The PyTorch code of "Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification", CVPR 2022 (Oral).

Overview

DeepBDC for few-shot learning

      

Introduction

In this repo, we provide the implementation of the following paper:
"Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification" [Project] [Paper].

In this paper, we propose deep Brownian Distance Covariance (DeepBDC) for few-shot classification. DeepBDC learns image representations by measuring, for the query and support images, the discrepancy between the joint distribution of their embedded features and the product of the marginals. The core of DeepBDC is formulated as a modular and efficient layer that can be flexibly inserted into deep networks. It suits not only the meta-learning framework based on episodic training, but also the simple transfer learning (STL) framework of pretraining plus a linear classifier.
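The BDC representation can be illustrated with a short NumPy sketch. It is not the repository's layer: it follows the classical construction of Székely and Rizzo [BDC], and it assumes that each of the d channels of an h×w×d feature map (a vector of h·w responses) is treated as one observation. The sketch computes the d×d Euclidean distance matrix between channels, double-centers it to obtain a BDC matrix A, and compares two images by the inner product tr(AᵀB) of their BDC matrices. The function names are hypothetical, and any learnable parts of the actual layer (for example dimensionality reduction or scaling) are omitted.

```python
import numpy as np


def bdc_matrix(x):
    """Double-centered pairwise distance matrix between the channels of a feature map.

    x: array of shape (m, d), i.e. an h*w x d feature map flattened over space;
    each of the d columns (one channel) is treated as one observation.
    """
    gram = x.T @ x                                    # (d, d) inner products between channels
    sq_norms = np.diag(gram)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * gram
    dist = np.sqrt(np.clip(sq_dist, 0.0, None))       # clip round-off negatives before sqrt
    # Double centering: subtract row and column means, add back the grand mean.
    return (dist
            - dist.mean(axis=1, keepdims=True)
            - dist.mean(axis=0, keepdims=True)
            + dist.mean())


def bdc_similarity(a, b):
    """Inner product tr(A^T B) between two BDC matrices."""
    return float(np.sum(a * b))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    query_feat = rng.standard_normal((25, 8))     # e.g. a 5x5 map with 8 channels
    support_feat = rng.standard_normal((25, 8))
    a, b = bdc_matrix(query_feat), bdc_matrix(support_feat)
    print(a.shape, bdc_similarity(a, b))
```

Because every row and column of a double-centered matrix sums to zero, the BDC matrix is symmetric and carries only d(d+1)/2 distinct entries, so its upper triangle is enough to serve as a vectorized image representation.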

If you find this repo helpful for your research, please consider citing our paper:

@inproceedings{DeepBDC-CVPR2022,
    title={Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification},
    author={Jiangtao Xie and Fei Long and Jiaming Lv and Qilong Wang and Peihua Li}, 
    booktitle={CVPR},
    year={2022}
 }

Few-shot classification results

Experimental results on miniImageNet and CUB. We report average accuracy (%) over 2,000 randomly sampled episodes for both 1-shot and 5-shot evaluation. See the paper for more experimental details.
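Episodic evaluation can be sketched as follows: each episode draws N classes (5-way), K labelled support images per class (1 or 5 shots), and a disjoint set of query images per class. The helper name and the default of 15 queries per class are hypothetical choices for illustration, not values taken from the repository.

```python
import random


def sample_episode(images_by_class, n_way=5, k_shot=1, n_query=15, rng=None):
    """Draw one N-way K-shot episode as (support, query) lists of (image, episode_label)."""
    if rng is None:
        rng = random.Random()
    # Pick N distinct classes; relabel them 0..N-1 within the episode.
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw disjoint support and query images from the same class.
        picks = rng.sample(list(images_by_class[cls]), k_shot + n_query)
        support += [(img, label) for img in picks[:k_shot]]
        query += [(img, label) for img in picks[k_shot:]]
    return support, query
```

Repeating this 2,000 times per setting and averaging the per-episode query accuracy gives numbers of the kind reported in the tables below.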

miniImageNet

Method          ResNet-12 accuracy (%)       Pre-trained models          Meta-trained models
                5-way 1-shot   5-way 5-shot  GoogleDrive  BaiduCloud     GoogleDrive  BaiduCloud
ProtoNet        62.11±0.44     80.77±0.30    Download     Download       Download     Download
Good-Embed      64.98±0.44     82.10±0.30    Download     Download       N/A
Meta DeepBDC    67.34±0.43     84.46±0.28    Download     Download       Download     Download
STL DeepBDC     67.83±0.43     85.45±0.29    Download     Download       N/A

Note that for Good-Embed and STL DeepBDC, a sequential self-distillation technique is used to obtain the pre-trained models; see the paper of Good-Embed for details.
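In sequential self-distillation, a model of generation k is trained as the student of the generation k-1 model, combining the usual cross-entropy with a distillation term. The sketch below writes that per-generation objective as a standard knowledge-distillation loss (KL divergence between temperature-softened teacher and student outputs, scaled by T² as in Hinton et al.). The weights alpha and beta and the temperature are placeholder values, not the settings used by run_distillation.sh, and the function names are hypothetical.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Row-wise log-softmax of logits / temperature."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)   # numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def self_distillation_loss(student_logits, teacher_logits, labels,
                           alpha=0.5, beta=0.5, temperature=4.0):
    """alpha * CE(student, labels) + beta * T^2 * KL(teacher || student)."""
    # Hard-label cross-entropy on the student's ordinary (T = 1) predictions.
    log_p = log_softmax(student_logits)
    ce = -log_p[np.arange(len(labels)), labels].mean()
    # KL divergence between temperature-softened teacher and student distributions.
    log_p_s = log_softmax(student_logits, temperature)
    log_p_t = log_softmax(teacher_logits, temperature)
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=1).mean()
    # The T^2 factor keeps the distillation gradients on a comparable scale.
    return alpha * ce + beta * (temperature ** 2) * kl
```

During training the teacher is frozen, so only the student's logits would receive gradients in a framework implementation.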

CUB

Method          ResNet-18 accuracy (%)       Pre-trained models          Meta-trained models
                5-way 1-shot   5-way 5-shot  GoogleDrive  BaiduCloud     GoogleDrive  BaiduCloud
ProtoNet        80.90±0.43     89.81±0.23    Download     Download       Download     Download
Good-Embed      77.92±0.46     89.94±0.26    Download     Download       N/A
Meta DeepBDC    83.55±0.40     93.82±0.17    Download     Download       Download     Download
STL DeepBDC     84.01±0.42     94.02±0.24    Download     Download       N/A

Note that for Good-Embed and STL DeepBDC, a sequential self-distillation technique is used to obtain the pre-trained models; see the paper of Good-Embed for details.
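The README does not say what the ± values denote; following the usual convention in few-shot papers, they are presumably 95% confidence intervals over the 2,000 test episodes (an assumption). Given the list of per-episode accuracies, a normal-approximation mean and interval half-width can be computed as below; the function name is hypothetical.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% confidence interval."""
    mean = statistics.fmean(accuracies)
    std = statistics.stdev(accuracies)          # sample standard deviation (n - 1)
    return mean, 1.96 * std / math.sqrt(len(accuracies))
```

Some codebases use the population standard deviation instead; with 2,000 episodes the difference is negligible.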

References

[BDC] G. J. Székely and M. L. Rizzo. Brownian distance covariance. Annals of Applied Statistics, 3:1236–1265, 2009.
[ProtoNet] J. Snell, K. Swersky, and R. Zemel. Prototypical networks for few-shot learning. In NIPS, 2017.
[Good-Embed] Y. Tian, Y. Wang, D. Krishnan, J. B. Tenenbaum, and P. Isola. Rethinking few-shot image classification: a good embedding is all you need? In ECCV, 2020.

Implementation details

Datasets

  • miniImageNet: We use the splits provided by Chen et al.
  • CUB: We use the splits provided by Chen et al.
  • tieredImageNet
  • Aircraft
  • Cars

Implementation environment

Note that the test accuracy may vary slightly with different PyTorch/CUDA versions, GPUs, etc.

  • Linux
  • Python 3.8.3
  • torch 1.7.1
  • GPU (RTX 3090) + CUDA 11.0 + cuDNN
  • sklearn 1.0.1, pillow 8.0.0, numpy 1.19.2

Installation

  • Clone this repo:
git clone https://github.com/Fei-Long121/DeepBDC.git
cd DeepBDC

For Meta DeepBDC on general object recognition

  1. cd scripts/mini_imagenet/run_meta_deepbdc
  2. modify the dataset path in run_pretrain.sh, run_metatrain.sh and run_test.sh
  3. bash run.sh
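The meta-training stage of Meta DeepBDC follows the episodic ProtoNet style. One step can be sketched as follows, assuming each image has already been mapped to a flattened BDC representation: class prototypes are the mean support representation of each class, and queries are scored by inner product with each prototype before a softmax cross-entropy loss. The inner-product score and the helper names are assumptions for illustration, not necessarily what run_metatrain.sh does.

```python
import numpy as np


def prototype_logits(support, support_labels, query, n_way):
    """Score queries against class prototypes (mean support representation per class)."""
    protos = np.stack([support[support_labels == c].mean(axis=0) for c in range(n_way)])
    return query @ protos.T                  # (n_query, n_way) inner-product scores


def episode_loss(logits, query_labels):
    """Mean softmax cross-entropy of the query predictions."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(query_labels)), query_labels].mean()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_way, k_shot, dim = 5, 1, 36            # dim: hypothetical flattened BDC size
    support = rng.standard_normal((n_way * k_shot, dim))
    support_labels = np.repeat(np.arange(n_way), k_shot)
    query = support + 0.1 * rng.standard_normal(support.shape)
    logits = prototype_logits(support, support_labels, query, n_way)
    print("episode loss:", episode_loss(logits, support_labels))
```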

For STL DeepBDC on general object recognition

  1. cd scripts/mini_imagenet/run_stl_deepbdc
  2. modify the dataset path in run_pretrain.sh, run_distillation.sh and run_test.sh
  3. bash run.sh
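In the STL setting, the test stage presumably fits a linear classifier on the frozen embeddings of each episode's support set and applies it to the queries. Below is a dependency-free sketch using multinomial logistic regression trained by full-batch gradient descent. The helper names, learning rate, step count, and weight decay are hypothetical, and the repository may instead rely on a library solver such as sklearn's LogisticRegression.

```python
import numpy as np


def fit_linear_classifier(feats, labels, n_classes, lr=0.5, steps=200, weight_decay=1e-3):
    """Multinomial logistic regression fitted by full-batch gradient descent."""
    n, d = feats.shape
    w = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[labels]
    for _ in range(steps):
        z = feats @ w + b
        z = z - z.max(axis=1, keepdims=True)      # numerical stability
        p = np.exp(z)
        p /= p.sum(axis=1, keepdims=True)         # softmax probabilities
        grad = (p - onehot) / n                   # gradient of mean cross-entropy w.r.t. logits
        w -= lr * (feats.T @ grad + weight_decay * w)
        b -= lr * grad.sum(axis=0)
    return w, b


def predict(feats, w, b):
    """Predicted class index for each row of feats."""
    return (feats @ w + b).argmax(axis=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy 5-way 5-shot episode with hypothetical 16-d embeddings clustered by class.
    centers = rng.standard_normal((5, 16)) * 3
    support_y = np.repeat(np.arange(5), 5)
    support_x = centers[support_y] + 0.1 * rng.standard_normal((25, 16))
    query_y = np.repeat(np.arange(5), 15)
    query_x = centers[query_y] + 0.1 * rng.standard_normal((75, 16))
    w, b = fit_linear_classifier(support_x, support_y, 5)
    print("query accuracy:", (predict(query_x, w, b) == query_y).mean())
```

A fresh classifier is fitted per episode, so the backbone is never updated at test time; only the small linear head adapts to the novel classes.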

Acknowledgments

Our code builds upon the following publicly available code:

Contact

If you have any questions or suggestions, please contact us:

Fei Long ([email protected])
Jiaming Lv ([email protected])
