AdaShare: Learning What To Share For Efficient Deep Multi-Task Learning (NeurIPS 2020)

Introduction


AdaShare is a novel and differentiable approach for efficient multi-task learning that learns the feature sharing pattern to achieve the best recognition accuracy, while restricting the memory footprint as much as possible. Our main idea is to learn the sharing pattern through a task-specific policy that selectively chooses which layers to execute for a given task in the multi-task network. In other words, we aim to obtain a single network for multi-task learning that supports separate execution paths for different tasks.
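The following minimal sketch in plain Python makes per-task execution paths concrete. It is not the AdaShare implementation. The four toy "layers", the task names, and the hard-coded select/skip decisions are hypothetical stand-ins: in AdaShare the decisions are learned, and the layers are blocks of a shared backbone, assumed here to be residual-style. A skipped layer simply passes its input through unchanged, the way a residual block reduces to its skip connection.

```python
from typing import Callable, Dict, List

# Hypothetical toy "layers" standing in for blocks of a shared backbone.
Layer = Callable[[float], float]


def run_task(x: float, layers: List[Layer], policy: List[bool]) -> float:
    """Execute layer i only if policy[i] is True; otherwise skip it (identity)."""
    for layer, execute in zip(layers, policy):
        if execute:
            x = layer(x)
    return x


def shared_layers(policies: Dict[str, List[bool]]) -> List[int]:
    """Indices of layers executed by every task, i.e. fully shared layers."""
    n = len(next(iter(policies.values())))
    return [i for i in range(n) if all(p[i] for p in policies.values())]


layers: List[Layer] = [
    lambda v: v + 1.0,
    lambda v: v * 2.0,
    lambda v: v - 3.0,
    lambda v: v * 10.0,
]

# Hypothetical per-task select (True) / skip (False) decisions, one per layer.
policies = {
    "segmentation": [True, True, False, True],
    "depth": [True, False, True, True],
}

if __name__ == "__main__":
    for task, policy in policies.items():
        print(task, run_task(1.0, layers, policy))
    print("shared:", shared_layers(policies))
```

Both tasks run through the same four-layer network. They share layers 0 and 3 and take different paths through layers 1 and 2.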

Here is the link to the arXiv version of our paper.

If you find our work helpful for your research, please consider citing it:

@article{sun2020adashare,
  title={Adashare: Learning what to share for efficient deep multi-task learning},
  author={Sun, Ximeng and Panda, Rameswar and Feris, Rogerio and Saenko, Kate},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

Experiment Environment

Our implementation uses PyTorch. We train and test our model on one Tesla V100 GPU for NYU v2 2-task and CityScapes 2-task, and on two Tesla V100 GPUs for NYU v2 3-task and Tiny-Taskonomy 5-task.

We use Python 3.6; please refer to this link to create a Python 3.6 conda environment.

Install the following packages in the virtual environment:

conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
conda install matplotlib
conda install -c menpo opencv
conda install pillow
conda install -c conda-forge tqdm
conda install -c anaconda pyyaml
conda install scikit-learn
conda install -c anaconda scipy
pip install tensorboardX
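After installation, you can confirm the environment with a small stdlib-only check like the one below. The package-to-module mapping reflects our assumption about import names (for example, opencv imports as cv2, pillow as PIL, and pyyaml as yaml). The script only checks that each module can be found and, if PyTorch is present, whether CUDA is visible.

```python
import importlib.util
from typing import Dict

# Installed package name -> import name (assumed mapping).
REQUIRED_MODULES = {
    "pytorch": "torch",
    "torchvision": "torchvision",
    "matplotlib": "matplotlib",
    "opencv": "cv2",
    "pillow": "PIL",
    "tqdm": "tqdm",
    "pyyaml": "yaml",
    "scikit-learn": "sklearn",
    "scipy": "scipy",
    "tensorboardX": "tensorboardX",
}


def check_modules(modules: Dict[str, str]) -> Dict[str, bool]:
    """Return package -> True if its module can be located, without importing it."""
    return {pkg: importlib.util.find_spec(mod) is not None for pkg, mod in modules.items()}


if __name__ == "__main__":
    status = check_modules(REQUIRED_MODULES)
    for pkg, ok in status.items():
        print(f"{pkg:15s} {'ok' if ok else 'MISSING'}")
    if status["pytorch"]:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```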

Datasets

Please download the formatted NYU v2 dataset here.

The formatted CityScapes can be found here.

Download Tiny-Taskonomy as instructed in its GitHub repository.

The formatted DomainNet can be found here.

Remember to change dataroot to your local dataset path in all yaml files under ./yamls/.
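With several datasets and configs, editing each file by hand is error-prone. The sketch below rewrites the value in place with a regular expression. It assumes each yaml file stores the path on a single line of the form `dataroot: <path>`, possibly indented under a parent key. The function name and example path are hypothetical. A regex replacement is used instead of loading and re-dumping the yaml so that comments and key order stay untouched.

```python
import re
from pathlib import Path

# Assumption: the path sits on one line "dataroot: <path>"; indentation is kept.
DATAROOT_LINE = re.compile(r"^([ \t]*dataroot:[ \t]*).*$", re.MULTILINE)


def set_dataroot(yaml_dir: str, new_root: str, pattern: str = "*.yml") -> int:
    """Rewrite dataroot in every file matching pattern under yaml_dir.

    Returns the number of files that were changed.
    """
    changed = 0
    for path in Path(yaml_dir).rglob(pattern):
        text = path.read_text()
        # A function replacement avoids backslashes in new_root being
        # interpreted as regex escape sequences.
        new_text = DATAROOT_LINE.sub(lambda m: m.group(1) + new_root, text)
        if new_text != text:
            path.write_text(new_text)
            changed += 1
    return changed
```

For example, `set_dataroot("./yamls/adashare", "/data/nyu_v2", pattern="nyu_v2_*.yml")` would update only the NYU v2 configs, leaving other datasets' paths alone (the data path is hypothetical).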

Training

Policy Learning Phase

Please execute train.py for policy learning, using the command:

python train.py --config <yaml_file_name> --gpus <gpu ids>

For example, python train.py --config yamls/adashare/nyu_v2_2task.yml --gpus 0.

Sample yaml files are under yamls/adashare.

Note: use the domainnet branch for experiments on DomainNet, i.e., python train_domainnet.py --config <yaml_file_name> --gpus <gpu ids>

Retrain Phase

After the Policy Learning Phase, we sample 8 different architectures and execute re-train.py to retrain them.

python re-train.py --config <yaml_file_name> --gpus <gpu ids> --exp_ids <random seed id>

where different --exp_ids values specify different random seeds and therefore generate different architectures. The paper reports the best performance among all 8 runs.
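The following conceptual sketch shows why different seeds yield different architectures. It assumes the learned policy can be summarized as a per-task, per-layer probability of executing each block. The probabilities, task names, and helper name are hypothetical, and this is not the repository's sampling code. Each seed fixes one draw of select/skip decisions, so the same seed always reproduces the same architecture.

```python
import random
from typing import Dict, List


def sample_architecture(
    exec_probs: Dict[str, List[float]], seed: int
) -> Dict[str, List[bool]]:
    """Draw one select (True) / skip (False) decision per task and layer."""
    rng = random.Random(seed)  # the seed plays the role of --exp_ids
    return {
        task: [rng.random() < p for p in probs]
        for task, probs in exec_probs.items()
    }


# Hypothetical learned probabilities of executing each of 4 blocks.
exec_probs = {
    "segmentation": [0.95, 0.90, 0.20, 0.85],
    "depth": [0.97, 0.40, 0.75, 0.60],
}

# Eight candidate architectures, one per seed id 0..7.
archs = [sample_architecture(exec_probs, seed) for seed in range(8)]
```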

For example, python re-train.py --config yamls/adashare/nyu_v2_2task.yml --gpus 0 --exp_ids 0.
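To launch all 8 retraining runs for one config, you can generate the commands in a loop. This convenience sketch uses a hypothetical helper name. It only builds and prints the documented re-train.py invocations. It quotes with shlex.quote instead of shlex.join so that it still runs on Python 3.6.

```python
import shlex
import subprocess
from typing import List


def retrain_commands(config: str, gpus: str, n_runs: int = 8) -> List[List[str]]:
    """One re-train.py invocation per random seed id (--exp_ids 0..n_runs-1)."""
    return [
        ["python", "re-train.py", "--config", config, "--gpus", gpus, "--exp_ids", str(i)]
        for i in range(n_runs)
    ]


if __name__ == "__main__":
    for cmd in retrain_commands("yamls/adashare/nyu_v2_2task.yml", "0"):
        print(" ".join(shlex.quote(c) for c in cmd))
        # subprocess.run(cmd, check=True)  # uncomment to run each job in turn
```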

Note: use the domainnet branch for experiments on DomainNet, i.e., python re-train_domainnet.py --config <yaml_file_name> --gpus <gpu ids>

Test/Inference

After the Retraining Phase, execute test.py to get the quantitative results on the test set.

python test.py --config <yaml_file_name> --gpus <gpu ids> --exp_ids <random seed id>

For example, python test.py --config yamls/adashare/nyu_v2_2task.yml --gpus 0 --exp_ids 0.

We provide our trained checkpoints as follows:

  1. Please download our model in NYU v2 2-Task Learning
  2. Please download our model in CityScapes 2-Task Learning
  3. Please download our model in NYU v2 3-Task Learning

To use these provided checkpoints, please download them to ../experiments/checkpoints/ and uncompress them there. Use the following commands to test:

python test.py --config yamls/adashare/nyu_v2_2task_test.yml --gpus 0 --exp_ids 0
python test.py --config yamls/adashare/cityscapes_2task_test.yml --gpus 0 --exp_ids 0
python test.py --config yamls/adashare/nyu_v2_3task_test.yml --gpus 0 --exp_ids 0
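You can also script the unpacking step. This sketch assumes the downloads are archives in a format that shutil.unpack_archive recognizes (zip or a tar variant). The helper name is hypothetical, and you pass in the paths of whatever archives you downloaded.

```python
import shutil
from pathlib import Path
from typing import List

CKPT_DIR = Path("../experiments/checkpoints")


def unpack_checkpoints(archives: List[str], dest: Path = CKPT_DIR) -> List[Path]:
    """Extract each archive into dest (created if missing); list dest's entries."""
    dest.mkdir(parents=True, exist_ok=True)
    for archive in archives:
        # unpack_archive infers the format (zip, tar, tar.gz, ...) from the extension.
        shutil.unpack_archive(archive, str(dest))
    return sorted(dest.iterdir())
```

For example, `unpack_checkpoints(["nyu_v2_2task.zip"])` would extract a (hypothetically named) download into ../experiments/checkpoints/ before you run the test commands above.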

Test with our pre-trained checkpoints

We also provide some sample images for easily testing our model on the NYU v2 3-task setting.

Please download our model in NYU v2 3-Task Learning

Execute test_sample.py to test on the sample images in ./nyu_v2_samples, using the command:

python test_sample.py --config yamls/adashare/nyu_v2_3task_test.yml --gpus 0

It prints the average quantitative results over the sample images.

Note

If any link is broken or you have any questions, please email [email protected]
