Code for "AutoMTL: A Programming Framework for Automated Multi-Task Learning"


AutoMTL: A Programming Framework for Automated Multi-Task Learning

This is the website for our paper "AutoMTL: A Programming Framework for Automated Multi-Task Learning", submitted to MLSys 2022. The arXiv version will be made public on Tue, 26 Oct 2021.

Abstract

Multi-task learning (MTL) jointly learns a set of tasks. It is a promising approach to reducing training and inference time and storage costs while improving prediction accuracy and generalization performance for many computer vision tasks. However, a major barrier to the widespread adoption of MTL is the lack of systematic support for developing compact multi-task models for a given set of tasks. In this paper, we aim to remove this barrier by developing AutoMTL, the first programming framework that automates MTL model development. AutoMTL takes as inputs an arbitrary backbone convolutional neural network and a set of tasks to learn. It then automatically produces a multi-task model that achieves high accuracy and has a small memory footprint. As a programming framework, AutoMTL could facilitate the development of MTL-enabled computer vision applications and even further improve task performance.


Cite

Please cite our work if you find it helpful in your research. [TODO: cite info]

Description

Environment

conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch # Or higher
conda install protobuf
pip install opencv-python
pip install scikit-learn
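
The "# Or higher" comment means newer PyTorch releases are also acceptable. Below is a minimal sketch that checks whether the installed versions meet those minimums, using only the Python standard library (3.8+). It assumes the packages are registered under the distribution names torch and torchvision. It ignores pre-release and local-build suffixes such as +cu117, so the comparison is approximate. The helper names are our own.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the conda command above; newer releases are allowed.
MINIMUM_VERSIONS = {"torch": "1.6.0", "torchvision": "0.7.0"}


def parse_release(v):
    """Turn a version string such as '1.13.1+cu117' into (1, 13, 1).

    The local suffix after '+' is dropped, and any non-digit tail of a component
    (e.g. the 'a0' in '0a0') is ignored, so this is only a rough comparison.
    """
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(installed, required):
    """True if `installed` is the same as or newer than `required`."""
    a, b = parse_release(installed), parse_release(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '1.6' compares equal to '1.6.0'.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment():
    """Map each package to 'ok', 'missing', or 'too old (<version>)'."""
    report = {}
    for package, minimum in MINIMUM_VERSIONS.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            report[package] = "missing"
            continue
        if meets_minimum(installed, minimum):
            report[package] = "ok"
        else:
            report[package] = f"too old ({installed})"
    return report


if __name__ == "__main__":
    print(check_environment())
```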

Datasets

We conducted experiments on three popular multi-task learning (MTL) datasets: CityScapes [1], NYUv2 [2], and Tiny-Taskonomy [3]. You can download them here. For Tiny-Taskonomy, you will need to contact the authors directly; see their official website.

File Structure

├── data
│   ├── dataloader
│   │   ├── *_dataloader.py
│   ├── heads
│   │   ├── pixel2pixel.py
│   ├── metrics
│   │   ├── pixel2pixel_loss/metrics.py
├── framework
│   ├── layer_containers.py
│   ├── base_node.py
│   ├── layer_node.py
│   ├── mtl_model.py
│   ├── trainer.py
├── models
│   ├── *.prototxt
└── utils
    └── pytorch_to_caffe.py

Code Description

Our code is divided into three parts: code for data, the AutoMTL framework itself, and others.

  • For Data

    • Dataloaders *_dataloader.py: For each dataset, we provide a corresponding PyTorch dataloader that takes a task variable specifying which task's labels to load.
    • Heads pixel2pixel.py: The ASPP head [4] is implemented for the pixel-to-pixel vision tasks.
    • Metrics pixel2pixel_loss/metrics.py: Each task has its own criterion (loss) and evaluation metric.
  • AutoMTL

    • Multi-Task Model Generator mtl_model.py: Transforms the given backbone model (in prototxt format) and the dictionary of task-specific heads into a multi-task supermodel.
    • Trainer Tools trainer.py: Implements a three-stage training pipeline that searches for a good multi-task model for the given tasks.
  • Others

    • Input Backbone *.prototxt: Typical vision backbone models including Deeplab-ResNet34 [4], MobileNetV2, and MNasNet.
    • Transfer to Prototxt pytorch_to_caffe.py: If you define your own custom backbone model with the PyTorch API, this tool converts it to a prototxt file.
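
Each dataloader class takes a task argument, as in CityScapes(dataroot, 'train', task, ...) under How to Use, so one dataset object serves exactly one task. The class below is a hypothetical stand-in, not one of the classes in data/dataloader/. It illustrates the idea with in-memory samples and assumes each item is an (input, label) pair, which may differ from the real output format. Because it implements __len__ and __getitem__, it can be wrapped in torch.utils.data.DataLoader like any map-style dataset.

```python
class ToyMultiTaskDataset:
    """Hypothetical stand-in for a dataset class in data/dataloader/.

    Every sample stores labels for all tasks; `task` selects which one
    __getitem__ returns, so each task gets its own dataset object.
    """

    def __init__(self, samples, task):
        # samples: list of (input, {task_name: label}) pairs
        if any(task not in labels for _, labels in samples):
            raise KeyError(f"some samples have no label for task {task!r}")
        self.samples = samples
        self.task = task

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        inputs, labels = self.samples[idx]
        return inputs, labels[self.task]


samples = [
    ("img0", {"segment_semantic": "seg0", "depth_zbuffer": "depth0"}),
    ("img1", {"segment_semantic": "seg1", "depth_zbuffer": "depth1"}),
]
per_task = {t: ToyMultiTaskDataset(samples, t) for t in ["segment_semantic", "depth_zbuffer"]}
# per_task["depth_zbuffer"][1] == ("img1", "depth1")
```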

How to Use

Set up Data

Each task will have its own dataloader for both training and validation, task-specific criterion (loss), evaluation metric, and model head. Here we take CityScapes as an example.

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1} # the number of classes in each task
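
task_cls_num fixes how many output channels each task head produces. Segmentation needs 19 scores per pixel (one per class), while depth needs a single regressed value. The small sketch below checks that the two structures agree and derives the prediction shapes. It assumes pixel-to-pixel heads emit NCHW tensors at the label resolution, and the helper names are hypothetical.

```python
tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1}


def check_task_config(tasks, task_cls_num):
    """Fail early if the task list and the class-count dictionary disagree."""
    missing = [t for t in tasks if t not in task_cls_num]
    extra = [t for t in task_cls_num if t not in tasks]
    if missing or extra:
        raise ValueError(f"missing class counts: {missing}; unused entries: {extra}")


def expected_output_shapes(task_cls_num, batch, height, width):
    """Per-task prediction shape, assuming NCHW output at label resolution."""
    return {t: (batch, c, height, width) for t, c in task_cls_num.items()}


check_task_config(tasks, task_cls_num)
shapes = expected_output_shapes(task_cls_num, batch=8, height=224, width=224)
# shapes['segment_semantic'] == (8, 19, 224, 224): one score per class per pixel
# shapes['depth_zbuffer'] == (8, 1, 224, 224): one depth value per pixel
```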

You can also define your own dataloaders, criteria, and evaluation metrics. Please refer to the files in data/ to make sure your custom classes produce the same output format as ours, so that they fit into our framework.

dataloader dictionary

from torch.utils.data import DataLoader
# CityScapes is the dataset class provided in data/dataloader/ (see *_dataloader.py)

trainDataloaderDict = {}
valDataloaderDict = {}
for task in tasks:
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)
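
Because every task has its own DataLoader, training has to interleave batches from several loaders. How Trainer does this is not described here. The sketch below shows one common option: round-robin sampling that restarts a loader when it runs out. round_robin_batches is a hypothetical helper, and in the usage example plain lists stand in for DataLoader objects.

```python
from itertools import cycle


def round_robin_batches(loader_dict, num_steps):
    """Yield (task, batch) pairs, visiting the tasks in turn.

    A loader that runs out of batches is restarted on its own, so tasks with
    smaller datasets simply begin a new epoch earlier. Every loader must yield
    at least one batch.
    """
    iterators = {task: iter(loader) for task, loader in loader_dict.items()}
    task_order = cycle(list(loader_dict))
    for _ in range(num_steps):
        task = next(task_order)
        try:
            batch = next(iterators[task])
        except StopIteration:
            iterators[task] = iter(loader_dict[task])
            batch = next(iterators[task])
        yield task, batch


# Usage with plain lists standing in for DataLoader objects.
toy_loaders = {"segment_semantic": ["s0", "s1"], "depth_zbuffer": ["d0"]}
steps = list(round_robin_batches(toy_loaders, num_steps=5))
# steps == [('segment_semantic', 's0'), ('depth_zbuffer', 'd0'),
#           ('segment_semantic', 's1'), ('depth_zbuffer', 'd0'),
#           ('segment_semantic', 's0')]
```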

criterion dictionary

criterionDict = {}
for task in tasks:
    criterionDict[task] = CityScapesCriterions(task)
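
CityScapesCriterions(task) returns the loss for a single task. Its implementation lives in data/ and is not reproduced here. The class below is a hypothetical replacement that shows the same one-criterion-per-task pattern in PyTorch. Several details are our own assumptions: per-pixel cross-entropy for segment_semantic, masked L1 for depth_zbuffer, -1 as the void label, and "depth > 0 is valid" as the mask.

```python
import torch
import torch.nn as nn


class ToyCriterion(nn.Module):
    """Hypothetical per-task loss; not the CityScapesCriterions class from data/."""

    def __init__(self, task):
        super().__init__()
        self.task = task
        if task == 'segment_semantic':
            # Assumption: label -1 marks void pixels that should not contribute.
            self.loss = nn.CrossEntropyLoss(ignore_index=-1)
        elif task == 'depth_zbuffer':
            self.loss = nn.L1Loss()
        else:
            raise ValueError(f"unknown task {task!r}")

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.task == 'segment_semantic':
            # pred: (N, C, H, W) logits; target: (N, H, W) class indices
            return self.loss(pred, target.long())
        # pred and target: (N, 1, H, W); assumption: depth <= 0 means "no measurement"
        valid = target > 0
        if not valid.any():
            return pred.sum() * 0.0  # zero loss that still keeps pred in the graph
        return self.loss(pred[valid], target[valid])


criterionDict = {task: ToyCriterion(task) for task in ['segment_semantic', 'depth_zbuffer']}
```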

evaluation metric dictionary

metricDict = {}
for task in tasks:
    metricDict[task] = CityScapesMetrics(task)
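
Unlike losses, evaluation metrics should usually be accumulated over the whole validation set before they are reduced, because averaging per-batch mIoU does not give the dataset-level mIoU. The hypothetical classes below are not CityScapesMetrics. They illustrate that accumulate-then-score pattern in plain Python, using mean IoU for segmentation and mean absolute error for depth. The update/score method names, the -1 void label, and the depth validity rule are assumptions.

```python
class ToySegMetric:
    """Accumulates per-class counts over batches, then reports mean IoU."""

    def __init__(self, num_classes, ignore_label=-1):
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.tp = [0] * num_classes
        self.pred_count = [0] * num_classes
        self.gt_count = [0] * num_classes

    def update(self, preds, labels):
        # preds/labels: flat sequences of class ids for the pixels in one batch
        for p, y in zip(preds, labels):
            if y == self.ignore_label:
                continue
            self.pred_count[p] += 1
            self.gt_count[y] += 1
            if p == y:
                self.tp[y] += 1

    def score(self):
        ious = []
        for c in range(self.num_classes):
            union = self.pred_count[c] + self.gt_count[c] - self.tp[c]
            if union > 0:  # skip classes absent from both predictions and labels
                ious.append(self.tp[c] / union)
        return sum(ious) / len(ious) if ious else 0.0


class ToyDepthMetric:
    """Mean absolute error over pixels with positive ground-truth depth."""

    def __init__(self):
        self.abs_sum = 0.0
        self.count = 0

    def update(self, preds, labels):
        for p, y in zip(preds, labels):
            if y > 0:  # assumption: non-positive depth marks an invalid pixel
                self.abs_sum += abs(p - y)
                self.count += 1

    def score(self):
        return self.abs_sum / self.count if self.count else 0.0


metricDict = {
    'segment_semantic': ToySegMetric(num_classes=19),
    'depth_zbuffer': ToyDepthMetric(),
}
```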

task-specific heads dictionary

import torch.nn as nn
# ASPPHeadNode is the ASPP head from data/heads/pixel2pixel.py

headsDict = nn.ModuleDict() # must be nn.ModuleDict(), not a plain Python dict
for task in tasks:
    headsDict[task] = ASPPHeadNode(<feature_dim>, task_cls_num[task])
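
The nn.ModuleDict requirement matters because nn.Module only registers submodules that are assigned as nn.Module attributes (nn.ModuleDict included). Heads kept in a plain dict are invisible to .parameters(), .to(device), and state_dict(). An optimizer would therefore never update them, and checkpoints would not save them. This is the likely reason for the requirement. The minimal demonstration below uses a toy nn.Linear head in place of ASPPHeadNode.

```python
import torch.nn as nn


class HeadsInPlainDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = {'task_a': nn.Linear(4, 2)}  # plain dict: head is NOT registered


class HeadsInModuleDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleDict({'task_a': nn.Linear(4, 2)})  # head is registered


def count_params(module):
    return sum(p.numel() for p in module.parameters())


print(count_params(HeadsInPlainDict()))   # 0: an optimizer would see no head weights
print(count_params(HeadsInModuleDict()))  # 10: 4*2 weights + 2 biases
```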

Construct Multi-Task Supermodel

prototxt = 'models/deeplab_resnet34_adashare.prototxt' # can be any CNN model
mtlmodel = MTLModel(prototxt, headsDict)
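
MTLModel combines the backbone described by the prototxt with the heads in headsDict. The toy module below is only a rough mental model of that input/output contract: it shares one small convolutional trunk across all tasks and routes the features to the requested head. MTLModel does not work this way internally. AutoMTL builds a supermodel whose sharing pattern is searched during training, whereas this toy hard-codes full sharing. The forward(x, task) signature is also our assumption.

```python
import torch
import torch.nn as nn


class ToySharedBackboneMTL(nn.Module):
    """Hypothetical illustration: one fixed shared trunk plus per-task heads."""

    def __init__(self, heads: nn.ModuleDict, feature_dim: int = 8):
        super().__init__()
        # Stand-in for the prototxt backbone: a single conv block keeping H and W.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.heads = heads  # nn.ModuleDict, so the heads are registered

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        return self.heads[task](self.backbone(x))


task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1}
# 1x1 convolutions stand in for ASPPHeadNode in this toy.
toy_heads = nn.ModuleDict({t: nn.Conv2d(8, c, kernel_size=1) for t, c in task_cls_num.items()})
toy_model = ToySharedBackboneMTL(toy_heads, feature_dim=8)
```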

3-stage Training

define the trainer

trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict)

pre-train phase

trainer.pre_train(iters=<total_iter>, lr=<init_lr>, savePath=<save_path>)

policy-train phase

loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005} # the weights for each task and the policy regularization term from the paper
trainer.alter_train_with_reg(iters=<total_iter>, policy_network_iters=<alter_iters>, policy_lr=<policy_lr>, network_lr=<network_lr>, 
                             loss_lambda=loss_lambda, savePath=<save_path>)
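
loss_lambda holds a weight for each task loss and one for the policy regularization term. We assume the usual weighted-sum form: the sum over tasks of loss_lambda[task] * loss[task], plus loss_lambda['policy'] times the regularization value. The hypothetical helper below computes that total. The regularization value itself is defined in the paper and must be supplied by the caller. The helper works for Python floats and, unchanged, for scalar PyTorch tensors.

```python
def weighted_total_loss(task_losses, loss_lambda, policy_reg=0.0):
    """Weighted sum of per-task losses plus the weighted policy regularizer.

    task_losses: {task: loss}; loss_lambda: {task: weight, 'policy': weight}.
    """
    missing = sorted(set(task_losses) - set(loss_lambda))
    if missing:
        raise KeyError(f"no weight given for tasks: {missing}")
    total = sum(loss_lambda[task] * loss for task, loss in task_losses.items())
    return total + loss_lambda.get('policy', 0.0) * policy_reg


loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy': 0.0005}
total = weighted_total_loss(
    {'segment_semantic': 2.0, 'depth_zbuffer': 0.5}, loss_lambda, policy_reg=100.0)
# total is approximately 2.55 (2.0 + 0.5 + 0.0005 * 100.0)
```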

Note that when the policy and the model weights are trained together, we train them alternately, for the numbers of iterations specified in policy_network_iters.
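
The exact format of policy_network_iters is not spelled out above. Assuming it is a pair (policy_iters, network_iters), the alternation could follow the schedule below, which reports which part is updated at each iteration. The pair interpretation, the policy-first order, and the helper name are all our assumptions.

```python
def alternating_schedule(total_iters, policy_network_iters):
    """Yield 'policy' or 'network' for each training iteration.

    Assumption: policy_network_iters = (policy_iters, network_iters). The policy
    is trained for policy_iters iterations, then the network weights for
    network_iters iterations, and the cycle repeats.
    """
    policy_iters, network_iters = policy_network_iters
    period = policy_iters + network_iters
    if period <= 0:
        raise ValueError("policy_network_iters must contain a positive total")
    for it in range(total_iters):
        yield 'policy' if it % period < policy_iters else 'network'


schedule = list(alternating_schedule(total_iters=5, policy_network_iters=(2, 1)))
# schedule == ['policy', 'policy', 'network', 'policy', 'policy']
```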

post-train phase

trainer.post_train(iters=<total_iter>, lr=<init_lr>,
                   loss_lambda=loss_lambda, savePath=<save_path>, reload=<policy_train_model_name>)
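
The three phases run in sequence, and post_train reloads the model produced by the policy-train phase. The hypothetical driver below chains the three calls with the keyword names used above. The keys of cfg are invented names that stand for the <...> placeholders. Pass it the Trainer constructed above and a dictionary holding your own values.

```python
def run_three_stage(trainer, cfg):
    """Run pre-train, policy-train, and post-train in order (hypothetical driver).

    Expected cfg keys (our own naming): pre_iters, pre_lr, policy_iters,
    policy_network_iters, policy_lr, network_lr, post_iters, post_lr,
    loss_lambda, save_path, policy_model_name.
    """
    trainer.pre_train(iters=cfg['pre_iters'], lr=cfg['pre_lr'], savePath=cfg['save_path'])
    trainer.alter_train_with_reg(
        iters=cfg['policy_iters'],
        policy_network_iters=cfg['policy_network_iters'],
        policy_lr=cfg['policy_lr'],
        network_lr=cfg['network_lr'],
        loss_lambda=cfg['loss_lambda'],
        savePath=cfg['save_path'],
    )
    trainer.post_train(
        iters=cfg['post_iters'],
        lr=cfg['post_lr'],
        loss_lambda=cfg['loss_lambda'],
        savePath=cfg['save_path'],
        reload=cfg['policy_model_name'],  # checkpoint written by the policy-train phase
    )
```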

Note: Please refer to Example.ipynb for more details.

References

[1] Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt. The Cityscapes dataset for semantic urban scene understanding. CVPR, 3213-3223, 2016.

[2] Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob. Indoor segmentation and support inference from RGBD images. ECCV, 746-760, 2012.

[3] Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio. Taskonomy: Disentangling task transfer learning. CVPR, 3712-3722, 2018.

[4] Chen, Liang-Chieh and Papandreou, George and Kokkinos, Iasonas and Murphy, Kevin and Yuille, Alan L. DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. PAMI, 834-848, 2017.

Owner
Ivy Zhang