PyTorch code for the paper: FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning

Overview

FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning

This is the PyTorch implementation of our paper:
FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning
Chia-Wen Kuo, Chih-Yao Ma, Jia-Bin Huang, Zsolt Kira
European Conference on Computer Vision (ECCV), 2020
[arXiv] [Project]

Abstract

Recent state-of-the-art semi-supervised learning (SSL) methods use a combination of image-based transformations and consistency regularization as core components. Such methods, however, are limited to simple transformations such as traditional data augmentation or convex combinations of two images. In this paper, we propose a novel learned feature-based refinement and augmentation method that produces a varied set of complex transformations. Importantly, these transformations also use information from both within-class and across-class prototypical representations that we extract through clustering. We use features already computed across iterations by storing them in a memory bank, obviating the need for significant extra computation. These transformations, combined with traditional image-based augmentation, are then used as part of the consistency-based regularization loss. We demonstrate that our method is comparable to current state of art for smaller datasets (CIFAR-10 and SVHN) while being able to scale up to larger datasets such as CIFAR-100 and mini-Imagenet where we achieve significant gains over the state of art (e.g., absolute 17.44% gain on mini-ImageNet). We further test our method on DomainNet, demonstrating better robustness to out-of-domain unlabeled data, and perform rigorous ablations and analysis to validate the method.

Installation

Prequesites

  • python == 3.7
  • pytorch == 1.6
  • torchvision == 0.7

Install python dependencies:

pip install -r requirements.txt

To augment data faster, we recommend using Pillow-SIMD.

Note: this project was developed under torch==1.4 originally. During code release, it is ported to torch==1.6 for the native support of automatic mixed precision (amp) training. The numbers are slightly different from those on the paper but are within the std margins.

Datasets

Download/Extract the following datasets to the dataset folder under the project root directory.

  • For SVHN, download train and test sets here.

  • For CIFAR-10 and CIFAR-100, download the python version dataset here.

  • For mini-ImageNet, use the following command to extract mini-ImageNet from ILSVRC-12:

    python3 dataloader/mini_imagenet.py -sz 128 \
     -sd [ILSVRC-12_ROOT] \
     -dd dataset/mini-imagenet
    

    Replace [ILSVRC-12_ROOT] with the root folder of your local ILSVRC-12 dataset.

  • For DomainNet, use the following command to download the domains:

    python3 dataloader/domainnet.py -r dataset/domainnet
    

Training

All commands should be run under the project root directory.

Running arguments

-cf CONFIG: training config
-d GPU_IDS: GPUs where the model is trained on
-n SAVE_ROOT: root directory where the checkpoints are saved to
-i ITERS: number of runs for average performance

CIFAR-100

# 4k labels
python3 train/featmatch.py -cf config/cifar100/[cifar100][test][cnn13][4000].json -d 0 1 -n [cifar100][test][cnn13][4000] -i 3 -o -a

# 10k labels
python3 train/featmatch.py -cf config/cifar100/[cifar100][test][cnn13][10000].json -d 0 1 -n [cifar100][test][cnn13][10000] -i 3 -o -a

mini-ImageNet

# 4k labels
python3 train/featmatch.py -cf config/mini-imagenet/[mimagenet][test][res18][4000].json -d 0 1 -n [mimagenet][test][res18][4000] -i 3 -o -a

# 10k lables
python3 train/featmatch.py -cf config/mini-imagenet/[mimagenet][test][res18][10000].json -d 0 1 -n [mimagenet][test][res18][10000] -i 3 -o -a

DomainNet

# ru = 0%
python3 train/featmatch.py -cf config/domainnet/[domainnet][test][res18][rl5-ru00].json -d 0 1 -n [domainnet][test][res18][rl5-ru00] -i 3 -a

# ru = 25%
python3 train/featmatch.py -cf config/domainnet/[domainnet][test][res18][rl5-ru25].json -d 0 1 -n [domainnet][test][res18][rl5-ru25] -i 3 -a

# ru = 50%
python3 train/featmatch.py -cf config/domainnet/[domainnet][test][res18][rl5-ru50].json -d 0 1 -n [domainnet][test][res18][rl5-ru50] -i 3 -a

