ML-Decoder: Scalable and Versatile Classification Head

Overview

ML-Decoder: Scalable and Versatile Classification Head

PWC
PWC
PWC


Paper

Official PyTorch Implementation

Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baruch, Asaf Noy
DAMO Academy, Alibaba Group

Abstract

In this paper, we introduce ML-Decoder, a new attention-based classification head. ML-Decoder predicts the existence of class labels via queries, and enables better utilization of spatial data compared to global average pooling. By redesigning the decoder architecture, and using a novel group-decoding scheme, ML-Decoder is highly efficient, and can scale well to thousands of classes. Compared to using a larger backbone, ML-Decoder consistently provides a better speed-accuracy trade-off. ML-Decoder is also versatile - it can be used as a drop-in replacement for various classification heads, and generalize to unseen classes when operated with word queries. Novel query augmentations further improve its generalization ability. Using ML-Decoder, we achieve state-of-the-art results on several classification tasks: on MS-COCO multi-label, we reach 91.4% mAP; on NUS-WIDE zero-shot, we reach 31.1% ZSL mAP; and on ImageNet single-label, we reach with vanilla ResNet50 backbone a new top score of 80.7%, without extra data or distillation.

ML-Decoder Implementation

ML-Decoder implementation is available here. It can be easily integrated into any backbone using this example code:

ml_decoder_head = MLDecoder(num_classes) # initilization

spatial_embeddings = self.backbone(input_image) # backbone generates spatial embeddings      
 
logits = ml_decoder_head(spatial_embeddings) # transfrom spatial embeddings to logits

Training Code

We will share a full reproduction code for the article results.

Multi-label Training Code


A reproduction code for MS-COCO multi-label:

python train.py  \
--data=/home/datasets/coco2014/ \
--model_name=tresnet_l \
--image_size=448

Single-label Training Code

Our single-label training code uses the excellent timm repo. Reproduction code is currently from a fork, we will work toward a full merge to the main repo.

git clone https://github.com/mrT23/pytorch-image-models.git

This is the code for A2 configuration training, with ML-Decoder (--use-ml-decoder-head=1):

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train.py \
/data/imagenet/ \
--amp \
-b=256 \
--epochs=300 \
--drop-path=0.05 \
--opt=lamb \
--weight-decay=0.02 \
--sched='cosine' \
--lr=4e-3 \
--warmup-epochs=5 \
--model=resnet50 \
--aa=rand-m7-mstd0.5-inc1 \
--reprob=0.0 \
--remode='pixel' \
--mixup=0.1 \
--cutmix=1.0 \
--aug-repeats 3 \
--bce-target-thresh 0.2 \
--smoothing=0 \
--bce-loss \
--train-interpolation=bicubic \
--use-ml-decoder-head=1

ZSL Training Code

Reproduction code for ZSL is WIP.

Citation

