PyTorch implementation of the Quasi-Recurrent Neural Network - up to 16 times faster than NVIDIA's cuDNN LSTM

Overview

Quasi-Recurrent Neural Network (QRNN) for PyTorch

Updated to support multi-GPU environments via DataParallel - see the multigpu_dataparallel.py example.

This repository contains a PyTorch implementation of Salesforce Research's Quasi-Recurrent Neural Networks paper.

The QRNN provides similar accuracy to the LSTM but can be between 2 and 17 times faster than the highly optimized NVIDIA cuDNN LSTM implementation, depending on the use case.

To install, simply run:

pip install cupy pynvrtc git+https://github.com/salesforce/pytorch-qrnn

If you use this code or our results in your research, please cite:

@article{bradbury2016quasi,
  title={{Quasi-Recurrent Neural Networks}},
  author={Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard},
  journal={International Conference on Learning Representations (ICLR 2017)},
  year={2017}
}

Software Requirements

This codebase requires Python 3, PyTorch, pynvrtc (NVIDIA's Python Bindings to NVRTC), and CuPy. While the codebase contains a CPU implementation of the QRNN, the GPU QRNN implementation is used by default if possible. Requirements are provided in requirements.txt.

Example Usage

We've updated the previously released Salesforce Research AWD-LSTM language modeling codebase to support use of the AWD-QRNN. With the same number of parameters as the LSTM and less well-tuned hyperparameters, the QRNN model trains over twice as quickly and achieves nearly equivalent state-of-the-art language modeling results. For full details, refer to the AWD-LSTM-LM repository.

Usage

The QRNN API is meant to be drop-in compatible with the LSTM for many standard use cases. As such, the easiest thing to do is replace any GRU or LSTM module with the QRNN.

Note: the bidirectional QRNN is not yet supported, though it will be in the near future.

import torch
from torchqrnn import QRNN

seq_len, batch_size, hidden_size = 7, 20, 256
size = (seq_len, batch_size, hidden_size)
# Input of shape (seq_len, batch, input_size)
X = torch.autograd.Variable(torch.rand(size), requires_grad=True).cuda()

# Two stacked QRNN layers with dropout of 0.4 between layers
qrnn = QRNN(hidden_size, hidden_size, num_layers=2, dropout=0.4)
qrnn.cuda()
output, hidden = qrnn(X)

print(output.size(), hidden.size())
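
As with the LSTM API, you can also pass an optional initial hidden state. A minimal sketch continuing the example above (the shapes follow the documentation below):

# Initial hidden state of shape (num_layers, batch, hidden_size)
h0 = torch.autograd.Variable(torch.zeros(2, batch_size, hidden_size)).cuda()
output, hidden = qrnn(X, h0)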

The full documentation for the QRNN is listed below:

QRNN(input_size, hidden_size, num_layers, dropout=0):
    Applies a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.

    Args:
        input_size: The number of expected features in the input x.
        hidden_size: The number of features in the hidden state h. If not specified, the input size is used.
        num_layers: The number of QRNN layers to produce.
        layers: List of preconstructed QRNN layers to use for the QRNN module (optional).
        save_prev_x: Whether to store previous inputs for use in future convolutional windows (i.e. for a continuing sequence such as in language modeling). If true, you must call reset to remove cached previous values of x. Default: False.
        window: Defines the size of the convolutional window (how many previous tokens to look at when computing the QRNN values). Supports 1 and 2. Default: 1.
        zoneout: Whether to apply zoneout (i.e. failing to update elements in the hidden state) to the hidden state updates. Default: 0.
        output_gate: If True, performs QRNN-fo (applying an output gate to the output). If False, performs QRNN-f. Default: True.
        use_cuda: If True, uses fast custom CUDA kernel. If False, uses naive for loop. Default: True.

    Inputs: X, hidden
        - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
        - hidden (layers, batch, hidden_size): tensor containing the initial hidden state for the QRNN.

    Outputs: output, h_n
        - output (seq_len, batch, hidden_size): tensor containing the output of the QRNN for each timestep.
        - h_n (layers, batch, hidden_size): tensor containing the hidden state for t=seq_len.

The included QRNN layer supports convolutional windows of size 1 or 2 but will be extended in the future to support arbitrary convolutions.

If you are using convolutional windows of size 2 (i.e. looking at the inputs from two previous timesteps to compute the input) and want to run over a long sequence in batches, such as when using BPTT, you can set save_prev_x=True and call reset when you wish to reset the cached previous inputs.
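
For example, a minimal sketch of running one long sequence in consecutive BPTT segments (the segment loop and sizes here are illustrative):

import torch
from torchqrnn import QRNN

seq_len, batch_size, hidden_size = 35, 20, 256

# Window-2 QRNN that caches the previous input between forward calls
qrnn = QRNN(hidden_size, hidden_size, num_layers=2, window=2, save_prev_x=True)
qrnn.cuda()

hidden = None
for _ in range(3):  # three consecutive segments of the same long sequence
    X = torch.autograd.Variable(torch.rand(seq_len, batch_size, hidden_size)).cuda()
    output, hidden = qrnn(X, hidden)
    hidden = hidden.detach()  # truncate backpropagation between segments

qrnn.reset()  # clear the cached inputs before moving on to a new sequence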

If you want flexibility in the definition of each QRNN layer, you can construct individual QRNNLayer modules and pass them to the QRNN module using the layers argument, as shown below.
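
A minimal sketch that mixes per-layer settings (this assumes QRNNLayer accepts the same per-layer arguments listed in the documentation above):

from torchqrnn import QRNN, QRNNLayer

# First layer uses a convolutional window of 2, the second a window of 1
layers = [QRNNLayer(256, 256, window=2), QRNNLayer(256, 256, window=1)]
qrnn = QRNN(256, 256, layers=layers, dropout=0.4)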

Speed

Speeds are between 2 and 17 times faster than NVIDIA's cuDNN LSTM, with the exact speedup depending on batch size and sequence length. The largest gains are for small batch sizes or long sequence lengths, both highlighting the LSTM's parallelization difficulty due to its forced sequentiality. For full information, refer to the Quasi-Recurrent Neural Networks paper.

Figure 4 from QRNN paper

Pictured above is Figure 4 from the QRNN paper:
Left: Training speed for two-layer 640-unit PTB LM on a batch of 20 examples of 105 timesteps. “RNN” and “softmax” include the forward and backward times, while “optimization overhead” includes gradient clipping, L2 regularization, and SGD computations.
Right: Inference speed advantage of a 320-unit QRNN layer alone over an equal-sized cuDNN LSTM layer for data with the given batch size and sequence length. Training results are similar.

Extending the QRNN speed advantage to other recurrent architectures with ForgetMult

The QRNN architecture's speed advantage comes from two primary sources: the ability to batch all computations into a few large matrix multiplications and the use of a fast element-wise recurrence function. This recurrence function, named ForgetMult, is general and can be used in other scenarios. The ForgetMult takes two arguments - the candidate input x and forget gates f - and computes h = f * x + (1 - f) * hm1 where hm1 is the previous hidden state output.

The QRNN class is a thin wrapper around this that performs the large matrix multiplications for the candidate x, the forget gates f, and the output gates o. Any other operation which requires recurrence and can have precomputed values for the candidate x and forget gates f can use this fast form of recurrence.

Example usage of the ForgetMult module: output = ForgetMult()(f, x, hidden).

    ForgetMult computes a simple recurrent equation:
    h_t = f_t * x_t + (1 - f_t) * h_{t-1}

    This equation is equivalent to dynamic weighted averaging.

    Inputs: X, F, hidden_init
        - X (seq_len, batch, input_size): tensor containing the features of the input sequence.
        - F (seq_len, batch, input_size): tensor containing the forget gate values, assumed in range [0, 1].
        - hidden_init (batch, input_size): tensor containing the initial hidden state for the recurrence (h_{t-1}).
        - cuda: If True, use the fast element-wise CUDA kernel for recurrence. If False, uses naive for loop. Default: True.
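
For reference, a minimal sketch of the same recurrence written as a plain Python loop (illustrative only; the shipped module uses the fused CUDA kernel when available):

import torch

def naive_forget_mult(f, x, hidden=None):
    # f, x: (seq_len, batch, input_size); hidden: (batch, input_size) or None
    outputs = []
    for f_t, x_t in zip(f, x):
        h_t = f_t * x_t if hidden is None else f_t * x_t + (1 - f_t) * hidden
        hidden = h_t
        outputs.append(h_t)
    return torch.stack(outputs)

f = torch.sigmoid(torch.rand(7, 20, 256))
x = torch.rand(7, 20, 256)
print(naive_forget_mult(f, x).size())  # torch.Size([7, 20, 256])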

Want to help out?

First, thanks! :)

Open tasks that are interesting:

  • Modify the ForgetMult CUDA kernel to produce a BackwardForgetMult. This will enable a bidirectional QRNN. The input should be the same - f and x - but the kernel should walk backwards through the inputs.
  • Bidirectional QRNN support (requires the modification above)
  • Support PyTorch's PackedSequence such that variable length sequences are correctly masked
  • Show how to use the underlying fast recurrence operator ForgetMult in other generic ways