2.86% and 15.85% on CIFAR-10 and CIFAR-100

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. The arXiv paper extends Shake-Shake regularization of 3-branch residual networks, which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims to help deep learning practitioners who face an overfitting problem. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best published single-shot results on CIFAR-10 and CIFAR-100, reaching test errors of 2.86% and 15.85%, respectively.

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.
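
The stochastic affine combination described above can be sketched in a few lines of plain Python. This is only an illustration, not the repository's Lua code: the function names are hypothetical, flat lists of floats stand in for tensors, and the choices of a uniform draw in [0, 1) during training and a fixed 0.5 at test time are assumptions based on the "Shake" and "Even" settings described in this README.

```python
import random


def shake_combine(branch1, branch2, alpha):
    """Elementwise affine combination: alpha * branch1 + (1 - alpha) * branch2."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(branch1, branch2)]


def residual_forward(x, branch1, branch2, training, rng=random):
    """Add the combined branch outputs to the skip connection x.

    Assumption: alpha is drawn uniformly from [0, 1) while training and
    fixed to 0.5 at test time.
    """
    alpha = rng.random() if training else 0.5
    mixed = shake_combine(branch1, branch2, alpha)
    return [xi + mi for xi, mi in zip(x, mixed)], alpha
```

With `training=False` the block reduces to the skip connection plus the plain average of the two branches, which is why the same function can serve both the training and the test-time cases in this sketch.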

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches, and a first residual block of width 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that the backward pass reuses the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, the same scaling coefficient is applied to all the images in the mini-batch. "Image" means that, for each residual block, a different scaling coefficient is applied to each image in the mini-batch. The numbers in the table below are averages over 3 runs, except for the 96d models, which were run 5 times.

Forward  Backward  Level  26 2x32d  26 2x64d  26 2x96d  26 2x112d
Even     Even      n/a    4.27      3.76      3.58      -
Even     Shake     Batch  4.44      -         -         -
Shake    Keep      Batch  4.11      -         -         -
Shake    Even      Batch  3.47      3.30      -         -
Shake    Shake     Batch  3.67      3.07      -         -
Even     Shake     Image  4.11      -         -         -
Shake    Keep      Image  4.09      -         -         -
Shake    Even      Image  3.47      3.20      -         -
Shake    Shake     Image  3.55      2.98      2.86      2.82 (1)

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)
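
The "Shake", "Even", "Keep", "Batch" and "Image" settings defined above can be illustrated with a small coefficient sampler. This is a hedged sketch rather than the repository's implementation: `sample_coefficients` and its arguments are hypothetical names, and the uniform draw in [0, 1) for "Shake" is an assumption.

```python
import random


def sample_coefficients(mode, level, batch_size, forward=None, rng=random):
    """Return one scaling coefficient per image for a single residual block.

    mode:  "Shake" (new random numbers), "Even" (0.5) or "Keep" (reuse the
           coefficients of the forward pass, passed in as `forward`).
    level: "Batch" (one coefficient shared by the whole mini-batch) or
           "Image" (a different coefficient for each image).
    """
    if mode == "Keep":
        if forward is None:
            raise ValueError("Keep needs the coefficients of the forward pass")
        return list(forward)
    if mode == "Even":
        return [0.5] * batch_size
    if mode == "Shake":
        if level == "Batch":
            return [rng.random()] * batch_size
        if level == "Image":
            return [rng.random() for _ in range(batch_size)]
        raise ValueError("level must be 'Batch' or 'Image'")
    raise ValueError("mode must be 'Shake', 'Even' or 'Keep'")


# "Shake-Shake-Image": new per-image numbers for the forward pass,
# then a second, independent draw for the backward pass.
rng = random.Random(0)
alpha = sample_coefficients("Shake", "Image", 4, rng=rng)
beta = sample_coefficients("Shake", "Image", 4, rng=rng)
```

In this sketch, "Shake-Keep" would pass `alpha` back in as `forward` for the backward pass, and "Shake-Even" would request `"Even"` for the backward pass instead.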

Other results

CIFAR-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05% (1)

SVHN:
26 2x96d: 1.4% (1)

Reduced SVHN:
26 2x96d: 12.32% (1)

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake:

git clone https://github.com/xgastaldi/shake-shake.git

  3. Copy the contents of the shake-shake folder into the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  4. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, change the batch size and the corresponding learning rate:

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

  5. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true
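
The commands above pass -lrShape cosine, and the two CIFAR-10 commands halve the learning rate when the batch size is halved (128 with 0.2, 64 with 0.1). The sketch below illustrates both ideas in Python. It is not the code in train.lua: the function names are hypothetical, the half-cosine decay to zero without restarts is an assumption about what "cosine" means here, and the proportional scaling only reflects the two CIFAR-10 commands (the CIFAR-100 command uses its own batch size and learning rate).

```python
import math


def scaled_lr(batch_size, ref_batch_size=128, ref_lr=0.2):
    """Scale the learning rate in proportion to the batch size.

    Defaults mirror the 2-GPU CIFAR-10 command (-batchSize 128 -LR 0.2).
    """
    return ref_lr * batch_size / ref_batch_size


def cosine_lr(base_lr, epoch, n_epochs):
    """Half-cosine decay from base_lr at epoch 0 to 0 at epoch n_epochs.

    Assumption: a single cosine cycle without restarts.
    """
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / n_epochs))


# Learning rate for the 1-GPU setting, then its value halfway through training.
lr_1gpu = scaled_lr(64)
lr_mid = cosine_lr(lr_1gpu, 900, 1800)
```

An actual schedule may be updated per iteration rather than per epoch; the per-epoch form is used here only to keep the example short.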

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds a log

train.lua
Ln 36-38, 58-60, 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'

The main model is in shakeshake.lua. The residual block model is in shakeshakeblock.lua. mulconstantslices.lua is an extension of nn.MulConstant that multiplies each image slice of a mini-batch tensor by the corresponding element of a vector. shakeshaketable.lua contains the method used for CIFAR-100, since the ResNeXt code uses a table implementation instead of a module version.
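
To make the per-slice multiplication concrete, here is a minimal Python sketch of the idea behind mulconstantslices.lua. It is an illustration only: `mul_constant_slices` is a hypothetical name, and nested lists of floats stand in for a mini-batch tensor whose first dimension indexes the images.

```python
def mul_constant_slices(batch, coeffs):
    """Scale each image slice batch[i] by its own coefficient coeffs[i]."""
    if len(batch) != len(coeffs):
        raise ValueError("need exactly one coefficient per image")
    return [[c * v for v in image] for image, c in zip(batch, coeffs)]


# Two flattened "images", each scaled by a different coefficient.
scaled = mul_constant_slices([[1.0, 2.0], [3.0, 4.0]], [0.5, 2.0])
```

A plain constant multiplication would apply one scalar to the whole tensor; taking one coefficient per slice is what allows a different scaling for each image in the mini-batch.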

Reimplementations

PyTorch
https://github.com/hysts/pytorch_shake_shake

TensorFlow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.
