Implementation of a Transformer that Ponders, using the scheme from the PonderNet paper

Overview

Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. Will also try to abstract out a pondering module that can be used with any block that returns an output with the halting probability.

This repository would not have been possible without repeated viewings of Yannic's educational video

Install

$ pip install ponder-transformer

Usage

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

mask = torch.ones(1, 512).bool()

x = torch.randint(0, 20000, (1, 512))
y = torch.randint(0, 20000, (1, 512))

loss = model(x, labels = y, mask = mask)
loss.backward()

Now you can set the model to .eval() mode and it will terminate early when all samples of the batch have emitted a halting signal

import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices

logits, layer_indices = model(x,  mask = mask) # (2, 512, 20000), (2)

# layer indices will contain, for each batch element, which layer they exited

Citations

@misc{banino2021pondernet,
    title   = {PonderNet: Learning to Ponder}, 
    author  = {Andrea Banino and Jan Balaguer and Charles Blundell},
    year    = {2021},
    eprint  = {2107.05407},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
You might also like...
Implementation of the Transformer variant proposed in
Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

FLASH - Pytorch Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time Install $ pip install FLASH-pytorch

Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

ImageProcessingTransformer Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

Episodic Transformer (E.T.) is a novel attention-based architecture for vision-and-language navigation. E.T. is based on a multimodal transformer that encodes language inputs and the full episode history of visual observations and actions. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

3D-Transformer: Molecular Representation with Transformer in 3D Space

3D-Transformer: Molecular Representation with Transformer in 3D Space

This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Transformer - Transformer in PyTorch

Transformer 完成进度 Embeddings and PositionalEncoding with example. MultiHeadAttent

Transformer Huffman coding - Complete Huffman coding through transformer

Transformer_Huffman_coding Complete Huffman coding through transformer 2022/2/19

Comments
  • Evaluating ponder-net on more pondering-steps than trained on.

    Evaluating ponder-net on more pondering-steps than trained on.

    As the paper says,

    In evaluation, and under known temporal or computational limitations, N can be set naively as a constant (or not set any limit, i.e. N → ∞). For training, we found that a more effective (and interpretable) way of parameterizing N is by defining a minimum cumulative probability of halting. N is then the smallest value of n such that sum( p_sub_ j > 1 − ε)over(j=1, n) , with the hyper-parameter ε positive near 0 (in our experiments 0.05).

    from that I infer that pondering can be done to more steps than trained on. How can be done so with this implementation?

    edit: I was going through the paper again,and I think what the paper means is that the max_num_pondering_steps:N should be re evaluated at every training-step, the model should be run till the condition is met or a pre-defined num of max steps is reached, and where the cumsum_probs condition will be met will be set as 'N', with the cumsum_probs normalised with one of the methods. Then that value of 'N' will be used to calc prior geom for the kl_div (and not normalising the prior geom term).

    i.e. if the num of pondering steps are initially set to 'M', then the model will recur for 'k' steps - i.e. till the condition is met or for 'M' num of max steps; then 'N' will be calculated by first calculating the probabilities - p_0 to p_k - then normalizing through one of the methods, then calculate cumulative-sum of those probabilities, and checking where the sum is greater than threshold, and assigning it the value 'N'. After that, calculating prior geometric values with the defined hyper-parameter, for 'N' seq-len, and using this in the kl-div term against the halting probs truncated to 'N' steps.

    λp is a hyper-parameter that defines a geometric prior distribution pG(λp) on the halting policy (truncated at N)

    opened by Vbansal21 0
  • Can pondernet used for imagenet?

    Can pondernet used for imagenet?

    I plan to do a project on the complexity of tasks on image dataset like imagenet, cifar 100. If I use a vision transformer, then can I implement my project?

    opened by fryegg 2
Releases(0.0.8)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Semantic Segmentation with Pytorch-Lightning

This is a simple demo for performing semantic segmentation on the Kitti dataset using Pytorch-Lightning and optimizing the neural network by monitoring and comparing runs with Weights & Biases.

Boris Dayma 58 Nov 18, 2022
🏅 Top 5% in 제2회 연구개발특구 인공지능 경진대회 AI SPARK 챌린지

AI_SPARK_CHALLENG_Object_Detection 제2회 연구개발특구 인공지능 경진대회 AI SPARK 챌린지 🏅 Top 5% in mAP(0.75) (443명 중 13등, mAP: 0.98116) 대회 설명 Edge 환경에서의 가축 Object Dete

3 Sep 19, 2022
Official repository for the paper "Going Beyond Linear Transformers with Recurrent Fast Weight Programmers"

Recurrent Fast Weight Programmers This is the official repository containing the code we used to produce the experimental results reported in the pape

IDSIA 36 Nov 15, 2022
Unified file system operation experience for different backend

megfile - Megvii FILE library Docs: http://megvii-research.github.io/megfile megfile provides a silky operation experience with different backends (cu

MEGVII Research 76 Dec 14, 2022
Using Self-Supervised Pretext Tasks for Active Learning - Official Pytorch Implementation

Using Self-Supervised Pretext Tasks for Active Learning - Official Pytorch Implementation Experiment Setting: CIFAR10 (downloaded and saved in ./DATA

John Seon Keun Yi 38 Dec 27, 2022
PyTorch Connectomics: segmentation toolbox for EM connectomics

Introduction The field of connectomics aims to reconstruct the wiring diagram of the brain by mapping the neural connections at the level of individua

Zudi Lin 132 Dec 26, 2022
Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning

Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning

ChongjianGE 89 Dec 02, 2022
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022
Unofficial keras(tensorflow) implementation of MAE model from Masked Autoencoders Are Scalable Vision Learners

MAE-keras Unofficial keras(tensorflow) implementation of MAE model described in 'Masked Autoencoders Are Scalable Vision Learners'. This work has been

Yewon 11 Jun 12, 2022
ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021 Update: 2021/03/11: update our new results. Now our T2T-ViT-14 w

YITUTech 1k Dec 31, 2022
Repo for "Benchmarking Robustness of 3D Point Cloud Recognition against Common Corruptions" https://arxiv.org/abs/2201.12296

Benchmarking Robustness of 3D Point Cloud Recognition against Common Corruptions This repo contains the dataset and code for the paper Benchmarking Ro

Jiachen Sun 168 Dec 29, 2022
D-NeRF: Neural Radiance Fields for Dynamic Scenes

D-NeRF: Neural Radiance Fields for Dynamic Scenes [Project] [Paper] D-NeRF is a method for synthesizing novel views, at an arbitrary point in time, of

Albert Pumarola 291 Jan 02, 2023
Toolbox to analyze temporal context invariance of deep neural networks

PyTCI A toolbox that estimates the integration window of a sensory response using the "Temporal Context Invariance" paradigm (TCI). The TCI method Int

4 Oct 23, 2022
[Link]mareteutral - pars tradg wth M []

pairs-trading-with-ML Jonathan Larkin, August 2017 One popular strategy classification is Pairs Trading. Though this category of strategies can exhibi

Jonathan Larkin 134 Jan 06, 2023
Cancer metastasis detection with neural conditional random field (NCRF)

NCRF Prerequisites Data Whole slide images Annotations Patch images Model Training Testing Tissue mask Probability map Tumor localization FROC evaluat

Baidu Research 731 Jan 01, 2023
Official PyTorch code of DeepPanoContext: Panoramic 3D Scene Understanding with Holistic Scene Context Graph and Relation-based Optimization (ICCV 2021 Oral).

DeepPanoContext (DPC) [Project Page (with interactive results)][Paper] DeepPanoContext: Panoramic 3D Scene Understanding with Holistic Scene Context G

Cheng Zhang 66 Nov 16, 2022
Feature extraction made simple with torchextractor

torchextractor: PyTorch Intermediate Feature Extraction Introduction Too many times some model definitions get remorselessly copy-pasted just because

Antoine Broyelle 89 Oct 31, 2022
Аналитика доходности инвестиционного портфеля в Тинькофф брокере

Аналитика доходности инвестиционного портфеля Тиньков Видео на YouTube Для работы скрипта нужно установить три переменных окружения: export TINKOFF_TO

Alexey Goloburdin 64 Dec 17, 2022
Lightweight mmm - Lightweight (Bayesian) Media Mix Model

Lightweight (Bayesian) Media Mix Model This is not an official Google product. L

Google 342 Jan 03, 2023
Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts

Face mask detection Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts in order to detect face masks in static im

Vaibhav Shukla 1 Oct 27, 2021