Simple and efficient RevNet-Library with DeepSpeed support

RevLib

Features

  • Half the constant memory usage of other RevNet libraries, and faster
  • Less memory than gradient checkpointing (1 * output_size instead of n_layers * output_size)
  • Same speed as activation checkpointing
  • Extensible
  • Trivial code (<100 Lines)

Getting started

Installation

python3 -m pip install revlib

Examples

iRevNet

iRevNet is not just partially reversible; it is a fully invertible model. Its source code looks complex at first glance, and it doesn't realize the memory savings it could, because a RevNet requires custom autograd functions that are hard to maintain. With revlib, an iRevNet can be implemented like this:

import torch
from torch import nn
import revlib

channels = 64
channel_multiplier = 4
depth = 3
classes = 1000


# Create a basic function that's reversibly executed multiple times. (Like f() in ResNet)
def conv(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)


def block_conv(in_channels, out_channels):
    return nn.Sequential(conv(in_channels, out_channels),
                         nn.Dropout(0.2),
                         nn.BatchNorm2d(out_channels),
                         nn.ReLU())


def block():
    return nn.Sequential(block_conv(channels, channels * channel_multiplier),
                         block_conv(channels * channel_multiplier, channels),
                         nn.Conv2d(channels, channels, (3, 3), padding=1))


# Create a reversible model. f() is invoked depth-times with different weights.
rev_model = revlib.ReversibleSequential(*[block() for _ in range(depth)])

# Wrap reversible model with non-reversible layers
model = nn.Sequential(conv(3, 2 * channels), rev_model, conv(2 * channels, classes))

# Use it like you would a regular PyTorch model
inp = torch.randn((1, 3, 224, 224))
out = model(inp)
out.mean().backward()
assert out.size() == (1, 1000, 224, 224)

MomentumNet

MomentumNet is another recent paper that made significant advances in memory-efficient networks. Its authors propose using a momentum stream instead of a second model output (figure: MomentumNet illustration). Implementing that with revlib requires a custom coupling operation (a functional analogue to MemCNN) that merges the input and output streams.

import torch
from torch import nn
import revlib

channels = 64
depth = 16
momentum_ema_beta = 0.99


# Compute y2 from x2 and f(x1) by merging x2 and f(x1) in the forward pass.
def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


# Calculate x2 from y2 and f(x1) by manually computing the inverse of momentum_coupling_forward.
def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Pass in coupling functions which will be used instead of x2 + f(x1) and y2 - f(x1)
rev_model = revlib.ReversibleSequential(*[layer for _ in range(depth)
                                          for layer in [nn.Conv2d(channels, channels, (3, 3), padding=1),
                                                        nn.Identity()]],
                                        coupling_forward=[momentum_coupling_forward, revlib.additive_coupling_forward],
                                        coupling_inverse=[momentum_coupling_inverse, revlib.additive_coupling_inverse])

inp = torch.randn((16, channels * 2, 224, 224))
out = rev_model(inp)
assert out.size() == (16, channels * 2, 224, 224)
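
The two coupling functions above are meant to be exact inverses of each other. As a quick sanity check in plain PyTorch (no revlib involved, tensor shapes chosen arbitrarily for illustration), the sketch below applies the forward coupling and then the inverse and measures how closely the original stream is recovered:

```python
import torch

momentum_ema_beta = 0.99


def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


torch.manual_seed(0)
# Arbitrary small tensors standing in for x2 and f(x1); float64 keeps rounding noise low.
x2 = torch.randn(4, 8, dtype=torch.float64)
fn_out = torch.randn(4, 8, dtype=torch.float64)

y2 = momentum_coupling_forward(x2, fn_out)
x2_recovered = momentum_coupling_inverse(y2, fn_out)

# Largest absolute difference between the original and the recovered stream.
max_error = (x2 - x2_recovered).abs().max().item()
print(f"max reconstruction error: {max_error:.2e}")
```

Because the inverse divides by momentum_ema_beta, rounding error is amplified by 1 / momentum_ema_beta at every layer. That factor is mild at 0.99, but it is worth keeping in mind when choosing a much smaller beta.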

Reformer

Reformer uses RevNet with chunking and LSH-attention to train a transformer efficiently. With revlib, standard implementations, such as lucidrains' Reformer, can be improved to use less memory. The example below still uses the basic building blocks from lucidrains' code so that the model stays comparable.

import torch
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHSelfAttention, Chunk, FeedForward, AbsolutePositionalEmbedding
import revlib


class Reformer(torch.nn.Module):
    def __init__(self, sequence_length: int, features: int, depth: int, heads: int, bucket_size: int = 64,
                 lsh_hash_count: int = 8, ff_chunks: int = 16, input_classes: int = 256, output_classes: int = 256):
        super().__init__()
        self.token_embd = nn.Embedding(input_classes, features * 2)
        self.pos_embd = AbsolutePositionalEmbedding(features * 2, sequence_length)

        self.core = revlib.ReversibleSequential(*[nn.Sequential(nn.LayerNorm(features), layer) for _ in range(depth)
                                                 for layer in
                                                 [LSHSelfAttention(features, heads, bucket_size, lsh_hash_count),
                                                  Chunk(ff_chunks, FeedForward(features, activation=nn.GELU), 
                                                        along_dim=-2)]],
                                                split_dim=-1)
        self.out_norm = nn.LayerNorm(features * 2)
        self.out_linear = nn.Linear(features * 2, output_classes)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.out_linear(self.out_norm(self.core(self.token_embd(inp) + self.pos_embd(inp))))


sequence = 1024
classes = 16
model = Reformer(sequence, 256, 6, 8, output_classes=classes)
out = model(torch.ones((16, sequence), dtype=torch.long))
assert out.size() == (16, sequence, classes)

Explanation

Most other RevNet libraries, such as MemCNN and Revtorch, calculate both f() and g() in one go to create one large computation. RevLib, on the other hand, brings Mesh TensorFlow's "reversible half residual and swap" to PyTorch. reversible_half_residual_and_swap computes only one of f() and g() per step and swaps the inputs and gradients. This way, the library only has to store one output, as it can recover the other output during the backward pass.
Following Mesh TensorFlow's example, revlib also uses separate x1 and x2 tensors instead of concatenating and splitting at every step to reduce the cost of memory-bound operations.
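
To make the idea concrete, here is a minimal sketch of a half-residual step with a swap, written in plain PyTorch. It is not revlib's actual implementation: forward_step and inverse_step are hypothetical helper names, additive coupling is assumed, and the memory-saving autograd machinery is left out. It only shows that the inputs can be recovered from the outputs, one function call per step:

```python
import torch
from torch import nn


def forward_step(fn, x1, x2):
    """Update one stream with fn of the other, then swap the two streams."""
    out_a = x2 + fn(x1)  # additive coupling (assumed for this sketch)
    out_b = x1           # passed through unchanged, now in the second slot
    return out_a, out_b


def inverse_step(fn, out_a, out_b):
    """Recover (x1, x2) from the outputs of forward_step."""
    x1 = out_b
    x2 = out_a - fn(x1)
    return x1, x2


torch.manual_seed(0)
# Four small layers standing in for alternating f() and g().
layers = [nn.Linear(8, 8).double() for _ in range(4)]
x1 = torch.randn(2, 8, dtype=torch.float64)
x2 = torch.randn(2, 8, dtype=torch.float64)

with torch.no_grad():
    s1, s2 = x1, x2
    for layer in layers:
        s1, s2 = forward_step(layer, s1, s2)

    # Walk back through the layers using only the final outputs.
    r1, r2 = s1, s2
    for layer in reversed(layers):
        r1, r2 = inverse_step(layer, r1, r2)

reconstruction_error = max((r1 - x1).abs().max().item(),
                           (r2 - x2).abs().max().item())
print(f"max reconstruction error: {reconstruction_error:.2e}")
```

Each step calls only one function and keeps the two streams as separate tensors, which mirrors the description above: no concatenation or splitting is needed between layers.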

RevNet's memory consumption doesn't scale with its depth, so it's significantly more memory-efficient for deep models. One problem in most implementations was that two tensors needed to be stored in the output, quadrupling the required memory. This high memory consumption rendered RevNet nearly useless for small networks, such as a six-layer BERT-style model.
RevLib works around this problem by storing only one output and two inputs for each forward pass, giving even a model that small a >2x improvement!
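
For a rough sense of scale, the estimate below applies the formula from the feature list (1 * output_size versus n_layers * output_size). The shapes are hypothetical, and only stored layer outputs are counted; weights, optimizer state and temporary buffers are ignored, so treat it as a back-of-the-envelope sketch rather than a measurement:

```python
def stored_output_bytes(output_elements: int, stored_outputs: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to keep `stored_outputs` layer outputs alive for the backward pass."""
    return output_elements * stored_outputs * bytes_per_element


# Hypothetical shapes: batch 16, sequence length 1024, 512 features, 12 layers, fp32.
n_layers = 12
output_elements = 16 * 1024 * 512

# Gradient checkpointing keeps one output per layer; a reversible model keeps one in total.
checkpointing_bytes = stored_output_bytes(output_elements, stored_outputs=n_layers)
reversible_bytes = stored_output_bytes(output_elements, stored_outputs=1)

print(f"checkpointing: {checkpointing_bytes / 2**20:.0f} MiB")  # 384 MiB
print(f"reversible:    {reversible_bytes / 2**20:.0f} MiB")     # 32 MiB
```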

Setting aside its dual-path structure, a RevNet used to be much slower than gradient checkpointing in most implementations. RevLib, however, uses minimal coupling functions and adds no overhead between the items of a sequence, allowing it to train as fast as a comparable model with gradient checkpointing.
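
For reference, the gradient checkpointing baseline mentioned above can be sketched with PyTorch's own torch.utils.checkpoint. The toy layers and sizes are made up for illustration, and the use_reentrant argument assumes a reasonably recent PyTorch release. The point is that checkpointing recomputes activations in the backward pass while producing the same outputs and gradients as a plain forward pass:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Toy stack of layers; sizes are arbitrary.
layers = nn.ModuleList([nn.Sequential(nn.Linear(16, 16), nn.ReLU()) for _ in range(4)])
inp = torch.randn(2, 16, requires_grad=True)


def run(use_checkpointing: bool) -> torch.Tensor:
    out = inp
    for layer in layers:
        if use_checkpointing:
            # Intermediate activations inside `layer` are dropped and recomputed in backward.
            out = checkpoint(layer, out, use_reentrant=False)
        else:
            out = layer(out)
    return out


plain_out = run(use_checkpointing=False)
ckpt_out = run(use_checkpointing=True)

plain_grad, = torch.autograd.grad(plain_out.sum(), inp)
ckpt_grad, = torch.autograd.grad(ckpt_out.sum(), inp)

print(torch.allclose(plain_out, ckpt_out), torch.allclose(plain_grad, ckpt_grad))
```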

Owner
Lucas Nestler
German AI researcher