Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets

This is the official PyTorch implementation of the paper Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets (ICLR 2021): https://openreview.net/forum?id=rkQuFUmUOg3

Abstract

Despite the success of recent Neural Architecture Search (NAS) methods on various tasks, which have been shown to output networks that largely outperform human-designed networks, conventional NAS methods have mostly tackled the optimization of searching for the network architecture for a single task (dataset), which does not generalize well across multiple tasks (datasets). Moreover, since such task-specific methods search for a neural architecture from scratch for every given task, they incur a large computational cost, which is problematic when the time and monetary budget are limited. In this paper, we propose an efficient NAS framework that is trained once on a database consisting of datasets and pretrained networks and can rapidly search for a neural architecture for a novel dataset. The proposed MetaD2A (Meta Dataset-to-Architecture) model can stochastically generate graphs (architectures) from a given set (dataset) via a cross-modal latent space learned with amortized meta-learning. Moreover, we also propose a meta-performance predictor to estimate and select the best architecture without direct training on target datasets. The experimental results demonstrate that our model, meta-learned on subsets of ImageNet-1K and architectures from the NAS-Bench-201 search space, successfully generalizes to multiple benchmark datasets including CIFAR-10 and CIFAR-100, with an average search time of 33 GPU seconds. Even under a large search space, MetaD2A is 5.5K times faster than NSGANetV2, a transferable NAS method, with comparable performance. We believe that MetaD2A proposes a new research direction for rapid NAS, as well as ways to utilize the knowledge from rich databases of datasets and architectures accumulated over the past years.
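To make the dataset-to-architecture idea concrete, below is a minimal, illustrative PyTorch sketch of the flow the abstract describes: a permutation-invariant set encoder maps samples of a dataset to a stochastic latent code, from which a decoder emits the operations of a cell. All module names and internals here are our assumptions for illustration only (the latent/hidden sizes mirror the --hs 56 --nz 56 flags used below); the actual model in this repository is more elaborate.

```python
import torch
import torch.nn as nn

class SetEncoder(nn.Module):
    """Encodes a set of sample features into a stochastic latent code (illustrative)."""
    def __init__(self, in_dim=512, nz=56):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(in_dim, 128), nn.ReLU())
        self.rho = nn.Linear(128, 2 * nz)      # outputs mean and log-variance

    def forward(self, x):                      # x: (num_samples, in_dim)
        h = self.phi(x).mean(dim=0)            # mean-pool over the set -> permutation invariance
        mu, logvar = self.rho(h).chunk(2)
        return mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterized sample

class GraphDecoder(nn.Module):
    """Emits one operation per slot of a cell, conditioned on z (illustrative)."""
    def __init__(self, nz=56, hs=56, num_ops=5, num_slots=6):
        # NAS-Bench-201 cells have 6 operation slots (edges), 5 candidate ops each.
        super().__init__()
        self.hs, self.num_slots = hs, num_slots
        self.rnn = nn.GRUCell(nz, hs)
        self.op_head = nn.Linear(hs, num_ops)

    def forward(self, z):                      # z: (nz,)
        z = z.unsqueeze(0)                     # add batch dim for GRUCell
        h = torch.zeros(1, self.hs)
        ops = []
        for _ in range(self.num_slots):
            h = self.rnn(z, h)
            ops.append(self.op_head(h).argmax(dim=1).item())  # greedy; sampling also possible
        return ops                             # one operation id per slot

# One "meta-test" query: encode a few samples of a new dataset, decode a cell.
features = torch.randn(20, 512)                # stand-in for encoded dataset samples
z = SetEncoder()(features)
print(GraphDecoder()(z))
```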

Framework of the MetaD2A Model (figure)

Prerequisites

  • Python 3.6 (Anaconda)
  • PyTorch 1.6.0
  • CUDA 10.2
  • python-igraph==0.8.2
  • tqdm==4.50.2
  • torchvision==0.7.0
  • nas-bench-201==1.3
  • scipy==1.5.2

If you are not familiar with setting up a conda environment, follow the instructions below:

$ conda create --name metad2a python=3.6
$ conda activate metad2a
$ conda install pytorch==1.6.0 torchvision cudatoolkit=10.2 -c pytorch
$ pip install nas-bench-201
$ conda install -c conda-forge tqdm
$ conda install -c conda-forge python-igraph
$ pip install scipy
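A quick way to confirm the environment matches the prerequisites above (assuming the standard __version__ attributes of these packages):

```python
# Sanity-check installed versions against the prerequisites above.
import torch, torchvision, igraph, scipy, tqdm
import nas_201_api                              # module provided by the nas-bench-201 package

print("torch      :", torch.__version__)        # expect 1.6.0
print("torchvision:", torchvision.__version__)  # expect 0.7.0
print("igraph     :", igraph.__version__)       # expect 0.8.2
print("scipy      :", scipy.__version__)        # expect 1.5.2
print("tqdm       :", tqdm.__version__)         # expect 4.50.2
print("CUDA OK    :", torch.cuda.is_available())
```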

Additionally, for data preprocessing, install requests:

$ pip install requests

Hardware used for the experiments in the paper

  • GPU: A single Nvidia GeForce RTX 2080Ti
  • CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz

NAS-Bench-201

Go to the folder for NAS-Bench-201 experiments (i.e. MetaD2A_nas_bench_201)

$ cd MetaD2A_nas_bench_201

Data Preparation

To download preprocessed data files, run get_files/get_preprocessed_data.py:

$ python get_files/get_preprocessed_data.py

It will take some time to download and preprocess each dataset.

To download the MNIST, Pets, and Aircraft datasets, run get_files/get_{DATASET}.py:

$ python get_files/get_mnist.py
$ python get_files/get_aircraft.py
$ python get_files/get_pets.py

Other datasets, such as CIFAR-10, CIFAR-100, and SVHN, are downloaded automatically by torchvision when you load the dataloader.

If you want to use your own dataset, first create your own preprocessed data by modifying process_dataset.py:

$ python process_dataset.py
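The sketch below shows the kind of preprocessing one might perform for a custom dataset: sampling a small number of images per class and saving them as tensors. This is a hedged illustration only; the file names, keys, and exact format MetaD2A expects are defined in process_dataset.py, so adapt this accordingly.

```python
# Hypothetical preprocessing sketch for a custom dataset; the actual format
# expected by the repo is defined in process_dataset.py.
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((32, 32)),   # assumed input resolution
    transforms.ToTensor(),
])
# "path/to/my_dataset/train" is a placeholder for your own data directory.
dataset = datasets.ImageFolder("path/to/my_dataset/train", transform=transform)

# Collect up to N samples per class so the set encoder sees every class.
per_class, N = {}, 20
for img, label in dataset:
    per_class.setdefault(label, [])
    if len(per_class[label]) < N:
        per_class[label].append(img)

data = {c: torch.stack(imgs) for c, imgs in per_class.items()}
torch.save(data, "my_dataset_processed.pt")
```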

MetaD2A Evaluation (Meta-Test)

You can download the trained checkpoint files for the generator and the predictor:

$ python get_files/get_checkpoint.py
$ python get_files/get_predictor_checkpoint.py

1. Evaluation on Cifar10 and Cifar100

By setting --data-name to the name of the dataset (i.e. cifar10 or cifar100), you can evaluate that specific dataset only:

# Meta-testing for generator 
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 500 --data-name {DATASET_NAME}

After neural architecture generation is complete, the meta-performance predictor selects high-performing architectures from among the candidates:

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 500 --data-name {DATASET_NAME}

2. Evaluation on Other Datasets

By setting --data-name to the name of the dataset (i.e. mnist, svhn, aircraft, or pets), you can evaluate that specific dataset only:

# Meta-testing for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 50 --data-name {DATASET_NAME}

After neural architecture generation is complete, the meta-performance predictor selects high-performing architectures from among the candidates:

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 50 --data-name {DATASET_NAME}
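Since each dataset goes through the same two-stage pipeline (generate candidates, then rank them), a small driver script can chain both stages for every dataset. This is a convenience sketch that reuses the exact flags shown above and assumes it is run from the MetaD2A_nas_bench_201 directory:

```python
# Chain the two meta-test stages for each dataset, using the flags above.
import subprocess

CANDIDATES = {"cifar10": 500, "cifar100": 500,   # 500 candidates for CIFAR
              "mnist": 50, "svhn": 50, "aircraft": 50, "pets": 50}

for name, num in CANDIDATES.items():
    common = ["--gpu", "0", "--nz", "56", "--test",
              "--num-gen-arch", str(num), "--data-name", name]
    # Stage 1: generate candidate architectures for this dataset.
    subprocess.run(["python", "main.py", "--model", "generator",
                    "--hs", "56", "--load-epoch", "400"] + common, check=True)
    # Stage 2: rank the candidates with the meta-performance predictor.
    subprocess.run(["python", "main.py", "--model", "predictor",
                    "--hs", "512"] + common, check=True)
```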

Meta-Training MetaD2A Model

You can train the generator and the predictor as follows:

# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 
                 
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 

Results

Results of training the architectures searched by the meta-trained MetaD2A model for each dataset:

Accuracy

                 CIFAR-10     CIFAR-100    MNIST        SVHN         Aircraft     Oxford-IIIT Pets
PC-DARTS         93.66±0.17   66.64±0.04   99.66±0.04   95.40±0.67   46.08±7.00   25.31±1.38
MetaD2A (Ours)   94.37±0.03   73.51±0.00   99.71±0.08   96.34±0.37   58.43±1.18   41.50±4.39

Search Time (GPU Sec)

                 CIFAR-10   CIFAR-100   MNIST   SVHN    Aircraft   Oxford-IIIT Pets
PC-DARTS         10395      19951       24857   31124   3524       2844
MetaD2A (Ours)   69         96          7       7       10         8
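The "average search time of 33 GPU seconds" quoted in the abstract is consistent with the table above; the following snippet reproduces that average and the per-dataset speedup over PC-DARTS from the table values:

```python
# Summary numbers from the search-time table above (GPU seconds).
pc_darts = {"cifar10": 10395, "cifar100": 19951, "mnist": 24857,
            "svhn": 31124, "aircraft": 3524, "pets": 2844}
metad2a  = {"cifar10": 69, "cifar100": 96, "mnist": 7,
            "svhn": 7, "aircraft": 10, "pets": 8}

avg = sum(metad2a.values()) / len(metad2a)
print(f"average MetaD2A search time: {avg:.1f} GPU sec")  # ~32.8, i.e. the paper's ~33

for k in metad2a:
    print(f"{k:>8}: {pc_darts[k] / metad2a[k]:,.0f}x faster than PC-DARTS")
```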

MobileNetV3 Search Space

Go to the folder for MobileNetV3 Search Space experiments (i.e. MetaD2A_mobilenetV3)

$ cd MetaD2A_mobilenetV3

Then follow the README.md written for the MobileNetV3 search space experiments.

Citation

If you find the provided code useful, please cite our work.

@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}
