Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Train longer, generalize better - Big batch training

This repository contains the code used to generate the results in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks" by Elad Hoffer, Itay Hubara and Daniel Soudry.

It is based on convNet.pytorch and adds some helpful options, such as:

  • Training on several datasets
  • Complete logging of trained experiment
  • Graph visualization of the training/validation loss and accuracy
  • Definition of preprocessing and optimization regime for each model

Dependencies

Data

  • Configure your dataset path in data.py.
  • To get the ILSVRC data, you should register on their site for access: http://www.image-net.org/

Experiment examples

python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix;
python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 2048 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256;
python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128;
  • See run_experiments.sh for more examples
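
The commands above combine a large batch size (--b) with --lr_bb_fix. The paper proposes scaling the learning rate with the square root of the batch-size ratio when moving from a small batch to a large one. The sketch below illustrates that rule only; it assumes --lr_bb_fix corresponds to this scaling, and the function name scaled_lr and the base batch size of 128 are hypothetical choices for illustration, not taken from the repository.

```python
import math


def scaled_lr(base_lr, batch_size, base_batch_size=128):
    """Square-root learning-rate scaling (illustrative helper, not the repo's code).

    base_batch_size=128 is an assumed reference batch size.
    """
    return base_lr * math.sqrt(batch_size / base_batch_size)


# Going from a batch of 128 to 2048 multiplies the learning rate by sqrt(16) = 4.
large_batch_lr = scaled_lr(0.1, 2048)
```

Under these assumptions, a base learning rate of 0.1 tuned at batch size 128 would become roughly 0.4 at batch size 2048.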

Model configuration

A network model is defined by writing a .py file in the models folder and selecting it with the model flag. The model function must be registered in models/__init__.py and must return a trainable network. It can also specify additional training options, such as an optimization regime (either a dictionary or a function) and input transform modifications.

For example, a model definition:

import torch.nn as nn
import torchvision.transforms as transforms


class Model(nn.Module):

    def __init__(self, num_classes=1000):
        super(Model, self).__init__()
        self.model = nn.Sequential(...)

        # Optimization regime: each key is the epoch at which its settings take effect.
        self.regime = {
            0: {'optimizer': 'SGD', 'lr': 1e-2,
                'weight_decay': 5e-4, 'momentum': 0.9},
            15: {'lr': 1e-3, 'weight_decay': 0}
        }

        # Input preprocessing for the training and evaluation phases.
        self.input_transform = {
            'train': transforms.Compose([...]),
            'eval': transforms.Compose([...])
        }

    def forward(self, inputs):
        return self.model(inputs)


def model(**kwargs):
    return Model()
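
To make the regime dictionary more concrete, here is a minimal sketch of how an epoch-keyed regime like the one above could be resolved into the settings in effect at a given epoch. It assumes that later entries override earlier ones and that unspecified keys carry over; the helper name settings_at_epoch is hypothetical and is not the repository's own function.

```python
def settings_at_epoch(regime, epoch):
    """Merge every regime entry whose starting epoch is <= epoch, in order.

    Illustrative helper: later entries override earlier ones, and keys that
    are not overridden (e.g. 'momentum') keep their earlier values.
    """
    settings = {}
    for start in sorted(regime):
        if start > epoch:
            break
        settings.update(regime[start])
    return settings


regime = {
    0: {'optimizer': 'SGD', 'lr': 1e-2,
        'weight_decay': 5e-4, 'momentum': 0.9},
    15: {'lr': 1e-3, 'weight_decay': 0},
}

early = settings_at_epoch(regime, 3)   # settings from the epoch-0 entry only
late = settings_at_epoch(regime, 20)   # epoch-15 entry overrides lr and weight_decay
```

With this reading, training starts with SGD at a learning rate of 1e-2, and from epoch 15 onward the learning rate drops to 1e-3 with weight decay disabled while the momentum stays at 0.9.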