LibMTL: A PyTorch Library for Multi-Task Learning

Overview

LibMTL

Documentation Status License: MIT PyPI version Supported Python versions Downloads CodeFactor Maintainability Made With Love

LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and API instructions.

Star us on GitHub — it motivates us a lot!

Table of Content

Features

  • Unified: LibMTL provides a unified code base to implement and a consistent evaluation procedure including data processing, metric objectives, and hyper-parameters on several representative MTL benchmark datasets, which allows quantitative, fair, and consistent comparisons between different MTL algorithms.
  • Comprehensive: LibMTL supports 84 MTL models combined by 7 architectures and 12 loss weighting strategies. Meanwhile, LibMTL provides a fair comparison on 3 computer vision datasets.
  • Extensible: LibMTL follows the modular design principles, which allows users to flexibly and conveniently add customized components or make personalized modifications. Therefore, users can easily and fast develop novel loss weighting strategies and architectures or apply the existing MTL algorithms to new application scenarios with the support of LibMTL.

Overall Framework

framework.

  • Config Module: Responsible for all the configuration parameters involved in the running framework, including the parameters of optimizer and learning rate scheduler, the hyper-parameters of MTL model, training configuration like batch size, total epoch, random seed and so on.
  • Dataloaders Module: Responsible for data pre-processing and loading.
  • Model Module: Responsible for inheriting classes architecture and weighting and instantiating a MTL model. Note that the architecture and the weighting strategy determine the forward and backward processes of the MTL model, respectively.
  • Losses Module: Responsible for computing the loss for each task.
  • Metrics Module: Responsible for evaluating the MTL model and calculating the metric scores for each task.

Supported Algorithms

LibMTL currently supports the following algorithms:

  • 12 loss weighting strategies.
Weighting Strategy Venues Comments
Equally Weighting (EW) - Implemented by us
Gradient Normalization (GradNorm) ICML 2018 Implemented by us
Uncertainty Weights (UW) CVPR 2018 Implemented by us
MGDA NeurIPS 2018 Referenced from official PyTorch implementation
Dynamic Weight Average (DWA) CVPR 2019 Referenced from official PyTorch implementation
Geometric Loss Strategy (GLS) CVPR 2019 workshop Implemented by us
Projecting Conflicting Gradient (PCGrad) NeurIPS 2020 Implemented by us
Gradient sign Dropout (GradDrop) NeurIPS 2020 Implemented by us
Impartial Multi-Task Learning (IMTL) ICLR 2021 Implemented by us
Gradient Vaccine (GradVac) ICLR 2021 Spotlight Implemented by us
Conflict-Averse Gradient descent (CAGrad) NeurIPS 2021 Referenced from official PyTorch implementation
Random Loss Weighting (RLW) arXiv Implemented by us
  • 7 architectures.
Architecture Venues Comments
Hrad Parameter Sharing (HPS) ICML 1993 Implemented by us
Cross-stitch Networks (Cross_stitch) CVPR 2016 Implemented by us
Multi-gate Mixture-of-Experts (MMoE) KDD 2018 Implemented by us
Multi-Task Attention Network (MTAN) CVPR 2019 Referenced from official PyTorch implementation
Customized Gate Control (CGC) ACM RecSys 2020 Best Paper Implemented by us
Progressive Layered Extraction (PLE) ACM RecSys 2020 Best Paper Implemented by us
DSelect-k NeurIPS 2021 Referenced from official TensorFlow implementation
  • 84 combinations of different architectures and loss weighting strategies.

Installation

The simplest way to install LibMTL is using pip.

pip install -U LibMTL

More details about environment configuration is represented in Docs.

Quick Start

We use the NYUv2 dataset as an example to show how to use LibMTL.

Download Dataset

The NYUv2 dataset we used is pre-processed by mtan. You can download this dataset here.

Run a Model

The complete training code for the NYUv2 dataset is provided in examples/nyu. The file train_nyu.py is the main file for training on the NYUv2 dataset.

You can find the command-line arguments by running the following command.

python train_nyu.py -h

For instance, running the following command will train a MTL model with EW and HPS on NYUv2 dataset.

python train_nyu.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step

More details is represented in Docs.

Citation

If you find LibMTL useful for your research or development, please cite the following:

