PyTorch Personal Trainer: My framework for deep learning experiments

Alex's PyTorch Personal Trainer (ptpt)

(name subject to change)

This repository contains my personal lightweight framework for deep learning projects in PyTorch.

Disclaimer: this project is very much a work in progress. Although technically usable, it is missing many features. Nonetheless, you may find some of the design patterns and code snippets useful in the meantime.

Installation

Run python -m build in the root of the repo, then run pip install on the resulting .whl file (written to dist/ by default).

There is no package on PyPI yet.

Usage

Import the library as you would any other Python library:

from ptpt.trainer import Trainer, TrainerConfig
from ptpt.log import debug, info, warning, error, critical

The core of the library is the trainer.Trainer class. In the simplest case, it takes the following as input:

net:            a `nn.Module` that is the model we wish to train.
loss_fn:        a function that takes a `nn.Module` and a batch as input.
                it returns the loss and optionally other metrics.
train_dataset:  the training dataset.
test_dataset:   the test dataset.
cfg:            a `TrainerConfig` instance that holds all
                hyperparameters.

Once this is instantiated, starting the training loop is as simple as calling trainer.train() where trainer is an instance of Trainer.
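To make the loss_fn contract concrete, here is a rough sketch of the kind of loop a trainer of this shape might run. It is not ptpt's actual implementation: run_epoch, mse_loss_fn and their arguments are hypothetical names used only for illustration, and the real Trainer does considerably more.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(net, loss_fn, dataset, optimizer, batch_size=8):
    """Hypothetical helper: one pass over `dataset`, returning the mean loss."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net.train()
    total_loss, nb_batches = 0.0, 0
    for batch in loader:
        optimizer.zero_grad()
        # loss_fn receives the model and the whole batch
        out = loss_fn(net, batch)
        # assumption: extra metrics, if any, follow the loss in a tuple
        loss, *metrics = out if isinstance(out, tuple) else (out,)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        nb_batches += 1
    return total_loss / nb_batches


def mse_loss_fn(net, batch):
    """Toy loss function that returns only the loss, with no extra metrics."""
    X, y = batch
    return nn.functional.mse_loss(net(X), y)


if __name__ == "__main__":
    X = torch.linspace(-1, 1, 32).unsqueeze(1)
    toy_dataset = TensorDataset(X, 2 * X)  # learn y = 2x
    toy_net = nn.Linear(1, 1)
    toy_opt = torch.optim.SGD(toy_net.parameters(), lr=0.1)
    for epoch in range(5):
        print(epoch, run_epoch(toy_net, mse_loss_fn, toy_dataset, toy_opt))
```

The point of passing the whole batch to loss_fn is that the loop never needs to know how a batch is structured; only the loss function unpacks it.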

cfg stores most of the configuration options for Trainer. See the class definition of TrainerConfig for details on all options.

Examples

An example workflow would go like this:

Define your training and test datasets:

from torchvision import datasets, transforms

# 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training set
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transform)

Define your model:

# in this case, we have imported `Net` from another file
net = Net()
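The definition of Net is not shown in this README. If you do not already have a model, something like the following small convolutional network would fit the MNIST example. This is only an illustrative sketch: the layer sizes are arbitrary choices, and it returns log-probabilities on the assumption that the loss function uses F.nll_loss, as the one in this example does.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    """Illustrative MNIST classifier; not the `Net` used by the author."""

    def __init__(self, nb_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 7 * 7, nb_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = torch.flatten(x, start_dim=1)
        # return log-probabilities so the output can be fed to F.nll_loss
        return F.log_softmax(self.fc(x), dim=-1)
```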

Define your loss function that calls net, taking the full batch as input:

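import torch.nn.functional as F

# minimising classification error
def loss_fn(net, batch):
    X, y = batch
    logits = net(X)
    # F.nll_loss expects log-probabilities, so `net` should end in log_softmax
    loss = F.nll_loss(logits, y)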

    pred = logits.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy

Optionally create a configuration object:

# see class definition for full list of parameters
cfg = TrainerConfig(
    exp_name = 'mnist-conv',
    batch_size = 64,
    learning_rate = 4e-4,
    nb_workers = 4,
    save_outputs = False,
    metric_names = ['accuracy']
)

Initialise the Trainer class:

trainer = Trainer(
    net=net,
    loss_fn=loss_fn,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    cfg=cfg
)

Call trainer.train() to begin the training loop:

trainer.train() # Go!

See more examples here.

Motivation

I found myself repeating a lot of the same structure in many of my deep learning projects. This project is the culmination of my efforts to refine the typical structure of my projects into (what I hope to be) a wholly reusable and general-purpose library.

Additionally, there are many nice theoretical and engineering tricks that are available to deep learning researchers. Unfortunately, a lot of them are forgotten because they fall outside the typical workflow, despite them being very beneficial to include. Another goal of this project is to transparently include these tricks so they can be added and removed with minimal code change. Where it is sane to do so, some of these could be on by default.

Finally, I am guilty of forgetting to implement decent logging: both of standard output and of metrics. Logging of standard output is not hard, and is implemented using other libraries such as rich. However, metric logging is less obvious. I'd like to avoid larger dependencies such as tensorboard being an integral part of the project, so metrics will be logged to simple numpy arrays. The library will then provide functions to produce plots from these, or they can be used in another library.
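As a rough idea of what logging metrics to plain numpy arrays could look like, here is a minimal sketch. MetricLog and its methods are hypothetical names, not part of ptpt, and the real implementation may store metrics quite differently.

```python
import numpy as np


class MetricLog:
    """Hypothetical metric store backed by a numpy array; not ptpt's API."""

    def __init__(self, metric_names):
        self.metric_names = list(metric_names)
        self._rows = []

    def append(self, *values):
        """Record one value per metric, e.g. once per evaluation step."""
        if len(values) != len(self.metric_names):
            raise ValueError("expected one value per metric name")
        self._rows.append([float(v) for v in values])

    def as_array(self):
        """Return all entries as an array of shape (nb_entries, nb_metrics)."""
        rows = np.asarray(self._rows, dtype=np.float64)
        return rows.reshape(-1, len(self.metric_names))

    def column(self, name):
        """Return the history of a single metric as a 1D array."""
        return self.as_array()[:, self.metric_names.index(name)]
```

An array like this can be written to disk with np.save and later passed to any plotting library, which keeps heavier tools optional rather than required.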

TODO:

  • Make a todo.

Owner
Alex McKinney