Source code of the paper Meta-learning with an Adaptive Task Scheduler.

If you find this repository useful in your research, please cite the following paper:

@inproceedings{yao2021adaptive,
  title={Meta-learning with an Adaptive Task Scheduler},
  author={Yao, Huaxiu and Wang, Yu and Wei, Ying and Zhao, Peilin and Mahdavi, Mehrdad and Lian, Defu and Finn, Chelsea},
  booktitle={Proceedings of the Thirty-fifth Conference on Neural Information Processing Systems},
  year={2021} 
}

Miniimagenet

The processed miniimagenet dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/miniimagenet, which has the following file structure:

-- miniimagenet  // /data/miniimagenet
  -- miniImagenet
    -- train_task_id.pkl
    -- test_task_id.pkl
    -- mini_imagenet_train.pkl
    -- mini_imagenet_test.pkl
    -- mini_imagenet_val.pkl
    -- training_classes_20000_2_new.npz
    -- training_classes_20000_4_new.npz

Then $datadir in the following code should be set to /data/miniimagenet.
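Before launching a long run, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name missing_miniimagenet_files is hypothetical, and the file names are simply copied from the listing above.

```python
from pathlib import Path

# File names copied from the listing above; the helper itself is hypothetical
# and is not part of the repository.
EXPECTED_FILES = [
    "train_task_id.pkl",
    "test_task_id.pkl",
    "mini_imagenet_train.pkl",
    "mini_imagenet_test.pkl",
    "mini_imagenet_val.pkl",
    "training_classes_20000_2_new.npz",
    "training_classes_20000_4_new.npz",
]


def missing_miniimagenet_files(datadir):
    """Return the expected files that are absent under <datadir>/miniImagenet."""
    root = Path(datadir) / "miniImagenet"
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_miniimagenet_files("/data/miniimagenet")
    print("all files present" if not missing else f"missing: {missing}")
```

An empty list means every file named in the listing was found; it does not validate the file contents.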

ATS with noise = 0.6

We first need to pretrain the model without noise. The pretrained model has been uploaded to this repo. You can also pretrain the model yourself. The script for pretraining is as follows:
(1) 1 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

(2) 5 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

Then move the model to the current directory:
(1) 1 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32/model20000 ./model20000_1shot

(2) 5 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32/model10000 ./model10000_5shot
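The directory names in the mv commands appear to encode the run's hyperparameters. The sketch below reconstructs that pattern so the path can be derived rather than typed by hand. The field meanings (mbs as meta batch size, ubs as update batch size, and so on) and the _noise suffix rule are assumptions inferred from the names shown in this README, not taken from the code, and experiment_dir is a hypothetical helper.

```python
def experiment_dir(update_batch_size, noise=0.0, datasource="miniimagenet",
                   num_classes=5, meta_batch_size=2, meta_lr=0.001,
                   inner_lr=0.01, hidden=32):
    """Rebuild the log sub-directory name used in this README.

    Assumption: the name is assembled from these hyperparameters, and a
    "_noise<value>" suffix is appended only when noise is non-zero.
    """
    name = (f"ANIL_pytorch.data_{datasource}cls_{num_classes}"
            f".mbs_{meta_batch_size}.ubs_{update_batch_size}"
            f".metalr{meta_lr}.innerlr{inner_lr}.hidden{hidden}")
    if noise > 0:
        name += f"_noise{noise}"
    return name


if __name__ == "__main__":
    print(experiment_dir(update_batch_size=1))             # 1-shot pretraining run
    print(experiment_dir(update_batch_size=5, noise=0.6))  # 5-shot run with noise
```

If the name produced here does not match what appears under $logdir, trust the directory on disk.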

With this model, we can run both uniform sampling and ATS sampling. For ATS, the script is:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10  --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 20000

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10 --utility_function sample --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 10000
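The flags --buffer_size, --temperature and --replace 0 suggest that ATS draws each meta-batch from a larger pool of candidate tasks, using temperature-scaled probabilities and sampling without replacement. That reading is an assumption; this README does not define the flags. The sketch below only illustrates the general technique (softmax sampling without replacement) and is not the repository's implementation. The function name and the scores are made up for illustration.

```python
import math
import random


def sample_meta_batch(scores, meta_batch_size, temperature, rng=None):
    """Pick `meta_batch_size` distinct candidate indices.

    Probabilities are softmax(scores / temperature), renormalised over the
    candidates that are still available after each draw.
    """
    rng = rng or random.Random()
    if meta_batch_size > len(scores):
        raise ValueError("cannot draw more tasks than there are candidates")
    remaining = list(range(len(scores)))
    chosen = []
    for _ in range(meta_batch_size):
        scaled = [scores[i] / temperature for i in remaining]
        peak = max(scaled)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scaled]
        pick = rng.choices(remaining, weights=weights, k=1)[0]
        chosen.append(pick)
        remaining.remove(pick)
    return chosen


if __name__ == "__main__":
    # Hypothetical scheduler scores for a pool of 10 candidate tasks.
    pool_scores = [0.1, 0.4, 0.2, 0.9, 0.3, 0.5, 0.0, 0.7, 0.6, 0.8]
    print(sample_meta_batch(pool_scores, meta_batch_size=2, temperature=0.1))
```

A lower temperature concentrates the probability mass on the highest-scoring candidates, while a higher temperature moves the draw toward uniform sampling.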

For uniform sampling, we need to use the validation set to finetune the model trained under uniform sampling. The training commands are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir models
mv ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_1shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir -p models  # only needed if directory "models" does not exist
mv ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_5shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

ATS with limited budgets

In this setting, pretraining is not needed. You can directly run the following code:
uniform sampling, 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

uniform sampling, 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

ATS 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 1 --warmup 0 --limit_classes 16

ATS 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 0.1 --warmup 0 --limit_classes 16
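The four limited-budget commands differ only in the shot count and in the ATS-specific flags. As a convenience, they can be generated from one place. This is a sketch with a hypothetical helper name (limited_budget_cmd); every flag and value is copied from the commands above, and nothing here is part of the repository.

```python
import shlex


def limited_budget_cmd(shots, ats=False, temperature="1",
                       datadir="./miniimagenet/", logdir="../train_logs"):
    """Build one limited-budget command as an argument list.

    `temperature` is only used when ats=True (the commands above use "1"
    for 1 shot and "0.1" for 5 shot).
    """
    cmd = ["python3", "main.py",
           "--meta_batch_size", "3",
           "--datasource", "miniimagenet",
           "--datadir", datadir,
           "--num_updates", "5",
           "--num_updates_test", "10",
           "--update_batch_size", str(shots),
           "--update_batch_size_eval", "15",
           "--resume", "0",
           "--num_classes", "5",
           "--metatrain_iterations", "30000"]
    if ats:
        cmd += ["--replace", "0"]
    cmd += ["--limit_data", "1", "--logdir", logdir]
    if ats:
        cmd += ["--sampling_method", "ATS",
                "--buffer_size", "6",
                "--utility_function", "sample",
                "--temperature", temperature,
                "--warmup", "0"]
    cmd += ["--limit_classes", "16"]
    return cmd


if __name__ == "__main__":
    for shots, temperature in ((1, "1"), (5, "0.1")):
        print(shlex.join(limited_budget_cmd(shots)))
        print(shlex.join(limited_budget_cmd(shots, ats=True, temperature=temperature)))
```

Each returned list can be passed directly to subprocess.run, which avoids shell quoting issues.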

Drug

The processed dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/drug, which has the following structure:

-- drug  // /data/drug
  -- ci9b00375_si_001.txt  
  -- compound_fp.npy               
  -- drug_split_id_group2.pickle  
  -- drug_split_id_group6.pickle
  -- ci9b00375_si_002.txt  
  -- drug_split_id_group17.pickle  
  -- drug_split_id_group3.pickle  
  -- drug_split_id_group9.pickle
  -- ci9b00375_si_003.txt  
  -- drug_split_id_group1.pickle   
  -- drug_split_id_group4.pickle  
  -- important_readme.md

Then $datadir in the following script should be set to /data/, i.e. the parent of the drug directory.

ATS with noise=4.

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir --train 0

ATS with full budgets

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir --train 0
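Every drug experiment above is a pair of commands: the second repeats the first and appends --train 0, which presumably switches main.py from meta-training to evaluation (an assumption; this README does not say so explicitly). Note also that the drug commands use --data_dir, whereas the miniimagenet commands use --datadir. The sketch below generates each pair so the two commands cannot drift apart; drug_commands is a hypothetical helper, and all flags are copied from the commands above.

```python
import shlex


def drug_commands(datadir="$datadir", sampling_method=None, noise=None):
    """Return (first_cmd, second_cmd) for one drug experiment.

    The second command is the first one plus "--train 0", as in the README.
    """
    first = ["python3", "main.py",
             "--datasource=drug",
             "--metatrain_iterations=20",
             "--update_lr=0.005",
             "--meta_lr=0.001",
             "--num_updates=5",
             "--test_num_updates=5",
             "--trial=1",
             "--drug_group=17"]
    if sampling_method is not None:
        first += ["--sampling_method", sampling_method]
    if noise is not None:
        first += ["--noise", str(noise)]
    first += ["--data_dir", datadir]
    return first, first + ["--train", "0"]


if __name__ == "__main__":
    # The datadir value follows the README: the parent of the drug directory.
    for cmd in drug_commands(datadir="/data/", sampling_method="ATS", noise=4):
        print(shlex.join(cmd))
```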

For ATS, the loss that is fed to the scheduler can be calculated in an alternative way. To use it, append --simple_loss to the ATS commands above.
