A PyTorch implementation of the CVPR 2021 paper "RSG: A Simple but Effective Module for Learning Imbalanced Datasets"

RSG: A Simple but Effective Module for Learning Imbalanced Datasets (CVPR 2021)

A PyTorch implementation of our CVPR 2021 paper "RSG: A Simple but Effective Module for Learning Imbalanced Datasets". RSG (Rare-class Sample Generator) is a flexible module that generates rare-class samples during training and can be combined with any backbone network. Because RSG is used only in the training phase, it adds no overhead to the backbone network at test time.

How to use RSG in your own networks

  1. Initialize RSG module:

    from RSG import *
    
    # n_center: The number of centers, e.g., 15.
    # feature_maps_shape: The shape of input feature maps (channel, width, height), e.g., [32, 16, 16].
    # num_classes: The number of classes, e.g., 10.
    # contrastive_module_dim: The dimension of the contrastive module, e.g., 256.
    # head_class_lists: The index of head classes, e.g., [0, 1, 2].
    # transfer_strength: Transfer strength, e.g., 1.0.
    # epoch_thresh: The epoch index at which rare-class sample generation starts, e.g., 159.
    
    self.RSG = RSG(
        n_center=15,
        feature_maps_shape=[32, 16, 16],
        num_classes=10,
        contrastive_module_dim=256,
        head_class_lists=[0, 1, 2],
        transfer_strength=1.0,
        epoch_thresh=159,
    )
    
    
  2. Use RSG in the forward pass during training:

    out = self.layer2(out)
    
    # feature_maps: The input feature maps.
    # head_class_lists: The index of head classes.
    # target: The label of samples.
    # epoch: The current index of epoch.
    
    if phase_train:
      out, cesc_total, loss_mv_total, combine_target = self.RSG.forward(feature_maps = out, head_class_lists = [0, 1, 2], target = target, epoch = epoch)
     
    out = self.layer3(out) 
    

The forward pass returns two loss terms, "cesc_total" and "loss_mv_total", which are combined with the cross-entropy loss for backpropagation. More examples and details can be found in the models in the directory "Imbalanced_Classification/models".
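
As a minimal sketch of how the two returned terms might be folded into the training objective, the helper below forms a weighted sum with the cross-entropy loss. The function name "combine_losses" and the default weights are assumptions for illustration and are not taken from this README; check the training scripts for the values actually used. The helper relies only on "+" and "*", so it works with Python floats and with scalar PyTorch tensors alike. If the terms are not scalars, reduce them (for example with a mean) before calling it.

```python
def combine_losses(ce_loss, cesc_total, loss_mv_total,
                   cesc_weight=0.1, mv_weight=0.01):
    """Weighted sum of cross-entropy and the two RSG loss terms.

    cesc_weight and mv_weight are placeholder values (an assumption),
    not the weights prescribed by the authors.
    """
    return ce_loss + cesc_weight * cesc_total + mv_weight * loss_mv_total


# Plain floats stand in for scalar loss tensors in this demo.
total = combine_losses(ce_loss=1.0, cesc_total=2.0, loss_mv_total=3.0)
print(total)
```

Since the forward pass also returns "combine_target", the cross-entropy term is presumably computed against that tensor rather than the original "target" once RSG has started generating samples.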

How to train

Some examples:

Go into the "Imbalanced_Classification" directory.

  1. To reproduce the result of ResNet-32 on long-tailed CIFAR-10 ($\rho$ = 100) with RSG and LDAM-DRW:

    export CUDA_VISIBLE_DEVICES=0,1
    python cifar_train.py --imb_type exp --imb_factor 0.01 --loss_type LDAM --train_rule DRW
    
  2. To reproduce the result of ResNet-32 on step CIFAR-10 ($\rho$ = 50) with RSG and Focal loss:

    export CUDA_VISIBLE_DEVICES=0,1
    python cifar_train.py --imb_type step --imb_factor 0.02 --loss_type Focal --train_rule None
    
  3. To run experiments on iNaturalist 2018, Places-LT, or ImageNet-LT:

    First, prepare the datasets and their corresponding list files. For convenience, we provide the list files on Google Drive and Baidu Disk.

    Google Drive: download
    Baidu Disk: download (code: q3dk)

    To train the model:

    python inaturalist_train.py
    

    or

    python places_train.py
    

    or

    python imagenet_lt_train.py
    

    For Places-LT and ImageNet-LT, the model is trained on the training set, and the model that performs best on the validation set is saved for testing. Use "places_test.py" and "imagenet_lt_test.py" for testing.
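
The CIFAR commands above pair $\rho$ = 100 with "--imb_factor 0.01" and $\rho$ = 50 with "--imb_factor 0.02", so the flag appears to be 1/$\rho$. The sketch below shows how per-class training-set sizes could be derived from "--imb_type" and "--imb_factor". It assumes a convention common in long-tailed CIFAR benchmarks (exponential decay for "exp", two plateaus for "step") and 5000 training images per class for balanced CIFAR-10. The function name "img_num_per_cls" is illustrative, and whether this repository computes the counts in exactly this way is an assumption.

```python
def img_num_per_cls(cls_num, imb_type, imb_factor, img_max=5000):
    """Per-class sample counts for an artificially imbalanced dataset.

    Assumed convention (not confirmed by this README):
      exp  -> counts decay geometrically from img_max to img_max * imb_factor
      step -> first half of the classes keep img_max,
              second half keep img_max * imb_factor
    """
    if imb_type == "exp":
        return [int(img_max * imb_factor ** (i / (cls_num - 1.0)))
                for i in range(cls_num)]
    if imb_type == "step":
        half = cls_num // 2
        return ([int(img_max)] * half
                + [int(img_max * imb_factor)] * (cls_num - half))
    # Any other value: leave the dataset balanced.
    return [int(img_max)] * cls_num


exp_counts = img_num_per_cls(10, "exp", 0.01)    # rho = 100
step_counts = img_num_per_cls(10, "step", 0.02)  # rho = 50
print(exp_counts[0] // exp_counts[-1])    # largest / smallest class size
print(step_counts[0] // step_counts[-1])
```

Under this convention the ratio between the largest and smallest class equals $\rho$ in both settings.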

Citation

@inproceedings{Jianfeng2021RSG,
  title = {RSG: A Simple but Effective Module for Learning Imbalanced Datasets},
  author = {Jianfeng Wang and Thomas Lukasiewicz and Xiaolin Hu and Jianfei Cai and Zhenghua Xu},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}