Code for "Searching for Efficient Multi-Stage Vision Transformers"

Overview

Searching for Efficient Multi-Stage Vision Transformers

This repository contains the official PyTorch implementation of "Searching for Efficient Multi-Stage Vision Transformers" and is based on DeiT and timm.


Illustration of the proposed multi-stage ViT-Res network.



Illustration of weight-sharing neural architecture search with multi-architectural sampling.



Accuracy-MACs trade-offs of the proposed ViT-ResNAS. Our networks achieve results comparable to previous work.

Content

  1. Requirements
  2. Data Preparation
  3. Pre-Trained Models
  4. Training ViT-Res
  5. Performing Neural Architecture Search
  6. Evaluation

Requirements

The codebase is tested with 8 V100 (16GB) GPUs.

To install requirements:

    pip install -r requirements.txt

Docker files are provided to set up the environment. Please run:

    cd docker

    sh 1_env_setup.sh
    
    sh 2_build_docker_image.sh
    
    sh 3_run_docker_image.sh

Make sure that the configuration specified in 3_run_docker_image.sh is correct before running the command.

Data Preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure follows the standard layout expected by torchvision.datasets.ImageFolder, with the training and validation data in the train/ and val/ folders, respectively:

    /path/to/imagenet/
      train/
        class1/
          img1.jpeg
        class2/
          img2.jpeg
      val/
        class1/
          img3.jpeg
        class2/
          img4.jpeg
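Before launching a long training run, it can help to confirm that the folders match this layout. The following is a minimal sketch, not part of this repository: check_imagefolder_layout is a hypothetical helper that only checks that train/ and val/ exist and contain the same class sub-folders, which is what ImageFolder relies on to assign consistent labels.

```python
from pathlib import Path


def check_imagefolder_layout(root):
    """Hypothetical helper: verify root/train and root/val for ImageFolder use.

    Returns the sorted list of class folder names. Raises ValueError if a split
    folder is missing or the two splits contain different class folders.
    """
    root = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise ValueError(f"missing split directory: {split_dir}")
        # Each immediate sub-directory is treated as one class, as in ImageFolder.
        classes[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return classes["train"]
```

For example, len(check_imagefolder_layout("/path/to/imagenet")) should be 1000 for ImageNet-1k.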

Pre-Trained Models

Pre-trained weights of super-networks and searched networks can be found here.

Training ViT-Res

To train ViT-Res-Tiny, modify IMAGENET_PATH in scripts/vit-sr-nas/reference_net/tiny.sh and run:

    sh scripts/vit-sr-nas/reference_net/tiny.sh 

We use 8 GPUs for training. If you use a different number of GPUs, change the number of processes (--nproc_per_node) and adjust the batch size (--batch-size) accordingly.
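One way to choose the new --batch-size is to keep the total batch size unchanged. The sketch below assumes, as in DeiT, that --batch-size is a per-GPU value, so the total batch size is --batch-size times --nproc_per_node; per_gpu_batch_size is a hypothetical helper, and the reference values you pass in should come from the script you are running.

```python
def per_gpu_batch_size(reference_per_gpu, reference_gpus, new_gpus):
    """Hypothetical helper: per-GPU --batch-size that keeps the total batch fixed.

    Assumes --batch-size is per GPU, so total = per-GPU batch * number of GPUs.
    Raises ValueError if the total cannot be split evenly across new_gpus.
    """
    total = reference_per_gpu * reference_gpus
    if total % new_gpus != 0:
        raise ValueError(f"total batch size {total} is not divisible by {new_gpus} GPUs")
    return total // new_gpus
```

For instance, a script using 128 images per GPU on 8 GPUs would need per_gpu_batch_size(128, 8, 4) == 256 on 4 GPUs, which may not fit in 16 GB of memory per GPU. If it does not fit, a smaller total batch is the alternative; since DeiT scales the learning rate with the total batch size, this also changes the effective learning rate.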

Performing Neural Architecture Search

0. Building Sub-Train and Sub-Val Sets

Modify _SOURCE_DIR, _SUB_TRAIN_DIR, and _SUB_VAL_DIR in search_utils/build_subset.py, and run:

    cd search_utils
    
    python build_subset.py
    
    cd ..
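To illustrate what building the subsets involves, here is a hypothetical sketch of a per-class split of the training set into sub-train and sub-val folders. It is not the logic of search_utils/build_subset.py, which may select images differently; build_subset, val_per_class, and seed are illustrative names, and files are symlinked rather than copied.

```python
import random
from pathlib import Path


def build_subset(source_dir, sub_train_dir, sub_val_dir, val_per_class=50, seed=0):
    """Hypothetical sketch: split an ImageFolder-style train set into sub-train/sub-val.

    For every class folder in source_dir, val_per_class randomly chosen images are
    linked into sub_val_dir/<class>/ and the remaining images into
    sub_train_dir/<class>/. Files are symlinked, not copied, to save disk space.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    for class_dir in sorted(p for p in Path(source_dir).iterdir() if p.is_dir()):
        images = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(images)
        splits = (
            (Path(sub_val_dir), images[:val_per_class]),
            (Path(sub_train_dir), images[val_per_class:]),
        )
        for split_root, files in splits:
            out_dir = split_root / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for src in files:
                link = out_dir / src.name
                if not link.exists():
                    link.symlink_to(src.resolve())
```

Keeping the two subsets disjoint is the point of the split: architectures sampled during search are then ranked on images the super-network was not trained on.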

1. Super-Network Training

Before running each script, set IMAGENET_PATH to the directory containing the sub-train and sub-val sets.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/super_net/tiny.sh

For ViT-ResNAS-Small and Medium, run:

    sh scripts/vit-sr-nas/super_net/small.sh

2. Evolutionary Search

Before running each script, set IMAGENET_PATH to the directory containing the sub-train and sub-val sets, and set MODEL_PATH.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/evolutionary_search/tiny.sh

For ViT-ResNAS-Small, run:

    sh scripts/vit-sr-nas/evolutionary_search/[email protected]

For ViT-ResNAS-Medium, run:

    sh scripts/vit-sr-nas/evolutionary_search/[email protected]

After running evolutionary search for each network, check summary.txt in the output directory and modify the network_def it reports.

For example, suppose the network_def in summary.txt is:

    ((4, 220), (1, (220, 5, 32), (220, 880), 1), (1, (220, 5, 32), (220, 880), 1), (1, (220, 7, 32), (220, 800), 1), (1, (220, 7, 32), (220, 800), 0), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (3, 220, 440), (1, (440, 10, 48), (440, 1760), 1), (1, (440, 10, 48), (440, 1440), 1), (1, (440, 10, 48), (440, 1920), 1), (1, (440, 10, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1120), 0), (1, (440, 12, 48), (440, 1440), 1), (3, 440, 880), (1, (880, 16, 64), (880, 3200), 1), (1, (880, 12, 64), (880, 3200), 1), (1, (880, 16, 64), (880, 2880), 1), (1, (880, 12, 64), (880, 3200), 0), (1, (880, 12, 64), (880, 2240), 1), (1, (880, 12, 64), (880, 3520), 0), (1, (880, 14, 64), (880, 2560), 1), (2, 880, 1000))

Remove every element whose first entry is 1 and whose last entry is 0 (e.g., (1, (220, 7, 32), (220, 800), 0)).

Such an element corresponds to a transformer block that has been removed in the searched network.

After this modification, the network_def becomes:

    ((4, 220), (1, (220, 5, 32), (220, 880), 1), (1, (220, 5, 32), (220, 880), 1), (1, (220, 7, 32), (220, 800), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (3, 220, 440), (1, (440, 10, 48), (440, 1760), 1), (1, (440, 10, 48), (440, 1440), 1), (1, (440, 10, 48), (440, 1920), 1), (1, (440, 10, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1440), 1), (3, 440, 880), (1, (880, 16, 64), (880, 3200), 1), (1, (880, 12, 64), (880, 3200), 1), (1, (880, 16, 64), (880, 2880), 1), (1, (880, 12, 64), (880, 2240), 1), (1, (880, 14, 64), (880, 2560), 1), (2, 880, 1000))
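Because the pruning rule is mechanical, it can be scripted instead of edited by hand. This is a minimal sketch, assuming the network_def from summary.txt can be copied as a Python tuple literal like the example; prune_network_def and prune_network_def_str are hypothetical helpers, not functions provided by this repository.

```python
import ast


def prune_network_def(network_def):
    """Hypothetical helper: drop transformer blocks the search marked as removed.

    An element is dropped when its first entry is 1 (a transformer block) and its
    last entry is 0; all other elements are kept unchanged and in order.
    """
    return tuple(elem for elem in network_def if not (elem[0] == 1 and elem[-1] == 0))


def prune_network_def_str(text):
    """Same as prune_network_def, for a network_def copied from summary.txt as text."""
    # literal_eval only parses Python literals, so no code in the text is executed.
    return prune_network_def(ast.literal_eval(text))
```

For example, print(prune_network_def_str("((4, 220), (1, (220, 7, 32), (220, 800), 0), (2, 220, 1000))")) prints ((4, 220), (2, 220, 1000)). Applied to the full network_def above, it produces the modified network_def shown.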

Then, use the modified network_def to train the searched network.

3. Searched Network Training

Before running each script, modify IMAGENET_PATH.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/searched_net/tiny.sh

For ViT-ResNAS-Small, run:

    sh scripts/vit-sr-nas/searched_net/[email protected]

For ViT-ResNAS-Medium, run:

    sh scripts/vit-sr-nas/searched_net/[email protected]

4. Fine-tuning Trained Networks at Higher Resolution

Before running, set IMAGENET_PATH, and set FINETUNE_PATH to the trained ViT-ResNAS-Medium checkpoint. Then, run:

    sh scripts/vit-sr-nas/finetune/[email protected]

To fine-tune at a different resolution, modify --model, --input-size, and --mix-patch-len. We provide models at resolutions 280, 336, and 392, as shown here. Note that --input-size must equal 56 * --mix-patch-len, since ViT-ResNAS reduces the spatial size by a factor of 56.
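The constraint between the two flags can be checked before launching a run. A minimal sketch; mix_patch_len_for is a hypothetical helper that simply inverts the rule --input-size = 56 * --mix-patch-len stated above and rejects sizes that are not multiples of 56.

```python
def mix_patch_len_for(input_size, reduction=56):
    """Hypothetical helper: the --mix-patch-len matching a given --input-size.

    Follows the rule --input-size = 56 * --mix-patch-len, since ViT-ResNAS
    reduces the spatial size by a factor of 56.
    """
    if input_size % reduction != 0:
        raise ValueError(f"--input-size {input_size} is not a multiple of {reduction}")
    return input_size // reduction


# Flag pairs for the provided fine-tuning resolutions.
for size in (280, 336, 392):
    print(f"--input-size {size} --mix-patch-len {mix_patch_len_for(size)}")
```

This prints --mix-patch-len 5, 6, and 7 for 280, 336, and 392, respectively.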

Evaluation

Before running, modify IMAGENET_PATH and MODEL_PATH. Then, run:

    sh scripts/vit-sr-nas/eval/[email protected]

Questions

Please direct questions to Yi-Lun Liao ([email protected]).

License

This repository is released under the CC-BY-NC 4.0 license, as found in the LICENSE file.

Owner
Yi-Lun Liao