2.86% and 15.85% on CIFAR-10 and CIFAR-100

Overview

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. This arXiv paper extends "Shake-Shake regularization of 3-branch residual networks", which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims to help deep learning practitioners faced with an overfitting problem. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best single-shot published results on CIFAR-10 and CIFAR-100 by reaching test errors of 2.86% and 15.85%.

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.
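
The three passes in Figure 1 can be sketched in a few lines of plain Python. This is only an illustration of the stochastic affine combination described above, not the repository's Lua code: the function names are hypothetical, the branches are represented as flat lists of numbers, and the skip connection of the residual block is left out. It assumes the coefficients are drawn uniformly from [0, 1) and that the test-time coefficient is 0.5.

```python
import random


def combine_branches(branch1, branch2, alpha):
    """Element-wise affine combination alpha * branch1 + (1 - alpha) * branch2."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(branch1, branch2)]


def shake_forward(branch1, branch2, training, rng=random):
    # Training: draw a fresh random coefficient (assumed uniform in [0, 1)).
    # Test: use the fixed coefficient 0.5, i.e. a plain average of the branches.
    alpha = rng.random() if training else 0.5
    return combine_branches(branch1, branch2, alpha), alpha


def shake_backward(grad_output, rng=random):
    # Backward training pass: a new coefficient beta, drawn independently of
    # alpha, splits the incoming gradient between the two branches.
    beta = rng.random()
    grad1 = [beta * g for g in grad_output]
    grad2 = [(1.0 - beta) * g for g in grad_output]
    return grad1, grad2, beta
```

Calling `shake_forward([1.0, 3.0], [3.0, 5.0], training=False)` averages the two branches, while `training=True` returns a random point between them.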

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that we keep, for the backward pass, the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, we apply the same scaling coefficient for all the images in the mini-batch. "Image" means that, for each residual block, we apply a different scaling coefficient for each image in the mini-batch. The numbers in the Table below represent the average of 3 runs except for the 96d models which were run 5 times.

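| Forward | Backward | Level | 26 2x32d | 26 2x64d | 26 2x96d | 26 2x112d |
|---------|----------|-------|----------|----------|----------|-----------|
| Even    | Even     | n/a   | 4.27     | 3.76     | 3.58     | -         |
| Even    | Shake    | Batch | 4.44     | -        | -        | -         |
| Shake   | Keep     | Batch | 4.11     | -        | -        | -         |
| Shake   | Even     | Batch | 3.47     | 3.30     | -        | -         |
| Shake   | Shake    | Batch | 3.67     | 3.07     | -        | -         |
| Even    | Shake    | Image | 4.11     | -        | -        | -         |
| Shake   | Keep     | Image | 4.09     | -        | -        | -         |
| Shake   | Even     | Image | 3.47     | 3.20     | -        | -         |
| Shake   | Shake    | Image | 3.55     | 2.98     | 2.86     | 2.82 (1)  |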

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)
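
The "Shake", "Even" and "Keep" modes and the "Batch" and "Image" levels used in Table 1 can be summarized as a small coefficient sampler. This is a minimal sketch rather than the code in shakeshakeblock.lua: the function name and its arguments are hypothetical, and it assumes that random coefficients are drawn uniformly from [0, 1).

```python
import random


def sample_coefficients(mode, level, batch_size, forward_coeffs=None, rng=random):
    """Return one scaling coefficient per image for a single residual block."""
    if mode == "Even":
        # All coefficients are set to 0.5 before the pass.
        return [0.5] * batch_size
    if mode == "Keep":
        # Backward pass only: reuse the coefficients from the forward pass.
        if forward_coeffs is None:
            raise ValueError("Keep requires the forward-pass coefficients")
        return list(forward_coeffs)
    if mode == "Shake":
        if level == "Batch":
            # One random coefficient shared by every image in the mini-batch.
            return [rng.random()] * batch_size
        if level == "Image":
            # A different random coefficient for each image.
            return [rng.random() for _ in range(batch_size)]
        raise ValueError("level must be 'Batch' or 'Image'")
    raise ValueError("mode must be 'Shake', 'Even' or 'Keep'")
```

Under these assumptions, the "Shake-Shake-Image" row corresponds to calling the sampler with `mode="Shake"` and `level="Image"` once before the forward pass and once more before the backward pass.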

Other results

CIFAR-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05% (1)

SVHN:
26 2x96d: 1.4% (1)

Reduced SVHN:
26 2x96d: 12.32% (1)

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake:
git clone https://github.com/xgastaldi/shake-shake.git
  3. Copy the contents of the shake-shake folder into the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  4. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, please change the batch size and the corresponding learning rate:

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true
  5. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true
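
The two CIFAR-10 commands for 2 GPUs and 1 GPU halve the learning rate together with the batch size (128 with 0.2, 64 with 0.1). If that relationship is read as a linear scaling rule, it can be written as the small helper below. The helper and its name are assumptions made for illustration and are not part of the repository. The rule is only inferred from the two CIFAR-10 commands, and the CIFAR-100 command uses its own settings (batch size 32, LR 0.025).

```python
def scaled_learning_rate(batch_size, reference_batch_size=128, reference_lr=0.2):
    """Scale the learning rate linearly with the batch size.

    The defaults are the 2-GPU CIFAR-10 settings from the commands above.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return reference_lr * batch_size / reference_batch_size


# The 1-GPU CIFAR-10 setting: batch size 64 gives a learning rate of about 0.1.
print(round(scaled_learning_rate(64), 6))
```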

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds a log

train.lua
Ln 36-38, 58-60, 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'
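
The exact form of the cosine learning rate function added to train.lua is not reproduced here. As a rough guide, a commonly used half-cosine schedule without restarts looks like the sketch below. The function name and the formula are assumptions and may differ from the repository's implementation.

```python
import math


def cosine_learning_rate(base_lr, epoch, n_epochs):
    """Half-cosine decay from base_lr at epoch 0 down to 0 at epoch n_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / n_epochs))


# With the settings from the CIFAR-10 commands (-LR 0.2 -nEpochs 1800), the rate
# starts at 0.2, passes through roughly 0.1 at the midpoint and ends near 0.
for epoch in (0, 900, 1800):
    print(epoch, round(cosine_learning_rate(0.2, epoch, 1800), 6))
```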

The main model is in shakeshake.lua. The residual block model is in shakeshakeblock.lua. mulconstantslices.lua is an extension of nn.MulConstant that multiplies each image slice of a mini-batch tensor by the corresponding element of a vector. shakeshaketable.lua contains the method used for CIFAR-100, since the ResNeXt code uses a table implementation instead of a module version.
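
The per-image multiplication that mulconstantslices.lua is described as performing can be illustrated without Torch. The sketch below is not a port of that file: the function names are hypothetical, and the mini-batch is represented as nested Python lists whose first dimension indexes the images.

```python
def scale_nested(values, coefficient):
    """Multiply every number in an arbitrarily nested list by coefficient."""
    if isinstance(values, list):
        return [scale_nested(v, coefficient) for v in values]
    return values * coefficient


def mul_constant_slices(batch, coefficients):
    """Scale image i of the mini-batch by coefficients[i]."""
    if len(batch) != len(coefficients):
        raise ValueError("need exactly one coefficient per image")
    return [scale_nested(image, c) for image, c in zip(batch, coefficients)]


# Two tiny one-channel "images", each scaled by its own coefficient.
batch = [[[1.0, 2.0]], [[3.0, 4.0]]]
print(mul_constant_slices(batch, [0.5, 2.0]))
```

A single shared coefficient, as in the "Batch" level, corresponds to passing the same value for every image.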

Reimplementations

PyTorch
https://github.com/hysts/pytorch_shake_shake

TensorFlow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.
