Stochastic gradient descent with model building

Overview

Stochastic Model Building (SMB)

This repository includes a new fast and robust stochastic optimization algorithm for training deep learning models. The core idea of the algorithm is based on building models with local stochastic gradient information. The details of the algorithm is given in our recent paper.

SMB

Abstract

Stochastic gradient descent method and its variants constitute the core optimization algorithms that achieve good convergence rates for solving machine learning problems. These rates are obtained especially when these algorithms are fine-tuned for the application at hand. Although this tuning process can require large computational costs, recent work has shown that these costs can be reduced by line search methods that iteratively adjust the stepsize. We propose an alternative approach to stochastic line search by using a new algorithm based on forward step model building. This model building step incorporates a second-order information that allows adjusting not only the stepsize but also the search direction. Noting that deep learning model parameters come in groups (layers of tensors), our method builds its model and calculates a new step for each parameter group. This novel diagonalization approach makes the selected step lengths adaptive. We provide convergence rate analysis, and experimentally show that the proposed algorithm achieves faster convergence and better generalization in most problems. Moreover, our experiments show that the proposed method is quite robust as it converges for a wide range of initial stepsizes.

Keywords: model building; second-order information; stochastic gradient descent; convergence analysis

Installation

pip install git+https://github.com/sbirbil/SMB.git

Testing

Here is how you can use SMB:

import smb

optimizer = smb.SMB(model.parameters(), independent_batch=False) #independent_batch=True for SMBi optimizer

for epoch in range(100):
    
    # training steps
    model.train()
    
    for batch_index, (data, target) in enumerate(train_loader):
            
        # create loss closure for smb algorithm
        def closure():
            optimizer.zero_grad()
            loss = torch.nn.CrossEntropyLoss()(model(data), target)
            return loss
        
        # forward pass
        loss = optimizer.step(closure=closure)

You can also check our tutorial for a complete example (or the Colab notebook without installation). Set the hyper-parameter independent_batch to True in order to use the SMBi optimizer. Our paper includes more information.

Reproducing The Experiments

See the following script in order to reproduce the results in our paper.

Owner
S. Ilker Birbil
I am a faculty member working on data science and optimization.
S. Ilker Birbil
A Simulation Environment to train Robots in Large Realistic Interactive Scenes

iGibson: A Simulation Environment to train Robots in Large Realistic Interactive Scenes iGibson is a simulation environment providing fast visual rend

Stanford Vision and Learning Lab 493 Jan 04, 2023
Adversarial Autoencoders

Adversarial Autoencoders (with Pytorch) Dependencies argparse time torch torchvision numpy itertools matplotlib Create Datasets python create_datasets

Felipe Ducau 188 Jan 01, 2023
Object Detection Projekt in GKI WS2021/22

tfObjectDetection Object Detection Projekt with tensorflow in GKI WS2021/22 Docker Container: docker run -it --name --gpus all -v path/to/project:p

Tim Eggers 1 Jul 18, 2022
Hard cater examples from Hopper ICLR paper

CATER-h Honglu Zhou*, Asim Kadav, Farley Lai, Alexandru Niculescu-Mizil, Martin Renqiang Min, Mubbasir Kapadia, Hans Peter Graf (*Contact: honglu.zhou

NECLA ML Group 6 May 11, 2021
Graph Attention Networks

GAT Graph Attention Networks (Veličković et al., ICLR 2018): https://arxiv.org/abs/1710.10903 GAT layer t-SNE + Attention coefficients on Cora Overvie

Petar Veličković 2.6k Jan 05, 2023
project page for VinVL

VinVL: Revisiting Visual Representations in Vision-Language Models Updates 02/28/2021: Project page built. Introduction This repository is the project

308 Jan 09, 2023
A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval

CLIP4CMR A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval The original data and pre-calculate

24 Dec 26, 2022
A visualization tool to show a TensorFlow's graph like TensorBoard

tfgraphviz tfgraphviz is a module to visualize a TensorFlow's data flow graph like TensorBoard using Graphviz. tfgraphviz enables to provide a visuali

44 Nov 09, 2022
Submission to Twitter's algorithmic bias bounty challenge

Twitter Ethics Challenge: Pixel Perfect Submission to Twitter's algorithmic bias bounty challenge, by Travis Hoppe (@metasemantic). Abstract We build

Travis Hoppe 4 Aug 19, 2022
Streaming Anomaly Detection Framework in Python (Outlier Detection for Streaming Data)

Python Streaming Anomaly Detection (PySAD) PySAD is an open-source python framework for anomaly detection on streaming multivariate data. Documentatio

Selim Firat Yilmaz 181 Dec 18, 2022
Learning What and Where to Draw

###Learning What and Where to Draw Scott Reed, Zeynep Akata, Santosh Mohan, Samuel Tenka, Bernt Schiele, Honglak Lee This is the code for our NIPS 201

Scott Ellison Reed 337 Nov 18, 2022
A self-supervised 3D representation learning framework named viewpoint bottleneck.

Pointly-supervised 3D Scene Parsing with Viewpoint Bottleneck Paper Created by Liyi Luo, Beiwen Tian, Hao Zhao and Guyue Zhou from Institute for AI In

63 Aug 11, 2022
The PyTorch improved version of TPAMI 2017 paper: Face Alignment in Full Pose Range: A 3D Total Solution.

Face Alignment in Full Pose Range: A 3D Total Solution By Jianzhu Guo. [Updates] 2020.8.30: The pre-trained model and code of ECCV-20 are made public

Jianzhu Guo 3.4k Jan 02, 2023
A simple Neural Network that predicts the label for a series of handwritten digits

Neural_Network A simple Neural Network that predicts the label for a series of handwritten numbers This program tries to predict the label (1,2,3 etc.

Ty 1 Dec 18, 2021
Robust Instance Segmentation through Reasoning about Multi-Object Occlusion [CVPR 2021]

Robust Instance Segmentation through Reasoning about Multi-Object Occlusion [CVPR 2021] Abstract Analyzing complex scenes with DNN is a challenging ta

Irene Yuan 24 Jun 27, 2022
an implementation of 3D Ken Burns Effect from a Single Image using PyTorch

3d-ken-burns This is a reference implementation of 3D Ken Burns Effect from a Single Image [1] using PyTorch. Given a single input image, it animates

Simon Niklaus 1.4k Dec 28, 2022
Code for our ALiBi method for transformer language models.

Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation This repository contains the code and models for our paper Tra

Ofir Press 211 Dec 31, 2022
An implementation of based on pytorch and mmcv

FisherPruning-Pytorch An implementation of Group Fisher Pruning for Practical Network Compression based on pytorch and mmcv Main Functions Pruning f

Peng Lu 15 Dec 17, 2022
MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc.

MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc. ⭐⭐⭐⭐⭐

568 Jan 04, 2023
A PyTorch-centric hybrid classical-quantum machine learning framework

torchquantum A PyTorch-centric hybrid classical-quantum dynamic neural networks framework. News Add a simple example script using quantum gates to do

MIT HAN Lab 400 Jan 02, 2023