zhanglijun95/AutoMTL

Good News 09/15/2022: Our paper has been accepted by NeurIPS 2022!

UPDATE 4/25/2022: PyTorch Friendly APIs for AutoMTL

We add PyTorch APIs in APIs/ for more convenient multi-task model construction.

  • Users can now use the VCNs in layer_node.py, including Conv2dNode and BN2dNode, directly when building a customized multi-task model in PyTorch (the model should inherit from mtl_model.py).
  • An example can be found in mobilenetv2.py and Example_mobilenetv2.ipynb.
  • The detailed documentation is online here.

AutoMTL: A Programming Framework for Automating Efficient Multi-Task Learning

This is the website for our paper "AutoMTL: A Programming Framework for Automating Efficient Multi-Task Learning". The arXiv version can be found here.

Abstract

Multi-task learning (MTL) jointly learns a set of tasks. It is a promising approach to reduce the training and inference time and storage costs while improving prediction accuracy and generalization performance for many computer vision tasks. However, a major barrier preventing the widespread adoption of MTL is the lack of systematic support for developing compact multi-task models given a set of tasks. In this paper, we aim to remove the barrier by developing the first programming framework AutoMTL that automates MTL model development. AutoMTL takes as inputs an arbitrary backbone convolutional neural network and a set of tasks to learn, then automatically produces a multi-task model that achieves high accuracy and has a small memory footprint simultaneously. As a programming framework, AutoMTL could facilitate the development of MTL-enabled computer vision applications and even further improve task performance.

Framework

Cite

Please cite our work if you find it helpful to your research.

@misc{zhang2021automtl,
      title={AutoMTL: A Programming Framework for Automating Efficient Multi-Task Learning}, 
      author={Lijun Zhang and Xiao Liu and Hui Guan},
      year={2021},
      eprint={2110.13076},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Description

Environment

You can build your conda environment from the provided environment.yml. Feel free to change the env name in the file.

conda env create -f environment.yml

Datasets

We conducted experiments on three popular datasets in multi-task learning (MTL): CityScapes [1], NYUv2 [2], and Tiny-Taskonomy [3]. You can download them here. For Tiny-Taskonomy, you will need to contact the authors directly. See their official website.

File Structure

├── data
│   ├── dataloader
│   │   └── *_dataloader.py
│   ├── heads
│   │   └── pixel2pixel.py
│   └── metrics
│       ├── pixel2pixel_loss.py
│       └── pixel2pixel_metrics.py
├── framework
│   ├── layer_containers.py
│   ├── base_node.py
│   ├── layer_node.py
│   ├── mtl_model.py
│   └── trainer.py
├── models
│   └── *.prototxt
└── utils
    └── pytorch_to_caffe.py

Code Description

Our code can be divided into three parts: code for data, code for AutoMTL, and others.

  • For Data

    • Dataloaders *_dataloader.py: For each dataset, we offer a corresponding PyTorch dataloader with a specific task variable.
    • Heads pixel2pixel.py: The ASPP head [4] is implemented for the pixel-to-pixel vision tasks.
    • Metrics pixel2pixel_loss.py and pixel2pixel_metrics.py: Each task has its own criterion and metric.
  • AutoMTL

    • Multi-Task Model Generator mtl_model.py: Transform the given backbone model (in prototxt format) and the task-specific model head dictionary into a multi-task supermodel.
    • Trainer Tools trainer.py: Materialize a three-stage training pipeline to search out a good multi-task model for the given tasks.
  • Others

    • Input Backbone *.prototxt: Typical vision backbone models including Deeplab-ResNet34 [4], MobileNetV2, and MNasNet.
    • Transfer to Prototxt pytorch_to_caffe.py: If you define your own customized backbone model with the PyTorch API, we also provide a tool to convert it to a prototxt file.

How to Use

Note: Please refer to Example.ipynb for more details.

Set up Data

Each task will have its own dataloader for both training and validation, task-specific criterion (loss), evaluation metric, and model head. Here we take CityScapes as an example.

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1} # the number of classes in each task

You can also define your own dataloader, criterion, and evaluation metrics. Please refer to the files in data/ to make sure your customized classes have the same output format as ours so that they fit our framework.

dataloader dictionary

trainDataloaderDict = {task: [] for task in tasks} # one list of training dataloaders per task
valDataloaderDict = {}
for task in tasks:
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset, <batch_size>, shuffle=True)) # full training set
    dataset1 = CityScapes(dataroot, 'train1', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset1, 16, shuffle=True)) # for network param training
    dataset2 = CityScapes(dataroot, 'train2', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset2, 16, shuffle=True)) # for policy param training

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)

criterion dictionary

criterionDict = {}
for task in tasks:
    criterionDict[task] = CityScapesCriterions(task)

evaluation metric dictionary

metricDict = {}
for task in tasks:
    metricDict[task] = CityScapesMetrics(task)

task-specific heads dictionary

headsDict = nn.ModuleDict() # must be nn.ModuleDict() instead of python dictionary
for task in tasks:
    headsDict[task] = ASPPHeadNode(<feature_dim>, task_cls_num[task])
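
Because every component above is keyed by task name, a missing or misspelled key tends to surface only once training starts. The following is a minimal sketch of a sanity check that could be run before building the model. `check_task_dicts` is a hypothetical helper, not part of AutoMTL, and it only assumes that each component supports the `in` operator on task names (true for Python dictionaries and for nn.ModuleDict).

```python
def check_task_dicts(tasks, **named_dicts):
    """Return {dict_name: [missing tasks]} for every mapping that lacks a task key.

    Hypothetical helper for illustration; not part of AutoMTL.
    """
    problems = {}
    for dict_name, mapping in named_dicts.items():
        missing = [task for task in tasks if task not in mapping]
        if missing:
            problems[dict_name] = missing
    return problems


tasks = ['segment_semantic', 'depth_zbuffer']
# Placeholder objects stand in for real criteria and metrics.
criterionDict = {'segment_semantic': object(), 'depth_zbuffer': object()}
metricDict = {'segment_semantic': object()}  # 'depth_zbuffer' left out on purpose

# An empty result means every dictionary covers every task.
print(check_task_dicts(tasks, criterionDict=criterionDict, metricDict=metricDict))
```

In a real setup the same call would receive the dataloader, criterion, metric, and heads dictionaries built above.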

Construct Multi-Task Supermodel

prototxt = 'models/deeplab_resnet34_adashare.prototxt' # can be any CNN model
mtlmodel = MTLModel(prototxt, headsDict)

Note: We currently support Conv2d, BatchNorm2d, Linear, ReLU, Dropout, MaxPool2d and AvgPool2d (including global pooling), and elementwise operators (including product, add, and max).

3-Stage Training

define the trainer

trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict)

pre-train phase

trainer.pre_train(iters=<total_iter>, lr=<init_lr>, savePath=<save_path>)

policy-train phase

loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005} # the weights for each task and the policy regularization term from the paper
trainer.alter_train_with_reg(iters=<total_iter>, policy_network_iters=<alter_iters>, policy_lr=<policy_lr>, network_lr=<network_lr>, 
                             loss_lambda=loss_lambda, savePath=<save_path>)
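
One plausible reading of loss_lambda is a weighted sum: each task loss is scaled by its weight, and the policy regularization term is scaled by the 'policy' entry. The sketch below only illustrates that arithmetic with plain floats. `total_loss` is a hypothetical helper and the loss values are made up; how the trainer actually combines the terms may differ in detail.

```python
def total_loss(task_losses, policy_reg, loss_lambda):
    """Weighted sum of per-task losses plus the weighted policy regularizer.

    Hypothetical helper for illustration; not the AutoMTL implementation.
    """
    weighted_tasks = sum(loss_lambda[task] * loss for task, loss in task_losses.items())
    return weighted_tasks + loss_lambda['policy'] * policy_reg


loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy': 0.0005}
task_losses = {'segment_semantic': 0.8, 'depth_zbuffer': 0.3}  # made-up values

# The small 'policy' weight keeps a large regularizer from dominating the task losses.
print(total_loss(task_losses, policy_reg=40.0, loss_lambda=loss_lambda))
```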

Note: When training the policy and the model weights together, we train them alternately, each for the number of iterations specified in policy_network_iters.
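
The alternation can be pictured as a repeating schedule. The sketch below assumes, for illustration only, that policy_network_iters is a pair (policy_iters, network_iters) and that each cycle starts with the policy. Both the pair format and the ordering are assumptions, and `alternating_schedule` is a hypothetical helper rather than an AutoMTL function.

```python
def alternating_schedule(total_iters, policy_network_iters):
    """Yield 'policy' or 'network' for each iteration of an alternating schedule.

    Hypothetical helper; the (policy_iters, network_iters) format is an assumption.
    """
    policy_iters, network_iters = policy_network_iters
    cycle = policy_iters + network_iters
    if cycle <= 0:
        raise ValueError("policy_network_iters must sum to a positive number")
    for i in range(total_iters):
        # The first policy_iters steps of each cycle update the policy,
        # the remaining steps update the network weights.
        yield 'policy' if i % cycle < policy_iters else 'network'


print(list(alternating_schedule(7, (1, 2))))
```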

sample policy from trained policy distribution

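from collections import OrderedDict

import numpy as np
import torch
from scipy.special import softmax # assumed source of softmax(x, axis=...)

# name_list[task] and policy_list[task] are assumed to hold the policy parameter
# names and their trained values (as NumPy arrays) taken from the policy-trained model.
sample_policy_dict = OrderedDict()
for task in tasks:
    for name, policy in zip(name_list[task], policy_list[task]):
        distribution = softmax(policy, axis=-1)
        distribution /= sum(distribution) # renormalize so the probabilities sum to 1
        choice = np.random.choice((0,1,2), p=distribution)
        # store the sampled choice as a one-hot policy
        if choice == 0:
            sample_policy_dict[name] = torch.tensor([1.0,0.0,0.0]).cuda()
        elif choice == 1:
            sample_policy_dict[name] = torch.tensor([0.0,1.0,0.0]).cuda()
        elif choice == 2:
            sample_policy_dict[name] = torch.tensor([0.0,0.0,1.0]).cuda()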

Note: The policy-train stage only yields a good policy distribution. Before running post-train, we need to sample a concrete policy from that distribution.
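
The sampling step can be reproduced without a GPU to see what it does to a single policy. The sketch below uses only the Python standard library. `sample_one_hot` is a hypothetical helper, and the logits are illustrative values rather than a trained policy.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(logits, rng=random):
    """Sample an index from softmax(logits) and return it as a one-hot list.

    Hypothetical helper for illustration; not part of AutoMTL.
    """
    probs = softmax(logits)
    choice = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    return [1.0 if i == choice else 0.0 for i in range(len(logits))]


rng = random.Random(0)  # seeded so that repeated runs sample the same policy
print(sample_one_hot([2.0, 0.5, -1.0], rng))
```

Seeding the generator is a design choice for reproducibility: different seeds can yield different sampled policies from the same trained distribution.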

post-train phase

trainer.post_train(iters=<total_iter>, lr=<init_lr>, 
                   loss_lambda=loss_lambda, savePath=<save_path>, reload=<sampled_policy>)

Validation Results in the Paper

You can download fully-trained models for each dataset here.

mtlmodel.load_state_dict(torch.load(<model_name>))
trainer.validate('mtl', hard=True)

Note: hard must be set to True during inference, because the model should follow the sampled policy rather than a soft policy at this point.
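
The difference between the two modes can be sketched with scalars. Assuming, for illustration only, that a policy weights three candidate outputs of a layer, a soft policy blends them while a hard policy selects exactly one. `apply_policy` is a hypothetical helper and not the AutoMTL implementation.

```python
def apply_policy(candidates, policy, hard):
    """Combine candidate outputs under a soft (blend) or hard (select one) policy.

    Hypothetical helper for illustration; not the AutoMTL implementation.
    """
    if len(candidates) != len(policy):
        raise ValueError("candidates and policy must have the same length")
    if hard:
        # Hard policy: keep only the candidate with the largest weight.
        best = max(range(len(policy)), key=lambda i: policy[i])
        return candidates[best]
    # Soft policy: weighted blend of all candidates.
    return sum(weight * value for weight, value in zip(policy, candidates))


candidates = [10.0, 20.0, 30.0]  # made-up outputs of three candidate branches
policy = [0.7, 0.2, 0.1]

print(apply_policy(candidates, policy, hard=False))  # blend of all three
print(apply_policy(candidates, policy, hard=True))   # only the top-weighted branch
```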

Inference from Trained Model

mtlmodel.load_state_dict(torch.load(<model_name>))
output = mtlmodel(x, task=<task_name>, hard=True)

References

[1] Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt. The cityscapes dataset for semantic urban scene understanding. CVPR, 3213-3223, 2016.

[2] Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob. Indoor segmentation and support inference from rgbd images. ECCV, 746-760, 2012.

[3] Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio. Taskonomy: Disentangling task transfer learning. CVPR, 3712-3722, 2018.

[4] Chen, Liang-Chieh and Papandreou, George and Kokkinos, Iasonas and Murphy, Kevin and Yuille, Alan L. Deeplab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. PAMI, 834-848, 2017.
