Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Related tags

Deep LearningMetaD2A
Overview

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets

This is the official PyTorch implementation for the paper Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets (ICLR 2021) : https://openreview.net/forum?id=rkQuFUmUOg3.

Abstract

Despite the success of recent Neural Architecture Search (NAS) methods on various tasks which have shown to output networks that largely outperform human-designed networks, conventional NAS methods have mostly tackled the optimization of searching for the network architecture for a single task (dataset), which does not generalize well across multiple tasks (datasets). Moreover, since such task-specific methods search for a neural architecture from scratch for every given task, they incur a large computational cost, which is problematic when the time and monetary budget are limited. In this paper, we propose an efficient NAS framework that is trained once on a database consisting of datasets and pretrained networks and can rapidly search a neural architecture for a novel dataset. The proposed MetaD2A (Meta Dataset-to-Architecture) model can stochastically generate graphs (architectures) from a given set (dataset) via a cross-modal latent space learned with amortized meta-learning. Moreover, we also propose a meta-performance predictor to estimate and select the best architecture without direct training on target datasets. The experimental results demonstrate that our model meta-learned on subsets of ImageNet-1K and architectures from NAS-Bench 201 search space successfully generalizes to multiple benchmark datasets including CIFAR-10 and CIFAR-100, with an average search time of 33 GPU seconds. Even under a large search space, MetaD2A is 5.5K times faster than NSGANetV2, a transferable NAS method, with comparable performance. We believe that the MetaD2A proposes a new research direction for rapid NAS as well as ways to utilize the knowledge from rich databases of datasets and architectures accumulated over the past years.

Framework of MetaD2A Model

Prerequisites

  • Python 3.6 (Anaconda)
  • PyTorch 1.6.0
  • CUDA 10.2
  • python-igraph==0.8.2
  • tqdm==4.50.2
  • torchvision==0.7.0
  • python-igraph==0.8.2
  • nas-bench-201==1.3
  • scipy==1.5.2

If you are not familiar with preparing conda environment, please follow the below instructions

$ conda create --name metad2a python=3.6
$ conda activate metad2a
$ conda install pytorch==1.6.0 torchvision cudatoolkit=10.2 -c pytorch
$ pip install nas-bench-201
$ conda install -c conda-forge tqdm
$ conda install -c conda-forge python-igraph
$ pip install scipy

And for data preprocessing,

$ pip install requests

Hardware Spec used for experiments of the paper

  • GPU: A single Nvidia GeForce RTX 2080Ti
  • CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz

NAS-Bench-201

Go to the folder for NAS-Bench-201 experiments (i.e. MetaD2A_nas_bench_201)

$ cd MetaD2A_nas_bench_201

Data Preparation

To download preprocessed data files, run get_files/get_preprocessed_data.py:

$ python get_files/get_preprocessed_data.py

It will take some time to download and preprocess each dataset.

To download MNIST, Pets and Aircraft Datasets, run get_files/get_{DATASET}.py

$ python get_files/get_mnist.py
$ python get_files/get_aircraft.py
$ python get_files/get_pets.py

Other datasets such as Cifar10, Cifar100, SVHN will be automatically downloaded when you load dataloader by torchvision.

If you want to use your own dataset, please first make your own preprocessed data, by modifying process_dataset.py .

$ process_dataset.py

MetaD2A Evaluation (Meta-Test)

You can download trained checkpoint files for generator and predictor

$ python get_files/get_checkpoint.py
$ python get_files/get_predictor_checkpoint.py

1. Evaluation on Cifar10 and Cifar100

By set --data-name as the name of dataset (i.e. cifar10, cifar100), you can evaluate the specific dataset only

# Meta-testing for generator 
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 500 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 500 --data-name {DATASET_NAME}

2. Evaluation on Other Datasets

By set --data-name as the name of dataset (i.e. mnist, svhn, aircraft, pets), you can evaluate the specific dataset only

# Meta-testing for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 50 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 50 --data-name {DATASET_NAME}

Meta-Training MetaD2A Model

You can train the generator and predictor as follows

# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 
                 
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 

Results

The results of training architectures which are searched by meta-trained MetaD2A model for each dataset

Accuracy

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 93.66±0.17 66.64±0.04 99.66±0.04 95.40±0.67 46.08±7.00 25.31±1.38
MetaD2A (Ours) 94.37±0.03 73.51±0.00 99.71±0.08 96.34±0.37 58.43±1.18 41.50±4.39

Search Time (GPU Sec)

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 10395 19951 24857 31124 3524 2844
MetaD2A (Ours) 69 96 7 7 10 8

MobileNetV3 Search Space

Go to the folder for MobileNetV3 Search Space experiments (i.e. MetaD2A_mobilenetV3)

$ cd MetaD2A_mobilenetV3

And follow README.md written for experiments of MobileNetV3 Search Space

Citation

If you found the provided code useful, please cite our work.

@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}

Reference

Owner
Ph.D. student @ School of Computing, Korea Advanced Institute of Science and Technology (KAIST)
Fine-tuning StyleGAN2 for Cartoon Face Generation