@misc{ridnik2021mldecoder,
      title={ML-Decoder: Scalable and Versatile Classification Head}, 
      author={Tal Ridnik and Gilad Sharir and Avi Ben-Cohen and Emanuel Ben-Baruch and Asaf Noy},
      year={2021},
      eprint={2111.12933},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Optimal space decomposition based-product quantization for approximate nearest neighbor search

Optimal space decomposition based-product quantization for approximate nearest neighbor search Abstract Product quantization(PQ) is an effective neare

Mylove 1 Nov 19, 2021
Implementation for ACProp ( Momentum centering and asynchronous update for adaptive gradient methdos, NeurIPS 2021)

This repository contains code to reproduce results for submission NeurIPS 2021, "Momentum Centering and Asynchronous Update for Adaptive Gradient Meth

Juntang Zhuang 15 Jun 11, 2022
Much faster than SORT(Simple Online and Realtime Tracking), a little worse than SORT

QSORT QSORT(Quick + Simple Online and Realtime Tracking) is a simple online and realtime tracking algorithm for 2D multiple object tracking in video s

Yonghye Kwon 8 Jul 27, 2022
Auto grind btdb2 exp for tower

Bloons TD Battles 2 EXP Grinder Auto grind btdb2 exp for towers Setup I suggest checking out every screenshot to see what they are supposed to be, so

Vincent 6 Jul 29, 2022
Code for "PVNet: Pixel-wise Voting Network for 6DoF Pose Estimation" CVPR 2019 oral

Good news! We release a clean version of PVNet: clean-pvnet, including how to train the PVNet on the custom dataset. Use PVNet with a detector. The tr

ZJU3DV 722 Dec 27, 2022
Contains a bunch of different python programm tasks

py_tasks Contains a bunch of different python programm tasks Armstrong.py - calculate Armsrong numbers in range from 0 to n with / without cache and c

Dmitry Chmerenko 1 Dec 17, 2021
ROS Basics and TurtleSim

Waypoint Follower Anna Garverick This package draws given waypoints, then waits for a service call with a start position to send the turtle to each wa

Anna Garverick 1 Dec 13, 2021
Visual Adversarial Imitation Learning using Variational Models (VMAIL)

Visual Adversarial Imitation Learning using Variational Models (VMAIL) This is the official implementation of the NeurIPS 2021 paper. Project website

14 Nov 18, 2022
Conditional Generative Adversarial Networks (CGAN) for Mobility Data Fusion

This code implements the paper, Kim et al. (2021). Imputing Qualitative Attributes for Trip Chains Extracted from Smart Card Data Using a Conditional Generative Adversarial Network. Transportation Re

Eui-Jin Kim 2 Feb 03, 2022
A Probabilistic End-To-End Task-Oriented Dialog Model with Latent Belief States towards Semi-Supervised Learning

LABES This is the code for EMNLP 2020 paper "A Probabilistic End-To-End Task-Oriented Dialog Model with Latent Belief States towards Semi-Supervised L

17 Sep 28, 2022
Code for the paper titled "Prabhupadavani: A Code-mixed Speech Translation Data for 25 languages"

Prabhupadavani: A Code-mixed Speech Translation Data for 25 languages Code for the paper titled "Prabhupadavani: A Code-mixed Speech Translation Data

Ayush Daksh 12 Dec 01, 2022
WSDM2022 Challenge - Large scale temporal graph link prediction

WSDM 2022 Large-scale Temporal Graph Link Prediction - Baseline and Initial Test Set WSDM Cup Website link Link to this challenge This branch offers A

Deep Graph Library 34 Dec 29, 2022
Building blocks for uncertainty-aware cycle consistency presented at NeurIPS'21.

UncertaintyAwareCycleConsistency This repository provides the building blocks and the API for the work presented in the NeurIPS'21 paper Robustness vi

EML Tübingen 19 Dec 12, 2022
Softlearning is a reinforcement learning framework for training maximum entropy policies in continuous domains. Includes the official implementation of the Soft Actor-Critic algorithm.

Softlearning Softlearning is a deep reinforcement learning toolbox for training maximum entropy policies in continuous domains. The implementation is

Robotic AI & Learning Lab Berkeley 997 Dec 30, 2022
[SIGGRAPH'22] StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets

[Project] [PDF] This repository contains code for our SIGGRAPH'22 paper "StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets" by Axel Sauer, Katja

742 Jan 04, 2023
Official PyTorch implementation of "Preemptive Image Robustification for Protecting Users against Man-in-the-Middle Adversarial Attacks" (AAAI 2022)

Preemptive Image Robustification for Protecting Users against Man-in-the-Middle Adversarial Attacks This is the code for reproducing the results of th

2 Dec 27, 2021
MoCoPnet - Deformable 3D Convolution for Video Super-Resolution

Deformable 3D Convolution for Video Super-Resolution Pytorch implementation of l

Xinyi Ying 28 Dec 15, 2022
Tensors and neural networks in Haskell

Hasktorch Hasktorch is a library for tensors and neural networks in Haskell. It is an independent open source community project which leverages the co

hasktorch 920 Jan 04, 2023
LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice,

LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice, for a model of choice, by iteratively removing each feature from the set, and eval

Ahmet Erdem 691 Dec 23, 2022
Code and dataset for ACL2018 paper "Exploiting Document Knowledge for Aspect-level Sentiment Classification"

Aspect-level Sentiment Classification Code and dataset for ACL2018 [paper] ‘‘Exploiting Document Knowledge for Aspect-level Sentiment Classification’’

Ruidan He 146 Nov 29, 2022