Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

Overview

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning

The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.

First version at NeurIPS 2017

This repo first contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that operate in nearly independent transition manners, and finally form unified representations of the complex environment.

Concretely, besides the original memory cell of LSTM, this network is featured by a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the learned visual dynamics at different levels of RNNs to communicate.

New in PredRNN-V2 (2021)

This repo also includes the implementation of PredRNN-V2 (2021) [paper], which improves PredRNN in the following two aspects.

1. Memory Decoupling

We find that the pair of memory cells in PredRNN contain undesirable, redundant features, and thus present a memory decoupling loss to encourage them to learn modular structures of visual dynamics.

decouple

2. Reverse Scheduled Sampling

Reverse scheduled sampling is a new curriculum learning strategy for seq-to-seq RNNs. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground truth. Benefits: (1) It makes the training converge quickly by reducing the encoder-forecaster training gap. (2) It enforces the model to learn more from long-term input context.

rss

Evaluation in LPIPS

LPIPS is more sensitive to perceptual human judgments, the lower the better.

Moving MNIST KTH action
PredRNN 0.109 0.204
PredRNN-V2 0.071 0.139

Prediction examples

mnist

kth

radar

Get Started

  1. Install Python 3.7, PyTorch 1.3, and OpenCV 3.4.
  2. Download data. This repo contains code for two datasets: the Moving Mnist dataset and the KTH action dataset.
  3. Train the model. You can use the following bash script to train the model. The learned model will be saved in the --save_dir folder. The generated future frames will be saved in the --gen_frm_dir folder.
  4. You can get pretrained models from here.
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh

cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh

Citation

If you find this repo useful, please cite the following papers.

@inproceedings{wang2017predrnn,
  title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
  author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
  booktitle={Advances in Neural Information Processing Systems},
  pages={879--888},
  year={2017}
}

@misc{wang2021predrnn,
      title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning}, 
      author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
      year={2021},
      eprint={2103.09504},
      archivePrefix={arXiv},
}
Owner
THUML: Machine Learning Group @ THSS
Machine Learning Group, School of Software, Tsinghua University
THUML: Machine Learning Group @ THSS
Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”

Official implementation for TransDA Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”. Overview: Result: Prerequisites:

stanley 54 Dec 22, 2022
Deep learning (neural network) based remote photoplethysmography: how to extract pulse signal from video using deep learning tools

Deep-rPPG: Camera-based pulse estimation using deep learning tools Deep learning (neural network) based remote photoplethysmography: how to extract pu

Terbe Dániel 138 Dec 17, 2022
Western-3DSlicer-Modules - Point-Set Registrations for Ultrasound Probe Calibrations

Point-Set Registrations for Ultrasound Probe Calibrations -Undergraduate Thesis-

Matteo Tanzi 0 May 04, 2022
Plugin adapted from Ultralytics to bring YOLOv5 into Napari

napari-yolov5 Plugin adapted from Ultralytics to bring YOLOv5 into Napari. Training and detection can be done using the GUI. Training dataset must be

2 May 05, 2022
(SIGIR2020) “Asymmetric Tri-training for Debiasing Missing-Not-At-Random Explicit Feedback’’

Asymmetric Tri-training for Debiasing Missing-Not-At-Random Explicit Feedback About This repository accompanies the real-world experiments conducted i

yuta-saito 19 Dec 01, 2022
U-Net implementation in PyTorch for FLAIR abnormality segmentation in brain MRI

U-Net for brain segmentation U-Net implementation in PyTorch for FLAIR abnormality segmentation in brain MRI based on a deep learning segmentation alg

562 Jan 02, 2023
code for "AttentiveNAS Improving Neural Architecture Search via Attentive Sampling"

code for "AttentiveNAS Improving Neural Architecture Search via Attentive Sampling"

Facebook Research 94 Oct 26, 2022
Official Pytorch implementation of MixMo framework

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks Official PyTorch implementation of the MixMo framework | paper | docs Alexandr

79 Nov 07, 2022
A texturizer that I just made. Nothing special here.

texturizer This is a little project that I did with an hour's time. It texturizes an image given a image and a texture to texturize it with. There is

1 Nov 11, 2021
10th place solution for Google Smartphone Decimeter Challenge at kaggle.

Under refactoring 10th place solution for Google Smartphone Decimeter Challenge at kaggle. Google Smartphone Decimeter Challenge Global Navigation Sat

12 Oct 25, 2022
CS5242_2021 - Neural Networks and Deep Learning, NUS CS5242, 2021

CS5242_2021 Neural Networks and Deep Learning, NUS CS5242, 2021 Cloud Machine #1 : Google Colab (Free GPU) Follow this Notebook installation : https:/

Xavier Bresson 165 Oct 25, 2022
TorchDistiller - a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and instance segmentation.

This project is a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and i

yifan liu 147 Dec 03, 2022
RLBot Python bindings for the Rust crate rl_ball_sym

RLBot Python bindings for rl_ball_sym 0.6 Prerequisites: Rust & Cargo Build Tools for Visual Studio RLBot - Verify that the file %localappdata%\RLBotG

Eric Veilleux 2 Nov 25, 2022
code for ICCV 2021 paper 'Generalized Source-free Domain Adaptation'

G-SFDA Code (based on pytorch 1.3) for our ICCV 2021 paper 'Generalized Source-free Domain Adaptation'. [project] [paper]. Dataset preparing Download

Shiqi Yang 84 Dec 26, 2022
GPT, but made only out of gMLPs

GPT - gMLP This repository will attempt to crack long context autoregressive language modeling (GPT) using variations of gMLPs. Specifically, it will

Phil Wang 80 Dec 01, 2022
Toolkit for collecting and applying prompts

PromptSource Promptsource is a toolkit for collecting and applying prompts to NLP datasets. Promptsource uses a simple templating language to programa

BigScience Workshop 998 Jan 03, 2023
Diverse Branch Block: Building a Convolution as an Inception-like Unit

Diverse Branch Block: Building a Convolution as an Inception-like Unit (PyTorch) (CVPR-2021) DBB is a powerful ConvNet building block to replace regul

253 Dec 24, 2022
Learning Tracking Representations via Dual-Branch Fully Transformer Networks

Learning Tracking Representations via Dual-Branch Fully Transformer Networks DualTFR ⭐ We achieves the runner-ups for both VOT2021ST (short-term) and

phiphi 19 May 04, 2022
Matching python environment code for Lux AI 2021 Kaggle competition, and a gym interface for RL models.

Lux AI 2021 python game engine and gym This is a replica of the Lux AI 2021 game ported directly over to python. It also sets up a classic Reinforceme

Geoff McDonald 74 Nov 03, 2022
QTool: A Low-bit Quantization Toolbox for Deep Neural Networks in Computer Vision

This project provides abundant choices of quantization strategies (such as the quantization algorithms, training schedules and empirical tricks) for quantizing the deep neural networks into low-bit c

Monash Green AI Lab 51 Dec 10, 2022