Differentiable architecture search for convolutional and recurrent networks

Differentiable Architecture Search

Code accompanying the paper

DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.

The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.
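The idea of continuous relaxation can be illustrated with a toy sketch. This is not the code from this repository: the candidate operations, the names `CANDIDATE_OPS`, `mixed_op` and `discretize`, and the scalar input are all hypothetical, chosen only to show how a discrete choice among operations becomes a softmax-weighted mixture governed by continuous architecture parameters.

```python
import math

def softmax(alphas):
    """Numerically stable softmax over a list of architecture parameters."""
    m = max(alphas)
    exps = [math.exp(a - m) for a in alphas]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical candidate operations acting on a scalar (real cells use conv/pool ops).
CANDIDATE_OPS = {
    "identity": lambda x: x,
    "double": lambda x: 2.0 * x,
    "zero": lambda x: 0.0,
}

def mixed_op(x, alphas):
    """Relaxed edge: a softmax-weighted sum of every candidate operation."""
    weights = softmax(alphas)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))

def discretize(alphas):
    """After search, keep only the operation with the largest parameter."""
    names = list(CANDIDATE_OPS)
    best = max(range(len(alphas)), key=lambda i: alphas[i])
    return names[best]

if __name__ == "__main__":
    alphas = [0.1, 2.0, -1.0]  # illustrative values, not learned parameters
    print(mixed_op(1.0, alphas), discretize(alphas))
```

Because `mixed_op` is differentiable with respect to `alphas`, the architecture parameters can be updated by gradient descent alongside the network weights; `discretize` stands in for the final step that turns the relaxed architecture back into a discrete one.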

Requirements

Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0

NOTE: PyTorch 0.4 is not supported at the moment and leads to out-of-memory (OOM) errors.

Datasets

Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to an SSD) following the instructions here.

Pretrained models

The easiest way to get started is to evaluate our pretrained DARTS models.

CIFAR-10 (cifar10_model.pt)

cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
  • Expected result: 2.63% test error rate with 3.3M model params.

PTB (ptb_model.pt)

cd rnn && python test.py --model_path ptb_model.pt
  • Expected result: 55.68 test perplexity with 23M model params.

ImageNet (imagenet_model.pt)

cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
  • Expected result: 26.7% top-1 error and 8.7% top-5 error with 4.7M model params.

Architecture search (using small proxy models)

To carry out architecture search using the second-order approximation, run

cd cnn && python train_search.py --unrolled     # for conv cells on CIFAR-10
cd rnn && python train_search.py --unrolled     # for recurrent cells on PTB
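For context on what the `--unrolled` flag selects, the second-order approximation can be written out as follows. This is a summary as we understand it from the paper, not something read off the code; here $w$ denotes the network weights, $\alpha$ the architecture parameters, $\xi$ the learning rate of the virtual weight step, and $\epsilon$ a small scalar.

```latex
% One virtual gradient step on the weights
w' = w - \xi \, \nabla_{w} \mathcal{L}_{train}(w, \alpha)

% Architecture gradient after unrolling that step
\nabla_{\alpha} \mathcal{L}_{val}(w', \alpha)
  - \xi \, \nabla^{2}_{\alpha, w} \mathcal{L}_{train}(w, \alpha) \,
    \nabla_{w'} \mathcal{L}_{val}(w', \alpha)

% Finite-difference estimate of the second-order term
\nabla^{2}_{\alpha, w} \mathcal{L}_{train}(w, \alpha) \,
  \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
  \approx
  \frac{\nabla_{\alpha} \mathcal{L}_{train}(w^{+}, \alpha)
      - \nabla_{\alpha} \mathcal{L}_{train}(w^{-}, \alpha)}{2 \epsilon},
\qquad
w^{\pm} = w \pm \epsilon \, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
```

Setting $\xi = 0$ drops the second term and gives the cheaper first-order approximation.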

Note that the validation performance in this step does not indicate the final performance of the architecture. One must train the obtained genotype/architecture from scratch using full-sized models, as described in the next section.

Also be aware that different runs may end up in different local minima. To get the best result, it is crucial to repeat the search process with different seeds and select the best cell(s) based on validation performance (obtained by training each derived cell from scratch for a small number of epochs). Please refer to fig. 3 and sect. 3.2 in our arXiv paper.
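The repeat-and-select procedure can be sketched as a small helper. The helper names are hypothetical, the `--seed` flag is an assumption about the search script's command-line interface (check `python train_search.py --help`), and the validation errors are placeholder numbers for illustration, not measured results.

```python
def search_commands(seeds, subdir="cnn"):
    """Build one search command per seed (assumes train_search.py accepts --seed)."""
    return [
        f"cd {subdir} && python train_search.py --unrolled --seed {seed}"
        for seed in seeds
    ]

def select_best_seed(val_errors):
    """Return the seed whose derived cell reached the lowest validation error.

    val_errors maps seed -> validation error (%) obtained by training the
    derived cell from scratch for a small number of epochs.
    """
    if not val_errors:
        raise ValueError("need at least one finished run")
    return min(val_errors, key=val_errors.get)

if __name__ == "__main__":
    for cmd in search_commands([0, 1, 2, 3]):
        print(cmd)
    # Placeholder values, for illustration only.
    example_errors = {0: 3.4, 1: 3.1, 2: 3.6, 3: 3.2}
    print("best seed:", select_best_seed(example_errors))
```

Only the cell selected this way would then go through the full from-scratch training described in the next section.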

Figure: Snapshots of the most likely normal conv, reduction conv, and recurrent cells over time.

Architecture evaluation (using full-sized models)

To evaluate our best cells by training from scratch, run

cd cnn && python train.py --auxiliary --cutout            # CIFAR-10
cd rnn && python train.py                                 # PTB
# WT2
cd rnn && python train.py --data ../data/wikitext-2 \
            --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7
cd cnn && python train_imagenet.py --auxiliary            # ImageNet

Customized architectures are supported through the --arch flag once specified in genotypes.py.
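As a rough illustration of what a custom convolutional entry might look like, the sketch below assumes that `genotypes.py` stores each architecture as a namedtuple with `normal`, `normal_concat`, `reduce` and `reduce_concat` fields, where each cell lists `(operation, input_index)` pairs, two per intermediate node. The name `MY_ARCH`, the chosen operations, and the `check_cell` helper are hypothetical; compare against the existing entries in `genotypes.py` before relying on this layout.

```python
from collections import namedtuple

# Assumed to mirror the container used in cnn/genotypes.py.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Hypothetical cell, for illustration only (not a searched architecture).
# Indices 0 and 1 are the two cell inputs; intermediate nodes are 2..5.
MY_ARCH = Genotype(
    normal=[
        ("sep_conv_3x3", 0), ("sep_conv_3x3", 1),   # node 2
        ("skip_connect", 0), ("sep_conv_3x3", 2),   # node 3
        ("dil_conv_3x3", 1), ("skip_connect", 0),   # node 4
        ("sep_conv_3x3", 3), ("max_pool_3x3", 1),   # node 5
    ],
    normal_concat=list(range(2, 6)),
    reduce=[
        ("max_pool_3x3", 0), ("max_pool_3x3", 1),   # node 2
        ("skip_connect", 2), ("max_pool_3x3", 0),   # node 3
        ("max_pool_3x3", 0), ("skip_connect", 2),   # node 4
        ("skip_connect", 2), ("avg_pool_3x3", 1),   # node 5
    ],
    reduce_concat=list(range(2, 6)),
)

def check_cell(edges, num_nodes=4, inputs_per_node=2):
    """Check that every edge only reads from an earlier node in the cell."""
    if len(edges) != num_nodes * inputs_per_node:
        raise ValueError("unexpected number of edges")
    for pos, (op_name, src) in enumerate(edges):
        node = 2 + pos // inputs_per_node  # index of the node this edge feeds
        if not 0 <= src < node:
            raise ValueError(f"{op_name} reads from node {src}, not before node {node}")
    return True

if __name__ == "__main__":
    print(check_cell(MY_ARCH.normal), check_cell(MY_ARCH.reduce))
```

Under these assumptions, adding such an entry to `genotypes.py` and passing `--arch MY_ARCH` to the training script would select it.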

The CIFAR-10 result at the end of training is subject to variance due to the non-determinism of cuDNN back-prop kernels. It would be misleading to report the result of only a single run. By training our best cell from scratch, one should expect the average test error of 10 independent runs to fall in the range of 2.76 +/- 0.09% with high probability.
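A minimal sketch of how results from several independent runs might be summarized in the same "mean +/- spread" form. The function name is hypothetical, the sample standard deviation is an assumption about how the spread is computed, and the error values are placeholders rather than measured results.

```python
import statistics

def summarize_runs(test_errors):
    """Return (mean, sample standard deviation) of per-run test errors in %."""
    if len(test_errors) < 2:
        raise ValueError("need at least two runs to estimate the spread")
    return statistics.mean(test_errors), statistics.stdev(test_errors)

if __name__ == "__main__":
    # Placeholder final test errors (%) from independent runs, illustration only.
    errors = [2.71, 2.80, 2.68, 2.85, 2.77]
    mean, std = summarize_runs(errors)
    print(f"{mean:.2f} +/- {std:.2f}%")
```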

Figure: Expected learning curves on CIFAR-10 (4 runs), ImageNet and PTB.

Visualization

The graphviz package is required to visualize the learned cells:

python visualize.py DARTS

where DARTS can be replaced by any customized architecture defined in genotypes.py.

Citation

If you use any part of this code in your research, please cite our paper:

@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}