Code to use Augmented Shapiro Wilks Stopping, as well as code for the paper "Statistically Signifigant Stopping of Neural Network Training"

Related tags

Deep LearningASWS
Overview

This codebase is being actively maintained, please create and issue if you have issues using it

Basics

All data files are included under losses and each folder. The main Augmented Shapiro-Wilk Stopping criterion is implemented in analysis.py, along with several helper functions and wrappers. The other comparison heuristics are also included in analysis.py, along with their wrappers. grapher.py contains all the code for generating the graphs used in the paper, and earlystopping_calculator.py includes code for generating tables and calculating some statistics from the data. hyperparameter_search.py contains all the code used to execute the grid-search on the ASWS method, along with the grid-search for the other heuristics.

Installing

If you would like to try our code, just run pip3 install git+https://github.com/justinkterry/ASWS

Example

If you wanted to try to determine the ASWS stopping point of a model, you can do so using the analysis.py file. If at anypoint during model training you wanted to perform the stop criterion test, you can do

from ASWS.analysis import aswt_stopping

test_acc = [] # for storing model accuracies
for i in training_epochs:

    model.train()
    test_accuracy = model.evaluate(test_set)
    test_acc.append(test_accuracy)
    gamma = 0.5 # fill hyperparameters as desired
    num_data = 20
    slack_prop=0.1
    count = 20

    if len(test_acc) > count:
        aswt_stop_criterion = aswt_stopping(test_acc, gamma, count, num_data, slack_prop=slack_prop)

        if aswt_stop_criterion:
            print("Stop Training")

and if you already have finished training the model and wanted to determine the ASWS stopping point, you would need a CSV with columns Epoch, Training Loss, Training Acc, Test Loss, Test Acc. You could then use the following example

from ASWS.analysis import get_aswt_stopping_point_of_model, read_file

_, _, _, test_acc = read_file("modelaccuracy.csv")
gamma = 0.5 # fill hyperparameters as desired
num_data = 20
slack_prop=0.1
count = 20

stop_epoch, stop_accuracy = get_aswt_stopping_point_of_model(test_acc, gamma=gamma, num_data=num_data, count=count, slack_prop=slack_prop)

pytorch-training

The pytorch-training folder contains the driver file for training each model, along with the model files which contain each network definition. The main.py file can be run out of the box for the models listed in the paper. The model to train is specified via the --model argument. All learning rate schedulers listed in the paper are available (via --schedule step etc.) and the ASWS learning rate scheduler is available via --schedule ASWT . The corresponding ASWS hyperparameters are passed in at the command line (for example --gamma 0.5).

Example

In order to recreate the GoogLeNet ASWT 1 scheduler from the paper, you can use the following command

python3 main.py --model GoogLeNet --schedule ASWT --gamma 0.76 --num_data 19 --slack_prop 0.05 --lr 0.1

Owner
J K Terry
CS PhD student at UMD, founder of Swarm Labs, maintainer of Gym and PettingZoo. I work in deep reinforcement learning.
J K Terry
Very simple NCHW and NHWC conversion tool for ONNX. Change to the specified input order for each and every input OP. Also, change the channel order of RGB and BGR. Simple Channel Converter for ONNX.

scc4onnx Very simple NCHW and NHWC conversion tool for ONNX. Change to the specified input order for each and every input OP. Also, change the channel

Katsuya Hyodo 16 Dec 22, 2022
This repository contains code for the paper "Disentangling Label Distribution for Long-tailed Visual Recognition", published at CVPR' 2021

Disentangling Label Distribution for Long-tailed Visual Recognition (CVPR 2021) Arxiv link Blog post This codebase is built on Causal Norm. Install co

Hyperconnect 85 Oct 18, 2022
Keras implementations of Generative Adversarial Networks.

This repository has gone stale as I unfortunately do not have the time to maintain it anymore. If you would like to continue the development of it as

Erik Linder-Norén 8.9k Jan 04, 2023
RLDS stands for Reinforcement Learning Datasets

RLDS RLDS stands for Reinforcement Learning Datasets and it is an ecosystem of tools to store, retrieve and manipulate episodic data in the context of

