Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Related tags

Deep LearningbigBatch
Overview

Train longer, generalize better - Big batch training

This is a code repository used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks" By Elad Hoffer, Itay Hubara and Daniel Soudry.

It is based off convNet.pytorch with some helpful options such as:

  • Training on several datasets
  • Complete logging of trained experiment
  • Graph visualization of the training/validation loss and accuracy
  • Definition of preprocessing and optimization regime for each model

Dependencies

Data

  • Configure your dataset path at data.py.
  • To get the ILSVRC data, you should register on their site for access: http://www.image-net.org/

Experiment examples

python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix;
python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 2048 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256;
python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128;
  • See run_experiments.sh for more examples

Model configuration

Network model is defined by writing a .py file in models folder, and selecting it using the model flag. Model function must be registered in models/__init__.py The model function must return a trainable network. It can also specify additional training options such optimization regime (either a dictionary or a function), and input transform modifications.

e.g for a model definition:

class Model(nn.Module):

    def __init__(self, num_classes=1000):
        super(Model, self).__init__()
        self.model = nn.Sequential(...)

        self.regime = {
            0: {'optimizer': 'SGD', 'lr': 1e-2,
                'weight_decay': 5e-4, 'momentum': 0.9},
            15: {'lr': 1e-3, 'weight_decay': 0}
        }

        self.input_transform = {
            'train': transforms.Compose([...]),
            'eval': transforms.Compose([...])
        }
    def forward(self, inputs):
        return self.model(inputs)

 def model(**kwargs):
        return Model()
Owner
Elad Hoffer
Elad Hoffer
Implementation for paper "STAR: A Structure-aware Lightweight Transformer for Real-time Image Enhancement" (ICCV 2021).

STAR-pytorch Implementation for paper "STAR: A Structure-aware Lightweight Transformer for Real-time Image Enhancement" (ICCV 2021). CVF (pdf) STAR-DC

43 Dec 21, 2022
Neural network for recognizing the gender of people in photos

Neural Network For Gender Recognition How to test it? Install requirements.txt file using pip install -r requirements.txt command Run nn.py using pyth

Valery Chapman 1 Sep 18, 2022
KIND: an Italian Multi-Domain Dataset for Named Entity Recognition

KIND (Kessler Italian Named-entities Dataset) KIND is an Italian dataset for Named-Entity Recognition. It contains more than one million tokens with t

Digital Humanities 5 Jun 21, 2022
Code for the TPAMI paper: "Syntax Customized Video Captioning by Imitating Exemplar Sentences"

Syntax-Customized-Video-Captioning Code for the TPAMI paper: "Syntax Customized Video Captioning by Imitating Exemplar Sentences". This is my second w

3 Dec 05, 2022
Vpw analyzer - A visual J1850 VPW analyzer written in Python

VPW Analyzer A visual J1850 VPW analyzer written in Python Requires Tkinter, Pan

7 May 01, 2022
Neural style transfer as a class in PyTorch

pt-styletransfer Neural style transfer as a class in PyTorch Based on: https://github.com/alexis-jacq/Pytorch-Tutorials Adds: StyleTransferNet as a cl

Tyler Kvochick 31 Jun 27, 2022
Curvlearn, a Tensorflow based non-Euclidean deep learning framework.

English | 简体中文 Why Non-Euclidean Geometry Considering these simple graph structures shown below. Nodes with same color has 2-hop distance whereas 1-ho

Alibaba 123 Dec 12, 2022
An optimization and data collection toolbox for convenient and fast prototyping of computationally expensive models.

An optimization and data collection toolbox for convenient and fast prototyping of computationally expensive models. Hyperactive: is very easy to lear

Simon Blanke 422 Jan 04, 2023
Optical Character Recognition + Instance Segmentation for russian and english languages

Распознавание рукописного текста в школьных тетрадях Соревнование, проводимое в рамках олимпиады НТО, разработанное Сбером. Платформа ODS. Результаты

Gerasimov Maxim 21 Dec 19, 2022
MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc.

MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc. ⭐⭐⭐⭐⭐

568 Jan 04, 2023
This repository contains pre-trained models and some evaluation code for our paper Towards Unsupervised Dense Information Retrieval with Contrastive Learning

Contriever: Towards Unsupervised Dense Information Retrieval with Contrastive Learning This repository contains pre-trained models and some evaluation

Meta Research 207 Jan 08, 2023
Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo"

dblmahmc Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo" Requirements: https://github.com

1 Dec 17, 2021
PyTorch reimplementation of the Smooth ReLU activation function proposed in the paper "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations" [arXiv 2022].

Smooth ReLU in PyTorch Unofficial PyTorch reimplementation of the Smooth ReLU (SmeLU) activation function proposed in the paper Real World Large Scale

Christoph Reich 10 Jan 02, 2023
Source code of all the projects of Udacity Self-Driving Car Engineer Nanodegree.

self-driving-car In this repository I will share the source code of all the projects of Udacity Self-Driving Car Engineer Nanodegree. Hope this might

Andrea Palazzi 2.4k Dec 29, 2022
Project Aquarium is a SUSE-sponsored open source project aiming at becoming an easy to use, rock solid storage appliance based on Ceph.

Project Aquarium Project Aquarium is a SUSE-sponsored open source project aiming at becoming an easy to use, rock solid storage appliance based on Cep

Aquarist Labs 73 Jul 21, 2022
Learning Logic Rules for Document-Level Relation Extraction

LogiRE Learning Logic Rules for Document-Level Relation Extraction We propose to introduce logic rules to tackle the challenges of doc-level RE. Equip

41 Dec 26, 2022
This is the implementation of GGHL (A General Gaussian Heatmap Labeling for Arbitrary-Oriented Object Detection)

GGHL: A General Gaussian Heatmap Labeling for Arbitrary-Oriented Object Detection This is the implementation of GGHL 👋 👋 👋 [Arxiv] [Google Drive][B

551 Dec 31, 2022
Differentiable Wavetable Synthesis

Differentiable Wavetable Synthesis

4 Feb 11, 2022
Code for ICCV 2021 paper "HuMoR: 3D Human Motion Model for Robust Pose Estimation"

Code for ICCV 2021 paper "HuMoR: 3D Human Motion Model for Robust Pose Estimation"

Davis Rempe 367 Dec 24, 2022
PipeTransformer: Automated Elastic Pipelining for Distributed Training of Large-scale Models

PipeTransformer: Automated Elastic Pipelining for Distributed Training of Large-scale Models This repository is the official implementation of the fol

DistributedML 41 Dec 06, 2022