Official repository for Automated Learning Rate Scheduler for Large-Batch Training (8th ICML Workshop on AutoML)

Automated Learning Rate Scheduler for Large-Batch Training

Overview

AutoWU is an automated learning rate (LR) scheduler with two phases: warmup and decay. In the warmup phase, the LR is increased at an exponential rate until the loss starts to increase. In the decay phase, the LR is decreased following a pre-specified decay schedule (cosine or constant-then-cosine in our experiments).
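A minimal sketch of exponential warmup, assuming the LR is multiplied by the same factor at every step, so that log(LR) grows linearly. The helper names and the example range (1e-4 to 1e-1 over 300 steps) are hypothetical. In AutoWU the warmup length is not fixed in advance; it ends when the Gaussian-process test in the next paragraph fires.

```python
def warmup_growth_factor(start_lr, target_lr, num_steps):
    """Per-step multiplier that takes start_lr to target_lr in num_steps steps (hypothetical helper)."""
    return (target_lr / start_lr) ** (1.0 / num_steps)


def exponential_warmup_lr(step, start_lr, growth):
    """LR after `step` warmup steps: start_lr * growth ** step."""
    return start_lr * growth ** step


# Example: grow the LR from 1e-4 to 1e-1 over 300 steps.
growth = warmup_growth_factor(1e-4, 1e-1, 300)
lrs = [exponential_warmup_lr(s, 1e-4, growth) for s in range(301)]
# Consecutive LRs differ by the same ratio `growth`.
```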

The transition from the warmup phase to the decay phase is automatic. AutoWU predicts the loss curve with Gaussian process (GP) regression and switches to the decay phase once it can conclude, with high probability, that the minimum of the predicted loss curve has already been attained.
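One way to realize such a test is sketched here in plain NumPy. It rests on several assumptions of ours, not the paper's: a zero-mean GP with a fixed RBF kernel (no hyperparameter fitting), a fixed look-ahead `horizon`, Monte Carlo sampling of posterior curves, and a 0.95 threshold. The sketch estimates the probability that the minimum of the predicted loss curve lies among the steps already observed. All function and parameter names are hypothetical, and AutoWU's actual procedure may differ.

```python
import numpy as np


def rbf_kernel(a, b, length_scale, variance):
    # Squared-exponential (RBF) kernel between two 1-D arrays of steps.
    d = a[:, None] - b[None, :]
    return variance * np.exp(-0.5 * (d / length_scale) ** 2)


def prob_min_in_past(steps, losses, horizon, length_scale=20.0, variance=1.0,
                     noise=1e-2, num_samples=2000, seed=0):
    """Estimate P(minimum of the predicted loss curve lies at an observed step)."""
    x = np.asarray(steps, dtype=float)
    y = np.asarray(losses, dtype=float)
    y_mean = y.mean()
    y_c = y - y_mean  # center targets so a zero-mean GP prior is reasonable

    # Prediction grid: the observed steps plus `horizon` future steps.
    x_star = np.concatenate([x, x[-1] + np.arange(1, horizon + 1, dtype=float)])

    # Standard GP regression posterior over the latent curve, via Cholesky.
    K = rbf_kernel(x, x, length_scale, variance) + noise * np.eye(len(x))
    K_s = rbf_kernel(x, x_star, length_scale, variance)
    K_ss = rbf_kernel(x_star, x_star, length_scale, variance)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_c))
    v = np.linalg.solve(L, K_s)
    mu = K_s.T @ alpha + y_mean
    cov = K_ss - v.T @ v

    # Draw posterior sample curves; clip tiny negative eigenvalues from round-off.
    w, V = np.linalg.eigh(0.5 * (cov + cov.T))
    w = np.clip(w, 0.0, None)
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((num_samples, len(x_star)))
    curves = mu + (z * np.sqrt(w)) @ V.T

    # Fraction of sampled curves whose minimum falls among the observed steps.
    return float(np.mean(curves.argmin(axis=1) < len(x)))


steps = np.arange(100)
u_shaped = 1.0 + ((steps - 50) / 50.0) ** 2  # minimum already passed at step 50
decreasing = 2.0 - steps / 50.0              # loss still falling
p_u = prob_min_in_past(steps, u_shaped, horizon=10)
p_d = prob_min_in_past(steps, decreasing, horizon=10)
switch_to_decay = p_u >= 0.95  # hypothetical confidence threshold
```

Sampling whole curves, rather than testing each future point separately, accounts for the correlation between predicted losses at neighboring steps.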

Diagram summarizing AutoWU

How to use

Setup

pip install -r requirements.txt

Quick use

You can use AutoWU like any other PyTorch scheduler, except that its step() method takes the loss as an argument (as ReduceLROnPlateau in PyTorch takes a metric). The following code snippet demonstrates typical usage of AutoWU.

from autowu import AutoWU

...

scheduler = AutoWU(optimizer,
                   len(train_loader),  # the number of steps in one epoch 
                   total_epochs,  # total number of epochs
                   immediate_cooldown=True,
                   cooldown_type='cosine',
                   device=device)

...

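for _ in range(total_epochs):
    for inputs, targets in train_loader:
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Called once per iteration: AutoWU uses the current training loss
        # to decide when to switch from warmup to decay.
        scheduler.step(loss)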

The default decay-phase schedule is cosine. To use the constant-then-cosine schedule instead, set immediate_cooldown=False and set cooldown_fraction to the desired fraction of cosine decay at the end:

scheduler = AutoWU(optimizer,
                   len(train_loader),  # the number of steps in one epoch 
                   total_epochs,  # total number of epochs
                   immediate_cooldown=False,
                   cooldown_type='cosine',
                   cooldown_fraction=0.2,  # fraction of cosine decay at the end
                   device=device)
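The two decay shapes can be sketched as a multiplier applied to the LR reached at the end of warmup. This sketch assumes that cooldown_fraction is the fraction of the decay phase spent in the final cosine segment. That is our reading of the comment above, not a documented definition. decay_multiplier is a hypothetical helper and is not part of the autowu package.

```python
import math


def decay_multiplier(step, decay_steps, cooldown_fraction=1.0):
    """LR multiplier at `step` (0-based) of a decay phase lasting `decay_steps` steps.

    cooldown_fraction=1.0 gives a plain cosine decay over the whole phase.
    Smaller values keep the LR constant first and apply cosine decay only over
    the last `cooldown_fraction` of the phase (hypothetical interpretation).
    The multiplier reaches 0 at step == decay_steps.
    """
    cosine_steps = max(1, round(cooldown_fraction * decay_steps))
    constant_steps = decay_steps - cosine_steps
    if step < constant_steps:
        return 1.0
    progress = min((step - constant_steps) / cosine_steps, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


cosine = [decay_multiplier(s, 100) for s in range(101)]
constant_then_cosine = [decay_multiplier(s, 100, cooldown_fraction=0.2) for s in range(101)]
```

For example, with decay_steps=100 and cooldown_fraction=0.2, the multiplier stays at 1.0 for the first 80 steps and then follows a half cosine down to 0 at step 100.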

Reproduction of results

We provide an example training script, train.py, based on PyTorch Image Models (timm). The script supports training ResNet-50 and EfficientNet-B0 for ImageNet classification in a setting almost identical to the one in the paper. The table below reports the top-1 validation accuracy of ResNet-50 and EfficientNet-B0 trained with this repository at batch sizes of 4K (4096) and 16K (16384), alongside the scores reported in our paper.

Model            Batch size   This repo.   Reported (paper)
ResNet-50        4K           75.54%       75.70%
ResNet-50        16K          74.87%       75.22%
EfficientNet-B0  4K           75.74%       75.81%
EfficientNet-B0  16K          75.66%       75.44%

You can use the torch.distributed.launch utility to run the script. For instance, for ResNet-50 training with batch size 4096, run the following command on each node, with the variables set according to your environment:

python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=4 \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train.py \
--data-root $DATA_ROOT \
--amp \
--batch-size 256 

For EfficientNet-B0 training, also add the --model efficientnet_b0 argument.
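torch.distributed.launch starts --nproc_per_node processes on each of the --nnodes nodes, typically one per GPU. The global batch size is therefore --nnodes × --nproc_per_node × the per-process --batch-size, so the command above gives 4 × 4 × 256 = 4096. This assumes --batch-size is per process, as in the PyTorch Image Models training script. The following helper is hypothetical and not part of this repository. It picks the per-process batch size for a target global batch size, for example when changing the number of nodes:

```python
def per_process_batch_size(global_batch, nnodes, nproc_per_node):
    """Per-process --batch-size needed to reach `global_batch` samples per step.

    Assumes every process sees the same number of samples per step and that
    there is no gradient accumulation (hypothetical helper).
    """
    world_size = nnodes * nproc_per_node
    if global_batch % world_size != 0:
        raise ValueError(f"{global_batch} is not divisible by world size {world_size}")
    return global_batch // world_size


print(per_process_batch_size(4096, nnodes=4, nproc_per_node=4))    # 256
print(per_process_batch_size(16384, nnodes=16, nproc_per_node=4))  # 256, e.g. 16K on 16 nodes
```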

Citation

@inproceedings{
    kim2021automated,
    title={Automated Learning Rate Scheduler for Large-batch Training},
    author={Chiheon Kim and Saehoon Kim and Jongmin Kim and Donghoon Lee and Sungwoong Kim},
    booktitle={8th ICML Workshop on Automated Machine Learning (AutoML)},
    year={2021},
    url={https://openreview.net/forum?id=ljIl7KCNYZH}
}

License

This project is licensed under the terms of the Apache License 2.0. Copyright 2021 Kakao Brain. All rights reserved.
