A Tensorfflow implementation of Attend, Infer, Repeat

Overview

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models

This is an unofficial Tensorflow implementation of Attend, Infear, Repeat (AIR), as presented in the following paper: S. M. Ali Eslami et. al., Attend, Infer, Repeat: Fast Scene Understanding with Generative Models.

  • Author (of the implementation): Adam Kosiorek, Oxford Robotics Institue, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

I describe the implementation and the issues I run into while working on it in this blog post.

Installation

Install Tensorflow v1.1.0rc1, Sonnet v1.1 and the following dependencies (using pip install -r requirements.txt (preferred) or pip install [package]):

  • matplotlib==1.5.3
  • numpy==1.12.1
  • attrdict==2.0.0
  • scipy==0.18.1

Sample Results

AIR learns to reconstruct objects by painting them one by one in a blank canvas. The below figure comes from a model trained for 175k iterations; the maximum number of steps is set to 3, but there are never more than 2 objects. The first row shows the input images, rows 2-4 are reconstructions at steps 1, 2 and 3 (with marked location of the attention glimpse in red, if it exists). Rows 4-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If the reconstructed crop is black and there is "0 with ..." written above it, it means that this step was not used.

AIR results

Data

Run ./scripts/create_dataset.sh The script creates train and validation datasets of multi-digit MNIST.

Training

Run ./scripts/train_multi_mnist.sh The training script will run for 300k iteratios and will save model checkpoints and training progress figures every 10k iterations in results/multi_mnist. Tensorflow summaries are also stored in the same folder and Tensorboard can be used for monitoring.

The model seems to be very sensitive to initialisation. It might be necessary to run training multiple times before achieving count step accuracy close to the one reported in the paper.

Experimentation

The jupyter notebook available at attend_infer_repeat/experiment.ipynb can be used for experimentation.

Citation

If you find this repo useful in your research, please consider citing the original paper:

@incollection{Eslami2016,
    title = {Attend, Infer, Repeat: Fast Scene Understanding with Generative Models},
    author = {Eslami, S. M. Ali and Heess, Nicolas and Weber, Theophane and Tassa, Yuval and Szepesvari, David and kavukcuoglu, koray and Hinton, Geoffrey E},
    booktitle = {Advances in Neural Information Processing Systems 29},
    editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett},
    pages = {3225--3233},
    year = {2016},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models.pdf}
}

License

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

Release Notes

Version 1.0

  • Original unofficial implementation; contains the multi-digit MNIST experiment.
Owner
Adam Kosiorek
I'm a PhD student at the Oxford Robotics Institute. I work on Machine Learning for perception - I'm looking into external memory and attention for RNNs.
Adam Kosiorek
“英特尔创新大师杯”深度学习挑战赛 赛道3:CCKS2021中文NLP地址相关性任务

ccks2021-track3 CCKS2021中文NLP地址相关性任务-赛道三-冠军方案 团队:我的加菲鱼- wodejiafeiyu 初赛第二/复赛第一/决赛第一 前言 19年开始,陆陆续续参加了一些比赛,拿到过一些top,比较懒一直都没分享过,这次比较幸运又拿了top1,打算分享下 分类的任务

shaochenjie 131 Dec 31, 2022
ChineseBERT: Chinese Pretraining Enhanced by Glyph and Pinyin Information

ChineseBERT: Chinese Pretraining Enhanced by Glyph and Pinyin Information This repository contains code, model, dataset for ChineseBERT at ACL2021. Ch

413 Dec 01, 2022
Pytorch implementation of Rosca, Mihaela, et al. "Variational Approaches for Auto-Encoding Generative Adversarial Networks."

alpha-GAN Unofficial pytorch implementation of Rosca, Mihaela, et al. "Variational Approaches for Auto-Encoding Generative Adversarial Networks." arXi

Victor Shepardson 78 Dec 08, 2022
Apollo optimizer in tensorflow

Apollo Optimizer in Tensorflow 2.x Notes: Warmup is important with Apollo optimizer, so be sure to pass in a learning rate schedule vs. a constant lea

