Code for using Augmented Shapiro-Wilk Stopping, as well as code for the paper "Statistically Significant Stopping of Neural Network Training"

Overview

This codebase is actively maintained; please create an issue if you have problems using it.

Basics

All data files are included in the losses folder and its subfolders. The main Augmented Shapiro-Wilk Stopping criterion is implemented in analysis.py, along with several helper functions and wrappers; the other comparison heuristics and their wrappers are also in analysis.py. grapher.py contains the code for generating the graphs used in the paper, and earlystopping_calculator.py contains code for generating tables and calculating some statistics from the data. hyperparameter_search.py contains the code used to run the grid search for the ASWS method and for the other heuristics.
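For intuition only, here is a simplified sketch of a plateau test built on the Shapiro-Wilk normality test. It is not the repository's aswt_stopping: the names plateau_stop and shapiro_pvalue are hypothetical, the slack_prop augmentation is omitted, and the rule itself (stop once the trailing window of num_data test accuracies fails to reject normality, p-value > gamma, on count consecutive epochs) is an assumption about how such a criterion can be built. See analysis.py for the actual implementation. SciPy is imported lazily, so the rule can also be exercised with any other p-value function.

```python
from collections.abc import Callable, Sequence


def shapiro_pvalue(window: Sequence[float]) -> float:
    """Shapiro-Wilk p-value for one window of accuracies (needs SciPy, len >= 3)."""
    # Lazy import so plateau_stop can be used with another pvalue_fn without SciPy.
    from scipy.stats import shapiro

    return float(shapiro(window).pvalue)


def plateau_stop(
    accuracies: Sequence[float],
    gamma: float,
    num_data: int,
    count: int,
    pvalue_fn: Callable[[Sequence[float]], float] = shapiro_pvalue,
) -> bool:
    """Hypothetical plateau rule, NOT the repository's aswt_stopping.

    Returns True when each of the last `count` trailing windows of `num_data`
    accuracies fails to reject normality (p-value > gamma). slack_prop is omitted.
    """
    n = len(accuracies)
    if n < num_data + count - 1:
        return False  # not enough history for `count` full windows yet
    for end in range(n - count + 1, n + 1):
        window = accuracies[end - num_data:end]
        if pvalue_fn(window) <= gamma:
            return False  # this window looks non-normal, so keep training
    return True
```

With SciPy installed, a call such as plateau_stop(test_acc, gamma=0.5, num_data=20, count=20) checks the 20 most recent windows; passing pvalue_fn keeps the windowing logic independent of the statistical test.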

Installing

If you would like to try our code, run:

pip3 install git+https://github.com/justinkterry/ASWS

Example

If you want to determine the ASWS stopping point of a model, you can use the analysis.py file. To run the stop criterion test at any point during model training, you can do the following:

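from ASWS.analysis import aswt_stopping

# ASWS hyperparameters; fill in as desired
gamma = 0.5
num_data = 20
slack_prop = 0.1
count = 20

test_acc = []  # test accuracy recorded after every epoch
for epoch in range(num_epochs):  # num_epochs: your maximum number of epochs
    train_one_epoch(model)  # placeholder for your own training step
    test_accuracy = evaluate(model, test_set)  # placeholder for your own evaluation
    test_acc.append(test_accuracy)

    # only run the test once enough epochs have been recorded
    if len(test_acc) > count:
        aswt_stop_criterion = aswt_stopping(test_acc, gamma, count, num_data, slack_prop=slack_prop)

        if aswt_stop_criterion:
            print("Stop Training")
            break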

If you have already finished training the model and want to determine the ASWS stopping point, you need a CSV with the columns Epoch, Training Loss, Training Acc, Test Loss, and Test Acc. You can then use the following example:

from ASWS.analysis import get_aswt_stopping_point_of_model, read_file

# keep only the test accuracies returned by read_file
_, _, _, test_acc = read_file("modelaccuracy.csv")

# ASWS hyperparameters; fill in as desired
gamma = 0.5
num_data = 20
slack_prop = 0.1
count = 20

stop_epoch, stop_accuracy = get_aswt_stopping_point_of_model(test_acc, gamma=gamma, num_data=num_data, count=count, slack_prop=slack_prop)
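If you log metrics yourself, the standard-library sketch below writes a per-epoch history with those five columns and reads the Test Acc column back. write_history and read_test_acc are hypothetical helpers, not part of the package, and whether read_file expects a header row is an assumption worth checking against analysis.py.

```python
import csv
import io
from collections.abc import Iterable, Sequence
from typing import TextIO

# Column order described above; the exact header spelling read_file expects is an assumption.
COLUMNS = ["Epoch", "Training Loss", "Training Acc", "Test Loss", "Test Acc"]


def write_history(stream: TextIO, rows: Iterable[Sequence[float]]) -> None:
    """Write a header plus one row per epoch (hypothetical helper)."""
    writer = csv.writer(stream)
    writer.writerow(COLUMNS)
    writer.writerows(rows)


def read_test_acc(stream: TextIO) -> list[float]:
    """Return the Test Acc column as floats (hypothetical stand-in for read_file)."""
    return [float(row["Test Acc"]) for row in csv.DictReader(stream)]


if __name__ == "__main__":
    buf = io.StringIO()
    write_history(buf, [(1, 2.1, 0.30, 2.0, 0.32), (2, 1.7, 0.45, 1.8, 0.41)])
    buf.seek(0)
    print(read_test_acc(buf))  # [0.32, 0.41]
```

When writing to disk, open the file with newline="" as the csv module documentation recommends, e.g. with open("modelaccuracy.csv", "w", newline="") as f: write_history(f, rows).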

pytorch-training

The pytorch-training folder contains the driver file for training each model, along with the model files containing each network definition. main.py can be run out of the box for the models listed in the paper; the model to train is selected with the --model argument. All learning rate schedulers listed in the paper are available (e.g. --schedule step), and the ASWS learning rate scheduler is available via --schedule ASWT. The corresponding ASWS hyperparameters are passed on the command line (for example, --gamma 0.5).
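This README does not describe how the ASWT scheduler works internally. One plausible way to turn a stop test into a learning-rate schedule is sketched here: each time the test fires, the learning rate is multiplied by a decay factor and the accuracy history is cleared. StopTestLRScheduler, factor=0.1, and the history reset are hypothetical choices, not taken from main.py; stop_test stands in for a criterion such as aswt_stopping.

```python
from collections.abc import Callable, Sequence


class StopTestLRScheduler:
    """Hypothetical scheduler: decay the learning rate whenever a stop test fires.

    `stop_test` stands in for a criterion such as ASWS; `factor` and the
    history reset are illustrative assumptions, not the repository's settings.
    """

    def __init__(self, lr: float, stop_test: Callable[[Sequence[float]], bool], factor: float = 0.1):
        self.lr = lr
        self.stop_test = stop_test
        self.factor = factor
        self.history: list[float] = []

    def step(self, test_accuracy: float) -> float:
        """Record this epoch's accuracy and return the learning rate for the next epoch."""
        self.history.append(test_accuracy)
        if self.stop_test(self.history):
            self.lr *= self.factor
            self.history = []  # start a fresh window at the new learning rate
        return self.lr
```

In a PyTorch loop you would call lr = scheduler.step(test_accuracy) once per epoch and copy lr into each entry of optimizer.param_groups. Clearing the history means the test only sees accuracies gathered at the current learning rate.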

Example

To reproduce the GoogLeNet ASWT 1 scheduler from the paper, use the following command:

python3 main.py --model GoogLeNet --schedule ASWT --gamma 0.76 --num_data 19 --slack_prop 0.05 --lr 0.1
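main.py's own argument parser is not reproduced in this README. As a hypothetical sketch of how the flags described above fit together, the argparse parser below accepts the GoogLeNet command; build_parser, the defaults, and the help strings are assumptions (the ASWS defaults simply mirror the values used in the examples above), and main.py may define additional flags.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags described for main.py."""
    parser = argparse.ArgumentParser(description="Train a model with an optional ASWS (ASWT) schedule")
    parser.add_argument("--model", required=True, help="network to train, e.g. GoogLeNet")
    parser.add_argument("--schedule", default="step", help="learning rate schedule, e.g. step or ASWT")
    parser.add_argument("--lr", type=float, default=0.1, help="initial learning rate")
    # ASWS hyperparameters; only meaningful with --schedule ASWT
    parser.add_argument("--gamma", type=float, default=0.5)
    parser.add_argument("--num_data", type=int, default=20)
    parser.add_argument("--slack_prop", type=float, default=0.1)
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())
```

Declaring type=float and type=int makes argparse convert the strings, so values like --num_data 19 arrive as numbers rather than text.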

Owner
J K Terry
CS PhD student at UMD, founder of Swarm Labs, maintainer of Gym and PettingZoo. I work in deep reinforcement learning.