Official implementation for the paper: Multi-label Classification with Partial Annotations using Class-aware Selective Loss

Overview

PWC

Multi-label Classification with Partial Annotations using Class-aware Selective Loss


Paper | Pretrained models

Official PyTorch Implementation

Emanuel Ben-Baruch, Tal Ridnik, Itamar Friedman, Avi Ben-Cohen, Nadav Zamir, Asaf Noy, Lihi Zelnik-Manor
DAMO Academy, Alibaba Group

Abstract

Large-scale multi-label classification datasets are commonly, and perhaps inevitably, partially annotated. That is, only a small subset of labels are annotated per sample. Different methods for handling the missing labels induce different properties on the model and impact its accuracy. In this work, we analyze the partial labeling problem, then propose a solution based on two key ideas. First, un-annotated labels should be treated selectively according to two probability quantities: the class distribution in the overall dataset and the specific label likelihood for a given data sample. We propose to estimate the class distribution using a dedicated temporary model, and we show its improved efficiency over a naive estimation computed using the dataset's partial annotations. Second, during the training of the target model, we emphasize the contribution of annotated labels over originally un-annotated labels by using a dedicated asymmetric loss. Experiments conducted on three partially labeled datasets, OpenImages, LVIS, and simulated-COCO, demonstrate the effectiveness of our approach. Specifically, with our novel selective approach, we achieve state-of-the-art results on OpenImages dataset. Code will be made available.

Class-aware Selective Approach

An overview of our approach is summarized in the following figure:

Loss Implementation

Our loss consists of a selective approach for adjusting the training mode for each class individualy and a partial asymmetric loss.

An implementation of the Class-aware Selective Loss (CSL) can be found here.

  • class PartialSelectiveLoss(nn.Module)

Pretrained Models

We provide models pretrained on the OpenImages datasset with different modes and architectures:

Model Architecture Link mAP
Ignore TResNet-M link 85.38
Negative TResNet-M link 85.85
Selective (CSL) TResNet-M link 86.72
Selective (CSL) TResNet-L link 87.34

Inference Code (Demo)

We provide inference code, that demonstrate how to load the model, pre-process an image and do inference. Example run of OpenImages model (after downloading the relevant model):

python infer.py  \
--dataset_type=OpenImages \
--model_name=tresnet_m \
--model_path=./models_local/mtresnet_opim_86.72.pth \
--pic_path=./pics/10162266293_c7634cbda9_o.jpg \
--input_size=448

Result Examples

Training Code

Training code is provided in (train.py). Also, code for simulating partial annotation for the MS-COCO dataset is available (here). In particular, two "partial" simulation schemes are implemented: fix-per-class(FPC) and random-per-sample (RPS).

  • FPC: For each class, we randomly sample a fixed number of positive annotations and the same number of negative annotations. The rest of the annotations are dropped.
  • RPA: We omit each annotation with probability p.

Pretrained weights using the ImageNet-21k dataset can be found here: link
Pretrained weights using the ImageNet-1k dataset can be found here: link

Example of training with RPS simulation:

--data=/mnt/datasets/COCO/COCO_2014
--model-path=models/pretrain/mtresnet_21k
--gamma_pos=0
--gamma_neg=4
--gamma_unann=4
--simulate_partial_type=rps
--simulate_partial_param=0.5
--partial_loss_mode=selective
--likelihood_topk=5
--prior_threshold=0.5
--prior_path=./outputs/priors/prior_fpc_1000.csv

Example of training with FPC simulation:

--data=/mnt/datasets/COCO/COCO_2014
--model-path=models/pretrain/mtresnet_21k
--gamma_pos=0
--gamma_neg=4
--gamma_unann=4
--simulate_partial_type=fpc
--simulate_partial_param=1000
--partial_loss_mode=selective
--likelihood_topk=5
--prior_threshold=0.5
--prior_path=./outputs/priors/prior_fpc_1000.csv

Typical Training Results

FPC (1,000) simulation scheme:

Model mAP
Ignore, CE 76.46
Negative, CE 81.24
Negative, ASL (4,1) 81.64
CSL - Selective, P-ASL(4,3,1) 83.44

RPS (0.5) simulation scheme:

Model mAP
Ignore, CE 84.90
Negative, CE 81.21
Negative, ASL (4,1) 81.91
CSL- Selective, P-ASL(4,1,1) 85.21

Estimating the Class Distribution

The training code contains also the procedure for estimting the class distribution from the data. Our approach enables to rank the classes based on training a temporary model usinig the Ignore mode. link

Top 10 classes:

Method Top 10 ranked classes
Original 'person', 'chair', 'car', 'dining table', 'cup', 'bottle', 'bowl', 'handbag', 'truck', 'backpack'
Estiimate (Ignore mode) 'person', 'chair', 'handbag', 'cup', 'bench', 'bottle', 'backpack', 'car', 'cell phone', 'potted plant'
Estimate (Negative mode) 'kite' 'truck' 'carrot' 'baseball glove' 'tennis racket' 'remote' 'cat' 'tie' 'horse' 'boat'

Citation

