Code for MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks

This is the code for the paper:

MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels
Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, Li Fei-Fei
Presented at ICML 2018

Please note that this is not an officially supported Google product.

If you find this code useful in your research, please cite:

@inproceedings{jiang2018mentornet,
  title={MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels},
  author={Jiang, Lu and Zhou, Zhengyuan and Leung, Thomas and Li, Li-Jia and Fei-Fei, Li},
  booktitle={ICML},
  year={2018}
}

Introduction

We are interested in training a deep network using curriculum learning (Bengio et al., 2009), i.e., learning examples with focus, where each training example receives a weight. Each curriculum is implemented as a network (called MentorNet).

  • During training, MentorNet supervises the training of the base network (called StudentNet).
  • At test time, StudentNet makes predictions alone, without MentorNet.

Training Overview

Setup

All code was developed and tested on NVIDIA V100/P100 (16GB) GPUs in the following environment:

  • Ubuntu 18.04
  • Python 2.7.15
  • TensorFlow 1.8.0
  • numpy 1.13.3
  • imageio 2.3.0

Download the Google Cloud SDK, which provides gsutil, to fetch the data and models. Then download the dataset and the pre-trained MentorNet models and place them alongside the code directory:

gsutil -m cp -r gs://mentornet_project/data .
gsutil -m cp -r gs://mentornet_project/mentornet_models .

Alternatively, you may download the zip files: data and models.

Running MentorNet on CIFAR

export PYTHONPATH="$PYTHONPATH:$PWD/code/"

python code/cifar_train_mentornet.py \
  --dataset_name=cifar10   \
  --trained_mentornet_dir=mentornet_models/models/mentornet_pd1_g_1/mentornet_pd \
  --loss_p_percentile=0.75  \
  --nofixed_epoch_after_burn_in  \
  --burn_in_epoch=0  \
  --example_dropout_rates="0.5,17,0.05,83" \
  --data_dir=data/cifar10/0.2 \
  --train_log_dir=cifar_models/cifar10/resnet/0.2/mentornet_pd1_g_1/train \
  --studentnet=resnet101 \
  --max_number_of_steps=39000

A full list of commands can be found in this file. The training script has a number of command-line flags that you can use to configure the model architecture, hyperparameters, and input / output settings:

  • --trained_mentornet_dir: Directory containing the trained MentorNet model, created by mentornet_learning/train.py.
  • --loss_p_percentile: p-percentile used to compute the loss moving average. Default is 0.7.
  • --burn_in_epoch: Number of initial epochs used for burn-in. During the burn-in period, every sample has a fixed weight of 1.0. Default is 0.
  • --fixed_epoch_after_burn_in: Whether to use the fixed epoch as the MentorNet input feature after the burn-in period. Set to True for MentorNet DD. Default is False (the command above disables it explicitly with --nofixed_epoch_after_burn_in).
  • --loss_moving_average_decay: Decay factor used in the moving average. Default is 0.5.
  • --example_dropout_rates: Comma-separated list specifying the example dropout rate over a total of 100 epochs. The format is [dropout_rate, epoch_num]+, a piecewise-constant schedule in which each rate applies for the following epoch_num epochs; the epoch_num values must sum to 100. Example dropout, proposed by Liang et al. (2016), is the probability of setting a sample weight to zero. Default is 0.5,17,0.05,78,1.0,5.
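The two schedule-related flags above can be illustrated with a small NumPy sketch. It assumes that --example_dropout_rates is read as consecutive (rate, epoch_num) pairs forming a piecewise-constant schedule, and that the loss moving average is updated as decay * old + (1 - decay) * percentile(batch_losses, 100 * p). The helper names (parse_dropout_schedule, dropout_rate_at_epoch, update_loss_moving_average) are hypothetical and are not the functions used in the repository.

```python
import numpy as np


def parse_dropout_schedule(flag_value, total_epochs=100):
    """Parse "rate,epochs,rate,epochs,..." into (rate, epochs) pairs.

    Hypothetical helper: assumes the flag lists consecutive pairs whose
    epoch counts must add up to `total_epochs`.
    """
    values = [float(x) for x in flag_value.split(",")]
    if len(values) % 2 != 0:
        raise ValueError("expected an even number of values")
    pairs = [(values[i], int(values[i + 1])) for i in range(0, len(values), 2)]
    if sum(epochs for _, epochs in pairs) != total_epochs:
        raise ValueError("epoch counts must sum to %d" % total_epochs)
    return pairs


def dropout_rate_at_epoch(pairs, epoch):
    """Return the piecewise-constant rate in effect at a 0-based epoch."""
    boundary = 0
    for rate, epochs in pairs:
        boundary += epochs
        if epoch < boundary:
            return rate
    return pairs[-1][0]  # past the last boundary: keep the final rate


def update_loss_moving_average(moving_avg, batch_losses, p=0.7, decay=0.5):
    """Exponential moving average of the p-th percentile of mini-batch losses.

    Assumed update rule: decay * old + (1 - decay) * percentile.
    """
    current = np.percentile(batch_losses, 100.0 * p)
    return decay * moving_avg + (1.0 - decay) * current


# Usage with the default schedule and the default p / decay values.
schedule = parse_dropout_schedule("0.5,17,0.05,78,1.0,5")
print(schedule)
print(dropout_rate_at_epoch(schedule, 16), dropout_rate_at_epoch(schedule, 17))
print(update_loss_moving_average(0.0, np.array([1.0, 2.0, 3.0, 4.0, 5.0])))
```

Treating epoch_num as a duration (rather than an absolute boundary) is what makes the "must sum to 100" constraint meaningful.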

To evaluate a model, run the evaluation job in parallel with the training job (on a different GPU).

python code/cifar_eval.py \
  --dataset_name=cifar10 \
  --data_dir=data/cifar10/val/ \
  --checkpoint_dir=cifar_models/cifar10/resnet/0.2/mentornet_pd1_g_1/train \
  --eval_dir=cifar_models/cifar10/resnet/0.2/mentornet_pd1_g_1/eval_val \
  --studentnet=resnet101 \
  --device_id=1

A complete list of commands for running the experiments can be found in commands/train_studentnet_resnet.sh and commands/train_studentnet_inception.sh.

MentorNet Framework

MentorNet is a general framework for curriculum learning, in which various curriculums can be learned by the same MentorNet architecture with different parameters.

It is flexible: we can switch curriculums by attaching a different MentorNet without modifying the training pipeline.

We trained the MentorNets listed below. A MentorNet can be thought of as a hyperparameter that should be tuned for each problem.

Curriculum | Intuition | Model Name
No curriculum | Assign a uniform weight to every sample. | baseline_mentornet
Self-paced (Kumar et al., 2010) | Favor samples with smaller loss. | self_paced_mentornet
SPCL linear (Jiang et al., 2015) | Discount the weight linearly as the loss increases. | spcl_linear_mentornet
Hard example mining (Felzenszwalb et al., 2008) | Favor samples with greater loss. | hard_example_mining_mentornet
Focal loss (Lin et al., 2017) | Increase the weight with the loss following the exponential CDF. | focal_loss_mentornet
Predefined Mixture | A mixture of SPL and SPCL that changes with the epoch. | mentornet_pd
MentorNet Data-driven | Learned on a small subset of the CIFAR data. | mentornet_dd
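The hand-designed curriculums in the table can each be written as a function that maps a per-sample loss to a weight. The sketch below is an illustration under stated assumptions, not the trained MentorNet models: lam is a hypothetical loss threshold, the self-paced and hard-example-mining rules are taken as hard 0/1 thresholds, and the focal-loss weight is taken as (1 - exp(-loss)) ** gamma, which is the exponential CDF of the loss when gamma = 1.

```python
import numpy as np


def uniform_weight(loss):
    """No curriculum: every sample gets weight 1."""
    return np.ones_like(loss)


def self_paced_weight(loss, lam):
    """Self-paced learning: keep samples whose loss is below the threshold."""
    return (loss < lam).astype(loss.dtype)


def spcl_linear_weight(loss, lam):
    """SPCL linear scheme: weight falls linearly from 1 to 0 as loss reaches lam."""
    return np.maximum(0.0, 1.0 - loss / lam)


def hard_example_mining_weight(loss, lam):
    """Hard example mining: keep samples whose loss is above the threshold."""
    return (loss > lam).astype(loss.dtype)


def focal_loss_weight(loss, gamma=1.0):
    """Focal-loss weight (1 - p_t) ** gamma, with p_t = exp(-cross_entropy)."""
    return (1.0 - np.exp(-loss)) ** gamma


losses = np.array([0.1, 0.5, 1.0, 2.0])
for name, weights in [
    ("uniform", uniform_weight(losses)),
    ("self_paced", self_paced_weight(losses, lam=1.0)),
    ("spcl_linear", spcl_linear_weight(losses, lam=1.0)),
    ("hard_example_mining", hard_example_mining_weight(losses, lam=1.0)),
    ("focal_loss", focal_loss_weight(losses)),
]:
    print(name, np.round(weights, 3))
```

A MentorNet replaces these fixed formulas with a learned network, which is why the same pipeline can switch among them by loading a different model directory.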

Note that many more curriculums can be trained with MentorNet, for example prediction variance (Chang et al., 2017), implicit regularizer (Fan et al., 2017), self-paced learning with diversity (Jiang et al., 2014), and sample re-weighting (Dehghani et al., 2018; Ren et al., 2018).

Performance

The numbers differ slightly from those reported in the paper because of the re-implementation on a third-party library.

CIFAR-10 ResNet

noise_fraction baseline self_paced focal_loss mentornet_pd mentornet_dd
0.2 0.796 0.822 0.797 0.910 0.914
0.4 0.568 0.802 0.634 0.776 0.887
0.8 0.238 0.297 0.25 0.283 0.463

CIFAR-100 ResNet

noise_fraction baseline self_paced focal_loss mentornet_pd mentornet_dd
0.2 0.624 0.652 0.613 0.733 0.726
0.4 0.448 0.509 0.467 0.567 0.675
0.8 0.084 0.089 0.079 0.193 0.301

CIFAR-10 Inception

noise_fraction baseline self_paced focal_loss mentornet_pd mentornet_dd
0.2 0.775 0.784 0.747 0.798 0.800
0.4 0.72 0.733 0.695 0.731 0.763
0.8 0.29 0.272 0.309 0.312 0.461

CIFAR-100 Inception

noise_fraction baseline self_paced focal_loss mentornet_pd mentornet_dd
0.2 0.42 0.408 0.391 0.451 0.466
0.4 0.346 0.32 0.313 0.386 0.411
0.8 0.108 0.091 0.107 0.125 0.203
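The noise_fraction column refers to the fraction of corrupted training labels, and the data directories used above (e.g., data/cifar10/0.2) appear to follow the same convention. As an assumption for illustration only, the sketch below corrupts labels by replacing each one, with probability noise_fraction, by a class drawn uniformly at random; the released datasets may have been generated differently, and corrupt_labels is a hypothetical helper.

```python
import numpy as np


def corrupt_labels(labels, noise_fraction, num_classes, seed=0):
    """Hypothetical helper: uniform label noise.

    Each label is, with probability `noise_fraction`, replaced by a class
    drawn uniformly from all `num_classes` classes (which may coincide
    with the original label).
    """
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels)
    flip = rng.rand(labels.shape[0]) < noise_fraction
    random_labels = rng.randint(0, num_classes, size=labels.shape[0])
    return np.where(flip, random_labels, labels)


clean = np.random.RandomState(1).randint(0, 10, size=50000)
noisy = corrupt_labels(clean, noise_fraction=0.4, num_classes=10)
# Expected changed fraction is about 0.4 * (1 - 1/10) = 0.36.
print("fraction of labels changed:", np.mean(noisy != clean))
```

Because the replacement class is drawn from all classes, slightly fewer labels actually change than noise_fraction suggests (about 36% rather than 40% for 10 classes).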

Algorithm

We propose an algorithm to optimize the StudentNet model parameter w jointly with a given MentorNet. Unlike alternating minimization, it minimizes over w (the StudentNet parameters) and v (the sample weights) stochastically over mini-batches.

The curriculum can change during training, and MentorNet is updated a few times in the algorithm.
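The mini-batch procedure described above can be sketched in NumPy on a toy problem. This is a minimal illustration under several assumptions, not the repository's TensorFlow implementation: the StudentNet is a logistic-regression model, the MentorNet is replaced by a fixed stand-in (a hypothetical mentor_weights function applying SPCL-style linear weighting with a threshold tied to the loss moving average), and the MentorNet update step is omitted. For each mini-batch, v is computed with w held fixed, and w then takes one SGD step on the v-weighted loss.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def per_sample_loss(w, x, y):
    """Binary cross-entropy of a logistic-regression StudentNet, per sample."""
    p = np.clip(sigmoid(x @ w), 1e-7, 1 - 1e-7)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def mentor_weights(loss, loss_moving_avg):
    """Hypothetical stand-in for MentorNet: SPCL-style linear weighting.

    Samples whose loss exceeds twice the moving-average threshold get
    weight 0; a trained MentorNet would instead output learned weights.
    """
    lam = 2.0 * loss_moving_avg + 1e-8
    return np.maximum(0.0, 1.0 - loss / lam)


def train(x, y, epochs=20, batch_size=64, lr=0.5, p=0.7, decay=0.5, seed=0):
    rng = np.random.RandomState(seed)
    w = np.zeros(x.shape[1])
    loss_moving_avg = 1.0
    for _ in range(epochs):
        order = rng.permutation(len(x))
        for start in range(0, len(x), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = x[idx], y[idx]
            loss = per_sample_loss(w, xb, yb)
            # Step 1: update the loss moving average and compute the sample
            # weights v with the StudentNet parameters w held fixed.
            loss_moving_avg = (decay * loss_moving_avg
                               + (1 - decay) * np.percentile(loss, 100 * p))
            v = mentor_weights(loss, loss_moving_avg)
            # Step 2: one SGD step on w for the v-weighted mean loss.
            grad = xb.T @ (v * (sigmoid(xb @ w) - yb)) / len(idx)
            w -= lr * grad
    return w


# Toy data: two Gaussian blobs with about 30% of the training labels flipped.
rng = np.random.RandomState(1)
x = np.vstack([rng.randn(500, 2) + 2, rng.randn(500, 2) - 2])
x = np.hstack([x, np.ones((1000, 1))])  # bias feature
y_clean = np.concatenate([np.ones(500), np.zeros(500)])
flip = rng.rand(1000) < 0.3
y_noisy = np.where(flip, 1 - y_clean, y_clean)

w = train(x, y_noisy)
accuracy = np.mean((sigmoid(x @ w) > 0.5) == y_clean)
print("accuracy on clean labels:", accuracy)
```

Computing v before the gradient step, and treating it as a constant in that step, is what distinguishes this from differentiating through the weights; the gradient therefore only flows into w.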

(Figure: the MentorNet training algorithm)

To learn new curriculums (Step 6), see this page.

We found that the specific MentorNet architecture does not matter much.

References

  • Bengio, Yoshua, et al. "Curriculum Learning." In ICML, 2009.
  • Kumar, M. Pawan, Benjamin Packer, and Daphne Koller. "Self-Paced Learning for Latent Variable Models." In NIPS, 2010.
  • Jiang, Lu, et al. "Self-Paced Learning with Diversity." In NIPS, 2014.
  • Jiang, Lu, et al. "Self-Paced Curriculum Learning." In AAAI, 2015.
  • Liang, Junwei, et al. "Learning to Detect Concepts from Webly-Labeled Video Data." In IJCAI, 2016.
  • Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection." In ICCV, 2017.
  • Fan, Yanbo, et al. "Self-Paced Learning: An Implicit Regularization Perspective." In AAAI, 2017.
  • Felzenszwalb, Pedro, et al. "A Discriminatively Trained, Multiscale, Deformable Part Model." In CVPR, 2008.
  • Dehghani, Mostafa, et al. "Fidelity-Weighted Learning." In ICLR, 2018.
  • Ren, Mengye, et al. "Learning to Reweight Examples for Robust Deep Learning." In ICML, 2018.
  • Fan, Yang, et al. "Learning to Teach." In ICLR, 2018.
  • Chang, Haw-Shiuan, et al. "Active Bias: Training More Accurate Neural Networks by Emphasizing High Variance Samples." In NIPS, 2017.