PyTorch implementation of a collections of scalable Video Transformer Benchmarks.

Overview

PyTorch implementation of Video Transformer Benchmarks

This repository is mainly built upon Pytorch and Pytorch-Lightning. We wish to maintain a collections of scalable video transformer benchmarks, and discuss the training recipes of how to train a big video transformer model.

Now, we implement the TimeSformer and ViViT. And we have pre-trained the TimeSformer-B on Kinetics600, but still can't guarantee the performance reported in the paper. However, we find some relevant hyper-parameters which may help us to reach the target performance.

Table of Contents

  1. Difference
  2. TODO
  3. Setup
  4. Usage
  5. Result
  6. Acknowledge
  7. Contribution

Difference

In order to share the basic divided spatial-temporal attention module to different video transformer, we make some changes in the following apart.

1. Position embedding

We split the position embedding from R(nt*h*w×d) mentioned in the ViViT paper into R(nh*w×d) and R(nt×d) to stay the same as TimeSformer.

2. Class token

In order to make clear whether to add the class_token into the module forward computation, we only compute the interaction between class_token and query when the current layer is the last layer (except FFN) of each transformer block.

3. Initialize from the pre-trained model

  • Tokenization: the token embedding filter can be chosen either Conv2D or Conv3D, and the initializing weights of Conv3D filters from Conv2D can be replicated along temporal dimension and averaging them or initialized with zeros along the temporal positions except at the center t/2.
  • Temporal MSA module weights: one can choose to copy the weights from spatial MSA module or initialize all weights with zeros.
  • Initialize from the MAE pre-trained model provided by ZhiLiang, where the class_token that does not appear in the MAE pre-train model is initialized from truncated normal distribution.
  • Initialize from the ViT pre-trained model can be found here.

TODO

  • add more TimeSformer and ViViT variants pre-trained weights.
    • A larger version and other operation types.
  • add linear prob and partial fine-tune.
    • Make available to transfer the pre-trained model to downstream task.
  • add more scalable Video Transformer benchmarks.
    • We will also extend to multi-modality version, e.g Perceiver is coming soon.
  • add more diverse objective functions.
    • Pre-train on larger dataset through the dominated self-supervised methods, e.g Contrastive Learning and MAE.

Setup

pip install -r requirements.txt

Usage

Training

# path to Kinetics600 train set
TRAIN_DATA_PATH='/path/to/Kinetics600/train_list.txt'
# path to root directory
ROOT_DIR='/path/to/work_space'

python model_pretrain.py \
	-lr 0.005 \
	-pretrain 'vit' \
	-epoch 15 \
	-batch_size 8 \
	-num_class 600 \
	-frame_interval 32 \
	-root_dir ROOT_DIR \
	-train_data_path TRAIN_DATA_PATH

The minimal folder structure will look like as belows.

root_dir
├── pretrain_model
│   ├── pretrain_mae_vit_base_mask_0.75_400e.pth
│   ├── vit_base_patch16_224.pth
├── results
│   ├── experiment_tag
│   │   ├── ckpt
│   │   ├── log

Inference

# path to Kinetics600 pre-trained model
PRETRAIN_PATH='/path/to/pre-trained model'
# path to the test video sample
VIDEO_PATH='/path/to/video sample'

python model_inference.py \
	-pretrain PRETRAIN_PATH \
	-video_path VIDEO_PATH \
	-num_frames 8 \
	-frame_interval 32 \

Result

Kinetics-600

1. Model Zoo

name pretrain epochs num frames spatial crop top1_acc top5_acc weight log
TimeSformer-B ImageNet-21K 15e 8 224 78.4 93.6 Google drive or BaiduYun(code: yr4j) log

2. Train Recipe(ablation study)

2.1 Acc

operation top1_acc top5_acc top1_acc (three crop)
base 68.2 87.6 -
+ frame_interval 4 -> 16 (span more time) 72.9(+4.7) 91.0(+3.4) -
+ RandomCrop, flip (overcome overfit) 75.7(+2.8) 92.5(+1.5) -
+ batch size 16 -> 8 (more iterations) 75.8(+0.1) 92.4(-0.1) -
+ frame_interval 16 -> 24 (span more time) 77.7(+1.9) 93.3(+0.9) 78.4
+ frame_interval 24 -> 32 (span more time) 78.4(+0.7) 94.0(+0.7) 79.1

tips: frame_interval and data augment counts for the validation accuracy.


2.2 Time

operation epoch_time
base (start with DDP) 9h+
+ speed up training recipes 1h+
+ switch from get_batch first to sample_Indice first 0.5h
+ batch size 16 -> 8 33.32m
+ num_workers 8 -> 4 35.52m
+ frame_interval 16 -> 24 44.35m

tips: Improve the frame_interval will drop a lot on time performance.

1.speed up training recipes:

  • More GPU device.
  • pin_memory=True.
  • Avoid CPU->GPU Device transfer (such as .item(), .numpy(), .cpu() operations on tensor or log to disk).

2.get_batch first means that we firstly read all frames through the video reader, and then get the target slice of frames, so it largely slow down the data-loading speed.