@misc{benbaruch2021multilabel,
      title={Multi-label Classification with Partial Annotations using Class-aware Selective Loss}, 
      author={Emanuel Ben-Baruch and Tal Ridnik and Itamar Friedman and Avi Ben-Cohen and Nadav Zamir and Asaf Noy and Lihi Zelnik-Manor},
      year={2021},
      eprint={2110.10955},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

Several images from OpenImages dataset are used in this project. ֿ
Some components of this code implementation are adapted from the repository https://github.com/Alibaba-MIIL/ASL.

Happywhale - Whale and Dolphin Identification Silver🥈 Solution (26/1588)

Kaggle-Happywhale Happywhale - Whale and Dolphin Identification Silver 🥈 Solution (26/1588) 竞赛方案思路 图像数据预处理-标志性特征图片裁剪:首先根据开源的标注数据训练YOLOv5x6目标检测模型,将训练集

Franxx 20 Nov 14, 2022
A simple and useful implementation of LPIPS.

lpips-pytorch Description Developing perceptual distance metrics is a major topic in recent image processing problems. LPIPS[1] is a state-of-the-art

So Uchida 121 Dec 24, 2022
General neural ODE and DAE modules for power system dynamic modeling.

Py_PSNODE General neural ODE and DAE modules for power system dynamic modeling. The PyTorch-based ODE solver is developed based on torchdiffeq. Sample

14 Dec 31, 2022
Neural Architecture Search Powered by Swarm Intelligence 🐜

Neural Architecture Search Powered by Swarm Intelligence 🐜 DeepSwarm DeepSwarm is an open-source library which uses Ant Colony Optimization to tackle

288 Oct 28, 2022
Code for classifying international patents based on the text of their titles/abstracts

Patent Classification Goal: To train a machine learning classifier that can automatically classify international patents downloaded from the WIPO webs

Prashanth Rao 1 Nov 08, 2022
Implementation of UNet on the Joey ML framework

Independent Research Project - Code Joey can be cloned from here https://github.com/devitocodes/joey/. Devito and other dependencies such as PyTorch a

Navjot Kukreja 1 Oct 21, 2021
[NeurIPS2021] Code Release of K-Net: Towards Unified Image Segmentation

K-Net: Towards Unified Image Segmentation Introduction This is an official release of the paper K-Net:Towards Unified Image Segmentation. K-Net will a

Wenwei Zhang 423 Jan 02, 2023
Evaluating saliency methods on artificial data with different background types

Evaluating saliency methods on artificial data with different background types This repository contains the relevant code for the MedNeurips 2021 subm

2 Jul 05, 2022
AI Face Mesh: This is a simple face mesh detection program based on Artificial intelligence.

AI Face Mesh: This is a simple face mesh detection program based on Artificial Intelligence which made with Python. It's able to detect 468 different

Md. Rakibul Islam 1 Jan 13, 2022
Everything you need to know about NumPy( Creating Arrays, Indexing, Math,Statistics,Reshaping).

Everything you need to know about NumPy( Creating Arrays, Indexing, Math,Statistics,Reshaping).

1 Feb 14, 2022
Code to accompany the paper "Finding Bipartite Components in Hypergraphs", which is published in NeurIPS'21.

Finding Bipartite Components in Hypergraphs This repository contains code to accompany the paper "Finding Bipartite Components in Hypergraphs", publis

Peter Macgregor 5 May 06, 2022
A new video text spotting framework with Transformer

TransVTSpotter: End-to-end Video Text Spotter with Transformer Introduction A Multilingual, Open World Video Text Dataset and End-to-end Video Text Sp

weijiawu 67 Jan 03, 2023
HybVIO visual-inertial odometry and SLAM system

HybVIO A visual-inertial odometry system with an optional SLAM module. This is a research-oriented codebase, which has been published for the purposes

Spectacular AI 320 Jan 03, 2023
A new GCN model for Point Cloud Analyse

Pytorch Implementation of PointNet and PointNet++ This repo is implementation for VA-GCN in pytorch. Classification (ModelNet10/40) Data Preparation D

12 Feb 02, 2022
'Aligned mixture of latent dynamical systems' (amLDS) for stimulus decoding probabilistic manifold alignment across animals. P. Herrero-Vidal et al. NeurIPS 2021 code.

Across-animal odor decoding by probabilistic manifold alignment (NeurIPS 2021) This repository is the official implementation of aligned mixture of la

Pedro Herrero-Vidal 3 Jul 12, 2022
EXplainable Artificial Intelligence (XAI)

EXplainable Artificial Intelligence (XAI) This repository includes the codes for different projects on eXplainable Artificial Intelligence (XAI) by th

4 Nov 28, 2022
Build Low Code Automated Tensorflow, What-IF explainable models in just 3 lines of code.

Build Low Code Automated Tensorflow explainable models in just 3 lines of code.

Hasan Rafiq 170 Dec 26, 2022
Source code for "Interactive All-Hex Meshing via Cuboid Decomposition [SIGGRAPH Asia 2021]".

Interactive All-Hex Meshing via Cuboid Decomposition Video demonstration This repository contains an interactive software to the PolyCube-based hex-me

Lingxiao Li 131 Dec 05, 2022
RL-driven agent playing tic-tac-toe on starknet against challengers.

tictactoe-on-starknet RL-driven agent playing tic-tac-toe on starknet against challengers. GUI reference: https://pythonguides.com/create-a-game-using

21 Jul 30, 2022
PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation.

PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation. Warning: the master branch might collapse. To ob

559 Dec 14, 2022