This is the official source code for SLATE. We provide the code for the model, the training code, and a dataset loader for the 3D Shapes dataset. This code is implemented in Pytorch.

Related tags

Deep Learningslate
Overview

SLATE

This is the official source code for SLATE. We provide the code for the model, the training code and a dataset loader for the 3D Shapes dataset. This code is implemented in Pytorch.

Arxiv: https://arxiv.org/pdf/2110.11405.pdf
Project Page: https://sites.google.com/view/slate-autoencoder

Dataset

The current release provides a boilerplate code to train the model on the 3D Shapes dataset. The dataset class is provided in shapes_3d.py. You can edit or replace this class if you need to run the code on a different dataset. The 3D Shapes dataset can be downloaded from the official URL https://console.cloud.google.com/storage/browser/3d-shapes. This should produce a dataset file 3dshapes.h5. During training, the path to this dataset file needs to be provided using the argument --data_path.

Training

To train the model, simply execute:

python train.py

Check train.py to see the full list of training arguments.

Outputs

The training code produces Tensorboard logs. To see these logs, run Tensorboard on the logging directory that was provided in the training argument --log_path. These logs contain the training loss curves and visualizations of reconstructions and object attention maps.

Hyperparameters of Interest

  • Learning Rate can be tuned using the training argument --lr_main and different choices can affect the characteristics of the object attention maps.
  • Number of Slots can be tuned using the training argument --num_slots. Number of slots should be set higher than the number of objects you expect to see in the images.
  • Number of Slot Attention Iterations can be tuned using the training argument --num_iterations. In general, keep the number of iterations as small as possible because too many iterations can prevent slots from learning to diversify and attach to different objects.

Code Files

This repository provides the following files.

  • train.py contains the main code for running the training.
  • slate.py provides the model class for SLATE.
  • shapes_3d.py contains the dataset class for 3D Shapes dataset.
  • dvae.py provides the encoder and the decoder for Discrete VAE.
  • slot_attn.py provides the model class for Slot Attention encoder.
  • transformer.py provides the model classes for Transformer.
  • utils.py provides helper classes and functions for the implementation.
Owner
Gautam Singh
PhD student at Rutgers CS
Gautam Singh
Preparation material for Dropbox interviews

Dropbox-Onsite-Interviews A guide for the Dropbox onsite interview! The Dropbox interview question bank is very small. The bank has been in a Chinese

386 Dec 31, 2022
A convolutional recurrent neural network for classifying A/B phases in EEG signals recorded for sleep analysis.

CAP-Classification-CRNN A deep learning model based on Inception modules paired with gated recurrent units (GRU) for the classification of CAP phases

Apurva R. Umredkar 2 Nov 25, 2022
Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”

Official implementation for TransDA Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”. Overview: Result: Prerequisites:

stanley 54 Dec 22, 2022
Code for "Graph-Evolving Meta-Learning for Low-Resource Medical Dialogue Generation". [AAAI 2021]

Graph Evolving Meta-Learning for Low-resource Medical Dialogue Generation Code to be further cleaned... This repo contains the code of the following p

Shuai Lin 29 Nov 01, 2022
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

Joshua Ji 3 Aug 20, 2022
Hypernetwork-Ensemble Learning of Segmentation Probability for Medical Image Segmentation with Ambiguous Labels

Hypernet-Ensemble Learning of Segmentation Probability for Medical Image Segmentation with Ambiguous Labels The implementation of Hypernet-Ensemble Le

Sungmin Hong 6 Jul 18, 2022
Implementation for the EMNLP 2021 paper "Interactive Machine Comprehension with Dynamic Knowledge Graphs".

Interactive Machine Comprehension with Dynamic Knowledge Graphs Implementation for the EMNLP 2021 paper. Dependencies apt-get -y update apt-get instal

Xingdi (Eric) Yuan 19 Aug 23, 2022
[Arxiv preprint] Causality-inspired Single-source Domain Generalization for Medical Image Segmentation (code&data-processing pipeline)

Causality-inspired Single-source Domain Generalization for Medical Image Segmentation Arxiv preprint Repository under construction. Might still be bug

Cheng 31 Dec 27, 2022
The CLRS Algorithmic Reasoning Benchmark

Learning representations of algorithms is an emerging area of machine learning, seeking to bridge concepts from neural networks with classical algorithms.

DeepMind 251 Jan 05, 2023
Causal Influence Detection for Improving Efficiency in Reinforcement Learning

Causal Influence Detection for Improving Efficiency in Reinforcement Learning This repository contains the code release for the paper "Causal Influenc

Autonomous Learning Group 21 Nov 29, 2022
Toontown: Galaxy, a new Toontown game based on Disney's Toontown Online

Toontown: Galaxy The official archive repo for Toontown: Galaxy, a new Toontown

1 Feb 15, 2022
A3C LSTM Atari with Pytorch plus A3G design

NEWLY ADDED A3G A NEW GPU/CPU ARCHITECTURE OF A3C FOR SUBSTANTIALLY ACCELERATED TRAINING!! RL A3C Pytorch NEWLY ADDED A3G!! New implementation of A3C

David Griffis 532 Jan 02, 2023
PyTorch Implementation of ECCV 2020 Spotlight TuiGAN: Learning Versatile Image-to-Image Translation with Two Unpaired Images

TuiGAN-PyTorch Official PyTorch Implementation of "TuiGAN: Learning Versatile Image-to-Image Translation with Two Unpaired Images" (ECCV 2020 Spotligh

181 Dec 09, 2022
This is the official pytorch implementation of Student Helping Teacher: Teacher Evolution via Self-Knowledge Distillation(TESKD)

Student Helping Teacher: Teacher Evolution via Self-Knowledge Distillation (TESKD) By Zheng Li[1,4], Xiang Li[2], Lingfeng Yang[2,4], Jian Yang[2], Zh

Zheng Li 9 Sep 26, 2022
Vector Quantization, in Pytorch

Vector Quantization - Pytorch A vector quantization library originally transcribed from Deepmind's tensorflow implementation, made conveniently into a

Phil Wang 665 Jan 08, 2023
Using image super resolution models with vapoursynth and speeding them up with TensorRT

vs-RealEsrganAnime-tensorrt-docker Using image super resolution models with vapoursynth and speeding them up with TensorRT. Also a docker image since

4 Aug 23, 2022
Benchmarks for Model-Based Optimization

Design-Bench Design-Bench is a benchmarking framework for solving automatic design problems that involve choosing an input that maximizes a black-box

Brandon Trabucco 43 Dec 20, 2022
Housing Price Prediction

This project aim was to predict the price of houses in the Boston area during the great financial crisis through regression, as well as classify houses into different quality categories according to

Florian Klement 1 Jan 27, 2022
A framework for GPU based high-performance medical image processing and visualization

FAST is an open-source cross-platform framework with the main goal of making it easier to do high-performance processing and visualization of medical images on heterogeneous systems utilizing both mu

Erik Smistad 315 Dec 30, 2022
This repository introduces a short project about Transfer Learning for Classification of MRI Images.

Transfer Learning for MRI Images Classification This repository introduces a short project made during my stay at Neuromatch Summer School 2021. This

Oscar Guarnizo 3 Nov 15, 2022