Details about the wide minima density hypothesis and metrics to compute width of a minima

Overview

wide-minima-density-hypothesis

Details about the wide minima density hypothesis and metrics to compute width of a minima

This repo presents the wide minima density hypothesis as proposed in the following paper:

Key contributions:

  • Hypothesis about minima density
  • A SOTA LR schedule that exploits the hypothesis and beats general baseline schedules
  • Reducing wall clock training time and saving GPU compute hours with our LR schedule (Pretraining BERT-Large in 33% less training steps)
  • SOTA BLEU score on IWSLT'14 ( DE-EN )

Prerequisite:

  • CUDA, cudnn
  • Python 3.6+
  • PyTorch 1.4.0

Knee LR Schedule

Based on the density of wide vs narrow minima , we propose the Knee LR schedule that pushes generalization boundaries further by exploiting the nature of the loss landscape. The LR schedule is an explore-exploit based schedule, where the explore phase maintains a high lr for a significant time to access and land into a wide minimum with a good probability. The exploit phase is a simple linear decay scheme, which decays the lr to zero over the exploit phase. The only hyperparameter to tune is the explore epochs/steps. We have shown that 50% of the training budget allocated for explore is good enough for landing in a wider minimum and better generalization, thus removing the need for hyperparameter tuning.

  • Note that many experiments require warmup, which is done in the initial phase of training for a fixed number of steps and is usually required for Adam based optimizers/ large batch training. It is complementary with the Knee schedule and can be added to it.

To use the Knee Schedule, import the scheduler into your training file:

>>> from knee_lr_schedule import KneeLRScheduler
>>> scheduler = KneeLRScheduler(optimizer, peak_lr, warmup_steps, explore_steps, total_steps)

To use it during training :

>>> model.train()
>>> output = model(inputs)
>>> loss = criterion(output, targets)
>>> loss.backward()
>>> optimizer.step()
>>> scheduler.step()

Details about args:

  • optimizer: optimizer needed for training the model ( SGD/Adam )
  • peak_lr: the peak learning required for explore phase to escape narrow minimas
  • warmup_steps: steps required for warmup( usually needed for adam optimizers/ large batch training) Default value: 0
  • explore_steps: total steps for explore phase.
  • total_steps: total training budget steps for training the model

Measuring width of a minima