Acknowledge

this repo is built on top of Pytorch-Lightning, decord and kornia. I also learn many code designs from MMaction2. I thank the authors for releasing their code.

Contribution

I look forward to seeing one can provide some ideas about the repo, please feel free to report it in the issue, or even better, submit a pull request.

And your star is my motivation, thank u~

Owner
Xin Ma
Xin Ma
Direct design of biquad filter cascades with deep learning by sampling random polynomials.

IIRNet Direct design of biquad filter cascades with deep learning by sampling random polynomials. Usage git clone https://github.com/csteinmetz1/IIRNe

Christian J. Steinmetz 55 Nov 02, 2022
A benchmark for the task of translation suggestion

WeTS: A Benchmark for Translation Suggestion Translation Suggestion (TS), which provides alternatives for specific words or phrases given the entire d

zhyang 55 Dec 24, 2022
CMUA-Watermark: A Cross-Model Universal Adversarial Watermark for Combating Deepfakes (AAAI2022)

CMUA-Watermark The official code for CMUA-Watermark: A Cross-Model Universal Adversarial Watermark for Combating Deepfakes (AAAI2022) arxiv. It is bas

50 Nov 26, 2022
SoK: Vehicle Orientation Representations for Deep Rotation Estimation

SoK: Vehicle Orientation Representations for Deep Rotation Estimation Raymond H. Tu, Siyuan Peng, Valdimir Leung, Richard Gao, Jerry Lan This is the o

FIRE Capital One Machine Learning of the University of Maryland 12 Oct 07, 2022
The Generic Manipulation Driver Package - Implements a ROS Interface over the robotics toolbox for Python

Armer Driver Armer aims to provide an interface layer between the hardware drivers of a robotic arm giving the user control in several ways: Joint vel

QUT Centre for Robotics (QCR) 13 Nov 26, 2022
Generative code template for PixelBeasts 10k NFT project.

generator-template Generative code template for combining transparent png attributes into 10,000 unique images. Used for the PixelBeasts 10k NFT proje

Yohei Nakajima 9 Aug 24, 2022
Propose a principled and practically effective framework for unsupervised accuracy estimation and error detection tasks with theoretical analysis and state-of-the-art performance.

Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles This project is for the paper: Detecting Errors and Estimating

Jiefeng Chen 13 Nov 21, 2022
MetaBalance: Improving Multi-Task Recommendations via Adapting Gradient Magnitudes of Auxiliary Tasks

MetaBalance: Improving Multi-Task Recommendations via Adapting Gradient Magnitudes of Auxiliary Tasks Introduction This repo contains the pytorch impl

Meta Research 38 Oct 10, 2022
TensorFlow ROCm port

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

ROCm Software Platform 622 Jan 09, 2023
Linear Variational State Space Filters

Linear Variational State Space Filters To set up the environment, use the provided scripts in the docker/ folder to build and run the codebase inside

0 Dec 13, 2021
2021搜狐校园文本匹配算法大赛 分比我们低的都是帅哥队

sohu_text_matching 2021搜狐校园文本匹配算法大赛Top2:分比我们低的都是帅哥队 本repo包含了本次大赛决赛环节提交的代码文件及答辩PPT,提交的模型文件可在百度网盘获取(链接:https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA ,

hflserdaniel 43 Oct 01, 2022
The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.

This is a curated list of tutorials, projects, libraries, videos, papers, books and anything related to the incredible PyTorch. Feel free to make a pu

Ritchie Ng 9.2k Jan 02, 2023
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

Denis Emelin 42 Nov 24, 2022
An implementation of Deep Graph Infomax (DGI) in PyTorch

DGI Deep Graph Infomax (Veličković et al., ICLR 2019): https://arxiv.org/abs/1809.10341 Overview Here we provide an implementation of Deep Graph Infom

Petar Veličković 491 Jan 03, 2023
Flower - A Friendly Federated Learning Framework

Flower - A Friendly Federated Learning Framework Flower (flwr) is a framework for building federated learning systems. The design of Flower is based o

Adap 1.8k Jan 01, 2023
CVPR2021 Workshop - HDRUNet: Single Image HDR Reconstruction with Denoising and Dequantization.

HDRUNet [Paper Link] HDRUNet: Single Image HDR Reconstruction with Denoising and Dequantization By Xiangyu Chen, Yihao Liu, Zhengwen Zhang, Yu Qiao an

XyChen 105 Dec 20, 2022
This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

ICCV Workshop 2021 VTGAN This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

Sharif Amit Kamran 25 Dec 08, 2022
Artifacts for paper "MMO: Meta Multi-Objectivization for Software Configuration Tuning"

MMO: Meta Multi-Objectivization for Software Configuration Tuning This repository contains the data and code for the following paper that is currently

0 Nov 17, 2021
DROPO: Sim-to-Real Transfer with Offline Domain Randomization

DROPO: Sim-to-Real Transfer with Offline Domain Randomization Gabriele Tiboni, Karol Arndt, Ville Kyrki. This repository contains the code for the pap

Gabriele Tiboni 8 Dec 19, 2022