A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper ---- NIPS 2017 poster ---- NIPS 2017 spotlight slides ---- Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
  3. Add an additional consistency cost between the student and teacher outputs (after softmax).
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate and applicable to large datasets.
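As a rough illustration of step 5, the EMA update can be sketched in plain Python. This is a minimal sketch, not the code from either implementation: the function name `ema_update`, the dict-of-lists weight layout, and the toy values are all assumptions made for the example.

```python
def ema_update(teacher, student, decay=0.999):
    """Move each teacher weight a small step toward the matching student weight.

    teacher and student map parameter names to lists of floats
    (a hypothetical layout chosen to keep the sketch dependency-free).
    """
    for name, student_values in student.items():
        teacher[name] = [
            decay * t + (1.0 - decay) * s
            for t, s in zip(teacher[name], student_values)
        ]
    return teacher


# Step 1: the teacher starts as a copy of the student.
student = {"w": [0.0, 0.0], "b": [0.0]}
teacher = {name: list(values) for name, values in student.items()}

# Step 4: pretend the optimizer has just changed the student weights.
student = {"w": [1.0, -1.0], "b": [0.5]}

# Step 5: the teacher moves 10% of the way toward the student (decay=0.9).
ema_update(teacher, student, decay=0.9)
print(teacher)
```

A decay close to 1 makes the teacher a slowly changing average over many recent student states; a smaller decay makes it follow the student more closely.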

[Figure: Mean Teacher model]

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

ImageNet using 10% of the labels        top-5 validation error (%)
Variational Auto-Encoder [paper]        35.42 ± 0.90
Mean Teacher ResNet-152                 9.11 ± 0.12
All labels, state of the art [paper]    3.79

CIFAR-10 using 4000 labels              test error (%)
CT-GAN [paper]                          9.98 ± 0.21
Mean Teacher ResNet-26                  6.28 ± 0.15
All labels, state of the art [paper]    2.86

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.
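To make the role of the consistency cost weight concrete, here is a small sketch of how it could scale the consistency term in the total cost. The helper names (`mse_consistency`, `kl_consistency`, `total_cost`) are hypothetical, the KL direction (teacher relative to student) is an assumption of this sketch, and the logits are made-up values for a single example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_total for x in logits]


def mse_consistency(student_logits, teacher_logits):
    """Mean squared difference between the two softmax outputs."""
    p = [math.exp(v) for v in log_softmax(student_logits)]
    q = [math.exp(v) for v in log_softmax(teacher_logits)]
    return sum((a - b) ** 2 for a, b in zip(p, q)) / len(p)


def kl_consistency(student_logits, teacher_logits):
    """KL(teacher || student); the direction is an assumption of this sketch."""
    log_p = log_softmax(student_logits)
    log_q = log_softmax(teacher_logits)
    return sum(math.exp(lq) * (lq - lp) for lp, lq in zip(log_p, log_q))


def total_cost(classification_cost, student_logits, teacher_logits,
               weight, consistency_fn=mse_consistency):
    """Classification cost plus the weighted consistency cost."""
    return classification_cost + weight * consistency_fn(student_logits, teacher_logits)


student_logits = [2.0, 0.5, -1.0]
teacher_logits = [1.5, 0.7, -0.8]
print(total_cost(0.7, student_logits, teacher_logits, weight=3.0))
print(total_cost(0.7, student_logits, teacher_logits, weight=3.0,
                 consistency_fn=kl_consistency))
```

In an autodiff framework the teacher output would typically be treated as a constant, so that gradients from the consistency term reach only the student.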

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Then, when you are happy with the architecture and hyperparameters, add Mean Teacher. The same network should work well, although you may want to reduce any regularization (such as weight decay) that you used with the small labeled data.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch to labeled examples. Then the supervised training signal is strong enough early on to train quickly and to keep the model from getting stuck in uncertainty. In the PyTorch examples, a quarter or a half of each minibatch is reserved for labeled examples and the rest for unlabeled ones. (See TwoStreamBatchSampler in the PyTorch code.)
  • For the EMA decay rate, 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, a good weight seems to lie between the number of classes and the number of classes squared. On small datasets we saw MSE getting better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: have two separate logit layers at the top level. Use one for classification of labeled examples and one for predicting the teacher output, and add an additional cost between the logits of these two predictions. The intent is the same as with the consistency cost ramp-up: in the beginning the teacher output may be wrong, so loosen the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
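
Two of the tips above, reserving part of each minibatch for labeled examples and ramping up the consistency cost, can be sketched as follows. This is an illustrative sketch only: `two_stream_batches` is a hypothetical stand-in and not the repository's TwoStreamBatchSampler, and the sigmoid-shaped ramp `exp(-5 * (1 - t)^2)` is one common choice assumed here rather than a schedule stated in this README.

```python
import math
import random


def two_stream_batches(labeled_idx, unlabeled_idx, batch_size, labeled_per_batch, seed=0):
    """Yield index batches with a fixed number of labeled examples in each.

    One pass over the unlabeled indices defines an epoch; labeled indices
    are reshuffled and reused as often as needed.
    """
    labeled = list(labeled_idx)
    unlabeled = list(unlabeled_idx)
    if not labeled or not 0 < labeled_per_batch < batch_size:
        raise ValueError("need labeled data and 0 < labeled_per_batch < batch_size")
    rng = random.Random(seed)
    rng.shuffle(unlabeled)
    unlabeled_per_batch = batch_size - labeled_per_batch
    pool = []
    # Drop the final incomplete unlabeled chunk so every batch has the same size.
    for start in range(0, len(unlabeled) - unlabeled_per_batch + 1, unlabeled_per_batch):
        while len(pool) < labeled_per_batch:
            refill = list(labeled)
            rng.shuffle(refill)
            pool.extend(refill)
        labeled_part, pool = pool[:labeled_per_batch], pool[labeled_per_batch:]
        yield labeled_part + unlabeled[start:start + unlabeled_per_batch]


def consistency_weight(epoch, max_weight, rampup_epochs):
    """Sigmoid-shaped ramp from a small value up to max_weight."""
    if rampup_epochs <= 0:
        return max_weight
    progress = min(max(epoch / rampup_epochs, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - progress) ** 2)


# Indices 0-7 are labeled, 8-39 unlabeled; a quarter of each batch is labeled.
batches = list(two_stream_batches(range(8), range(8, 40), batch_size=8, labeled_per_batch=2))
print(batches[0])
print([round(consistency_weight(e, 10.0, 5), 3) for e in range(7)])
```

Fixing the labeled share per batch keeps the supervised signal constant even when unlabeled examples vastly outnumber labeled ones, at the cost of revisiting the labeled examples many times per pass over the unlabeled data.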