This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationships.

Overview

Auto-Lambda


We encourage readers to check out our project page, which includes further discussions and insights not covered in our technical paper.

Multi-task Methods

We implemented all weighting-based and gradient-based baselines presented in the paper for two groups of computer vision tasks: Dense Prediction Tasks (NYUv2 and CityScapes) and Multi-domain Classification Tasks (CIFAR-100).

Specifically, we have implemented the following multi-task optimisation methods:

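Weighting-based:

  • Equal (--weight equal)
  • Uncertainty (--weight uncert)
  • DWA, Dynamic Weight Average (--weight dwa)
  • Auto-Lambda (--weight autol)

Gradient-based:

  • GradDrop (--grad_method graddrop)
  • PCGrad (--grad_method pcgrad)
  • CAGrad (--grad_method cagrad)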

Note: Combining a weighting-based method with a gradient-based method can further improve performance.
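The note above can be illustrated with a small NumPy sketch: task gradients are first scaled by per-task weights (the weighting-based step), then conflicting components are removed with a PCGrad-style projection (the gradient-based step). It follows the published PCGrad rule (project a task gradient onto the normal plane of any other task gradient it conflicts with) on flat vectors. The helper name `combine_gradients` and the toy inputs are hypothetical, and this is not the repository's implementation.

```python
import numpy as np


def pcgrad(grads, rng=None):
    """PCGrad-style projection on a list of flat task gradients.

    For each task gradient g_i, visit the other tasks in random order; if
    g_i conflicts with the original g_j (negative dot product), remove the
    component of g_i along g_j. The projected gradients are then summed.
    """
    rng = np.random.default_rng() if rng is None else rng
    projected = []
    for i, g in enumerate(grads):
        g_i = g.astype(float).copy()
        others = [j for j in range(len(grads)) if j != i]
        for j in rng.permutation(others):
            g_j = grads[j]
            dot = g_i @ g_j
            if dot < 0:  # conflicting directions: drop the component along g_j
                g_i -= dot / (g_j @ g_j) * g_j
        projected.append(g_i)
    return np.sum(projected, axis=0), projected


def combine_gradients(grads, weights, rng=None):
    """Hypothetical helper: apply task weights first, then PCGrad."""
    weighted = [w * g for w, g in zip(weights, grads)]
    return pcgrad(weighted, rng)


# Two toy task gradients that conflict (negative dot product).
g_seg = np.array([1.0, 0.0])
g_depth = np.array([-1.0, 1.0])
total, projected = combine_gradients([g_seg, g_depth], weights=[1.0, 1.0])
print(total)
```

With only two tasks the visiting order does not matter; with more tasks, the random order is part of the PCGrad procedure, which is why an optional `rng` is accepted for reproducibility.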

Datasets

We applied the same data pre-processing as in our previous project, MTAN, which experimented on:

  • NYUv2 [3 Tasks] - 13 Class Segmentation + Depth Estimation + Surface Normal. [288 x 384] Resolution.
  • CityScapes [3 Tasks] - 19 Class Segmentation + 10 Class Part Segmentation + Disparity (Inverse Depth) Estimation. [256 x 512] Resolution.

Note: We have added a new task, Part Segmentation, for the CityScapes dataset. The pre-processing file for CityScapes is also included in the dataset folder.

Experiments

All experiments were written in PyTorch 1.7 and can be configured with different flags (hyper-parameters) passed to each training script. We briefly introduce some important flags below.

Flag Name | Usage | Comments
--- | --- | ---
network | choose multi-task network: split, mtan | both architectures are based on ResNet-50; only available in dense prediction tasks
dataset | choose dataset: nyuv2, cityscapes | only available in dense prediction tasks
weight | choose weighting-based method: equal, uncert, dwa, autol | only autol will behave differently when set to different primary tasks
grad_method | choose gradient-based method: graddrop, pcgrad, cagrad | weight and grad_method can be applied together
task | choose primary tasks: seg, depth, normal for NYUv2; seg, part_seg, disp for CityScapes; all: a combination of all standard 3 tasks | only available in dense prediction tasks
with_noise | toggle on to add a noise prediction task for training (to evaluate robustness in the auxiliary learning setting) | only available in dense prediction tasks
subset_id | choose domain ID for CIFAR-100; choose -1 for the multi-task learning setting | only available in CIFAR-100 tasks
autol_init | initialisation of Auto-Lambda, default 0.1 | only available when applying Auto-Lambda
autol_lr | learning rate of Auto-Lambda, default 1e-4 for NYUv2 and 3e-5 for CityScapes | only available when applying Auto-Lambda
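The `autol_init` and `autol_lr` flags set the initial task weightings and the step size used to update them. As a rough, hypothetical illustration of learning task weights by looking one step ahead, the sketch below works on a toy least-squares problem: it takes one SGD step on the weighted training loss, then moves each weight to reduce the primary task's validation loss after that step. The toy data, the inner step size `alpha`, and the exact analytic meta-gradient are assumptions chosen for clarity; the paper's actual update rule and approximations may differ, so treat this as a sketch rather than the repository's implementation.

```python
import numpy as np


def mse_and_grad(theta, X, y):
    """MSE of a linear model X @ theta and its gradient with respect to theta."""
    residual = X @ theta - y
    loss = float(residual @ residual) / len(y)
    grad = 2.0 * X.T @ residual / len(y)
    return loss, grad


def lookahead(theta, lambdas, train_tasks, alpha):
    """One SGD step on the weighted loss: theta - alpha * sum_i lambda_i * g_i."""
    task_grads = [mse_and_grad(theta, X, y)[1] for X, y in train_tasks]
    step = sum(lam * g for lam, g in zip(lambdas, task_grads))
    return theta - alpha * step, task_grads


def update_lambdas(theta, lambdas, train_tasks, val_primary, alpha, autol_lr):
    """Move task weights to lower the primary validation loss after the lookahead step.

    Because d theta' / d lambda_i = -alpha * g_i, the chain rule gives
    d L_val(theta') / d lambda_i = -alpha * grad L_val(theta') . g_i.
    """
    theta_next, task_grads = lookahead(theta, lambdas, train_tasks, alpha)
    _, val_grad = mse_and_grad(theta_next, *val_primary)
    meta_grad = np.array([-alpha * (val_grad @ g) for g in task_grads])
    return lambdas - autol_lr * meta_grad


# Toy setup: a primary task, a helpful auxiliary task (same target),
# and a conflicting auxiliary task (negated target).
X = np.eye(3)
y = X @ np.array([1.0, 2.0, 3.0])
train_tasks = [(X, y), (X, y), (X, -y)]
lambdas = np.full(3, 0.1)  # plays the role of autol_init
new_lambdas = update_lambdas(np.zeros(3), lambdas, train_tasks, (X, y), alpha=0.1, autol_lr=1e-4)
print(new_lambdas)
```

In this toy case the weights of the primary and helpful tasks grow slightly while the conflicting task's weight shrinks. In a full training loop the network parameters would then be updated with the new weights; that step is omitted here.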

Training Auto-Lambda in Multi-task / Auxiliary Learning Mode:

python trainer_dense.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --weight autol --gpu 0   # for NYUv2 or CityScapes dataset
python trainer_cifar.py --subset_id [PRIMARY_DOMAIN_ID] --weight autol --gpu 0   # for CIFAR-100 dataset

Training in Single-task Learning Mode:

python trainer_dense_single.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --gpu 0   # for NYUv2 or CityScapes dataset
python trainer_cifar_single.py --subset_id [PRIMARY_DOMAIN_ID] --gpu 0   # for CIFAR-100 dataset

Note: All experiments in the original paper were trained from scratch without pre-training.

Benchmark

For the standard 3 tasks in NYUv2 (without the noise prediction task) in the multi-task learning setting with the Split architecture, please refer to the results below.

Method | Sem. Seg. (mIoU) | Depth (aErr.) | Normal (mDist.) | Delta MTL
--- | --- | --- | --- | ---
Single | 43.37 | 52.24 | 22.40 | -
Equal | 44.64 | 43.32 | 24.48 | +3.57%
DWA | 45.14 | 43.06 | 24.17 | +4.58%
GradDrop | 45.39 | 43.23 | 24.18 | +4.65%
PCGrad | 45.15 | 42.38 | 24.13 | +5.09%
Uncertainty | 45.98 | 41.26 | 24.09 | +6.50%
CAGrad | 46.14 | 41.91 | 23.52 | +7.05%
Auto-Lambda | 47.17 | 40.97 | 23.68 | +8.21%
Auto-Lambda + CAGrad | 48.26 | 39.82 | 22.81 | +11.07%

Note: The results were averaged across three random seeds. You should expect an error range of less than +/-1%.
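The Delta MTL column is consistent with the average relative change over the Single-task baseline across the three tasks, with the sign flipped for metrics where lower is better (the depth and normal errors). The sketch below is written under that assumption and recomputes the column from the table's own numbers; the function name `delta_mtl` is illustrative.

```python
def delta_mtl(method, single, lower_is_better):
    """Average relative improvement (in %) of `method` over `single` across tasks."""
    total = 0.0
    for m, s, lower in zip(method, single, lower_is_better):
        sign = -1.0 if lower else 1.0  # an error going down counts as an improvement
        total += sign * (m - s) / s
    return 100.0 * total / len(single)


SINGLE = (43.37, 52.24, 22.40)         # Sem. Seg. mIoU, Depth aErr., Normal mDist.
LOWER_IS_BETTER = (False, True, True)  # only mIoU is higher-is-better

RESULTS = {
    "Equal": (44.64, 43.32, 24.48),
    "DWA": (45.14, 43.06, 24.17),
    "GradDrop": (45.39, 43.23, 24.18),
    "PCGrad": (45.15, 42.38, 24.13),
    "Uncertainty": (45.98, 41.26, 24.09),
    "CAGrad": (46.14, 41.91, 23.52),
    "Auto-Lambda": (47.17, 40.97, 23.68),
    "Auto-Lambda + CAGrad": (48.26, 39.82, 22.81),
}

for name, scores in RESULTS.items():
    print(f"{name:<22} {delta_mtl(scores, SINGLE, LOWER_IS_BETTER):+.2f}%")
```

Averaging relative (rather than absolute) changes keeps a task with large raw values, such as depth error, from dominating the summary metric.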

Citation

If you found this code/work to be useful in your own research, please consider citing the following:

@article{liu2022auto-lambda,
  title={Auto-Lambda: Disentangling Dynamic Task Relationships},
  author={Liu, Shikun and James, Stephen and Davison, Andrew J and Johns, Edward},
  journal={arXiv preprint arXiv:2202.03091},
  year={2022}
}

Acknowledgement

We would like to thank @Cranial-XIX for his clean implementation of gradient-based optimisation methods.

Contact

If you have any questions, please contact [email protected].

Owner: Shikun Liu, Ph.D. Student, The Dyson Robotics Lab at Imperial College.