Unofficial PyTorch implementation of DeepMind's Perceiver IO with PyTorch Lightning scripts for distributed training

Overview

Perceiver IO

Unofficial PyTorch implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs (Jaegle et al., arXiv:2107.14795).

This implementation supports training of Perceiver IO models with PyTorch Lightning on some example tasks via a command line interface. Perceiver IO models are constructed from generic encoder and decoder classes and task-specific input and output adapters (see Model API).

Setup

conda env create -f environment.yml
conda activate perceiver-io
export PYTHONPATH=.

Tasks

In the following subsections, Perceiver IO models are trained at smaller scale on some example tasks. In particular, they were trained on two NVIDIA GTX 1080 GPUs (8 GB memory each) using PyTorch Lightning's support for distributed data-parallel training. I didn't tune model architectures or other hyper-parameters much, so you'll probably get better results with a bit of experimentation. Support for more datasets and tasks will be added later.

Masked language modeling

Pretrain a Perceiver IO model on masked language modeling (MLM) with text from the IMDB training set. The pretrained encoder is then used for training a sentiment classification model.

python train/train_mlm.py --dataset=imdb --learning_rate=1e-3 --batch_size=64 \
  --max_epochs=200 --dropout=0.0 --weight_decay=0.0 \
  --accelerator=ddp --gpus=-1

All available command line options and their default values can be displayed with python train/train_mlm.py -h.
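
For intuition, the masking step behind MLM pretraining can be sketched in a few lines. This is not the repository's preprocessing code, just a minimal sketch of BERT-style random token masking with placeholder token ids:

import torch

# Sketch only: mask ~15% of non-padding tokens and keep the originals as labels.
# MASK_ID and PAD_ID are placeholder vocabulary ids, token_ids a toy batch.
MASK_ID, PAD_ID = 4, 0
token_ids = torch.tensor([[12, 7, 93, 5, 61, PAD_ID, PAD_ID]])

labels = token_ids.clone()
mask = (torch.rand(token_ids.shape) < 0.15) & (token_ids != PAD_ID)

labels[~mask] = -100                           # ignore unmasked positions in the loss
inputs = token_ids.masked_fill(mask, MASK_ID)  # model sees [MASK] at masked positions

# The model is then trained to reconstruct the original tokens, e.g. with
# torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)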

Sentiment classification

Train a classification decoder using a frozen encoder from masked language modeling. If you ran MLM yourself you'll need to modify the --mlm_checkpoint argument accordingly, otherwise download checkpoints from here and extract them in the root directory of this project.

python train/train_seq_clf.py --dataset=imdb --learning_rate=1e-3 --batch_size=128 \
  --max_epochs=15 --dropout=0.0 --weight_decay=1e-3 --freeze_encoder \
  --accelerator=ddp --gpus=-1 \
  --mlm_checkpoint 'logs/mlm/version_0/checkpoints/epoch=199-val_loss=4.899.ckpt'

Unfreeze the encoder and jointly fine-tune it together with the decoder that has been trained in the previous step. If you ran the previous step yourself you'll need to modify the --clf_checkpoint argument accordingly, otherwise download checkpoints from here.

python train/train_seq_clf.py --dataset=imdb --learning_rate=1e-4 --batch_size=128 \
  --max_epochs=15 --dropout=0.2 --weight_decay=1e-3 \
  --accelerator=ddp --gpus=-1 \
  --clf_checkpoint 'logs/seq_clf/version_0/checkpoints/epoch=014-val_loss=0.350.ckpt'

All available command line options and their default values can be displayed with python train/train_seq_clf.py -h.
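
The effect of --freeze_encoder in the first command can be approximated with plain PyTorch. This is only a generic sketch of the pattern (with stand-in modules), not the script's actual implementation:

import torch
import torch.nn as nn

# Stand-in modules; in the scripts these are the Perceiver IO encoder and decoder.
encoder = nn.Linear(128, 64)
decoder = nn.Linear(64, 2)

for param in encoder.parameters():
    param.requires_grad = False   # exclude the pretrained encoder from gradient updates
encoder.eval()                    # also disable dropout in the frozen part

optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3, weight_decay=1e-3)

# For the fine-tuning step, the encoder is unfrozen again and both modules are
# optimized jointly, typically at a lower learning rate (here 1e-4).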

Image classification

Classify MNIST images. See also Model API for details about the underlying Perceiver IO model.

python train/train_img_clf.py --dataset=mnist --learning_rate=1e-3 --batch_size=128 \
  --max_epochs=20 --dropout=0.0 --weight_decay=1e-4 \
  --accelerator=ddp --gpus=-1

All available command line options and their default values can be displayed with python train/train_img_clf.py -h.
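
The training script loads MNIST via its own data module. As a rough sketch of how input batches relate to the image_shape=(28, 28, 1) used in the Model API section below (assuming channels-last inputs and that torchvision is available), MNIST batches can be produced like this:

import torch
from torchvision import datasets, transforms

# Sketch only: channels-last (H, W, C) MNIST batches; the actual script uses
# its own PyTorch Lightning data module instead.
transform = transforms.Compose([
    transforms.ToTensor(),                             # (1, 28, 28), values in [0, 1]
    transforms.Lambda(lambda x: x.permute(1, 2, 0)),   # -> (28, 28, 1)
])
mnist = datasets.MNIST(root=".cache", train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([128, 28, 28, 1])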

Model API

The model API is based on generic encoder and decoder classes (PerceiverEncoder and PerceiverDecoder) and task-specific input and output adapters. The following snippet shows how they can be used to create an MNIST image classifier, for example:

from perceiver.adapter import ImageInputAdapter, ClassificationOutputAdapter
from perceiver.model import PerceiverIO, PerceiverEncoder, PerceiverDecoder

latent_shape = (32, 128)

# Fourier-encode pixel positions and flatten along spatial dimensions
input_adapter = ImageInputAdapter(image_shape=(28, 28, 1), num_frequency_bands=32)

# Project generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(num_classes=10, num_output_channels=128)

# Generic Perceiver encoder
encoder = PerceiverEncoder(
    input_adapter=input_adapter,
    latent_shape=latent_shape,
    num_layers=3,
    num_cross_attention_heads=4,
    num_self_attention_heads=4,
    num_self_attention_layers_per_block=3,
    dropout=0.0)

# Generic Perceiver decoder
decoder = PerceiverDecoder(
    output_adapter=output_adapter,
    latent_shape=latent_shape,
    num_cross_attention_heads=1,
    dropout=0.0)

# MNIST classifier implemented as Perceiver IO model
mnist_classifier = PerceiverIO(encoder, decoder)
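
A quick way to check that the model wires together is a forward pass with a dummy batch. The input format and output shape below are assumptions (channels-last images in, one logit per class out); consult the model code if in doubt:

import torch

# Dummy batch of 4 MNIST-sized images, channels last (assumed input format).
x = torch.rand(4, 28, 28, 1)

with torch.no_grad():
    logits = mnist_classifier(x)

print(logits.shape)  # expected: torch.Size([4, 10])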

Tensorboard

Commands in the Tasks section write Tensorboard logs to the logs directory, which can be visualized with tensorboard --logdir logs. MLM training additionally writes predictions for masked sample text to Tensorboard's TEXT page. For example, the command

python train/train_mlm.py --dataset=imdb --learning_rate=1e-3 --batch_size=64 \
  --max_epochs=200 --dropout=0.0 --weight_decay=0.0 \
  --accelerator=ddp --gpus=-1 --predict_k=5 \
  --predict_samples='i have watched this [MASK] and it was awesome'  

writes the top 5 predictions for I have watched this [MASK] and it was awesome to Tensorboard after each epoch:

i have watched this [MASK] and it was awesome
i have watched this movie and it was awesome
i have watched this show and it was awesome
i have watched this film and it was awesome
i have watched this series and it was awesome
i have watched this dvd and it was awesome
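
The predictions above are essentially a top-k over the MLM head's logits at the position of the [MASK] token. The following is only an illustrative sketch with random stand-in logits, not the repository's prediction code:

import torch

# Toy values: a stand-in for model output of shape (batch, seq_len, vocab_size).
vocab_size, mask_pos = 1000, 4
logits = torch.randn(1, 7, vocab_size)

top5 = logits[0, mask_pos].topk(k=5)
print(top5.indices)  # ids of the 5 most likely replacement tokens, which the
                     # tokenizer decodes to words such as "movie" or "film"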

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

@misc{jaegle2021perceiverio,
    title   = {Perceiver IO: A General Architecture for Structured Inputs & Outputs},
    author  = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
    year    = {2021},
    eprint  = {2107.14795},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • use Conda version of TorchMetrics

    preferably use Conda version, similar could be done also with PL except Conda does not support extras

    • https://anaconda.org/conda-forge/torchmetrics
    • https://anaconda.org/conda-forge/pytorch-lightning
    opened by Borda 8
  • Genomic sequences

    Hello,

    Thank you for your implementation of the PerceiverIO project. I am trying to use your work for genomic sequences of shape (10k, 1). I noticed that your model produces the SAME output for DIFFERENT inputs when the num_channels dimension is 1 (I am not using the Fourier feature encodings). If the outputs are not exactly the same, they differ only marginally. Can you please guide me in solving this issue? Thanks in advance!

    Please let me know what additional information you would need to reproduce this bug.

    opened by ajv012 7
  • Fix poetry python version limits

    The grpc dependency doesn't build with Python 3.10 yet because it relies on outdated setuptools which the dependency tree isn't managing automatically, so don't allow Python 3.10+ environments until grpc gets their act together.

    Commit also updates minor versions of some deps because re-synchronizing poetry.lock with pyproject requires a full "poetry update" which pulls updated minor dependency versions everywhere.

    opened by mattsta 5
  • make repo as installable package

    I found it will be useful to have this project as an installable package so I suggest the following changes:

    • rename the existing perceiver package to model (as it holds all model-related components)
    • create a new package perceiverio which pulls in model, data and cli as sub-packages
    • simplify the cli package to a single module as it is quite lightweight
    • add setup.py to be installed
    opened by Borda 5
  • Data preprocessing and documentation enhancements, major refactorings

    Functional enhancements:

    • Support for static word masking in addition to dynamic word masking.
    • Support for individual token masking in addition to whole word masking.
    • Task-specific data preprocessing for all supported text datasets.
    • Constant learning rate scheduler with warmup now used by default.

    Documentation enhancements:

    • All training examples now provided as command line and Python script.
    • Better overview of official models and example training checkpoints.
    • Example training checkpoints can now be downloaded individually.
    • Minor enhancements to all other documentation sections.

    Refactorings and breaking changes:

    • Rename image package to vision.
    • TextDataModule base class now implements complete preprocessing logic.
    • TextDataModule subclasses only convert source dataset to a common structure.
    • Abstraction over cross-attention query creation (QueryProvider).
    • Decouple OutputAdapter interface from trainable cross-attention query.
    • Implement learned position encodings as nn.Embedding.
    • Move adapters to separate perceiver.model.core.adapter module.
    • Rename PerceiverConfig to PerceiverIOConfig
    • Rename LitModel base class to LitPerceiverIO.
    • LitClassifier.forward now behaves like the wrapped model's forward.
    • Object-oriented design of conversion from Hugging Face Perceiver models.
    • Major refactoring of PerceiverAR and CausalLanguageModel.
    • Move FourierPositionEncoding to the perceiver.model.core.position module.
    opened by krasserm 2
  • Multi-head attention as specified in paper, API changes and refactorings

    • Multi-head attention as specified in https://arxiv.org/abs/2107.14795 Appendix E
    • Renaming of constructor parameters in Pytorch Model API
    • Redesign of config classes in Pytorch Lightning API and CLI
    • Output query now managed by output adapter instead of decoder
    opened by krasserm 2
  • text encoding error

    Hi, I am getting this error

    Traceback (most recent call last):
      File "train/train_mlm.py", line 113, in <module>
        main(parser.parse_args())
      File "train/train_mlm.py", line 69, in main
        data_module.setup()
      File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/datamodule.py", line 428, in wrapped_fn
        fn(*args, **kwargs)
      File "/opt/perceiver-io/data/imdb.py", line 131, in setup
        self.ds_train = IMDBDataset(root=self.root, split='train')
      File "/opt/perceiver-io/data/imdb.py", line 42, in __init__
        self.raw_x, self.raw_y = load_split(root, split)
      File "/opt/perceiver-io/data/imdb.py", line 34, in load_split
        raw_x.append(f.read())
      File "/usr/lib/python3.6/encodings/ascii.py", line 26, in decode
        return codecs.ascii_decode(input, self.errors)[0]
    UnicodeDecodeError: 'ascii' codec can't decode byte 0xc2 in position 449: ordinal not in range(128)
    

    it is probably related to the unicode encoding

    opened by batrlatom 2
  • What is Q in the latent encoder layers?

    It seems that in the multi-layer encoder you use x_latent as Q and x as KV. Shouldn't Q, K and V all be x_latent in the latent layers? Please correct me if I missed something in the paper, thank you!

    opened by zhangyuygss 1
  • Support key padding masks for Perceiver AR

    • tokenizers must be configured to padding_side="left" in order to be compatible with Perceiver AR
    • support configuration of padding_side on base class of text data modules (TextDataModule).
    • implement random sequence truncation in data module instead of model
    • sequences in a batch are individually truncated to different lengths.
    • enable random_train_shift by default which increases regularization.
    opened by krasserm 0
  • Implement processor for optical flow

    • Implement OpticalFlowProcessor to preprocess input images and create optical flows from model predictions
    • Add video_utils to sample frames from videos and create output videos from estimated optical flows
    • Extend inference notebook with examples for optical flow
    opened by cstub 0
  • Major refactorings

    • Better modularization.
    • Documentation rewrite.
    • Add support for Hugging Face tokenizers.
    • Add support for Hugging Face datasets.
    • Add support for Docker.
    • Fix missing bias terms in MHA.
    • Weight init according to paper.

    opened by krasserm 0
Releases (0.7.0)
  • 0.7.0 (Dec 4, 2022)

  • 0.7b1 (Nov 20, 2022)

    Data preprocessing and documentation enhancements, major refactorings

    Functional enhancements:

    • Support for static word masking in addition to dynamic word masking.
    • Support for individual token masking in addition to whole word masking.
    • Task-specific data preprocessing for all supported text datasets.
    • Constant learning rate scheduler with warmup now used by default.

    Documentation enhancements:

    • All training examples now provided as command line and Python script.
    • Better overview of official models and example training checkpoints.
    • Example training checkpoints can now be downloaded individually.
    • Minor enhancements to all other documentation sections.

    Refactorings and breaking changes:

    • Rename image package to vision.
    • TextDataModule base class now implements complete preprocessing logic.
    • TextDataModule subclasses only convert source dataset to a common structure.
    • Abstraction over cross-attention query creation (QueryProvider).
    • Decouple OutputAdapter interface from trainable cross-attention query.
    • Implement learned position encodings as nn.Embedding.
    • Move adapters to separate perceiver.model.core.adapter module.
    • Rename PerceiverConfig to PerceiverIOConfig
    • Rename LitModel base class to LitPerceiverIO.
    • LitClassifier.forward now behaves like the wrapped model's forward.
    • Object-oriented design of conversion from Hugging Face Perceiver models.
    • Major refactoring of PerceiverAR and CausalLanguageModel.
    • Move FourierPositionEncoding to the perceiver.model.core.position module.
  • 0.6.0 (Sep 25, 2022)

  • 0.5.1 (Aug 31, 2022)

  • 0.5.0 (Aug 22, 2022)

Owner
Martin Krasser
Freelance machine learning engineer, software developer and consultant. Mountainbike freerider, bass guitar player.