PyTorch implementation of "Efficient Neural Architecture Search via Parameters Sharing"

Overview

Efficient Neural Architecture Search (ENAS) in PyTorch

PyTorch implementation of Efficient Neural Architecture Search via Parameters Sharing.

ENAS_rnn

ENAS reduce the computational requirement (GPU-hours) of Neural Architecture Search (NAS) by 1000x via parameter sharing between models that are subgraphs within a large computational graph. SOTA on Penn Treebank language modeling.

**[Caveat] Use official code from the authors: link**

Prerequisites

  • Python 3.6+
  • PyTorch==0.3.1
  • tqdm, scipy, imageio, graphviz, tensorboardX

Usage

Install prerequisites with:

conda install graphviz
pip install -r requirements.txt

To train ENAS to discover a recurrent cell for RNN:

python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \
               --shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001

python main.py --network_type rnn --dataset wikitext

To train ENAS to discover CNN architecture (in progress):

python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr_cosine=True \
               --controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1

or you can use your own dataset by placing images like:

data
├── YOUR_TEXT_DATASET
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── YOUR_IMAGE_DATASET
│   ├── test
│   │   ├── xxx.jpg (name doesn't matter)
│   │   ├── yyy.jpg (name doesn't matter)
│   │   └── ...
│   ├── train
│   │   ├── xxx.jpg
│   │   └── ...
│   └── valid
│       ├── xxx.jpg
│       └── ...
├── image.py
└── text.py

To generate gif image of generated samples:

python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif

More configurations can be found here.

Results

Efficient Neural Architecture Search (ENAS) is composed of two sets of learnable parameters, controller LSTM θ and the shared parameters ω. These two parameters are alternatively trained and only trained controller is used to derive novel architectures.

1. Discovering Recurrent Cells

rnn

Controller LSTM decide 1) what activation function to use and 2) which previous node to connect.

The RNN cell ENAS discovered for Penn Treebank and WikiText-2 dataset:

ptb wikitext

Best discovered ENAS cell for Penn Treebank at epoch 27:

ptb

You can see the details of training (e.g. reward, entropy, loss) with:

tensorboard --logdir=logs --port=6006

2. Discovering Convolutional Neural Networks

cnn

Controller LSTM samples 1) what computation operation to use and 2) which previous node to connect.

The CNN network ENAS discovered for CIFAR-10 dataset:

(in progress)

3. Designing Convolutional Cells

(in progress)

Reference

Author

Taehoon Kim / @carpedm20

Owner
Taehoon Kim
ex OpenAI
Taehoon Kim
Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning

Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning This is the official repository for Conservative and Adaptive Penalty fo

7 Nov 22, 2022
This repository contains the code used for the implementation of the paper "Probabilistic Regression with HuberDistributions"

Public_prob_regression_with_huber_distributions This repository contains the code used for the implementation of the paper "Probabilistic Regression w

David Mohlin 1 Dec 04, 2021
Code for EMNLP'21 paper "Types of Out-of-Distribution Texts and How to Detect Them"

ood-text-emnlp Code for EMNLP'21 paper "Types of Out-of-Distribution Texts and How to Detect Them" Files fine_tune.py is used to finetune the GPT-2 mo

Udit Arora 19 Oct 28, 2022
A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations.

IllustrationGAN A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations. Generated Images

268 Nov 27, 2022
Geometric Algebra package for JAX

JAXGA - JAX Geometric Algebra GitHub | Docs JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing onl

Robin Kahlow 36 Dec 22, 2022
Pytorch implementation for "Adversarial Robustness under Long-Tailed Distribution" (CVPR 2021 Oral)

Adversarial Long-Tail This repository contains the PyTorch implementation of the paper: Adversarial Robustness under Long-Tailed Distribution, CVPR 20

Tong WU 89 Dec 15, 2022
This repository contains the entire code for our work "Two-Timescale End-to-End Learning for Channel Acquisition and Hybrid Precoding"

Two-Timescale-DNN Two-Timescale End-to-End Learning for Channel Acquisition and Hybrid Precoding This repository contains the entire code for our work

QiyuHu 3 Mar 07, 2022
Official implementation of the ICLR 2021 paper

You Only Need Adversarial Supervision for Semantic Image Synthesis Official PyTorch implementation of the ICLR 2021 paper "You Only Need Adversarial S

Bosch Research 272 Dec 28, 2022
tensorflow implementation of 'YOLO : Real-Time Object Detection'

YOLO_tensorflow (Version 0.3, Last updated :2017.02.21) 1.Introduction This is tensorflow implementation of the YOLO:Real-Time Object Detection It can

Jinyoung Choi 1.7k Nov 21, 2022
RITA is a family of autoregressive protein models, developed by LightOn in collaboration with the OATML group at Oxford and the Debora Marks Lab at Harvard.

RITA: a Study on Scaling Up Generative Protein Sequence Models RITA is a family of autoregressive protein models, developed by a collaboration of Ligh

LightOn 69 Dec 22, 2022
Google-drive-to-sqlite - Create a SQLite database containing metadata from Google Drive

google-drive-to-sqlite Create a SQLite database containing metadata from Google

Simon Willison 140 Dec 04, 2022
A unified framework for machine learning with time series

Welcome to sktime A unified framework for machine learning with time series We provide specialized time series algorithms and scikit-learn compatible

The Alan Turing Institute 6k Jan 08, 2023
Official repository for the paper "GN-Transformer: Fusing AST and Source Code information in Graph Networks".

GN-Transformer AST This is the official repository for the paper "GN-Transformer: Fusing AST and Source Code information in Graph Networks". Data Prep

Cheng Jun-Yan 10 Nov 26, 2022
A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch

A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch The official pytorch implementation of the paper "Towards Faster and Stabilize

Bingchen Liu 455 Jan 08, 2023
Kaggle competition: Springleaf Marketing Response

PruebaEnel Prueba Kaggle-Springleaf-master Prueba Kaggle-Springleaf Kaggle competition: Springleaf Marketing Response Competencia de Kaggle: Marketing

1 Feb 09, 2022
Official Chainer implementation of GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral)

GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral) [Project] [Paper] [Demo] [Related Work: A2RL (for Auto Image Cropping)] [C

Wu Huikai 402 Dec 27, 2022
Repo for the paper Extrapolating from a Single Image to a Thousand Classes using Distillation

Extrapolating from a Single Image to a Thousand Classes using Distillation by Yuki M. Asano* and Aaqib Saeed* (*Equal Contribution) Extrapolating from

Yuki M. Asano 16 Nov 04, 2022
Multi-task head pose estimation in-the-wild

Multi-task head pose estimation in-the-wild We provide C++ code in order to replicate the head-pose experiments in our paper https://ieeexplore.ieee.o

Roberto Valle 26 Oct 06, 2022
This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

haifeng xia 32 Oct 26, 2022
Codes for 'Dual Parameterization of Sparse Variational Gaussian Processes'

Dual Parameterization of Sparse Variational Gaussian Processes Documentation | Notebooks | API reference Introduction This repository is the official

AaltoML 7 Dec 23, 2022