# ru = 75%
python3 train/featmatch.py -cf config/domainnet/[domainnet][test][res18][rl5-ru75].json -d 0 1 -n [domainnet][test][res18][rl5-ru75] -i 3 -a

SVHN

# 250 labels
python3 train/featmatch.py -cf config/svhn/[svhn][test][wrn][250].json -d 0 1 -n [svhn][test][wrn][250] -i 3 -o -a

# 1k labels
python3 train/featmatch.py -cf config/svhn/[svhn][test][wrn][1000].json -d 0 1 -n [svhn][test][wrn][1000] -i 3 -o -a

# 4k labels
python3 train/featmatch.py -cf config/svhn/[svhn][test][wrn][4000].json -d 0 1 -n [svhn][test][wrn][4000] -i 3 -o -a

CIFAR-10

# 250 labels
python3 train/featmatch.py -cf config/cifar10/[cifar10][test][wrn][250].json -d 0 1 -n [cifar10][test][wrn][250] -i 3 -o -a

# 1k labels
python3 train/featmatch.py -cf config/cifar10/[cifar10][test][wrn][1000].json -d 0 1 -n [cifar10][test][wrn][1000] -i 3 -o -a

# 4k labels
python3 train/featmatch.py -cf config/cifar10/[cifar10][test][wrn][4000].json -d 0 1 -n [cifar10][test][wrn][4000] -i 3 -o -a

Results

Here are the quantitative results on different datasets, with different number of labels. Numbers represent error rate in three runs (lower the better).

For CIFAR-100, mini-ImageNet, CIFAR-10, and SVHN, we follow the conventional evaluation method. The model is evaluated directly on the test set, and the median of the last K (K=10 in our case) testing accuracies is reported.

For our proposed DomainNet setting, we reserve 1% of validation data, which is much fewer than the 5% of labeled data. The model is evaluated on the validation data, and the model with the best validation accuracy is selected. Finally, we report the test accuracy of the selected model.

CIFAR-100

#labels 4k 10k
paper 31.06 ± 0.41 26.83 ± 0.04
repo 30.79 ± 0.35 26.88 ± 0.13

mini-ImageNet

#labels 4k 10k
paper 39.05 ± 0.06 34.79 ± 0.22
repo 38.94 ± 0.19 34.84 ± 0.19

DomainNet

ru 0% 25% 50% 75%
paper 40.66 ± 0.60 46.11 ± 1.15 54.01 ± 0.66 58.30 ± 0.93
repo 40.47 ± 0.23 43.40 ± 0.25 52.49 ± 1.06 56.20 ± 1.25

SVHN

#labels 250 1k 4k
paper 3.34 ± 0.19 3.10 ± 0.06 2.62 ± 0.08
repo 3.62 ± 0.12 3.02 ± 0.04 2.61 ± 0.02

CIFAR-10

#labels 250 1k 4k
paper 7.50 ± 0.64 5.76 ± 0.07 4.91 ± 0.18
repo 7.38 ± 0.94 6.04 ± 0.24 5.19 ± 0.05

Acknowledgement

This work was funded by DARPA’s Learning with Less Labels (LwLL) program under agreement HR0011-18-S-0044 and DARPAs Lifelong Learning Machines (L2M) program under Cooperative Agreement HR0011-18-2-0019.

Citation

@inproceedings{kuo2020featmatch,
  title={Featmatch: Feature-based augmentation for semi-supervised learning},
  author={Kuo, Chia-Wen and Ma, Chih-Yao and Huang, Jia-Bin and Kira, Zsolt},
  booktitle={European Conference on Computer Vision},
  pages={479--495},
  year={2020},
  organization={Springer}
}
Beyond Image to Depth: Improving Depth Prediction using Echoes (CVPR 2021)

Beyond Image to Depth: Improving Depth Prediction using Echoes (CVPR 2021) Kranti Kumar Parida, Siddharth Srivastava, Gaurav Sharma. We address the pr

Kranti Kumar Parida 33 Jun 27, 2022
This is a simple backtesting framework to help you test your crypto currency trading. It includes a way to download and store historical crypto data and to execute a trading strategy.

You can use this simple crypto backtesting script to ensure your trading strategy is successful Minimal setup required and works well with static TP a

