Official code for: A Probabilistic Hard Attention Model For Sequentially Observed Scenes

Overview

"A Probabilistic Hard Attention Model For Sequentially Observed Scenes"

Authors: Samrudhdhi Rangrej, James Clark Accepted to: BMVC'21 framework A recurrent attention model sequentially observes glimpses from an image and predicts a class label. At time t, the model actively observes a glimpse gt and its coordinates lt. Given gt and lt, the feed-forward module F extracts features ft, and the recurrent module R updates a hidden state to ht. Using an updated hidden state ht, the linear classifier C predicts the class distribution p(y|ht). At time t+1, the model assesses various candidate locations l before attending an optimal one. It predicts p(y|g,l,ht) ahead of time and selects the candidate l that maximizes KL[p(y|g,l,ht)||p(y|ht)]. The model synthesizes the features of g using a Partial VAE to approximate p(y|g,l,ht) without attending to the glimpse g. The normalizing flow-based encoder S predicts the approximate posterior q(z|ht). The decoder D uses a sample z~q(z|ht) to synthesize a feature map f~ containing features of all glimpses. The model uses f~(l) as features of a glimpse at location l and evaluates p(y|g,l,ht)=p(y|f~(l),ht). Dashed arrows show a path to compute the lookahead class distribution p(y|f~(l),ht).

Requirements:

torch==1.8.1, torchvision==0.9.1, tensorboard==2.5.0, fire==0.4.0

Datasets:

Training a model

Use main.py to train and evaluate the model.

Arguments

  • dataset: one of 'svhn', 'cifar10', 'cifar100', 'cinic10', 'tinyimagenet'
  • datapath: path to the downloaded datasets
  • lr: learning rate
  • training_phase: one of 'first', 'second', 'third'
  • ccebal: coefficient for cross entropy loss
  • batch: batch-size for training
  • batchv: batch-size for evaluation
  • T: maximum time-step
  • logfolder: path to log directory
  • epochs: number of training epochs
  • pretrain_checkpoint: checkpoint for pretrained model from previous training phase

Example commands to train the model for SVHN dataset are as follows. Training Stage one

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='first' \
    --ccebal=1 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_first' \
    --epochs=1000 \
    --pretrain_checkpoint=None

Training Stage two

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='second' \
    --ccebal=0 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_second' \
    --epochs=100 \
    --pretrain_checkpoint='./svhn_log_first/weights_f_1000.pth'

Training Stage three

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='third' \
    --ccebal=16 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_third' \
    --epochs=100 \
    --pretrain_checkpoint='./svhn_log_second/weights_f_100.pth'

Visualization of attention policy for a CIFAR-10 image

example The top row shows the entire image and the EIG maps for t=1 to 6. The bottom row shows glimpses attended by our model. The model observes the first glimpse at a random location. Our model observes a glimpse of size 8x8. The glimpses overlap with the stride of 4, resulting in a 7x7 grid of glimpses. The EIG maps are of size 7x7 and are upsampled for the display. We display the entire image for reference; our model never observes the whole image.

Acknowledgement

Major parts of neural spline flows implementation are borrowed from Karpathy's pytorch-normalizing-flows.

Colab notebook for openai/glide-text2im.

GLIDE text2im on Colab This repository provides a Colab notebook to produce images conditioned on text prompts with GLIDE [1]. Usage Run text2im.ipynb

Wok 19 Oct 19, 2022
Official implementation of "Learning Proposals for Practical Energy-Based Regression", 2021.

ebms_proposals Official implementation (PyTorch) of the paper: Learning Proposals for Practical Energy-Based Regression, 2021 [arXiv] [project]. Fredr

Fredrik Gustafsson 10 Oct 22, 2022
[ACM MM 2021] Multiview Detection with Shadow Transformer (and View-Coherent Data Augmentation)