@misc{LibMTL,
 author = {Baijiong Lin and Yu Zhang},
 title = {LibMTL: A PyTorch Library for Multi-Task Learning},
 year = {2021},
 publisher = {GitHub},
 journal = {GitHub repository},
 howpublished = {\url{https://github.com/median-research-group/LibMTL}}
}

Contributors

LibMTL is developed and maintained by Baijiong Lin and Yu Zhang.

Contact Us

If you have any question or suggestion, please feel free to contact us by raising an issue or sending an email to [email protected].

Acknowledgements

We would like to thank the authors that release the public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, and mtan.

License

LibMTL is released under the MIT license.

CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Generation

CPT This repository contains code and checkpoints for CPT. CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Gener

fastNLP 341 Dec 29, 2022
Automatic Image Background Subtraction

Automatic Image Background Subtraction This repo contains set of scripts for automatic one-shot image background subtraction task using the following

Oleg Sémery 6 Dec 05, 2022
CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

Frederick Wang 3 Apr 26, 2022
Fantasy Points Prediction and Dream Team Formation

Fantasy-Points-Prediction-and-Dream-Team-Formation Collected Data from open source resources that have over 100 Parameters for predicting cricket play

Akarsh Singh 2 Sep 13, 2022
Command-line tool for downloading and extending the RedCaps dataset.

RedCaps Downloader This repository provides the official command-line tool for downloading and extending the RedCaps dataset. Users can seamlessly dow

RedCaps dataset 33 Dec 14, 2022
A Kernel fuzzer focusing on race bugs

Razzer: Finding kernel race bugs through fuzzing Environment setup $ source scripts/envsetup.sh scripts/envsetup.sh sets up necessary environment var

Systems and Software Security Lab at Seoul National University (SNU) 328 Dec 26, 2022
AI Summer's complete catalog of articles

Learn Deep Learning with AI Summer A collection of all articles (almost 100) written for the AI Summer blog organized by topic. Deep Learning Theory M

AI Summer 95 Dec 29, 2022
Face Transformer for Recognition

Face-Transformer This is the code of Face Transformer for Recognition (https://arxiv.org/abs/2103.14803v2). Recently there has been great interests of

Zhong Yaoyao 153 Nov 30, 2022
LQM - Improving Object Detection by Estimating Bounding Box Quality Accurately

Improving Object Detection by Estimating Bounding Box Quality Accurately Abstract Object detection aims to locate and classify object instances in ima

IM Lab., POSTECH 0 Sep 28, 2022
DSL for matching Python ASTs

py-ast-rule-engine This library provides a DSL (domain-specific language) to match a pattern inside a Python AST (abstract syntax tree). The library i

1 Dec 18, 2021
Computer Vision Script to recognize first person motion, developed as final project for the course "Machine Learning and Deep Learning"

Overview of The Code BaseColab/MLDL_FPAR.pdf: it contains the full explanation of our work Base Colab: it contains the base colab used to perform all

Simone Papicchio 4 Jul 16, 2022
Lightweight mmm - Lightweight (Bayesian) Media Mix Model

Lightweight (Bayesian) Media Mix Model This is not an official Google product. L

Google 342 Jan 03, 2023
Repository for the Bias Benchmark for QA dataset.

BBQ Repository for the Bias Benchmark for QA dataset. Authors: Alicia Parrish, Angelica Chen, Nikita Nangia, Vishakh Padmakumar, Jason Phang, Jana Tho

ML² AT CILVR 18 Nov 18, 2022
This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement"

DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE

Guochen Yu 68 Dec 16, 2022
code and models for "Laplacian Pyramid Reconstruction and Refinement for Semantic Segmentation"

Laplacian Pyramid Reconstruction and Refinement for Semantic Segmentation This repository contains code and models for the method described in: Golnaz

55 Jun 18, 2022
Rest API Written In Python To Classify NSFW Images.

Rest API Written In Python To Classify NSFW Images.

Wahyusaputra 2 Dec 23, 2021
LoFTR:Detector-Free Local Feature Matching with Transformers CVPR 2021

LoFTR-with-train-script LoFTR:Detector-Free Local Feature Matching with Transformers CVPR 2021 (with train script --- unofficial ---). About Megadepth

Nan Xiaohu 15 Nov 04, 2022
MazeRL is an application oriented Deep Reinforcement Learning (RL) framework

MazeRL is an application oriented Deep Reinforcement Learning (RL) framework, addressing real-world decision problems. Our vision is to cover the complete development life cycle of RL applications ra

EnliteAI GmbH 222 Dec 24, 2022
Pytorch code for "DPFM: Deep Partial Functional Maps" - 3DV 2021 (Oral)

DPFM Code for "DPFM: Deep Partial Functional Maps" - 3DV 2021 (Oral) Installation This implementation runs on python = 3.7, use pip to install depend

Souhaib Attaiki 29 Oct 03, 2022
Jupyter Dock is a set of Jupyter Notebooks for performing molecular docking protocols interactively, as well as visualizing, converting file formats and analyzing the results.

Molecular Docking integrated in Jupyter Notebooks Description | Citation | Installation | Examples | Limitations | License Table of content Descriptio

Angel J. Ruiz Moreno 173 Dec 25, 2022