Andrei 154 Sep 12, 2022
A bare-bones Python library for quality diversity optimization.

pyribs Website Source PyPI Conda CI/CD Docs Docs Status Twitter pyribs.org GitHub docs.pyribs.org A bare-bones Python library for quality diversity op

ICAROS 127 Jan 06, 2023
Codebase for "ProtoAttend: Attention-Based Prototypical Learning."

Codebase for "ProtoAttend: Attention-Based Prototypical Learning." Authors: Sercan O. Arik and Tomas Pfister Paper: Sercan O. Arik and Tomas Pfister,

47 2 May 17, 2022
Finetuning Pipeline

KLUE Baseline Korean(한국어) KLUE-baseline contains the baseline code for the Korean Language Understanding Evaluation (KLUE) benchmark. See our paper fo

74 Dec 13, 2022
Like Dirt-Samples, but cleaned up

Clean-Samples Like Dirt-Samples, but cleaned up, with clear provenance and license info (generally a permissive creative commons licence but check the

TidalCycles 39 Nov 30, 2022
Geometric Sensitivity Decomposition

Geometric Sensitivity Decomposition This repo is the official implementation of A Geometric Perspective towards Neural Calibration via Sensitivity Dec

16 Dec 26, 2022
Small utility to demangle Nim symbols in callgrind files

nim_callgrind A small utility to demangle Nim symbols from callgrind files. Usage Run your (Nim) program with something like this: valgrind --tool=cal

kraptor 3 Feb 15, 2022
Animation of solving the traveling salesman problem to optimality using mixed-integer programming and iteratively eliminating sub tours

tsp-streamlit Animation of solving the traveling salesman problem to optimality using mixed-integer programming and iteratively eliminating sub tours.

4 Nov 05, 2022
a basic code repository for basic task in CV(classification,detection,segmentation)

basic_cv a basic code repository for basic task in CV(classification,detection,segmentation,tracking) classification generate dataset train predict de

1 Oct 15, 2021
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
This repository contains the code for the ICCV 2019 paper "Occupancy Flow - 4D Reconstruction by Learning Particle Dynamics"

Occupancy Flow This repository contains the code for the project Occupancy Flow - 4D Reconstruction by Learning Particle Dynamics. You can find detail

189 Dec 29, 2022
PiRank: Learning to Rank via Differentiable Sorting

PiRank: Learning to Rank via Differentiable Sorting This repository provides a reference implementation for learning PiRank-based models as described

54 Dec 17, 2022
This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge column damage detection

Bridge-damage-segmentation This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge c

Jingxiao Liu 5 Dec 07, 2022
A lightweight python AUTOmatic-arRAY library.

A lightweight python AUTOmatic-arRAY library. Write numeric code that works for: numpy cupy dask autograd jax mars tensorflow pytorch ... and indeed a

Johnnie Gray 62 Dec 27, 2022
pytorch bert intent classification and slot filling

pytorch_bert_intent_classification_and_slot_filling 基于pytorch的中文意图识别和槽位填充 说明 基本思路就是:分类+序列标注(命名实体识别)同时训练。 使用的预训练模型:hugging face上的chinese-bert-wwm-ext 依

西西嘛呦 33 Dec 15, 2022
Transformer in Computer Vision

Transformer-in-Vision A paper list of some recent Transformer-based CV works. If you find some ignored papers, please open issues or pull requests. **

506 Dec 26, 2022
Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)

ASGNet The code is for the paper "Adaptive Prototype Learning and Allocation for Few-Shot Segmentation" (accepted to CVPR 2021) [arxiv] Overview data/

Gen Li 91 Dec 23, 2022
Event queue (Equeue) dialect is an MLIR Dialect that models concurrent devices in terms of control and structure.

Event Queue Dialect Event queue (Equeue) dialect is an MLIR Dialect that models concurrent devices in terms of control and structure. Motivation The m

Cornell Capra 23 Dec 08, 2022
NeuralDiff: Segmenting 3D objects that move in egocentric videos

NeuralDiff: Segmenting 3D objects that move in egocentric videos Project Page | Paper + Supplementary | Video About This repository contains the offic

Vadim Tschernezki 14 Dec 05, 2022