Keskar et.al 2016 (https://arxiv.org/abs/1609.04836) argue that wider minima generalize much better than sharper minima. The computation method in their work uses the compute expensive LBFGS-B second order method, which is hard to scale. We use a projected gradient ascent based method, which is first order in nature and very easy to implement/use. Here is a simple way you can compute the width of the minima your model finds during training:

>>> from minima_width_compute import ComputeKeskarSharpness
>>> cks = ComputeKeskarSharpness(model_final_ckpt, optimizer, criterion, trainloader, epsilon, lr, max_steps)
>>> width = cks.compute_sharpness()

Details about args:

  • model_final_ckpt: model loaded with the saved checkpoint after final training step
  • optimizer : optimizer to use for projected gradient ascent ( SGD, Adam )
  • criterion : criterion for computing loss (e.g. torch.nn.CrossEntropyLoss)
  • trainloader : iterator over the training dataset (torch.utils.data.DataLoader)
  • epsilon : epsilon value determines the local boundary around which minima witdh is computed (Default value : 1e-4)
  • lr : lr for the optimizer to perform projected gradient ascent ( Default: 0.001)
  • max_steps : max steps to compute the width (Default: 1000). Setting it too low could lead to the gradient ascent method not converging to an optimal point.

The above default values have been chosen after tuning and observing the loss values of projected gradient ascent on Cifar-10 with ResNet-18 and SGD-Momentum optimizer, as mentioned in our paper. The values may vary for experiments with other optimizers/datasets/models. Please tune them for optimal convergence.

  • Acknowledgements: We would like to thank Harshay Shah (https://github.com/harshays) for his helpful discussions for computing the width of the minima.

Citation

Please cite our paper in your publications if you use our work:

@article{iyer2020wideminima,
  title={Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule},
  author={Iyer, Nikhil and Thejas, V and Kwatra, Nipun and Ramjee, Ramachandran and Sivathanu, Muthian},
  journal={arXiv preprint arXiv:2003.03977},
  year={2020}
}
  • Note: This work was done during an internship at Microsoft Research India
Owner
Nikhil Iyer
Studied at BITS-Pilani Hyderabad Campus. AI Research @ Jio-Haptik. Ex Microsoft Research India
Nikhil Iyer
(under submission) Bayesian Integration of a Generative Prior for Image Restoration

BIGPrior: Towards Decoupling Learned Prior Hallucination and Data Fidelity in Image Restoration Authors: Majed El Helou, and Sabine Süsstrunk {Note: p

Majed El Helou 22 Dec 17, 2022
Pytorch implementation of set transformer

set_transformer Official PyTorch implementation of the paper Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks .

Juho Lee 410 Jan 06, 2023
LSTMs (Long Short Term Memory) RNN for prediction of price trends

Price Prediction with Recurrent Neural Networks LSTMs BTC-USD price prediction with deep learning algorithm. Artificial Neural Networks specifically L

5 Nov 12, 2021
[ICCV 2021] Target Adaptive Context Aggregation for Video Scene Graph Generation

Target Adaptive Context Aggregation for Video Scene Graph Generation This is a PyTorch implementation for Target Adaptive Context Aggregation for Vide

Multimedia Computing Group, Nanjing University 44 Dec 14, 2022
A tutorial on training a DarkNet YOLOv4 model for the CrowdHuman dataset

YOLOv4 CrowdHuman Tutorial This is a tutorial demonstrating how to train a YOLOv4 people detector using Darknet and the CrowdHuman dataset. Table of c

JK Jung 118 Nov 10, 2022
A library of multi-agent reinforcement learning components and systems

Mava: a research framework for distributed multi-agent reinforcement learning Table of Contents Overview Getting Started Supported Environments System

InstaDeep Ltd 463 Dec 23, 2022
ThunderGBM: Fast GBDTs and Random Forests on GPUs

Documentations | Installation | Parameters | Python (scikit-learn) interface What's new? ThunderGBM won 2019 Best Paper Award from IEEE Transactions o

Xtra Computing Group 647 Jan 04, 2023
Shape-aware Semi-supervised 3D Semantic Segmentation for Medical Images

SASSnet Code for paper: Shape-aware Semi-supervised 3D Semantic Segmentation for Medical Images(MICCAI 2020) Our code is origin from UA-MT You can fin

klein 125 Jan 03, 2023
Patch SVDD for Image anomaly detection

Patch SVDD Patch SVDD for Image anomaly detection. Paper: https://arxiv.org/abs/2006.16067 (published in ACCV 2020). Original Code : https://github.co

Hong-Jeongmin 0 Dec 03, 2021
A resource for learning about deep learning techniques from regression to LSTM and Reinforcement Learning using financial data and the fitness functions of algorithmic trading

A tour through tensorflow with financial data I present several models ranging in complexity from simple regression to LSTM and policy networks. The s

195 Dec 07, 2022
A python software that can help blind people find things like laptops, phones, etc the same way a guide dog guides a blind person in finding his way.

GuidEye A python software that can help blind people find things like laptops, phones, etc the same way a guide dog guides a blind person in finding h

Munal Jain 0 Aug 09, 2022
SuRE Evaluation: A Supplementary Material

SuRE Evaluation: A Supplementary Material This repository contains supplementary material regarding the evaluations presented in the paper Visual Expl

NYU Visualization Lab 0 Dec 14, 2021
Code of Puregaze: Purifying gaze feature for generalizable gaze estimation, AAAI 2022.

PureGaze: Purifying Gaze Feature for Generalizable Gaze Estimation Description Our work is accpeted by AAAI 2022. Picture: We propose a domain-general

39 Dec 05, 2022
lightweight python wrapper for vowpal wabbit

vowpal_porpoise Lightweight python wrapper for vowpal_wabbit. Why: Scalable, blazingly fast machine learning. Install Install vowpal_wabbit. Clone and

Joseph Reisinger 163 Nov 24, 2022
Real-time Joint Semantic Reasoning for Autonomous Driving

MultiNet MultiNet is able to jointly perform road segmentation, car detection and street classification. The model achieves real-time speed and state-

Marvin Teichmann 518 Dec 12, 2022
Pca-on-genotypes - Mini bioinformatics project - PCA on genotypes

Mini bioinformatics project: PCA on genotypes This repo contains the code from t

Maria Nattestad 8 Dec 04, 2022
Fibonacci Method Gradient Descent

An implementation of the Fibonacci method for gradient descent, featuring a TKinter GUI for inputting the function / parameters to be examined and a matplotlib plot of the function and results.

Emma 1 Jan 28, 2022
Medical-Image-Triage-and-Classification-System-Based-on-COVID-19-CT-and-X-ray-Scan-Dataset

Medical-Image-Triage-and-Classification-System-Based-on-COVID-19-CT-and-X-ray-Sc

2 Dec 26, 2021
Amazing-Python-Scripts - 🚀 Curated collection of Amazing Python scripts from Basics to Advance with automation task scripts.

📑 Introduction A curated collection of Amazing Python scripts from Basics to Advance with automation task scripts. This is your Personal space to fin

Avinash Ranjan 1.1k Dec 29, 2022
Code for "FPS-Net: A convolutional fusion network for large-scale LiDAR point cloud segmentation".

FPS-Net Code for "FPS-Net: A convolutional fusion network for large-scale LiDAR point cloud segmentation", accepted by ISPRS journal of Photogrammetry

15 Nov 30, 2022