CATE: Computation-aware Neural Architecture Encoding with Transformers

Code for the paper:

CATE: Computation-aware Neural Architecture Encoding with Transformers
Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang.
ICML 2021 (Long Talk).

Overview of CATE: it takes computationally similar architecture pairs as input and is trained to predict masked operators given the pairwise computation information. Excluding the cross-attention blocks, the pretrained Transformer encoder is used to extract architecture encodings for the downstream search.
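The masked-operator objective can be illustrated on a single operator sequence. This is a BERT-style sketch, not the repository's code: the mask_operators helper, the 15% mask rate and the [MASK] token are all hypothetical. Per the caption above, CATE masks operators in architecture pairs, so the real model can also draw on the partner architecture when filling in a masked position.

```python
import random

MASK = "[MASK]"  # hypothetical mask token

def mask_operators(ops, mask_prob=0.15, rng=None):
    """Hide a random subset of operators; return (masked_ops, targets).

    targets[i] is the original operator where position i was masked and
    None elsewhere, so a model is trained only on the masked positions.
    """
    if rng is None:
        rng = random.Random()
    masked, targets = [], []
    for op in ops:
        if rng.random() < mask_prob:
            masked.append(MASK)
            targets.append(op)
        else:
            masked.append(op)
            targets.append(None)
    # Guarantee at least one masked position so every sample has a target.
    if ops and all(t is None for t in targets):
        i = rng.randrange(len(ops))
        masked[i], targets[i] = MASK, ops[i]
    return masked, targets


ops = ["input", "conv3x3-bn-relu", "maxpool3x3", "conv1x1-bn-relu", "output"]
print(mask_operators(ops, rng=random.Random(0)))
```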

The repository is built upon pybnn and nas-encodings.

Requirements

conda create -n tf python=3.7
source activate tf
cat requirements.txt | xargs -n 1 -L 1 pip install

Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord into the ./data folder.

python preprocessing/gen_json.py

Data will be saved in ./data/nasbench101.json.

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench101 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench101 --flag build_pair --k 2 --d 2000000 --metric params

The corresponding training data and pairs will be saved in ./data/nasbench101/.
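The --k, --d and --metric flags suggest how "computationally similar" pairs are formed. A minimal sketch, assuming each architecture is paired with up to k others whose parameter count (or FLOPs) differs from its own by at most d; build_pairs is a hypothetical helper, not the logic of data_generate.py. Sorting once and using bisect finds each architecture's window [v - d, v + d] without comparing every pair.

```python
import bisect
import random

def build_pairs(metric_values, k=2, d=2_000_000, seed=0):
    """Return (i, j) index pairs with |metric[i] - metric[j]| <= d.

    For each architecture i, up to k partners j != i are sampled uniformly
    from the architectures whose metric lies within distance d of it.
    """
    rng = random.Random(seed)
    order = sorted(range(len(metric_values)), key=lambda i: metric_values[i])
    sorted_vals = [metric_values[i] for i in order]
    pairs = []
    for i, v in enumerate(metric_values):
        # Inclusive window [v - d, v + d] in the sorted metric values.
        lo = bisect.bisect_left(sorted_vals, v - d)
        hi = bisect.bisect_right(sorted_vals, v + d)
        candidates = [order[p] for p in range(lo, hi) if order[p] != i]
        for j in rng.sample(candidates, min(k, len(candidates))):
            pairs.append((i, j))
    return pairs


params = [1_000_000, 1_500_000, 9_000_000, 2_800_000]
print(build_pairs(params, k=2, d=2_000_000))
```

An architecture with no neighbour inside the window (the 9M-parameter one here) simply gets no pairs.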

Alternatively, you can download the data (train_data.pt, test_data.pt) and the pair indices (train_pair_k2_d2000000_metric_params.pt, test_pair_k2_d2000000_metric_params.pt) from here.

Pretraining

bash run_scripts/pretrain_nasbench101.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench101_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench101_model_best.pth.tar --train_data data/nasbench101/train_data.pt --valid_data data/nasbench101/test_data.pt --dataset nasbench101

The extracted embeddings will be saved in ./cate_nasbench101.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench101.pt from here.

Run search experiments on NAS-Bench-101

bash run_scripts/run_search_nasbench101.sh

Search results will be saved in ./nasbench101/.

Experiments on NAS-Bench-301

Dataset preparation on NAS-Bench-301

Install nasbench301 and download the xgb_v1.0 and lgb_runtime_v1.0 files. You may need to install a pytorch_geometric build that is compatible with your PyTorch and CUDA versions.

python preprocessing/gen_json_darts.py # randomly sample 1,000,000 archs

Data will be saved in ./data/nasbench301_proxy.json.

Alternatively, you can download the json file nasbench301_proxy.json from here.
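To show what one of the randomly sampled architectures looks like, here is a minimal sketch of sampling a DARTS genotype: a normal and a reduction cell, each with 4 intermediate nodes that take two distinct earlier nodes as inputs, each edge carrying one of the seven non-"none" DARTS operations. The (op, input_index) edge format and the helper names are assumptions; gen_json_darts.py may sample and store architectures differently.

```python
import random

# The seven non-"none" candidate operations of the DARTS search space.
DARTS_OPS = [
    "max_pool_3x3", "avg_pool_3x3", "skip_connect",
    "sep_conv_3x3", "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5",
]

def sample_cell(rng, n_nodes=4):
    """Sample one cell as a list of (op, input_index) edges, two per node.

    Indices 0 and 1 are the two cell inputs; intermediate node i (0-based)
    may read from any of the 2 + i earlier nodes, using two distinct inputs.
    """
    edges = []
    for i in range(n_nodes):
        for src in rng.sample(range(2 + i), 2):
            edges.append((rng.choice(DARTS_OPS), src))
    return edges

def sample_genotype(seed=None):
    """Sample a full architecture: one normal and one reduction cell."""
    rng = random.Random(seed)
    return {"normal": sample_cell(rng), "reduce": sample_cell(rng)}


print(sample_genotype(seed=0))
```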

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench301 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench301 --flag build_pair --k 1 --d 5000000 --metric flops

The corresponding training data and pairs will be saved in ./data/nasbench301/.

Alternatively, you can download the data (train_data.pt, test_data.pt) and the pair indices (train_pair_k1_d5000000_metric_flops.pt, test_pair_k1_d5000000_metric_flops.pt) from here.

Pretraining

bash run_scripts/pretrain_nasbench301.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench301_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench301_model_best.pth.tar --train_data data/nasbench301/train_data.pt --valid_data data/nasbench301/test_data.pt --dataset nasbench301 --n_vocab 11

The extracted encodings will be saved in ./cate_nasbench301.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench301.pt from here.

Run search experiments on NAS-Bench-301

bash run_scripts/run_search_nasbench301.sh

Search results will be saved in ./nasbench301/.

DARTS experiments without surrogate models

Download the pretrained embeddings cate_darts.pt from here.

python search_methods/dngo_ls_darts.py --dim 64 --init_size 16 --topk 5 --dataset darts --output_path bo  --embedding_path cate_darts.pt

Search logs will be saved in ./darts/, and the final search result in ./darts/bo/dim64.
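Judging by its name, dngo_ls_darts.py combines a DNGO predictor with local search over the pretrained encodings. The sketch below only illustrates the outer loop that --init_size and --topk appear to control, under stated assumptions: the encodings are a dict from architecture id to vector, a k-nearest-neighbour regressor stands in for DNGO, and local search is omitted. predictor_search and knn_predict are hypothetical helpers.

```python
import random

def knn_predict(train, x, k=3):
    """Predict accuracy at embedding x as the mean over its k nearest neighbours.

    Neighbours are ranked by squared Euclidean distance (same order as Euclidean).
    """
    scored = sorted(
        (sum((p - q) ** 2 for p, q in zip(e, x)), acc) for e, acc in train
    )
    nearest = scored[:k]
    return sum(acc for _, acc in nearest) / len(nearest)

def predictor_search(embeddings, query, init_size=16, topk=5, budget=100, seed=0):
    """Return (best_arch, best_acc) after spending at most `budget` queries.

    embeddings: dict arch_id -> embedding vector (e.g. pretrained encodings)
    query:      function arch_id -> accuracy (the expensive evaluation)
    """
    rng = random.Random(seed)
    pool = list(embeddings)
    # Start from init_size randomly chosen, fully evaluated architectures.
    history = {a: query(a) for a in rng.sample(pool, min(init_size, len(pool)))}
    while len(history) < min(budget, len(pool)):
        train = [(embeddings[a], acc) for a, acc in history.items()]
        candidates = [a for a in pool if a not in history]
        # Rank unseen architectures by predicted accuracy (stand-in for DNGO).
        ranked = sorted(
            candidates, key=lambda a: knn_predict(train, embeddings[a]), reverse=True
        )
        # Evaluate the top-k predictions, never exceeding the query budget.
        for a in ranked[: min(topk, budget - len(history))]:
            history[a] = query(a)
    best = max(history, key=history.get)
    return best, history[best]
```

Each round refits the predictor on every evaluated architecture, so later picks are guided by all accuracies collected so far.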

Evaluate the learned cells from the DARTS search space on CIFAR-10

python darts/cnn/train.py --auxiliary --cutout --arch cate_small
python darts/cnn/train.py --auxiliary --cutout --arch cate_large
  • Expected results (CATE-Small): 2.55% avg. test error with 3.5M model params.
  • Expected results (CATE-Large): 2.46% avg. test error with 4.1M model params.

Transfer learning on ImageNet

python darts/cnn/train_imagenet.py  --arch cate_small --seed 1 
python darts/cnn/train_imagenet.py  --arch cate_large --seed 1
  • Expected results (CATE-Small): 26.05% test error with 5.0M model params and 556M mult-adds.
  • Expected results (CATE-Large): 25.01% test error with 5.8M model params and 642M mult-adds.

Visualize the learned cell

python darts/cnn/visualize.py cate_small
python darts/cnn/visualize.py cate_large

Experiments on outside search space

Build outside search space dataset

bash run_scripts/generate_oo.sh

Data will be saved in ./data/nasbench101_oo_train.json and ./data/nasbench101_oo_test.json.

Generate architecture pairs

python preprocessing/data_generate_oo.py --flag extract_seq
python preprocessing/data_generate_oo.py --flag build_pair

The corresponding training data and pair indices will be saved in ./data/nasbench101/.

Pretraining

python run.py --do_train --parallel --train_data data/nasbench101/nasbench101_oo_trainSet_train.pt --train_pair data/nasbench101/oo_train_pairs_k2_params_dist2e6.pt  --valid_data data/nasbench101/nasbench101_oo_trainSet_validation.pt --valid_pair data/nasbench101/oo_validation_pairs_k2_params_dist2e6.pt --dataset oo

The pretrained models will be saved in ./model/.

Extract embeddings on outside search space

# Adjacency encoding
python inference/inference_adj.py
# CATE encoding
python inference/inference.py --pretrained_path model/oo_model_best.pth.tar --train_data data/nasbench101/nasbench101_oo_testSet_split1.pt --valid_data data/nasbench101/nasbench101_oo_testSet_split2.pt --dataset oo_nasbench101

The extracted encodings will be saved as ./adj_oo_nasbench101.pt and ./cate_oo_nasbench101.pt.

Alternatively, you can download the data, pair indices, pretrained models, and extracted embeddings from here.
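inference_adj.py produces the adjacency encoding that is evaluated alongside the CATE encoding below. As a rough illustration, assuming the common NAS-Bench-101 style adjacency encoding (the strict upper triangle of the adjacency matrix followed by one-hot operation labels of the interior nodes), a sketch could look like this. The OPS list and adjacency_encoding are assumptions, and the outside-search-space encoding actually passed to oo_mlp.py with --dim 27 may be laid out differently.

```python
# Candidate operations of NAS-Bench-101 interior nodes.
OPS = ["conv3x3-bn-relu", "conv1x1-bn-relu", "maxpool3x3"]

def adjacency_encoding(matrix, ops):
    """Flatten the strict upper triangle of a DAG adjacency matrix and append
    a one-hot vector for the operation of every interior node.

    matrix: n x n lists of 0/1, upper-triangular (edge i -> j only for i < j)
    ops:    n labels; ops[0] is the input node and ops[-1] the output node
    """
    n = len(matrix)
    edges = [matrix[i][j] for i in range(n) for j in range(i + 1, n)]
    one_hot = []
    for op in ops[1:-1]:  # input and output nodes carry no operation choice
        one_hot.extend(1 if op == choice else 0 for choice in OPS)
    return edges + one_hot


print(adjacency_encoding([[0, 1, 1], [0, 0, 1], [0, 0, 0]],
                         ["input", "maxpool3x3", "output"]))
```

For a full 7-node NAS-Bench-101 cell this yields 21 edge bits plus 5 * 3 one-hot bits, i.e. 36 values.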

Run MLP predictor experiments on outside search space

for s in {1..500}; do python search_methods/oo_mlp.py --dim 27 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_adj  --embedding_path adj_oo_nasbench101.pt; done
for s in {1..500}; do python search_methods/oo_mlp.py --dim 64 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_cate  --embedding_path cate_oo_nasbench101.pt; done

Search results will be saved in ./oo_nasbench101.

Citation

If you find this useful for your work, please consider citing:

@InProceedings{yan2021cate,
  title = {CATE: Computation-aware Neural Architecture Encoding with Transformers},
  author = {Yan, Shen and Song, Kaiqiang and Liu, Fei and Zhang, Mi},
  booktitle = {ICML},
  year = {2021}
}