Official code for: A Probabilistic Hard Attention Model For Sequentially Observed Scenes

Overview

"A Probabilistic Hard Attention Model For Sequentially Observed Scenes"

Authors: Samrudhdhi Rangrej, James Clark Accepted to: BMVC'21 framework A recurrent attention model sequentially observes glimpses from an image and predicts a class label. At time t, the model actively observes a glimpse gt and its coordinates lt. Given gt and lt, the feed-forward module F extracts features ft, and the recurrent module R updates a hidden state to ht. Using an updated hidden state ht, the linear classifier C predicts the class distribution p(y|ht). At time t+1, the model assesses various candidate locations l before attending an optimal one. It predicts p(y|g,l,ht) ahead of time and selects the candidate l that maximizes KL[p(y|g,l,ht)||p(y|ht)]. The model synthesizes the features of g using a Partial VAE to approximate p(y|g,l,ht) without attending to the glimpse g. The normalizing flow-based encoder S predicts the approximate posterior q(z|ht). The decoder D uses a sample z~q(z|ht) to synthesize a feature map f~ containing features of all glimpses. The model uses f~(l) as features of a glimpse at location l and evaluates p(y|g,l,ht)=p(y|f~(l),ht). Dashed arrows show a path to compute the lookahead class distribution p(y|f~(l),ht).

Requirements:

torch==1.8.1, torchvision==0.9.1, tensorboard==2.5.0, fire==0.4.0

Datasets:

Training a model

Use main.py to train and evaluate the model.

Arguments

  • dataset: one of 'svhn', 'cifar10', 'cifar100', 'cinic10', 'tinyimagenet'
  • datapath: path to the downloaded datasets
  • lr: learning rate
  • training_phase: one of 'first', 'second', 'third'
  • ccebal: coefficient for cross entropy loss
  • batch: batch-size for training
  • batchv: batch-size for evaluation
  • T: maximum time-step
  • logfolder: path to log directory
  • epochs: number of training epochs
  • pretrain_checkpoint: checkpoint for pretrained model from previous training phase

Example commands to train the model for SVHN dataset are as follows. Training Stage one

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='first' \
    --ccebal=1 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_first' \
    --epochs=1000 \
    --pretrain_checkpoint=None

Training Stage two

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='second' \
    --ccebal=0 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_second' \
    --epochs=100 \
    --pretrain_checkpoint='./svhn_log_first/weights_f_1000.pth'

Training Stage three

python3 main.py \
    --dataset='svhn' \
    --datapath='./data/' \
    --lr=0.001 \
    --training_phase='third' \
    --ccebal=16 \
    --batch=64 \
    --batchv=64 \
    --T=7 \
    --logfolder='./svhn_log_third' \
    --epochs=100 \
    --pretrain_checkpoint='./svhn_log_second/weights_f_100.pth'

Visualization of attention policy for a CIFAR-10 image

example The top row shows the entire image and the EIG maps for t=1 to 6. The bottom row shows glimpses attended by our model. The model observes the first glimpse at a random location. Our model observes a glimpse of size 8x8. The glimpses overlap with the stride of 4, resulting in a 7x7 grid of glimpses. The EIG maps are of size 7x7 and are upsampled for the display. We display the entire image for reference; our model never observes the whole image.

Acknowledgement

Major parts of neural spline flows implementation are borrowed from Karpathy's pytorch-normalizing-flows.

Image restoration with neural networks but without learning.

Warning! The optimization may not converge on some GPUs. We've personally experienced issues on Tesla V100 and P40 GPUs. When running the code, make s

Dmitry Ulyanov 7.4k Jan 01, 2023
This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation

This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation. Yolov5 is used to detect fire and smoke and unet is used to segment fire.

7 Jan 08, 2023
Grad2Task: Improved Few-shot Text Classification Using Gradients for Task Representation

Grad2Task: Improved Few-shot Text Classification Using Gradients for Task Representation Prerequisites This repo is built upon a local copy of transfo

Jixuan Wang 10 Sep 28, 2022
CIFS: Improving Adversarial Robustness of CNNs via Channel-wise Importance-based Feature Selection

CIFS This repository provides codes for CIFS (ICML 2021). CIFS: Improving Adversarial Robustness of CNNs via Channel-wise Importance-based Feature Sel