Google Research 135 Jan 01, 2023
Implemenets the Contourlet-CNN as described in C-CNN: Contourlet Convolutional Neural Networks, using PyTorch

C-CNN: Contourlet Convolutional Neural Networks This repo implemenets the Contourlet-CNN as described in C-CNN: Contourlet Convolutional Neural Networ

Goh Kun Shun (KHUN) 10 Nov 03, 2022
Target Propagation via Regularized Inversion

Target Propagation via Regularized Inversion The present code implements an ideal formulation of target propagation using regularized inverses compute

Vincent Roulet 0 Dec 02, 2021
This is a yolo3 implemented via tensorflow 2.7

YoloV3 - an object detection algorithm implemented via TF 2.x source code In this article I assume you've already familiar with basic computer vision

2 Jan 17, 2022
A Pytorch implementation of MoveNet from Google. Include training code and pre-train model.

Movenet.Pytorch Intro MoveNet is an ultra fast and accurate model that detects 17 keypoints of a body. This is A Pytorch implementation of MoveNet fro

Mr.Fire 241 Dec 26, 2022
Genpass - A Passwors Generator App With Python3

Genpass Welcom again into another python3 App this is simply an Passwors Generat

Mal4D 1 Jan 09, 2022
A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset.

A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset. This repo contains scripts to train RL agents to navigate the closed world and collect vi

MUGEN 11 Oct 22, 2022
Tracking code for the winner of track 1 in the MMP-Tracking Challenge at ICCV 2021 Workshop.

Tracking Code for the winner of track1 in MMP-Trakcing challenge This repository contains our tracking code for the Multi-camera Multiple People Track

DamoCV 29 Nov 13, 2022
WSDM‘2022: Knowledge Enhanced Sports Game Summarization

Knowledge Enhanced Sports Game Summarization Cooming Soon! :) Data will be released after approval process. Code will be published once the author of

Jiaan Wang 14 Jul 13, 2022
Distributed Evolutionary Algorithms in Python

DEAP DEAP is a novel evolutionary computation framework for rapid prototyping and testing of ideas. It seeks to make algorithms explicit and data stru

Distributed Evolutionary Algorithms in Python 4.9k Jan 05, 2023
Task Transformer Network for Joint MRI Reconstruction and Super-Resolution (MICCAI 2021)

T2Net Task Transformer Network for Joint MRI Reconstruction and Super-Resolution (MICCAI 2021) [Paper][Code] Dependencies numpy==1.18.5 scikit_image==

64 Nov 23, 2022
ChatBot-Pytorch - A GPT-2 ChatBot implemented using Pytorch and Huggingface-transformers

ChatBot-Pytorch A GPT-2 ChatBot implemented using Pytorch and Huggingface-transf

ParZival 42 Dec 09, 2022
Learning Generative Models of Textured 3D Meshes from Real-World Images, ICCV 2021

Learning Generative Models of Textured 3D Meshes from Real-World Images This is the reference implementation of "Learning Generative Models of Texture

Dario Pavllo 115 Jan 07, 2023
TransPrompt - Towards an Automatic Transferable Prompting Framework for Few-shot Text Classification

TransPrompt This code is implement for our EMNLP 2021's paper 《TransPrompt:Towards an Automatic Transferable Prompting Framework for Few-shot Text Cla

WangJianing 23 Dec 21, 2022
BBB streaming without Xorg and Pulseaudio and Chromium and other nonsense (heavily WIP)

BBB Streamer NG? Makes a conference like this... ...streamable like this! I also recorded a small video showing the basic features: https://www.youtub

Lukas Schauer 60 Oct 21, 2022
Code & Models for Temporal Segment Networks (TSN) in ECCV 2016

Temporal Segment Networks (TSN) We have released MMAction, a full-fledged action understanding toolbox based on PyTorch. It includes implementation fo

1.4k Jan 01, 2023
Code for NeurIPS 2021 paper 'Spatio-Temporal Variational Gaussian Processes'

Spatio-Temporal Variational GPs This repository is the official implementation of the methods in the publication: O. Hamelijnck, W.J. Wilkinson, N.A.

AaltoML 26 Sep 16, 2022