PyTorch META-DATASET (Few-shot classification benchmark)

Overview

PyTorch META-DATASET (Few-shot classification benchmark)

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, allowing to repeat the same random series of episodes if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained using this repo thanks to an enhanced way of resizing images. More details in the paper.

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (by default set to pytorch).

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

Table of contents

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

The present code was developped and tested in Python 3.8. The list of requirements is provided in requirements.txt:

pip install -r requirements.txt

1.2 Data

To download the META-DATASET, please follow the details instructions provided at meta-dataset to obtain the .tfrecords converted data. Once done, make sure all converted dataset are in a single folder, and execute the following script to produce index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide trained Resnet-18 and WRN-2810 models on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
Finetune Resnet-18 59.8 60.5 63.5 80.6 80.9 61.5 45.2 91.1 55.1 41.8 64.0
ProtoNet Resnet-18 48.2 46.7 44.6 53.8 70.3 45.1 38.5 82.4 42.2 38.0 51.0
SimpleShot Resnet-18 60.0 54.2 55.9 78.6 77.8 57.4 49.2 90.3 49.6 44.2 61.7
Transductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
BD-CSPN Resnet-18 60.5 54.4 55.2 80.9 77.9 57.3 50.0 91.7 47.8 43.9 62.0
TIM-GD Resnet-18 63.6 65.6 66.4 85.6 84.7 65.8 57.5 95.6 65.2 50.9 70.1

See Sect. 1.4 and 1.5 to reproduce these results.

1.4 Train models from scratch (optional)

In order to train you model from scratch, execute scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

method is to be chosen among all method specific config files in config/, architecture in ['resnet18', 'wideres2810'] and dataset among all datasets (as named by the META-DATASET converted folders). Note that the hierarchy of arguments passed to src/train.py and src/eval.py is the following: base_config < method_config < opts arguments.

Mutiprocessing : This code supports distributed training. To leverage this feature, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once trained (or once pre-trained models downloaded), you can evaluate your model on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in results/ / where corresponds to a unique hash number of the config (you can only get the same result folder iff all hyperparameters are the same).

2. Visualization of results

2.1 Training metrics

During training, training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. Then, you can use the src/plot.py to plot these metrics (even during training).

Example 1: Plot the metrics of the standard (=non episodic) resnet-18 on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 trained on ImageNet

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method specific metrics are plotted in real-time (versus test iterations) and averaged over test epidodes, which can allow you to track unexpected behavior easily. Such metrics are implemented in src/metrics/, and the choice of which metric to plot is specificied through the eval_metrics option in the method .yaml config file. An example with TIM method is provided below.

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant optons can be found in the base.yaml file, in the EVAL-VISU section.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Add import in src/methods/__init__.py

Step 3: Add your method .yaml config file including the required options episodic_training and method (name of the class corresponding to your method). Also make sure that if your method performs test-time optimization, you also properly set the option iter that specifies the number of optimization steps performed at inference (this argument is also used to plot the inference metrics, see section 2.2).

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods/pre-trained models, do make a pull-request.

5. Citation

If you find this repo useful for your research, please consider citing the following papers:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf ([email protected]).

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code and the author of open-source TFRecord reader for open sourcing an awesome Pytorch-compatible TFRecordReader ! Also big thanks to @hkervadec for his thorough code review !

Owner
Malik Boudiaf
Malik Boudiaf
On the adaptation of recurrent neural networks for system identification

On the adaptation of recurrent neural networks for system identification This repository contains the Python code to reproduce the results of the pape

Marco Forgione 3 Jan 13, 2022
A Deep Learning based project for creating line art portraits.

ArtLine The main aim of the project is to create amazing line art portraits. Sounds Intresting,let's get to the pictures!! Model-(Smooth) Model-(Quali

Vijish Madhavan 3.3k Jan 07, 2023
BEAMetrics: Benchmark to Evaluate Automatic Metrics in Natural Language Generation

BEAMetrics: Benchmark to Evaluate Automatic Metrics in Natural Language Generation Installing The Dependencies $ conda create --name beametrics python