Cartoon-StyleGAN 🙃 : Fine-tuning StyleGAN2 for Cartoon Face Generation Abstract Recent studies have shown remarkable success in the unsupervised imag

Jihye Back 520 Jan 04, 2023
UV matrix decompostion using movielens dataset

UV-matrix-decompostion-with-kfold UV matrix decompostion using movielens dataset upload the 'ratings.dat' file install the following python libraries

2 Oct 18, 2022
This folder contains the implementation of the multi-relational attribute propagation algorithm.

MrAP This folder contains the implementation of the multi-relational attribute propagation algorithm. It requires the package pytorch-scatter. Please

6 Dec 06, 2022
Unofficial PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution

PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution [arXiv 2021].

Christoph Reich 122 Dec 12, 2022
Pytorch implementation of "A simple neural network module for relational reasoning" (Relational Networks)

Pytorch implementation of Relational Networks - A simple neural network module for relational reasoning Implemented & tested on Sort-of-CLEVR task. So

Kim Heecheol 800 Dec 05, 2022
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning The predictive learning of spatiotemporal sequences aims to generate future

THUML: Machine Learning Group @ THSS 243 Dec 26, 2022
This repository collects project-relevant Isabelle/HOL formalizations.

Isabelle/HOL formalizations related to the AuReLeE project Formalization of Abstract Argumentation Frameworks See AbstractArgumentation folder for the

AuReLeE project 1 Sep 10, 2022
Sentiment analysis translations of the Bhagavad Gita

Sentiment and Semantic Analysis of Bhagavad Gita Translations It is well known that translations of songs and poems not only breaks rhythm and rhyming

Machine learning and Bayesian inference @ UNSW Sydney 3 Aug 01, 2022
Implementation of TimeSformer, a pure attention-based solution for video classification

TimeSformer - Pytorch Implementation of TimeSformer, a pure and simple attention-based solution for reaching SOTA on video classification.

Phil Wang 602 Jan 03, 2023
Official PyTorch implementation of "Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning" (ICCV2021 Oral)

MeTAL - Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning (ICCV2021 Oral) Sungyong Baik, Janghoon Choi, Heewon Kim, Dohee Cho, Jaes

Sungyong Baik 44 Dec 29, 2022
ZeroGen: Efficient Zero-shot Learning via Dataset Generation

ZEROGEN This repository contains the code for our paper “ZeroGen: Efficient Zero

Jiacheng Ye 31 Dec 30, 2022
Benchmark library for high-dimensional HPO of black-box models based on Weighted Lasso regression

LassoBench LassoBench is a library for high-dimensional hyperparameter optimization benchmarks based on Weighted Lasso regression. Note: LassoBench is

Kenan Šehić 5 Mar 15, 2022
Demo code for ICCV 2021 paper "Sensor-Guided Optical Flow"

Sensor-Guided Optical Flow Demo code for "Sensor-Guided Optical Flow", ICCV 2021 This code is provided to replicate results with flow hints obtained f

10 Mar 16, 2022
Multi-View Consistent Generative Adversarial Networks for 3D-aware Image Synthesis (CVPR2022)

Multi-View Consistent Generative Adversarial Networks for 3D-aware Image Synthesis Multi-View Consistent Generative Adversarial Networks for 3D-aware

Xuanmeng Zhang 78 Dec 10, 2022
Open source person re-identification library in python

Open-ReID Open-ReID is a lightweight library of person re-identification for research purpose. It aims to provide a uniform interface for different da

Tong Xiao 1.3k Jan 01, 2023
Towards Rolling Shutter Correction and Deblurring in Dynamic Scenes (CVPR2021)

RSCD (BS-RSCD & JCD) Towards Rolling Shutter Correction and Deblurring in Dynamic Scenes (CVPR2021) by Zhihang Zhong, Yinqiang Zheng, Imari Sato We co

81 Dec 15, 2022
PyTorch Implementation of Fully Convolutional Networks. (Training code to reproduce the original result is available.)

pytorch-fcn PyTorch implementation of Fully Convolutional Networks. Requirements pytorch = 0.2.0 torchvision = 0.1.8 fcn = 6.1.5 Pillow scipy tqdm

Kentaro Wada 1.6k Jan 07, 2023
Source Code for Simulations in the Publication "Can the brain use waves to solve planning problems?"

Code for Simulations in the Publication Can the brain use waves to solve planning problems? Installing Required Python Packages Please use Python vers

EMD Group 2 Jul 01, 2022
This is my research project for the Irving Center for Cancer Dynamics/Azizi Lab, Columbia University.

bayesian_uncertainty This is my research project for the Irving Center for Cancer Dynamics/Azizi Lab, Columbia University. In this project I build a s

Max David Gupta 1 Feb 13, 2022
A tutorial on training a DarkNet YOLOv4 model for the CrowdHuman dataset

YOLOv4 CrowdHuman Tutorial This is a tutorial demonstrating how to train a YOLOv4 people detector using Darknet and the CrowdHuman dataset. Table of c

JK Jung 118 Nov 10, 2022