Learning to Self-Train for Semi-Supervised Few-Shot

Overview

Learning to Self-Train for Semi-Supervised Few-Shot Classification

LICENSE Python TensorFlow

This repository contains the TensorFlow implementation for NeurIPS 2019 Paper "Learning to Self-Train for Semi-Supervised Few-Shot Classification".

Check the few-shot classification leaderboard.

Summary

Installation

In order to run this repository, we advise you to install python 2.7 or 3.5 and TensorFlow 1.3.0 with Anaconda.

You may download Anaconda and read the installation instruction on their official website: https://www.anaconda.com/download/

Create a new environment and install tensorflow on it:

conda create --name lst-tf python=2.7
conda activate lst-tf
conda install tensorflow-gpu=1.3.0

Install other requirements:

pip install scipy tqdm opencv-python pillow matplotlib

Clone this repository:

git clone https://github.com/xinzheli1217/learning-to-self-train.git 
cd learning-to-self-train

Project Architecture

.
├── data_generator              # dataset generator 
|   └── meta_data_generator.py  # data genertor for meta-train phase
├── models                      # tensorflow model files 
|   ├── models.py               # resnet12 CNN class
|   └── meta_model_LST.py       # semi-supervised meta-train model class
├── trainer                     # tensorflow trianer files  
|   └── meta_LST.py             # semi-supervised meta-train trainer class
├── utils                       # a series of tools used in this repo
|   └── misc.py                 # miscellaneous tool functions
| 
├── data                        # the folder containing datasets for experiments
├── pretrain_weights_dir        # the folder containing MTL pre-training weights
├── weights_saving_dir          # the folder containing meta-training weights
├── test_output_dir             # the folder containing meta-testing files
├── filenames_and_labels        # the folder containing image file paths and labels for experiments
|
├── exp_train.py                # the python file with main function and parameter settings for meta-training
└── exp_test.py                 # the python file with main function and parameter settings for meta-testing

Running Experiments

First, download our processed images: miniImagenet[Download Page] or tieredImagenet[Download Page], move the unziped folder to ./data. And then download the pre-trained models: miniImagenet[Download Page] or tieredImagenet[Download Page], move the unziped folder to ./pretrain_weights_dir.

Training from Pre-Trained Models

Run semi-supervised meta-train phase (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot) :

python exp_train.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --nb_ul_samples=10 --metatrain_iterations=15000 --exp_name='LST_mini_1_shot'

Run semi-supervised meta-test phase (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot) :

python exp_test.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --use_distractors=False --nb_ul_samples=100 --unfiles_num=10 --test_iter=15000 --recurrent_stage_nums=6 --nums_in_folders=30 --hard_selection=20 --exp_name='LST_mini_1_shot' 

Hyperparameters and Options

There are some main hyperparameters used in the experiments, you can edit them in the exp_train.py and the exp_test.py file for meta-train and meta-test phase respectively. There are two kinds of hyperparameters: (1) common hyperparameters that shared with meta-train and meta-test, (2) test-specific hyperparameters that used for recurrent self-training process in meta-test.

  • Common hyperparameters:

    • way_num number of classes
    • shot_num number of examples per class
    • dataset dataset used in the experiment (miniImagenet or tieredImagenet)
    • pretrain_class_num number of meta-train classes
    • exp_name name for the current experiment
    • meta_batch_size number of tasks sampled per meta-update in meta-train phase
    • base_lr step size alpha for inner gradient update
    • meta_lr the meta learning rate for SS and initial model parameters
    • min_meta_lr the min meta learning rate for all meta-parameters
    • swn_lr the meta learning rate for SWN
    • nb_ul_samples number of unlabeled examples per class
    • re_train_epoch_num number of re-training inner gradient updates
    • train_base_epoch_num number of total inner gradient updates during train (meta-train only)
    • test_base_epoch_num number of total inner gradient updates during test (meta-test only)
  • Test-specific hyperparameters:

    • use_distractors if using distractor classes during meta-test
    • num_dis number of distracting classes used for meta-testing
    • unfiles_num number of unlabeled sample files used in the experiment (There are 10 unlabeled samples per class in each file)
    • recurrent_stage_nums number of recurrent stages used during meta-test
    • local_update_num number of inner gradient updates used in each recurrent stage
    • nums_in_folders number of unlabeled samples (per class) used in each recurrent stage
    • hard_selection number of remaining samples (per class) after applying hard-selection

If you want to change other settings, please see the comments and descriptions in exp_train.py and exp_test.py.

Performance

(%) 𝑚𝑖𝑛𝑖 𝒕𝒊𝒆𝒓𝒆𝒅 𝑚𝑖𝑛𝑖 (w/D) 𝒕𝒊𝒆𝒓𝒆𝒅 (w/D)
1-shot 70.1 ± 1.9 77.7 ± 1.6 64.1 ± 1.9 73.5 ± 1.6
5-shot 78.7 ± 0.8 85.2 ± 0.8 77.4 ± 1.8 83.4 ± 0.8

Citation

Please cite our paper if it is helpful to your work:

@inproceedings{li2019lst,
  title={Learning to Self-Train for Semi-Supervised Few-Shot Classification},
  author = {Li, Xinzhe and Sun, Qianru and Liu, Yaoyao and Zhou, Qin and Zheng, Shibao and Chua, Tat-Seng and Schiele, Bernt},
  booktitle={NeurIPS},
  year={2019}
}

