DrNAS: Dirichlet Neural Architecture Search

About

Code accompanying the ICLR 2021 paper DrNAS: Dirichlet Neural Architecture Search.
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating the search as a distribution learning problem. We treat the continuously relaxed architecture mixing weights as random variables modeled by a Dirichlet distribution. Using recently developed pathwise derivatives, the Dirichlet parameters can be optimized with a gradient-based optimizer in an end-to-end manner. This formulation improves generalization and induces stochasticity that naturally encourages exploration of the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between the search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method: we obtain a test error of 2.46% on CIFAR-10 and 23.7% on ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.
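As a rough, self-contained sketch of this distribution-learning idea (not the repository's actual code), the snippet below treats the per-edge mixing weights as Dirichlet samples and optimizes the concentration parameters through rsample(), whose pathwise (implicit reparameterization) gradients PyTorch provides. The sizes, the name raw_beta, and the mock loss are illustrative assumptions:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

# Illustrative sizes: 14 edges in a DARTS cell, 8 candidate operations per edge.
NUM_EDGES, NUM_OPS = 14, 8

# Free parameters; softplus keeps the Dirichlet concentration strictly positive.
raw_beta = torch.zeros(NUM_EDGES, NUM_OPS, requires_grad=True)
optimizer = torch.optim.Adam([raw_beta], lr=3e-4)

for step in range(100):
    optimizer.zero_grad()
    beta = F.softplus(raw_beta) + 1e-3           # concentration > 0
    # rsample() uses pathwise (implicit reparameterization) gradients,
    # so the loss can backpropagate into beta end-to-end.
    weights = Dirichlet(beta).rsample()          # (NUM_EDGES, NUM_OPS); rows sum to 1
    # Stand-in for the supernet: in DrNAS these weights mix the candidate
    # operations on each edge before computing the validation loss.
    loss = (weights * torch.randn(NUM_EDGES, NUM_OPS)).sum()
    loss.backward()
    optimizer.step()
```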

Results

On NAS-Bench-201

The table below shows test accuracy on the NAS-Bench-201 space. We achieve state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even reaches the global optimum with zero variance!

| Method | CIFAR-10 (test) | CIFAR-100 (test) | ImageNet-16-120 (test) |
| --- | --- | --- | --- |
| ENAS | 54.30 ± 0.00 | 10.62 ± 0.27 | 16.32 ± 0.00 |
| DARTS | 54.30 ± 0.00 | 38.97 ± 0.00 | 18.41 ± 0.00 |
| SNAS | 92.77 ± 0.83 | 69.34 ± 1.98 | 43.16 ± 2.64 |
| PC-DARTS | 93.41 ± 0.30 | 67.48 ± 0.89 | 41.31 ± 0.22 |
| DrNAS (ours) | 94.36 ± 0.00 | 73.51 ± 0.00 | 46.34 ± 0.00 |
| optimal | 94.37 | 73.51 | 47.31 |

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the architecture selected by the Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts wide and narrows gradually during the search phase, indicating that DrNAS encourages exploration in the early stages and then gradually reduces it as the algorithm becomes increasingly confident in its current choice. Moreover, the performance of our selected architecture consistently matches the best of the sampled architectures, demonstrating the effectiveness of DrNAS.
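A minimal sketch of how such samples and the mean-selected architecture could be obtained, assuming PyTorch; the concentrations here are random placeholders rather than learned values, and the shapes follow a NAS-Bench-201 cell (6 edges, 5 candidate operations):

```python
import torch
from torch.distributions import Dirichlet

# Placeholder concentrations for 6 edges x 5 ops; in the actual search
# these would be the learned Dirichlet parameters.
beta = torch.rand(6, 5) * 5 + 0.1
dist = Dirichlet(beta)

# Architecture selected by the Dirichlet mean (the solid line in the figure):
mean_arch = dist.mean.argmax(dim=-1)                # one op index per edge

# 100 architectures sampled from the current distribution (the plotted range):
sampled_archs = dist.sample((100,)).argmax(dim=-1)  # shape (100, 6)
```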

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking among the best of recent NAS results.

| Method | Test Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| ENAS | 2.89 | 4.6 | 0.5 |
| DARTS | 2.76 ± 0.09 | 3.3 | 1.0 |
| SNAS | 2.85 ± 0.02 | 2.8 | 1.5 |
| PC-DARTS | 2.57 ± 0.07 | 3.6 | 0.1 |
| DrNAS (ours) | 2.46 ± 0.03 | 4.1 | 0.6 |

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet, achieving a top-1 test error below 24.0%!

| Method | Top-1 Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| DARTS* | 26.7 | 4.7 | 1.0 |
| SNAS* | 27.3 | 4.3 | 1.5 |
| PC-DARTS | 24.2 | 5.3 | 3.8 |
| DSNAS | 25.7 | - | - |
| DrNAS (ours) | 23.7 | 5.7 | 4.6 |

\* not a direct search

Usage

Architecture Search

Search on the NAS-Bench-201 space (three datasets to choose from):

  • Data preparation: please first download the NAS-Bench-201 benchmark file and prepare the API following the NAS-Bench-201 repository.

  • `cd 201-space && python train_search.py`

  • With progressive pruning: `cd 201-space && python train_search_progressive.py`
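The progressive variant alternates between widening the supernet and pruning unpromising candidate operations to keep memory bounded. The sketch below shows one plausible pruning step, keeping the operations with the largest learned Dirichlet concentration on each edge; the function name and sizes are illustrative, and the exact criterion in the repository may differ:

```python
import torch

def prune_ops(beta: torch.Tensor, keep: int) -> torch.Tensor:
    """Keep the `keep` candidate ops with the largest concentration on each edge.

    Returns a boolean mask of shape (num_edges, num_ops); masked-out ops are
    dropped from the supernet before the next, wider search stage.
    """
    topk = beta.topk(keep, dim=-1).indices
    mask = torch.zeros_like(beta, dtype=torch.bool)
    mask.scatter_(-1, topk, True)
    return mask

# Example: after a stage with 5 ops per edge, keep the 3 most promising ones.
beta = torch.rand(6, 5)
mask = prune_ops(beta, keep=3)
```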

Search on DARTS Space:

  • Data preparation: for a direct search on ImageNet, we follow PC-DARTS and sample 10% and 2.5% of the images from each class as the train and validation sets, respectively (see the sketch after this list).

  • CIFAR-10: `cd DARTS-space && python train_search.py`

  • ImageNet: `cd DARTS-space && python train_search_imagenet.py`
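One way such a subset could be built; the paths and the one-folder-per-class layout are assumptions, and this is not the repository's preparation script:

```python
import os
import random
import shutil

# Hypothetical layout: one sub-directory per class, as in standard ImageNet folders.
SRC = "imagenet/train"
DST_TRAIN, DST_VAL = "imagenet_search/train", "imagenet_search/val"

random.seed(0)
for cls in os.listdir(SRC):
    images = sorted(os.listdir(os.path.join(SRC, cls)))
    random.shuffle(images)
    n_train = int(0.10 * len(images))   # 10% of each class for search training
    n_val = int(0.025 * len(images))    # 2.5% of each class for validation
    for dst, subset in ((DST_TRAIN, images[:n_train]),
                        (DST_VAL, images[n_train:n_train + n_val])):
        os.makedirs(os.path.join(dst, cls), exist_ok=True)
        for img in subset:
            shutil.copy(os.path.join(SRC, cls, img),
                        os.path.join(dst, cls, img))
```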

Architecture Evaluation

  • CIFAR-10: `cd DARTS-space && python train.py --cutout --auxiliary`

  • ImageNet: `cd DARTS-space && python train_imagenet.py --auxiliary`

Reference

If you find this code useful in your research, please cite:

```bibtex
@inproceedings{chen2021drnas,
    title={Dr{NAS}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}
```
