A Tensorfflow implementation of Attend, Infer, Repeat

Overview

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models

This is an unofficial Tensorflow implementation of Attend, Infear, Repeat (AIR), as presented in the following paper: S. M. Ali Eslami et. al., Attend, Infer, Repeat: Fast Scene Understanding with Generative Models.

  • Author (of the implementation): Adam Kosiorek, Oxford Robotics Institue, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

I describe the implementation and the issues I run into while working on it in this blog post.

Installation

Install Tensorflow v1.1.0rc1, Sonnet v1.1 and the following dependencies (using pip install -r requirements.txt (preferred) or pip install [package]):

  • matplotlib==1.5.3
  • numpy==1.12.1
  • attrdict==2.0.0
  • scipy==0.18.1

Sample Results

AIR learns to reconstruct objects by painting them one by one in a blank canvas. The below figure comes from a model trained for 175k iterations; the maximum number of steps is set to 3, but there are never more than 2 objects. The first row shows the input images, rows 2-4 are reconstructions at steps 1, 2 and 3 (with marked location of the attention glimpse in red, if it exists). Rows 4-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If the reconstructed crop is black and there is "0 with ..." written above it, it means that this step was not used.

AIR results

Data

Run ./scripts/create_dataset.sh The script creates train and validation datasets of multi-digit MNIST.

Training

Run ./scripts/train_multi_mnist.sh The training script will run for 300k iteratios and will save model checkpoints and training progress figures every 10k iterations in results/multi_mnist. Tensorflow summaries are also stored in the same folder and Tensorboard can be used for monitoring.

The model seems to be very sensitive to initialisation. It might be necessary to run training multiple times before achieving count step accuracy close to the one reported in the paper.

Experimentation

The jupyter notebook available at attend_infer_repeat/experiment.ipynb can be used for experimentation.

Citation

If you find this repo useful in your research, please consider citing the original paper:

@incollection{Eslami2016,
    title = {Attend, Infer, Repeat: Fast Scene Understanding with Generative Models},
    author = {Eslami, S. M. Ali and Heess, Nicolas and Weber, Theophane and Tassa, Yuval and Szepesvari, David and kavukcuoglu, koray and Hinton, Geoffrey E},
    booktitle = {Advances in Neural Information Processing Systems 29},
    editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett},
    pages = {3225--3233},
    year = {2016},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models.pdf}
}

License

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

Release Notes

Version 1.0

  • Original unofficial implementation; contains the multi-digit MNIST experiment.
Owner
Adam Kosiorek
I'm a PhD student at the Oxford Robotics Institute. I work on Machine Learning for perception - I'm looking into external memory and attention for RNNs.
Adam Kosiorek
A framework for joint super-resolution and image synthesis, without requiring real training data

SynthSR This repository contains code to train a Convolutional Neural Network (CNN) for Super-resolution (SR), or joint SR and data synthesis. The met

83 Jan 01, 2023
LeViT a Vision Transformer in ConvNet's Clothing for Faster Inference

LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference This repository contains PyTorch evaluation code, training code and pretrained

Facebook Research 504 Jan 02, 2023
Source code for Acorn, the precision farming rover by Twisted Fields

Acorn precision farming rover This is the software repository for Acorn, the precision farming rover by Twisted Fields. For more information see twist

Twisted Fields 198 Jan 02, 2023
September-Assistant - Open-source Windows Voice Assistant

September - Windows Assistant September is an open-source Windows personal assis

The Nithin Balaji 9 Nov 22, 2022
Activating More Pixels in Image Super-Resolution Transformer

HAT [Paper Link] Activating More Pixels in Image Super-Resolution Transformer Xiangyu Chen, Xintao Wang, Jiantao Zhou and Chao Dong BibTeX @article{ch

XyChen 270 Dec 27, 2022
ICRA 2021 "Towards Precise and Efficient Image Guided Depth Completion"

PENet: Precise and Efficient Depth Completion This repo is the PyTorch implementation of our paper to appear in ICRA2021 on "Towards Precise and Effic

232 Dec 25, 2022
Nicholas Lee 3 Jan 09, 2022
An example project demonstrating how the Autonomous Learning Library can be used to build new reinforcement learning agents.

About This repository shows how Autonomous Learning Library can be used to build new reinforcement learning agents. In particular, it contains a model

Chris Nota 5 Aug 30, 2022
Code for Domain Adaptive Video Segmentation via Temporal Consistency Regularization in ICCV 2021

Domain Adaptive Video Segmentation via Temporal Consistency Regularization Updates 08/2021: check out our domain adaptation for sematic segmentation p

36 Dec 12, 2022
Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021

Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021 The code for training mCOLT/mRASP2, a multilingua

104 Jan 01, 2023
Implements the training, testing and editing tools for "Pluralistic Image Completion"

Pluralistic Image Completion ArXiv | Project Page | Online Demo | Video(demo) This repository implements the training, testing and editing tools for "

Chuanxia Zheng 615 Dec 08, 2022
Implementation of UNET architecture for Image Segmentation.

Semantic Segmentation using UNET This is the implementation of UNET on Carvana Image Masking Kaggle Challenge About the Dataset This dataset contains

Anushka agarwal 4 Dec 21, 2021
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022
PyTorch implementation of the ACL, 2021 paper Parameter-efficient Multi-task Fine-tuning for Transformers via Shared Hypernetworks.

Parameter-efficient Multi-task Fine-tuning for Transformers via Shared Hypernetworks This repo contains the PyTorch implementation of the ACL, 2021 pa

Rabeeh Karimi Mahabadi 98 Dec 28, 2022
《Unsupervised 3D Human Pose Representation with Viewpoint and Pose Disentanglement》(ECCV 2020) GitHub: [fig9]

Unsupervised 3D Human Pose Representation [Paper] The implementation of our paper Unsupervised 3D Human Pose Representation with Viewpoint and Pose Di

42 Nov 24, 2022
Adversarial Robustness Toolbox (ART) - Python Library for Machine Learning Security - Evasion, Poisoning, Extraction, Inference - Red and Blue Teams

Adversarial Robustness Toolbox (ART) is a Python library for Machine Learning Security. ART provides tools that enable developers and researchers to defend and evaluate Machine Learning models and ap

3.4k Jan 04, 2023
Implementation of Heterogeneous Graph Attention Network

HetGAN Implementation of Heterogeneous Graph Attention Network This is the code repository of paper "Prediction of Metro Ridership During the COVID-19

5 Dec 28, 2021
Multi-layer convolutional LSTM with Pytorch

Convolution_LSTM_pytorch Thanks for your attention. I haven't got time to maintain this repo for a long time. I recommend this repo which provides an

Zijie Zhuang 734 Jan 03, 2023
ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral)

ILVR + ADM This is the implementation of ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral). This repository is h

Jooyoung Choi 225 Dec 28, 2022
The codes for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation"

Swin-Unet The codes for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation"(https://arxiv.org/abs/2105.05537). A validatio

869 Jan 07, 2023