ESGD-M - A stochastic non-convex second order optimizer, suitable for training deep learning models, for PyTorch

ESGD-M

ESGD-M is a stochastic non-convex second order optimizer, suitable for training deep learning models. It is based on ESGD (from the paper "Equilibrated adaptive learning rates for non-convex optimization") and incorporates quasi-hyperbolic momentum (from the paper "Quasi-hyperbolic momentum and Adam for deep learning") to accelerate convergence, which considerably improves its performance over plain ESGD.

ESGD-M obtains Hessian information through occasional Hessian-vector products and uses it to adapt per-parameter learning rates. By default a Hessian-vector product is computed every ten optimizer steps, and each one costs approximately the same as a gradient evaluation. From these products it estimates the diagonal of the absolute Hessian, diag(|H|), to use as a diagonal preconditioner.

To use this optimizer you must call .backward() with the create_graph=True option. Gradient accumulation steps and distributed training are currently not supported.

Learning rates

ESGD-M learning rates have a different meaning from those of SGD and Adagrad/Adam/etc., so values that work for those optimizers do not transfer directly. You may need to try learning rates in the range 1e-3 to 1.

SGD class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n^2.

  • If you rescale your loss by a factor of n, you must scale your learning rate by a factor of 1 / n.

Adagrad/Adam class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n.

  • If you rescale your loss by a factor of n, you do not have to scale your learning rate.

Second order optimizers (including ESGD-M):

  • You do not have to scale your learning rate if you rescale either your parameters or your loss.
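The three scaling rules above can be checked on a one-dimensional quadratic. The following is a minimal illustrative sketch, not ESGD-M itself: the function name, the "adam_like" idealization (gradient divided by its own magnitude), and the toy loss 0.5 * a * x**2 are all assumptions made for the example. Each step is reported in the original units of x so the cases are directly comparable.

```python
import math

def step_in_original_units(kind, lr, x, a=3.0, loss_scale=1.0, param_scale=1.0):
    """One optimizer step on loss_scale * 0.5 * a * (y / param_scale)**2.

    y = param_scale * x is the parameter the optimizer actually sees; the
    returned value is the resulting change in x.
    """
    y = param_scale * x
    grad = loss_scale * a * y / param_scale**2   # d(loss)/dy
    hess = loss_scale * a / param_scale**2       # d^2(loss)/dy^2
    if kind == "sgd":
        dy = -lr * grad
    elif kind == "adam_like":      # idealized: gradient over its own magnitude
        dy = -lr * grad / abs(grad)
    elif kind == "second_order":   # gradient over |second derivative|
        dy = -lr * grad / abs(hess)
    else:
        raise ValueError(f"unknown optimizer kind: {kind}")
    return dy / param_scale

n, lr, x = 10.0, 0.1, 2.0

# Learning rate needed to reproduce the baseline step after each rescaling.
compensated_lr = {
    "sgd": {"params": lr * n**2, "loss": lr / n},
    "adam_like": {"params": lr * n, "loss": lr},
    "second_order": {"params": lr, "loss": lr},
}

for kind, lrs in compensated_lr.items():
    baseline = step_in_original_units(kind, lr, x)
    after_param_rescale = step_in_original_units(kind, lrs["params"], x, param_scale=n)
    after_loss_rescale = step_in_original_units(kind, lrs["loss"], x, loss_scale=n)
    assert math.isclose(baseline, after_param_rescale)
    assert math.isclose(baseline, after_loss_rescale)
    print(kind, baseline)
```

Only the second order step is reproduced with an unchanged learning rate in both cases; the other two need the compensation factors listed above.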

Momentum

The default configuration is Nesterov momentum: if v is not specified, it defaults to the value of beta_1, which produces Nesterov momentum. Written out explicitly:

opt = ESGD(model.parameters(), lr=1, betas=(0.9, 0.999), v=0.9)

The Quasi-Hyperbolic Momentum recommended defaults can be obtained using:

opt = ESGD(model.parameters(), lr=1, betas=(0.999, 0.999), v=0.7)

Setting v equal to 1 will do normal (non-Nesterov) momentum.
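For reference, the role of v can be seen in the basic quasi-hyperbolic momentum rule from the QHM paper, sketched below for a single scalar parameter. This is an illustration of the momentum rule only and an assumption about how it is applied: the helper name qhm_step is hypothetical, and ESGD-M's actual update also involves the diag(|H|) preconditioner, which is not shown here.

```python
import math

def qhm_step(param, grad, buf, lr, beta, v):
    """One quasi-hyperbolic momentum step on a scalar parameter.

    buf is an exponential moving average of past gradients. The update mixes
    the raw gradient (weight 1 - v) with the averaged gradient (weight v).
    """
    buf = beta * buf + (1 - beta) * grad
    update = (1 - v) * grad + v * buf
    return param - lr * update, buf

# v = 0 ignores the momentum buffer entirely (plain gradient step).
p_plain, _ = qhm_step(1.0, 1.0, 0.0, lr=0.1, beta=0.9, v=0.0)

# v = 1 uses only the momentum buffer (normal momentum).
p_momentum, _ = qhm_step(1.0, 1.0, 0.0, lr=0.1, beta=0.9, v=1.0)

# v = beta is the setting the text above identifies with Nesterov momentum.
p_nesterov, _ = qhm_step(1.0, 1.0, 0.0, lr=0.1, beta=0.9, v=0.9)

print(p_plain, p_momentum, p_nesterov)
```

Values of v between 0 and 1 interpolate between a plain gradient step and a pure momentum step.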

The ESGD-M decay coefficient beta_2 refers not to the squared gradient as in Adam but to the squared Hessian diagonal estimate, which it uses in place of the squared gradient to provide per-parameter adaptive learning rates.

Hessian-vector products

The absolute Hessian diagonal diag(|H|) is estimated every update_d_every steps (the default is 10). In addition, during the first d_warmup steps the diagonal is estimated on every step, to obtain a lower variance estimate of diag(|H|) quickly. The estimation uses a Hessian-vector product, which takes around the same amount of time to compute as a gradient evaluation. You must explicitly signal to PyTorch that you want to do a double backward pass by passing create_graph=True to .backward():

opt.zero_grad(set_to_none=True)
loss = loss_fn(model(inputs), targets)
loss.backward(create_graph=True)
opt.step()
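To give some intuition for how Hessian-vector products can yield a diagonal preconditioner, the sketch below follows the equilibration idea from the ESGD paper: average the elementwise square of Hv over random Gaussian vectors v, then take the square root. It is an illustration under assumptions, not ESGD-M's code. It uses a small explicit matrix in place of the double backward pass, a plain average in place of the beta_2 moving average, and hypothetical helper names.

```python
import math
import random

def hessian_vector_product(H, v):
    """Multiply an explicit symmetric matrix H (a list of rows) by the vector v."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]

def estimate_preconditioner(H, num_samples, seed=0):
    """Monte Carlo estimate of sqrt(E[(Hv)**2]) elementwise, with v ~ N(0, I)."""
    rng = random.Random(seed)
    n = len(H)
    sq_sum = [0.0] * n
    for _ in range(num_samples):
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        hv = hessian_vector_product(H, v)
        for i in range(n):
            sq_sum[i] += hv[i] ** 2
    return [math.sqrt(s / num_samples) for s in sq_sum]

# An indefinite diagonal Hessian (a saddle point): curvature +2 and -5.
H = [[2.0, 0.0], [0.0, -5.0]]
d = estimate_preconditioner(H, num_samples=20000)
print(d)  # close to [2.0, 5.0], the absolute values of the diagonal
```

Each individual sample is noisy, which is why many Hessian-vector products have to be averaged before the estimate becomes reliable. For a non-diagonal H the same estimator converges to the row norms of H rather than to the diagonal alone.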

Weight decay

Weight decay is performed separately from the Hessian-vector product and the preconditioner, similar to AdamW except that the weight decay value provided by the user is multiplied by the current learning rate to determine the factor to decay the weights by.
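One reading of this description is a multiplicative shrink applied to each weight independently of the gradient step. The sketch below assumes that reading; the function name is hypothetical, and the exact form and ordering used inside ESGD-M may differ.

```python
import math

def decay_weight(weight, lr, weight_decay):
    """Shrink a weight by the fraction lr * weight_decay.

    The gradient and the preconditioner play no part in this step.
    """
    return weight * (1.0 - lr * weight_decay)

# With lr=0.5 and weight_decay=0.01 the weight shrinks by 0.5% per step.
w = decay_weight(2.0, lr=0.5, weight_decay=0.01)
assert math.isclose(w, 1.99)
```

Under this reading, lowering the learning rate also lowers the effective amount of decay per step.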

Learning rate warmup

Because the diag(|H|) estimates are high variance, the adaptive learning rates are not very reliable until many steps have been taken and many estimates have been averaged together. To deal with this, ESGD-M applies a short exponential learning rate warmup by default, which is combined with any external learning rate schedulers. On each step (starting from 1) the learning rate will be:

lr * (1 - lr_warmup**step)

The default value for lr_warmup is 0.99, which reaches 63% of the specified learning rate in 100 steps and 95% in 300 steps.
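The warmup formula is easy to evaluate directly. The short sketch below (the function name is hypothetical) reproduces the figures quoted above for the default lr_warmup of 0.99.

```python
def warmed_up_lr(lr, step, lr_warmup=0.99):
    """Effective learning rate on a given step (steps are counted from 1)."""
    return lr * (1 - lr_warmup ** step)

for step in (1, 10, 100, 300, 1000):
    fraction = warmed_up_lr(1.0, step)
    print(f"step {step:4d}: {fraction:.1%} of the specified learning rate")
```

A value of lr_warmup closer to 1 gives a longer warmup, and a smaller value gives a shorter one.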

Owner
Katherine Crowson