Hanshu YAN 19 Nov 12, 2022
disentanglement_lib is an open-source library for research on learning disentangled representations.

disentanglement_lib disentanglement_lib is an open-source library for research on learning disentangled representation. It supports a variety of diffe

Google Research 1.3k Dec 28, 2022
Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts

Face mask detection Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts in order to detect face masks in static im

Vaibhav Shukla 1 Oct 27, 2021
NeuTex: Neural Texture Mapping for Volumetric Neural Rendering

NeuTex: Neural Texture Mapping for Volumetric Neural Rendering Paper: https://arxiv.org/abs/2103.00762 Running Run on the provided DTU scene cd run ba

Fanbo Xiang 67 Dec 28, 2022
Implementation of Segnet, FCN, UNet , PSPNet and other models in Keras.

Image Segmentation Keras : Implementation of Segnet, FCN, UNet, PSPNet and other models in Keras. Implementation of various Deep Image Segmentation mo

Divam Gupta 2.6k Jan 05, 2023
Simple ray intersection library similar to coldet - succedeed by libacc

Ray Intersection This project offers a header only acceleration structure library including implementations for a BVH- and KD-Tree. Applications may i

Nils Moehrle 29 Jun 23, 2022
[CVPR 2021] Involution: Inverting the Inherence of Convolution for Visual Recognition, a brand new neural operator

involution Official implementation of a neural operator as described in Involution: Inverting the Inherence of Convolution for Visual Recognition (CVP

Duo Li 1.3k Dec 28, 2022
2021 Artificial Intelligence Diabetes Datathon

A.I.D.D. 2021 2021 Artificial Intelligence Diabetes Datathon A.I.D.D. 2021은 ‘2021 인공지능 학습용 데이터 구축사업’을 통해 만들어진 학습용 데이터를 활용하여 당뇨병을 효과적으로 예측할 수 있는가에 대한 A

2 Dec 27, 2021
Diverse Image Captioning with Context-Object Split Latent Spaces (NeurIPS 2020)

Diverse Image Captioning with Context-Object Split Latent Spaces This repository is the PyTorch implementation of the paper: Diverse Image Captioning

Visual Inference Lab @TU Darmstadt 34 Nov 21, 2022
Code for DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents

DeepXML Code for DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents Architectures and algorithms DeepXML supports

Extreme Classification 49 Nov 06, 2022
An investigation project for SISR.

SISR-Survey An investigation project for SISR. This repository is an official project of the paper "From Beginner to Master: A Survey for Deep Learnin

Juncheng Li 79 Oct 20, 2022
Research code for the paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models"

Introduction This repository contains research code for the ACL 2021 paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual

AdapterHub 20 Aug 04, 2022
Hack Camera, Microphone, Location, Clipboard With Just a Link. Also, Get Many Details About Victim's Device. And So On...

An Automated Tool to Hack Victim's Camera, Microphone, Location, Clipboard. Has 2 Extra Features. Version 1.1 Update Fixed Some Major Bugs Data Saving

ToxicNoob 36 Jan 07, 2023
The datasets and code of ACL 2021 paper "Aspect-Category-Opinion-Sentiment Quadruple Extraction with Implicit Aspects and Opinions".

Aspect-Category-Opinion-Sentiment (ACOS) Quadruple Extraction This repo contains the data sets and source code of our paper: Aspect-Category-Opinion-S

NUSTM 144 Jan 02, 2023
DL & CV-based indicator toolset for the vehicle drivers via live dash-cam footage.

Vehicle Indicator Toolset Deep Learning and Computer Vision based indicator toolset for vehicle drivers using live dash-cam footages. Tracking of vehi

Alex Xu 12 Dec 28, 2021
Patch-Diffusion Code (AAAI2022)

Patch-Diffusion This is an official PyTorch implementation of "Patch Diffusion: A General Module for Face Manipulation Detection" in AAAI2022. Require

H 7 Nov 02, 2022
A package related to building quasi-fibration symmetries

qf A package related to building quasi-fibration symmetries. If you'd like to learn more about how it works, see the brief explanation and References

Paolo Boldi 1 Dec 01, 2021