As-ViT: Auto-scaling Vision Transformers without Training

MIT licensed

Wuyang Chen, Wei Huang, Xianzhi Du, Xiaodan Song, Zhangyang Wang, Denny Zhou

In ICLR 2022.

Note: This code base implements topology search (Sec. 3.3) and scaling (Sec. 3.4) in PyTorch. Our training code is based on TensorFlow and Keras on TPU and will be released soon.

Overview

We present As-ViT, a framework that unifies automatic architecture design and scaling for vision transformers (ViTs) in a training-free manner.

Highlights:

  • Training-free ViT Architecture Design: we design a "seed" ViT topology with a training-free search process. This extremely fast search is enabled by our comprehensive study of ViT's network complexity (length distortion), which yields a strong Kendall-tau correlation with ground-truth accuracies.
  • Training-free ViT Architecture Scaling: starting from the "seed" topology, we automate the scaling rule for ViTs by growing the widths/depths of different ViT layers. This generates a series of architectures with different numbers of parameters in a single run.
  • Efficient ViT Training via Progressive Tokenization: we observe that ViTs can tolerate coarse tokenization in early training stages, and further propose to train ViTs faster and cheaper with a progressive tokenization strategy.
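
To give a feel for why coarse tokenization is cheaper, the sketch below counts tokens for a square image at a few patch strides. It assumes that "coarse" means a larger, non-overlapping patch stride; the image size, the stride values, and the `num_tokens` helper are illustrative assumptions, not the schedule used in the paper.

```python
def num_tokens(image_size: int, stride: int) -> int:
    """Number of tokens for a square image cut into non-overlapping square patches."""
    if image_size % stride != 0:
        raise ValueError("stride must divide image_size in this simplified sketch")
    per_side = image_size // stride
    return per_side * per_side


# Hypothetical coarse-to-fine schedule for a 224x224 input (illustration only).
IMAGE_SIZE = 224
token_counts = {stride: num_tokens(IMAGE_SIZE, stride) for stride in (32, 16, 8)}

for stride, count in token_counts.items():
    print(f"stride {stride:>2}: {count} tokens")
```

Because the cost of standard self-attention grows quadratically with the number of tokens, the early, coarse stages of such a schedule would be much cheaper than the final, fine stage.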

Figure (teaser): Left: Length Distortion shows a strong correlation with ViT's accuracy. Middle: Auto scaling rule of As-ViT. Right: Progressive re-tokenization for efficient ViT training.
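
For readers unfamiliar with the Kendall-tau rank correlation mentioned above, here is a minimal pure-Python sketch of the tau-a variant (no tie correction). The `proxy_scores` and `accuracies` values are made up for illustration and are not results from the paper; the function is not part of this repository.

```python
from itertools import combinations
from typing import Sequence


def kendall_tau(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Kendall tau-a: (concordant pairs - discordant pairs) / total pairs."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    concordant = discordant = 0
    for (x1, y1), (x2, y2) in combinations(zip(xs, ys), 2):
        sign = (x1 - x2) * (y1 - y2)
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
        # pairs tied in either variable count toward neither total
    n_pairs = len(xs) * (len(xs) - 1) // 2
    return (concordant - discordant) / n_pairs


# Hypothetical values: a training-free score and the trained accuracy
# of five candidate architectures (illustration only).
proxy_scores = [0.8, 1.1, 1.5, 1.9, 2.4]
accuracies = [61.0, 64.5, 63.9, 70.2, 72.8]
tau = kendall_tau(proxy_scores, accuracies)
print(f"Kendall tau = {tau:.2f}")
```

A value near 1 means the training-free score ranks architectures almost the same way their trained accuracies do, which is what makes it usable as a search signal.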

Prerequisites

  • Ubuntu 18.04
  • Python 3.6.9
  • CUDA 11.0 (lower versions may work but were not tested)
  • NVIDIA GPU + cuDNN v7.6

This repository has been tested on a V100 GPU. Configurations may need to be changed on other platforms.
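
A quick way to compare a local machine against the list above is a small report script. The sketch below is not part of the repository; `check_environment` is a hypothetical helper, and treating any Python version at or above 3.6.9 as acceptable is an assumption. It only reports PyTorch, CUDA, and cuDNN details when PyTorch is already installed.

```python
import sys

REQUIRED_PYTHON = (3, 6, 9)  # version listed in the prerequisites


def check_environment() -> dict:
    """Collect the interpreter and (if installed) PyTorch/CUDA/cuDNN versions."""
    report = {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        # Assumption: newer Python versions are treated as acceptable here.
        "python_ok": sys.version_info[:3] >= REQUIRED_PYTHON,
    }
    try:
        import torch
    except ImportError:
        report["torch"] = None  # PyTorch not installed yet
        return report
    report["torch"] = torch.__version__
    report["cuda_available"] = torch.cuda.is_available()
    report["cuda_version"] = torch.version.cuda  # None for CPU-only builds
    report["cudnn_version"] = torch.backends.cudnn.version()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```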

Installation

  • Clone this repo:
git clone https://github.com/VITA-Group/AsViT.git
cd AsViT
  • Install dependencies:
pip install -r requirements.txt

1. Seed As-ViT Topology Search

CUDA_VISIBLE_DEVICES=0 python ./search/reinforce.py --save_dir ./output/REINFORCE-imagenet --data_path /path/to/imagenet

This job returns a seed topology. For example, our searched seed topology is 8,2,3|4,1,2|4,1,4|4,1,6|32, which is read as follows:

Stage     Kernel (K)   Split (S)   Expansion (E)
Stage1    8            2           3
Stage2    4            1           2
Stage3    4            1           4
Stage4    4            1           6

Head: 32
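
The string format can also be unpacked programmatically. The following is a minimal sketch based only on the layout described above (stages separated by "|", fields by ","); `Stage` and `parse_seed_topology` are hypothetical names and not functions from this repository.

```python
from typing import List, NamedTuple, Tuple


class Stage(NamedTuple):
    kernel: int
    split: int
    expansion: int


def parse_seed_topology(arch: str) -> Tuple[List[Stage], int]:
    """Split a seed topology string into per-stage settings and the head value."""
    *stage_fields, head_field = arch.split("|")
    stages = []
    for field in stage_fields:
        values = [int(v) for v in field.split(",")]
        if len(values) != 3:
            raise ValueError(f"expected kernel,split,expansion but got {field!r}")
        stages.append(Stage(*values))
    return stages, int(head_field)


stages, head = parse_seed_topology("8,2,3|4,1,2|4,1,4|4,1,6|32")
for index, stage in enumerate(stages, start=1):
    print(f"Stage{index}: {stage}")
print(f"Head: {head}")
```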

2. Scaling

CUDA_VISIBLE_DEVICES=0 python ./search/grow.py --save_dir ./output/GROW-imagenet \
--arch "[arch]" --data_path /path/to/imagenet

Here [arch] is the seed topology (the output of step 1 above). This job returns a series of topologies. For example, our largest topology (As-ViT Large) is 8,2,3,5|4,1,2,2|4,1,4,5|4,1,6,2|32,180, which is read as follows:

Stage     Kernel (K)   Split (S)   Expansion (E)   Layers (L)
Stage1    8            2           3               5
Stage2    4            1           2               2
Stage3    4            1           4               5
Stage4    4            1           6               2

Head: 32
Initial Hidden Size: 180
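
A scaled topology adds a layer count per stage and an initial hidden size at the end. As a rough sketch under the same assumptions about the string layout (the `parse_scaled_topology` helper and the dictionary keys are hypothetical, not part of this repository), it can be unpacked and summarized like this:

```python
from typing import Dict

STAGE_KEYS = ("kernel", "split", "expansion", "layers")


def parse_scaled_topology(arch: str) -> Dict[str, object]:
    """Split a scaled topology string into stages, head value, and hidden size."""
    *stage_fields, tail = arch.split("|")
    stages = []
    for field in stage_fields:
        values = [int(v) for v in field.split(",")]
        if len(values) != len(STAGE_KEYS):
            raise ValueError(f"expected {len(STAGE_KEYS)} values but got {field!r}")
        stages.append(dict(zip(STAGE_KEYS, values)))
    head, hidden_size = (int(v) for v in tail.split(","))
    return {"stages": stages, "head": head, "hidden_size": hidden_size}


large = parse_scaled_topology("8,2,3,5|4,1,2,2|4,1,4,5|4,1,6,2|32,180")
total_layers = sum(stage["layers"] for stage in large["stages"])
print(f"layers per stage: {[s['layers'] for s in large['stages']]}")
print(f"total layers: {total_layers}, initial hidden size: {large['hidden_size']}")
```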

3. Evaluation

TensorFlow and Keras code for training on TPU will be released soon.

Citation

@inproceedings{chen2021asvit,
  title={Auto-scaling Vision Transformers without Training},
  author={Chen, Wuyang and Huang, Wei and Du, Xianzhi and Song, Xiaodan and Wang, Zhangyang and Zhou, Denny},
  booktitle={International Conference on Learning Representations},
  year={2022}
}
Owner: VITA, Visual Informatics Group @ University of Texas at Austin