Multiview Detection with Shadow Transformer (and View-Coherent Data Augmentation) [arXiv] [paper] @inproceedings{hou2021multiview, title={Multiview

Yunzhong Hou 27 Dec 13, 2022
Run object detection model on the Raspberry Pi

Using TensorFlow Lite with Python is great for embedded devices based on Linux, such as Raspberry Pi.

Dimitri Yanovsky 6 Oct 08, 2022
The Official Repository for "Generalized OOD Detection: A Survey"

Generalized Out-of-Distribution Detection: A Survey 1. Overview This repository is with our survey paper: Title: Generalized Out-of-Distribution Detec

Jingkang Yang 338 Jan 03, 2023
Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

STARS Laboratory 8 Sep 14, 2022
LAnguage Model Analysis

LAMA: LAnguage Model Analysis LAMA is a probe for analyzing the factual and commonsense knowledge contained in pretrained language models. The dataset

Meta Research 960 Jan 08, 2023
efficient neural audio synthesis in the waveform domain

neural waveshaping synthesis real-time neural audio synthesis in the waveform domain paper • website • colab • audio by Ben Hayes, Charalampos Saitis,

Ben Hayes 169 Dec 23, 2022
Pairwise Learning for Neural Link Prediction for OGB (PLNLP-OGB)

Pairwise Learning for Neural Link Prediction for OGB (PLNLP-OGB) This repository provides evaluation codes of PLNLP for OGB link property prediction t

Zhitao WANG 31 Oct 10, 2022
Practical and Real-world applications of ML based on the homework of Hung-yi Lee Machine Learning Course 2021

Machine Learning Theory and Application Overview This repository is inspired by the Hung-yi Lee Machine Learning Course 2021. In that course, professo

SilenceJiang 35 Nov 22, 2022
Prompt-BERT: Prompt makes BERT Better at Sentence Embeddings

Prompt-BERT: Prompt makes BERT Better at Sentence Embeddings Results on STS Tasks Model STS12 STS13 STS14 STS15 STS16 STSb SICK-R Avg. unsup-prompt-be

196 Jan 08, 2023
Synthetic Humans for Action Recognition, IJCV 2021

SURREACT: Synthetic Humans for Action Recognition from Unseen Viewpoints Gül Varol, Ivan Laptev and Cordelia Schmid, Andrew Zisserman, Synthetic Human

Gul Varol 59 Dec 14, 2022
Official PyTorch code of Holistic 3D Scene Understanding from a Single Image with Implicit Representation (CVPR 2021)

Implicit3DUnderstanding (Im3D) [Project Page] Holistic 3D Scene Understanding from a Single Image with Implicit Representation Cheng Zhang, Zhaopeng C

Cheng Zhang 149 Jan 08, 2023
Adversarial Autoencoders

Adversarial Autoencoders (with Pytorch) Dependencies argparse time torch torchvision numpy itertools matplotlib Create Datasets python create_datasets

Felipe Ducau 188 Jan 01, 2023
Decensoring Hentai with Deep Neural Networks. Formerly named DeepMindBreak.

DeepCreamPy Decensoring Hentai with Deep Neural Networks. Formerly named DeepMindBreak. A deep learning-based tool to automatically replace censored a

616 Jan 06, 2023
Official Pytorch implementation for 2021 ICCV paper "Learning Motion Priors for 4D Human Body Capture in 3D Scenes" and trained models / data

Learning Motion Priors for 4D Human Body Capture in 3D Scenes (LEMO) Official Pytorch implementation for 2021 ICCV (oral) paper "Learning Motion Prior

165 Dec 19, 2022
Diagnostic tests for linguistic capacities in language models

LM diagnostics This repository contains the diagnostic datasets and experimental code for What BERT is not: Lessons from a new suite of psycholinguist

61 Jan 02, 2023
NitroFE is a Python feature engineering engine which provides a variety of modules designed to internally save past dependent values for providing continuous calculation.

NitroFE is a Python feature engineering engine which provides a variety of modules designed to internally save past dependent values for providing continuous calculation.

100 Sep 28, 2022
Large-scale open domain KNOwledge grounded conVERsation system based on PaddlePaddle

Knover Knover is a toolkit for knowledge grounded dialogue generation based on PaddlePaddle. Knover allows researchers and developers to carry out eff

607 Dec 31, 2022
The source code of CVPR17 'Generative Face Completion'.

GenerativeFaceCompletion Matcaffe implementation of our CVPR17 paper on face completion. In each panel from left to right: original face, masked input

Yijun Li 313 Oct 18, 2022