Source code of the paper Meta-learning with an Adaptive Task Scheduler.

ATS

If you find this repository useful in your research, please cite the following paper:

@inproceedings{yao2021adaptive,
  title={Meta-learning with an Adaptive Task Scheduler},
  author={Yao, Huaxiu and Wang, Yu and Wei, Ying and Zhao, Peilin and Mahdavi, Mehrdad and Lian, Defu and Finn, Chelsea},
  booktitle={Proceedings of the Thirty-fifth Conference on Neural Information Processing Systems},
  year={2021} 
}

Miniimagenet

The processed miniImageNet dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/miniimagenet, which has the following file structure:

-- miniimagenet  // /data/miniimagenet
  -- miniImagenet
    -- train_task_id.pkl
    -- test_task_id.pkl
    -- mini_imagenet_train.pkl
    -- mini_imagenet_test.pkl
    -- mini_imagenet_val.pkl
    -- training_classes_20000_2_new.npz
    -- training_classes_20000_4_new.npz

Then $datadir in the following scripts should be set to /data/miniimagenet.
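Before launching a long run, it can help to confirm the unzipped layout. The sketch below only checks for the file names listed in the tree above; the helper name `missing_miniimagenet_files` is hypothetical and not part of the repository.

```python
from pathlib import Path

# File names copied from the directory tree above. They live in the
# "miniImagenet" subfolder of the directory that $datadir points to.
EXPECTED_FILES = [
    "train_task_id.pkl",
    "test_task_id.pkl",
    "mini_imagenet_train.pkl",
    "mini_imagenet_test.pkl",
    "mini_imagenet_val.pkl",
    "training_classes_20000_2_new.npz",
    "training_classes_20000_4_new.npz",
]


def missing_miniimagenet_files(datadir: str) -> list[str]:
    """Return the expected files that are absent under <datadir>/miniImagenet."""
    root = Path(datadir) / "miniImagenet"
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_miniimagenet_files("/data/miniimagenet")
    print("Missing files:", missing if missing else "none")
```

An empty list means every file from the tree is present; it does not validate the pickle contents.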

ATS with noise = 0.6

We first need to pretrain the model without noise. The pretrained models have already been uploaded to this repo, but you can also pretrain them yourself with the following scripts:
(1) 1 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

(2) 5 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0
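The two pretraining commands differ only in --update_batch_size, which matches the shot (1 or 5). A small sketch that generates both argument lists from one place, so the shared flags cannot drift apart; `pretrain_command` and the "./logs" directory are hypothetical.

```python
import shlex


def pretrain_command(shot: int, datadir: str, logdir: str) -> list[str]:
    """Argument list for the no-noise pretraining run shown above.

    All flags except --update_batch_size (the number of support examples
    per class, i.e. the shot) are identical for 1-shot and 5-shot.
    """
    return [
        "python3", "main.py",
        "--meta_batch_size", "2",
        "--datasource", "miniimagenet",
        "--datadir", datadir,
        "--num_updates", "5",
        "--num_updates_test", "10",
        "--update_batch_size", str(shot),
        "--update_batch_size_eval", "15",
        "--resume", "0",
        "--num_classes", "5",
        "--metatrain_iterations", "30000",
        "--logdir", logdir,
        "--noise", "0.0",
    ]


if __name__ == "__main__":
    for shot in (1, 5):
        # shlex.join quotes each argument so the line can be pasted into a shell.
        print(shlex.join(pretrain_command(shot, "/data/miniimagenet", "./logs")))
```

To launch a run directly instead of printing it, pass the list to `subprocess.run(pretrain_command(1, "/data/miniimagenet", "./logs"), check=True)`.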

Then move the model to the current directory:
(1) 1 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32/model20000 ./model20000_1shot

(2) 5 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32/model10000 ./model10000_5shot
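The run-directory names in these mv commands appear to encode the hyperparameters. The sketch below rebuilds them from that inferred pattern only; the authoritative naming logic lives in main.py and may include fields not shown here. The metalr, innerlr and hidden values are taken from the paths, not from the commands, and the "_noise" suffix rule is inferred from the noisy runs later in this README. Both helper names are hypothetical.

```python
from pathlib import Path


def run_dir_name(update_batch_size: int, noise: float = 0.0, num_classes: int = 5,
                 meta_batch_size: int = 2, meta_lr: float = 0.001,
                 inner_lr: float = 0.01, hidden: int = 32) -> str:
    """Rebuild the run-directory name seen in the mv commands (inferred pattern)."""
    name = (f"ANIL_pytorch.data_miniimagenetcls_{num_classes}"
            f".mbs_{meta_batch_size}.ubs_{update_batch_size}"
            f".metalr{meta_lr}.innerlr{inner_lr}.hidden{hidden}")
    # Noisy runs in this README carry an extra "_noise<value>" suffix.
    if noise > 0:
        name += f"_noise{noise}"
    return name


def checkpoint_path(logdir: str, iteration: int, **kwargs) -> Path:
    """Path of the checkpoint saved after `iteration` meta-training steps."""
    return Path(logdir) / run_dir_name(**kwargs) / f"model{iteration}"
```

For example, `checkpoint_path("logs", 20000, update_batch_size=1)` points at the 1-shot checkpoint moved above. Note that the 1-shot and 5-shot checkpoints are taken at iterations 20000 and 10000, which match the --pretrain_iter values in the ATS commands.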

With this pretrained model, we can run both uniform sampling and ATS sampling. For ATS, the scripts are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10  --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 20000

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10 --utility_function sample --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 10000
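To build intuition for --temperature and --replace 0, here is a simplified sketch of tempered softmax sampling from a buffer of candidate tasks. It is not the repository's scheduler: in ATS the per-task scores come from a learned scheduler network, whereas here `scores` is a hypothetical input. Reading --replace 0 as "sample without replacement" is also our assumption.

```python
import numpy as np


def sample_tasks(scores, meta_batch_size, temperature, replace=False, rng=None):
    """Sample task indices from a candidate buffer via a tempered softmax.

    scores: one score per candidate task (hypothetical; in ATS these would be
    produced by the learned scheduler). Returns (indices, probabilities).
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(scores, dtype=float) / temperature
    logits -= logits.max()  # subtract the max so exp() cannot overflow
    probs = np.exp(logits)
    probs /= probs.sum()
    # replace=False: each candidate task is picked at most once per meta-batch.
    idx = rng.choice(len(probs), size=meta_batch_size, replace=replace, p=probs)
    return idx, probs
```

A smaller temperature (0.1, as in the commands above) concentrates probability on the highest-scoring candidates, while a temperature of 1 spreads it more evenly; the buffer must contain at least meta_batch_size candidates when sampling without replacement.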

For uniform sampling, we first train a model under uniform sampling and then fine-tune it on the validation set. The commands are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir -p models
mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_1shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir -p models  # -p: no error if the directory "models" already exists
mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_5shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune
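The mkdir/mv step between training and --finetune can also be scripted. This sketch assumes the run directory sits under the --logdir passed to main.py, and it keeps the target name ANIL_0.4_model_<shot>shot exactly as written in the commands above (presumably the name the --finetune run loads). `stash_uniform_checkpoint` is a hypothetical helper.

```python
import shutil
from pathlib import Path


def stash_uniform_checkpoint(logdir, shot, noise=0.6, iteration=30000, models_dir="models"):
    """Move the uniform-sampling checkpoint to where the --finetune run expects it."""
    run = (f"ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_{shot}"
           f".metalr0.001.innerlr0.01.hidden32_noise{noise}")
    src = Path(logdir) / run / f"model{iteration}"
    if not src.exists():
        raise FileNotFoundError(f"checkpoint not found: {src}")
    target_dir = Path(models_dir)
    target_dir.mkdir(parents=True, exist_ok=True)  # same effect as `mkdir -p models`
    # Target name copied verbatim from the README commands.
    dst = target_dir / f"ANIL_0.4_model_{shot}shot"
    shutil.move(str(src), str(dst))
    return dst
```

Raising early on a missing checkpoint gives a clearer error than a failed load inside the fine-tuning run.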

ATS with limited budgets

In this setting, pretraining is not needed. You can directly run the following scripts:
Uniform sampling, 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

Uniform sampling, 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

ATS 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 1 --warmup 0 --limit_classes 16

ATS 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 0.1 --warmup 0 --limit_classes 16
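The four limited-budget commands share most flags; ATS adds a fixed set of scheduler flags, and only --temperature differs between 1 shot (1) and 5 shot (0.1). A sketch that encodes exactly those differences; the names BASE, ATS_EXTRA and `limited_budget_args` are hypothetical, and the flag order differs from the README, which should not matter to an argument parser.

```python
# Flags common to all four limited-budget commands above.
BASE = {
    "--meta_batch_size": "3", "--datasource": "miniimagenet",
    "--datadir": "./miniimagenet/", "--num_updates": "5",
    "--num_updates_test": "10", "--update_batch_size_eval": "15",
    "--resume": "0", "--num_classes": "5", "--metatrain_iterations": "30000",
    "--limit_data": "1", "--logdir": "../train_logs", "--limit_classes": "16",
}
# Flags that only the ATS commands add.
ATS_EXTRA = {"--replace": "0", "--sampling_method": "ATS", "--buffer_size": "6",
             "--utility_function": "sample", "--warmup": "0"}
ATS_TEMPERATURE = {1: "1", 5: "0.1"}  # shot -> --temperature


def limited_budget_args(shot: int, method: str) -> list[str]:
    """Argument list for one limited-budget run; method is "uniform" or "ATS"."""
    args = {**BASE, "--update_batch_size": str(shot)}
    if method == "ATS":
        args.update(ATS_EXTRA)
        args["--temperature"] = ATS_TEMPERATURE[shot]
    argv = ["python3", "main.py"]
    for flag, value in args.items():
        argv += [flag, value]
    return argv
```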

Drug

The processed dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/drug, which has the following structure:

-- drug  // /data/drug
  -- ci9b00375_si_001.txt  
  -- compound_fp.npy               
  -- drug_split_id_group2.pickle  
  -- drug_split_id_group6.pickle
  -- ci9b00375_si_002.txt  
  -- drug_split_id_group17.pickle  
  -- drug_split_id_group3.pickle  
  -- drug_split_id_group9.pickle
  -- ci9b00375_si_003.txt  
  -- drug_split_id_group1.pickle   
  -- drug_split_id_group4.pickle  
  -- important_readme.md

Then $datadir in the following scripts should be set to /data/, i.e. the parent directory of drug.
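Because $datadir here is the parent of the drug folder, a common mistake is passing /data/drug itself. A sketch of a pre-flight check that catches both that and missing files; it relies only on the file list above, and `missing_drug_files` is a hypothetical helper.

```python
from pathlib import Path

# File names copied from the directory tree above.
EXPECTED_DRUG_FILES = sorted([
    "ci9b00375_si_001.txt", "ci9b00375_si_002.txt", "ci9b00375_si_003.txt",
    "compound_fp.npy", "important_readme.md",
    "drug_split_id_group1.pickle", "drug_split_id_group2.pickle",
    "drug_split_id_group3.pickle", "drug_split_id_group4.pickle",
    "drug_split_id_group6.pickle", "drug_split_id_group9.pickle",
    "drug_split_id_group17.pickle",
])


def missing_drug_files(datadir: str) -> list[str]:
    """Return expected files missing under <datadir>/drug (datadir is e.g. /data/)."""
    root = Path(datadir)
    if root.name == "drug" and not (root / "drug").is_dir():
        raise ValueError(f"{datadir} looks like the drug folder itself; pass its parent, e.g. /data/")
    return [name for name in EXPECTED_DRUG_FILES if not (root / "drug" / name).is_file()]
```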

ATS with noise = 4

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir --train 0

ATS with full budgets

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir --train 0
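Each drug configuration above is a pair of commands: the second repeats the first with --train 0 appended, which we read as the evaluation pass over the trained model. A sketch that generates all four pairs ({uniform, ATS} x {noise 4, full budget}) and either prints or runs them; `drug_commands` and `run_all` are hypothetical names.

```python
import shlex
import subprocess

# Flags shared by every drug command in this README.
BASE = ["python3", "main.py", "--datasource=drug", "--metatrain_iterations=20",
        "--update_lr=0.005", "--meta_lr=0.001", "--num_updates=5",
        "--test_num_updates=5", "--trial=1", "--drug_group=17"]


def drug_commands(datadir, ats=False, noise=None):
    """Return the (train, evaluate) argument lists for one drug configuration."""
    args = list(BASE)
    if ats:
        args += ["--sampling_method", "ATS"]
    if noise is not None:
        args += ["--noise", str(noise)]
    args += ["--data_dir", datadir]
    # The second command is the first one with --train 0 appended.
    return args, args + ["--train", "0"]


def run_all(datadir, dry_run=True):
    """Print (dry_run=True) or execute all four train/evaluate pairs in order."""
    for noise in (4, None):
        for ats in (False, True):
            for cmd in drug_commands(datadir, ats=ats, noise=noise):
                if dry_run:
                    print(shlex.join(cmd))
                else:
                    subprocess.run(cmd, check=True)  # stop on the first failure


if __name__ == "__main__":
    run_all("/data/")
```

Running the training command before its --train 0 counterpart matters, since evaluation presumably loads the model that training just saved.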

For ATS, you can append --simple_loss to the scripts above to change how the loss that serves as the scheduler's input is calculated.