Evan Walters 1 Nov 09, 2021
A Comprehensive Study on Learning-Based PE Malware Family Classification Methods

A Comprehensive Study on Learning-Based PE Malware Family Classification Methods Datasets Because of copyright issues, both the MalwareBazaar dataset

8 Oct 21, 2022
FADNet++: Real-Time and Accurate Disparity Estimation with Configurable Networks

FADNet++: Real-Time and Accurate Disparity Estimation with Configurable Networks

HKBU High Performance Machine Learning Lab 6 Nov 18, 2022
DAN: Unfolding the Alternating Optimization for Blind Super Resolution

DAN-Basd-on-Openmmlab DAN: Unfolding the Alternating Optimization for Blind Super Resolution We reproduce DAN via mmediting based on open-sourced code

AlexZou 72 Dec 13, 2022
Lung Pattern Classification for Interstitial Lung Diseases Using a Deep Convolutional Neural Network

ild-cnn This is supplementary material for the manuscript: "Lung Pattern Classification for Interstitial Lung Diseases Using a Deep Convolutional Neur

22 Nov 05, 2022
Efficient Conformer: Progressive Downsampling and Grouped Attention for Automatic Speech Recognition

Efficient Conformer: Progressive Downsampling and Grouped Attention for Automatic Speech Recognition Official implementation of the Efficient Conforme

Maxime Burchi 145 Dec 30, 2022
Flexible-Modal Face Anti-Spoofing: A Benchmark

Flexible-Modal FAS This is the official repository of "Flexible-Modal Face Anti-

Zitong Yu 22 Nov 10, 2022
a curated list of docker-compose files prepared for testing data engineering tools, databases and open source libraries.

data-services A repository for storing various Data Engineering docker-compose files in one place. How to use it ? Set the required settings in .env f

BigData.IR 525 Dec 03, 2022
Implementation of the paper titled "Using Sampling to Estimate and Improve Performance of Automated Scoring Systems with Guarantees"

Using Sampling to Estimate and Improve Performance of Automated Scoring Systems with Guarantees Implementation of the paper titled "Using Sampling to

MIDAS, IIIT Delhi 2 Aug 29, 2022
This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset

HiRID-ICU-Benchmark This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset for which the manuscript can be found here.

Biomedical Informatics at ETH Zurich 30 Dec 16, 2022
Honours project, on creating a depth estimation map from two stereo images of featureless regions

image-processing This module generates depth maps for shape-blocked-out images Install If working with anaconda, then from the root directory: conda e

2 Oct 17, 2022
PyTorch Autoencoders - Implementing a Variational Autoencoder (VAE) Series in Pytorch.

PyTorch Autoencoders Implementing a Variational Autoencoder (VAE) Series in Pytorch. Inspired by this repository Model List check model paper conferen

Subin An 8 Nov 21, 2022
Yas CRNN model training - Yet Another Genshin Impact Scanner

Yas-Train Yet Another Genshin Impact Scanner 又一个原神圣遗物导出器 介绍 该仓库为 Yas 的模型训练程序 相关资料 MobileNetV3 CRNN 使用 假设你会设置基本的pytorch环境。 生成数据集 python main.py gen 训练

wormtql 18 Jan 08, 2023
Python implementation of ADD: Frequency Attention and Multi-View based Knowledge Distillation to Detect Low-Quality Compressed Deepfake Images, AAAI2022.

ADD: Frequency Attention and Multi-View based Knowledge Distillation to Detect Low-Quality Compressed Deepfake Images Binh M. Le & Simon S. Woo, "ADD:

2 Oct 24, 2022
A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

2 Jul 25, 2022
A port of muP to JAX/Haiku

MUP for Haiku This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to sugg

18 Dec 30, 2022
Machine learning and Deep learning models, deploy on telegram (the best social media)

Semi Intelligent BOT The project involves : Classifying fake news Classifying objects such as aeroplane, automobile, bird, cat, deer, dog, frog, horse

MohammadReza Norouzi 5 Mar 06, 2022