Implementation of Feedback Transformer in Pytorch

Overview

Feedback Transformer - Pytorch

A simple implementation of the Feedback Transformer in PyTorch. It improves on Transformer-XL by giving each token access to the representations of all previous layers through time. This is achieved by aggregating the outputs of all layers into a shared memory, which each token, at every layer, can attend to at each time step.
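The aggregation step can be sketched as follows. This is a minimal illustration, not this repository's code: it assumes one learnable scalar weight per layer, normalized with a softmax, and the class name `LayerMemoryAggregator`, the layer count, and the tensor shapes are hypothetical.

```python
import torch
import torch.nn as nn


class LayerMemoryAggregator(nn.Module):
    """Collapse the hidden states of all layers at one time step into one memory vector."""

    def __init__(self, num_layers: int):
        super().__init__()
        # One learnable scalar per layer (assumed); zeros give a uniform average at init.
        self.layer_weight = nn.Parameter(torch.zeros(num_layers))

    def forward(self, hiddens: torch.Tensor) -> torch.Tensor:
        # hiddens: (num_layers, batch, dim), the outputs of every layer for one token
        weights = self.layer_weight.softmax(dim=0)  # (num_layers,)
        weights = weights.view(-1, 1, 1)            # reshape so it broadcasts over batch and dim
        return (hiddens * weights).sum(dim=0)       # (batch, dim)


agg = LayerMemoryAggregator(num_layers=7)  # assumed: embedding output plus 6 layers
hiddens = torch.randn(7, 2, 512)
memory = agg(hiddens)                      # one memory vector per batch element
```

Note the explicit reshape of the per-layer weights before the multiplication: without it, a vector of shape `(num_layers,)` would not broadcast against the leading layer axis.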

The main drawback is longer training time, due to its non-parallel nature. Still, I thought I'd build it to encourage further exploration and research into this line of work.
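A toy loop makes the sequential dependency concrete. This is only a sketch under simplifying assumptions: a single `nn.Linear` stands in for the whole layer stack, attention is single-headed and unprojected, and all names and sizes are hypothetical rather than taken from this repository.

```python
from collections import deque

import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, dim, mem_len = 2, 10, 16, 4

x = torch.randn(batch, seq_len, dim)  # toy token embeddings
layer = nn.Linear(dim, dim)           # stand-in for the full layer stack
memory = deque(maxlen=mem_len)        # bounded buffer of past memory vectors

outputs = []
for t in range(seq_len):              # one step per token, in order
    h = x[:, t]                       # (batch, dim)
    if memory:
        mem = torch.stack(list(memory), dim=1)                    # (batch, <= mem_len, dim)
        scores = torch.einsum('bd,bmd->bm', h, mem) / dim ** 0.5  # similarity to each memory slot
        attn = scores.softmax(dim=-1)
        h = h + torch.einsum('bm,bmd->bd', attn, mem)             # read from past memory
    h = torch.tanh(layer(h))
    memory.append(h)                  # step t + 1 depends on this result
    outputs.append(h)

out = torch.stack(outputs, dim=1)     # (batch, seq_len, dim)
```

Because the memory written at step `t` is read at step `t + 1`, the loop over time cannot be replaced by a single batched matrix multiplication over the sequence.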

Yannic Kilcher video

I also took the liberty of adding various enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
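For readers unfamiliar with the first two enhancements, here is a generic sketch of a pre-normalized, GLU-gated feedforward block. It is not copied from this repository: the class name, the expansion factor `mult`, and the choice of GELU as the gate activation are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreNormGLUFeedForward(nn.Module):
    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Doubled width: one half carries values, the other half carries gates.
        self.proj_in = nn.Linear(dim, dim * mult * 2)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(dim * mult, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)                           # pre-normalization: normalize before the sublayer
        value, gate = self.proj_in(h).chunk(2, dim=-1)
        h = value * F.gelu(gate)                   # gated linear unit (GELU gate assumed)
        return x + self.proj_out(self.dropout(h))  # residual connection around the sublayer


ff = PreNormGLUFeedForward(dim=512)
y = ff(torch.randn(2, 64, 512))                    # output keeps the input shape
```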

Install

$ pip install feedback-transformer-pytorch

Usage

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,           # number of tokens
    dim = 512,                    # dimension
    depth = 6,                    # depth
    seq_len = 2,                  # the sequence length of each segment or window
    mem_len = 256,                # length of the memory buffer
    dim_head = 64,                # dimension of each head
    heads = 8,                    # number of heads
    attn_dropout = 0.1,           # attention dropout
    ff_dropout = 0.1              # feedforward dropout
).cuda()

x = torch.randint(0, 20000, (2, 64)).cuda()
model(x)  # (2, 64, 20000)

If you would like fine control over the memory (when to detach, etc.), you can use some extra keyword arguments on .forward:

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
).cuda()

x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()

out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
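What detaching a carried state does to gradients can be shown with plain PyTorch. The sketch below deliberately uses a toy `nn.GRUCell` as a stand-in for a model that passes state between segments; it makes no assumption about the structure of the memory object returned by this library.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(8, 8)  # toy stand-in for a model that carries state across segments
segments = [torch.randn(2, 8, requires_grad=True) for _ in range(3)]

state = torch.zeros(2, 8)
for i, seg in enumerate(segments):
    state = cell(seg, state)
    if i < len(segments) - 1:
        state = state.detach()  # cut the graph: earlier segments receive no gradient

loss = state.sum()
loss.backward()

# Only the last segment is still connected to the loss.
grads = [seg.grad is not None for seg in segments]
```

Detaching between segments bounds the length of the backward pass, at the cost of not propagating gradients into earlier segments.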

Citations

@misc{fan2021addressing,
    title   = {Addressing Some Limitations of Transformers with Feedback Memory}, 
    author  = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
    year    = {2021},
    eprint  = {2002.09402},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Should it really be using lower layers output for keys and values?

    Could you explain the logic of how the key-value pairs are formed at these lines and whether it is necessary?

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/d7d8939910d1491f01a3d93ce81d4663925fb389/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L146-L151

    It looks to me that line 146 transforms the output of the layer below (x) to keys and values, and the following lines combine these keys and values with the memory. I thought that x should only be used for forming the query here, and only the existing memory is used for keys and values.

    opened by tarvaina 6
  • In place operation with gradient

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L173 I think this is an error.

    opened by hadaev8 4
  • Bug in weighted sum

    Bug in https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L264

    Should be layer_weight = rearrange(layer_weight, 'd -> d () () ()')

    opened by Victor0118 1
  • Input/Output dimensions

    Hey @lucidrains

    Can I check the dimensions of the input and output, is it (seq_len, dim) -> (? ,dim, tokens)?

    model = FeedbackTransformer(
        num_tokens = 20000,           # number of tokens
        dim = 512,                    # dimension
        depth = 6,                    # depth
        seq_len = 2,                  # the sequence length of each segment or window
        mem_len = 256,                # length of the memory buffer
        dim_head = 64,                # dimension of each head
        heads = 8,                    # number of heads
        attn_dropout = 0.1,           # attention dropout
        ff_dropout = 0.1              # feedforward dropout
    ).cuda()
    
    x = torch.randint(0, 256, (2, 512)).cuda()
    model(x)  # (1, 512, 20000)
    
    opened by iiSeymour 1
  • Non intuitive memory usage with cross attention

    Given a simple tensor with dimension 256 and length 512, and a memory length of 16, the feedback transformer uses 3.6 GB of memory after the forward pass. With cross attention on a tensor of length 100, usage grows to 14 GB.

    The parallel version uses 3.1 GB and 3.5 GB, respectively.

    Notebooks for testing https://colab.research.google.com/drive/1dRImydFn3WthOXdLYIvdf5bsqjXcmhC5?usp=sharing https://colab.research.google.com/drive/1n653j4Pz9_U7OukhTlUbomAHMvpPXwx0?usp=sharing

    opened by hadaev8 0
  • I think mask padding value should be False

    Here https://github.com/lucidrains/feedback-transformer-pytorch/blob/with-cross-attention/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L181

    opened by hadaev8 0
  • ETA for the enwiki8 example

    Hey @lucidrains,

    Any ETA on the auto-regressive enwiki8 example? I and others would really appreciate it, as always :)

    Also, if you can provide an example for training on custom line-by-line TXT datasets, it would be absolutely fantastic.

    Thank you.

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need.