Acknowledgements

Our implementations use the source code from the following repositories and users:

Transformers are Graph Neural Networks!

🚀 Gated Graph Transformers Gated Graph Transformers for graph-level property prediction, i.e. graph classification and regression. Associated article

Chaitanya Joshi 46 Jun 30, 2022
SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement

SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement This repository implements the approach described in SporeAgent: Reinforced

Dominik Bauer 5 Jan 02, 2023
Developing your First ML Workflow of the AWS Machine Learning Engineer Nanodegree Program

Exercises and project documentation for the 3. Developing your First ML Workflow of the AWS Machine Learning Engineer Nanodegree Program

Simona Mircheva 1 Jan 13, 2022
A Peer-to-peer Platform for Secure, Privacy-preserving, Decentralized Data Science

PyGrid is a peer-to-peer network of data owners and data scientists who can collectively train AI models using PySyft. PyGrid is also the central serv

OpenMined 615 Jan 03, 2023
HybridNets: End-to-End Perception Network

HybridNets: End2End Perception Network HybridNets Network Architecture. HybridNets: End-to-End Perception Network by Dat Vu, Bao Ngo, Hung Phan 📧 FPT

Thanh Dat Vu 370 Dec 29, 2022
A series of Python scripts to access measurements from Fluke 28X meters. Fluke IR Remote Interface required.

Fluke289_data_access A series of Python scripts to access measurements from Fluke 28X meters. Fluke IR Remote Interface required. Created from informa

3 Dec 08, 2022
Pose Detection and Machine Learning for real-time body posture analysis during exercise to provide audiovisual feedback on improvement of form.

Posture: Pose Tracking and Machine Learning for prescribing corrective suggestions to improve posture and form while exercising. This repository conta

Pratham Mehta 10 Nov 11, 2022
Real-Time Multi-Contact Model Predictive Control via ADMM

Here, you can find the code for the paper 'Real-Time Multi-Contact Model Predictive Control via ADMM'. Code is currently being cleared up and optimize

17 Dec 28, 2022
(ICCV 2021) Official code of "Dressing in Order: Recurrent Person Image Generation for Pose Transfer, Virtual Try-on and Outfit Editing."

Dressing in Order (DiOr) 👚 [Paper] 👖 [Webpage] 👗 [Running this code] The official implementation of "Dressing in Order: Recurrent Person Image Gene

Aiyu Cui 277 Dec 28, 2022
YOLOv4 / Scaled-YOLOv4 / YOLO - Neural Networks for Object Detection (Windows and Linux version of Darknet )

Yolo v4, v3 and v2 for Windows and Linux (neural networks for object detection) Paper YOLO v4: https://arxiv.org/abs/2004.10934 Paper Scaled YOLO v4:

Alexey 20.2k Jan 09, 2023
Automatic library of congress classification, using word embeddings from book titles and synopses.

Automatic Library of Congress Classification The Library of Congress Classification (LCC) is a comprehensive classification system that was first deve

Ahmad Pourihosseini 3 Oct 01, 2022
SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning

SPCL SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning Update on 2021/11/25: ArXiv Ver

Binhui Xie (谢斌辉) 11 Oct 29, 2022
The trained model and denoising example for paper : Cardiopulmonary Auscultation Enhancement with a Two-Stage Noise Cancellation Approach

The trained model and denoising example for paper : Cardiopulmonary Auscultation Enhancement with a Two-Stage Noise Cancellation Approach

ycj_project 1 Jan 18, 2022
[2021 MultiMedia] CONQUER: Contextual Query-aware Ranking for Video Corpus Moment Retrieval

CONQUER: Contexutal Query-aware Ranking for Video Corpus Moment Retreival PyTorch implementation of CONQUER: Contexutal Query-aware Ranking for Video

Hou zhijian 23 Dec 26, 2022
5 Jan 05, 2023
Two-Stage Peer-Regularized Feature Recombination for Arbitrary Image Style Transfer

Two-Stage Peer-Regularized Feature Recombination for Arbitrary Image Style Transfer Paper on arXiv Public PyTorch implementation of two-stage peer-reg

NNAISENSE 38 Oct 14, 2022
Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-like Documents.

Value Retrieval with Arbitrary Queries for Form-like Documents Introduction Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-

Salesforce 13 Sep 15, 2022
IAUnet: Global Context-Aware Feature Learning for Person Re-Identification

IAUnet This repository contains the code for the paper: IAUnet: Global Context-Aware Feature Learning for Person Re-Identification Ruibing Hou, Bingpe

30 Jul 14, 2022
TCube generates rich and fluent narratives that describes the characteristics, trends, and anomalies of any time-series data (domain-agnostic) using the transfer learning capabilities of PLMs.

TCube: Domain-Agnostic Neural Time series Narration This repository contains the code for the paper: "TCube: Domain-Agnostic Neural Time series Narrat

Mandar Sharma 7 Oct 31, 2021
Blind visual quality assessment on 360° Video based on progressive learning

Blind visual quality assessment on omnidirectional or 360 video (ProVQA) Blind VQA for 360° Video via Progressively Learning from Pixels, Frames and V

5 Jan 06, 2023