7 Jul 04, 2022
library for nonlinear optimization, wrapping many algorithms for global and local, constrained or unconstrained, optimization

NLopt is a library for nonlinear local and global optimization, for functions with and without gradient information. It is designed as a simple, unifi

Steven G. Johnson 1.4k Dec 25, 2022
TensorFlow implementation of original paper : https://github.com/hszhao/PSPNet

Keras implementation of PSPNet(caffe) Implemented Architecture of Pyramid Scene Parsing Network in Keras. For the best compability please use Python3.

VladKry 386 Dec 29, 2022
TAPEX: Table Pre-training via Learning a Neural SQL Executor

TAPEX: Table Pre-training via Learning a Neural SQL Executor The official repository which contains the code and pre-trained models for our paper TAPE

Microsoft 157 Dec 28, 2022
Autoencoders pretraining using clustering

Autoencoders pretraining using clustering

IITiS PAN 2 Dec 16, 2021
implementation of paper - You Only Learn One Representation: Unified Network for Multiple Tasks

YOLOR implementation of paper - You Only Learn One Representation: Unified Network for Multiple Tasks To reproduce the results in the paper, please us

Kin-Yiu, Wong 1.8k Jan 04, 2023
Scripts for training an AI to play the endless runner Subway Surfers using a supervised machine learning approach by imitation and a convolutional neural network (CNN) for image classification

About subwAI subwAI - a project for training an AI to play the endless runner Subway Surfers using a supervised machine learning approach by imitation

82 Jan 01, 2023
Efficient Multi Collection Style Transfer Using GAN

Proposed a new model that can make style transfer from single style image, and allow to transfer into multiple different styles in a single model.

Zhaozheng Shen 2 Jan 15, 2022
A PyTorch implementation of "Multi-Scale Contrastive Siamese Networks for Self-Supervised Graph Representation Learning", IJCAI-21

MERIT A PyTorch implementation of our IJCAI-21 paper Multi-Scale Contrastive Siamese Networks for Self-Supervised Graph Representation Learning. Depen

Graph Analysis & Deep Learning Laboratory, GRAND 32 Jan 02, 2023
Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment (ICCV2021)

Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment This is a pytorch project for the paper Seeing Dynamic Scene i

DV Lab 21 Nov 28, 2022
Train the HRNet model on ImageNet

High-resolution networks (HRNets) for Image classification News [2021/01/20] Add some stronger ImageNet pretrained models, e.g., the HRNet_W48_C_ssld_

HRNet 866 Jan 04, 2023
Nicholas Lee 3 Jan 09, 2022
OpenPose: Real-time multi-person keypoint detection library for body, face, hands, and foot estimation

Build Type Linux MacOS Windows Build Status OpenPose has represented the first real-time multi-person system to jointly detect human body, hand, facia

25.7k Jan 09, 2023
Code release for NeuS

NeuS We present a novel neural surface reconstruction method, called NeuS, for reconstructing objects and scenes with high fidelity from 2D image inpu

Peng Wang 813 Jan 04, 2023
Digital Twin Mobility Profiling: A Spatio-Temporal Graph Learning Approach

Digital Twin Mobility Profiling: A Spatio-Temporal Graph Learning Approach This is the implementation of traffic prediction code in DTMP based on PyTo

chenxin 1 Dec 19, 2021
Official repository for the paper "Self-Supervised Models are Continual Learners" (CVPR 2022)

Self-Supervised Models are Continual Learners This is the official repository for the paper: Self-Supervised Models are Continual Learners Enrico Fini

Enrico Fini 73 Dec 18, 2022
WaveFake: A Data Set to Facilitate Audio DeepFake Detection

WaveFake: A Data Set to Facilitate Audio DeepFake Detection This is the code repository for our NeurIPS 2021 (Track on Datasets and Benchmarks) paper

Chair for Sys­tems Se­cu­ri­ty 27 Dec 22, 2022
Towards Multi-Camera 3D Human Pose Estimation in Wild Environment

PanopticStudio Toolbox This repository has a toolbox to download, process, and visualize the Panoptic Studio (Panoptic) data. Note: Sep-21-2020: Curre

335 Jan 09, 2023