A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper ---- NIPS 2017 poster ---- NIPS 2017 spotlight slides ---- Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, feed the same minibatch to both the student and the teacher, but apply random augmentation or noise to each copy of the inputs separately.
  3. Add a consistency cost between the student and teacher outputs (after softmax).
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate and applicable to large datasets.
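The EMA update in the last step can be sketched in a few lines of plain Python. This is a minimal illustration rather than code from either implementation: weights are represented as flat lists of floats, `ema_update` is a hypothetical helper name, and the decay of 0.9 is chosen only so the effect is visible after one step.

```python
def ema_update(teacher_weights, student_weights, decay):
    """Move each teacher weight a small step toward the matching student weight."""
    return [
        decay * t + (1.0 - decay) * s
        for t, s in zip(teacher_weights, student_weights)
    ]


# Step 1: the teacher starts as a copy of the student.
student = [1.0, -2.0, 0.5]
teacher = list(student)

# Step 4: pretend the optimizer has just updated the student weights.
student = [1.5, -2.0, 0.0]

# Step 5: the teacher follows the student slowly (illustrative decay value).
teacher = ema_update(teacher, student, decay=0.9)
```

With a decay close to 1 the teacher changes slowly, so it averages the student over many training steps instead of tracking any single one.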

[Figure: Mean Teacher model]

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

ImageNet using 10% of the labels        top-5 validation error (%)
Variational Auto-Encoder [paper]        35.42 ± 0.90
Mean Teacher ResNet-152                 9.11 ± 0.12
All labels, state of the art [paper]    3.79

CIFAR-10 using 4000 labels              test error (%)
CT-GAN [paper]                          9.98 ± 0.21
Mean Teacher ResNet-26                  6.28 ± 0.15
All labels, state of the art [paper]    2.86

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.
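To show where the consistency cost weight enters, here is a rough per-example sketch in plain Python. It assumes a cross-entropy classification cost and a mean-squared-error consistency cost between softmax outputs. The function names and the weight value are illustrative and are not taken from the TensorFlow or PyTorch code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Classification cost for one labeled example."""
    return -math.log(probs[label])


def mse_consistency(student_probs, teacher_probs):
    """Mean squared difference between the two softmax outputs."""
    diffs = [(s - t) ** 2 for s, t in zip(student_probs, teacher_probs)]
    return sum(diffs) / len(diffs)


def total_cost(student_logits, teacher_logits, label, consistency_weight):
    """Consistency cost for every example, plus classification cost if labeled."""
    student_probs = softmax(student_logits)
    teacher_probs = softmax(teacher_logits)  # treated as a fixed target
    cost = consistency_weight * mse_consistency(student_probs, teacher_probs)
    if label is not None:  # unlabeled examples contribute only consistency
        cost += cross_entropy(student_probs, label)
    return cost


student_logits = [2.0, 0.0, -1.0]
teacher_logits = [1.5, 0.5, -1.0]
# consistency_weight=3.0 is an arbitrary value for illustration.
labeled_cost = total_cost(student_logits, teacher_logits, label=0, consistency_weight=3.0)
unlabeled_cost = total_cost(student_logits, teacher_logits, label=None, consistency_weight=3.0)
```

In a real training loop, gradients of this cost would flow only into the student. The teacher changes solely through the EMA update.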

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Once you are happy with the architecture and hyperparameters, add Mean Teacher. The same network should work well, although you may want to reduce regularization, such as weight decay, that you used with the small labeled set.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch to labeled examples. The supervised training signal is then strong enough early on to train quickly and to keep the model from getting stuck in uncertainty. In the PyTorch examples we use a quarter or a half of the minibatch for labeled examples and the rest for unlabeled ones. (See TwoStreamBatchSampler in the PyTorch code.)
  • For the EMA decay rate, 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, a good weight seems to lie between the number of classes and the number of classes squared. On small datasets MSE gave us better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: have two separate logit layers at the top level. Use one for classifying labeled examples and the other for predicting the teacher output, and add an extra cost between the logits of these two predictions. The intent is the same as with the consistency cost ramp-up: in the beginning the teacher output may be wrong, so loosen the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
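
Two of the tips above, reserving part of each minibatch for labeled examples and ramping up the consistency cost, can be sketched in plain Python as follows. Both helpers are hypothetical illustrations: `two_stream_batches` is not the TwoStreamBatchSampler from the repository, and the sigmoid-shaped ramp is an assumed choice (a linear ramp would serve the same purpose).

```python
import math
import random


def sigmoid_rampup(step, rampup_steps):
    """Scale factor that rises smoothly from exp(-5) to 1.0 over rampup_steps."""
    if rampup_steps <= 0:
        return 1.0
    progress = min(max(step / rampup_steps, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - progress) ** 2)


def two_stream_batches(labeled_ids, unlabeled_ids, batch_size, labeled_per_batch, rng):
    """Yield index batches with a fixed number of labeled examples in each.

    One pass over the (shuffled) unlabeled indices; labeled indices are
    resampled for every batch because there are usually far fewer of them.
    """
    unlabeled = list(unlabeled_ids)
    rng.shuffle(unlabeled)
    unlabeled_per_batch = batch_size - labeled_per_batch
    last_start = len(unlabeled) - unlabeled_per_batch
    for start in range(0, last_start + 1, unlabeled_per_batch):
        labeled_part = rng.sample(labeled_ids, labeled_per_batch)
        yield labeled_part + unlabeled[start:start + unlabeled_per_batch]


# Toy index sets: 10 labeled examples, 40 unlabeled ones.
labeled_ids = list(range(10))
unlabeled_ids = list(range(100, 140))
batches = list(
    two_stream_batches(labeled_ids, unlabeled_ids, batch_size=8,
                       labeled_per_batch=2, rng=random.Random(0))
)

# The effective consistency weight grows during the first steps of training.
max_consistency_weight = 10.0  # illustrative value
weights = [max_consistency_weight * sigmoid_rampup(s, 5) for s in range(7)]
```

Here a quarter of each batch (2 of 8 indices) is labeled, and the consistency weight reaches its maximum once the ramp-up period is over.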