Runtime type annotations for the shape, dtype etc. of PyTorch Tensors.

Overview

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

into this:

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like # x has shape (batch, hidden_state) or statements like assert x.shape == y.shape , just to keep track of what shape everything is, then this is for you.


Installation

pip install torchtyping

Requires Python 3.7+ and PyTorch 1.7.0+.

Usage

torchtyping allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • arbitrary number of batch dimensions with ...;
  • ...plus anything else you like, as torchtyping is highly extensible.

If typeguard is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.

typeguard also has an import hook that can be used to automatically test an entire module, without needing to manually add @typeguard.typechecked decorators.

If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes. If you're not already using typeguard for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both typeguard and torchtyping also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.

API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of shape, dtype, layout, details are optional.

  • The shape argument can be any of:
    • An int: the dimension must be of exactly this size. If it is -1 then any size is allowed.
    • A str: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
    • A ...: An arbitrary number of dimensions of any sizes.
    • A str: int pair (technically it's a slice), combining both str and int behaviour. (Just a str on its own is equivalent to str: -1.)
    • A str: ... pair, in which case the multiple dimensions corresponding to ... will be bound to the name specified by str, and again checked for consistency between arguments.
    • None, which when used in conjunction with is_named below, indicates a dimension that must not have a name in the sense of named tensors.
    • A None: int pair, combining both None and int behaviour. (Just a None on its own is equivalent to None: -1.)
    • A typing.Any: Any size is allowed for this dimension (equivalent to -1).
    • Any tuple of the above. For example.TensorType["batch": ..., "length": 10, "channels", -1]. If you just want to specify the number of dimensions then use for example TensorType[-1, -1, -1] for a three-dimensional tensor.
  • The dtype argument can be any of:
    • torch.float32, torch.float64 etc.
    • int, bool, float, which are converted to their corresponding PyTorch types. float is specifically interpreted as torch.get_default_dtype(), which is usually float32.
  • The layout argument can be either torch.strided or torch.sparse_coo, for dense and sparse tensors respectively.
  • The details argument offers a way to pass an arbitrary number of additional flags that customise and extend torchtyping. Two flags are built-in by default. torchtyping.is_named causes the names of tensor dimensions to be checked, and torchtyping.is_float can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. TensorType[torch.float32].) For discussion on how to customise torchtyping with your own details, see the further documentation.

Check multiple things at once by just putting them all together inside a single []. For example TensorType["batch": ..., "length", "channels", float, is_named].

torchtyping.patch_typeguard()

torchtyping integrates with typeguard to perform runtime type checking. torchtyping.patch_typeguard() should be called at the global level, and will patch typeguard to check TensorTypes.

This function is safe to run multiple times. (It does nothing after the first run).

  • If using @typeguard.typechecked, then torchtyping.patch_typeguard() should be called any time before using @typeguard.typechecked. For example you could call it at the start of each file using torchtyping.
  • If using typeguard.importhook.install_import_hook, then torchtyping.patch_typeguard() should be called any time before defining the functions you want checked. For example you could call torchtyping.patch_typeguard() just once, at the same time as the typeguard import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes.
pytest --torchtyping-patch-typeguard

torchtyping offers a pytest plugin to automatically run torchtyping.patch_typeguard() before your tests. pytest will automatically discover the plugin, you just need to pass the --torchtyping-patch-typeguard flag to enable it. Packages can then be passed to typeguard as normal, either by using @typeguard.typechecked, typeguard's import hook, or the pytest flag --typeguard-packages="your_package_here".

Further documentation

See the further documentation for:

  • FAQ;
    • Including flake8 and mypy compatibility;
  • How to write custom extensions to torchtyping;
  • Resources and links to other libraries and materials on this topic;
  • More examples.
Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)
Patrick Kidger
A pre-trained language model for social media text in Spanish

RoBERTuito A pre-trained language model for social media text in Spanish READ THE FULL PAPER Github Repository RoBERTuito is a pre-trained language mo

25 Dec 29, 2022
The author's officially unofficial PyTorch BigGAN implementation.

BigGAN-PyTorch The author's officially unofficial PyTorch BigGAN implementation. This repo contains code for 4-8 GPU training of BigGANs from Large Sc

Andy Brock 2.6k Jan 02, 2023
StackRec: Efficient Training of Very Deep Sequential Recommender Models by Iterative Stacking

StackRec: Efficient Training of Very Deep Sequential Recommender Models by Iterative Stacking Datasets You can download datasets that have been pre-pr

25 May 29, 2022
An automated facial recognition based attendance system (desktop application)

Facial_Recognition_based_Attendance_System An automated facial recognition based attendance system (desktop application) Made using Python, Tkinter an

1 Jun 21, 2022
This is a Python Module For Encryption, Hashing And Other stuff

EnroCrypt This is a Python Module For Encryption, Hashing And Other Basic Stuff You Need, With Secure Encryption And Strong Salted Hashing You Can Do

5 Sep 15, 2022
RoIAlign & crop_and_resize for PyTorch

RoIAlign for PyTorch This is a PyTorch version of RoIAlign. This implementation is based on crop_and_resize and supports both forward and backward on

Long Chen 530 Jan 07, 2023
[ICCV 2021] Amplitude-Phase Recombination: Rethinking Robustness of Convolutional Neural Networks in Frequency Domain

Amplitude-Phase Recombination (ICCV'21) Official PyTorch implementation of "Amplitude-Phase Recombination: Rethinking Robustness of Convolutional Neur

Guangyao Chen 53 Oct 05, 2022
Random Walk Graph Neural Networks

Random Walk Graph Neural Networks This repository is the official implementation of Random Walk Graph Neural Networks. Requirements Code is written in

Giannis Nikolentzos 38 Jan 02, 2023
PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020).

NHDRRNet-PyTorch This is the PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020). 0. Differences between Original Paper and

Yutong Zhang 1 Mar 01, 2022
A data-driven maritime port simulator

PySeidon - A Data-Driven Maritime Port Simulator 🌊 Extendable and modular software for maritime port simulation. This software uses entity-component

6 Apr 10, 2022
HyperLib: Deep learning in the Hyperbolic space

HyperLib: Deep learning in the Hyperbolic space Background This library implements common Neural Network components in the hypberbolic space (using th

105 Dec 25, 2022
A general 3D Object Detection codebase in PyTorch.

Det3D is the first 3D Object Detection toolbox which provides off the box implementations of many 3D object detection algorithms such as PointPillars, SECOND, PIXOR, etc, as well as state-of-the-art

Benjin Zhu 1.4k Jan 05, 2023
The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp.

PISE The code for our CVPR paper PISE: Person Image Synthesis and Editing with Decoupled GAN, Project Page, supp. Requirement conda create -n pise pyt

jinszhang 110 Nov 21, 2022
Adversarial Autoencoders

Adversarial Autoencoders (with Pytorch) Dependencies argparse time torch torchvision numpy itertools matplotlib Create Datasets python create_datasets

Felipe Ducau 188 Jan 01, 2023
RP-GAN: Stable GAN Training with Random Projections

RP-GAN: Stable GAN Training with Random Projections This repository contains a reference implementation of the algorithm described in the paper: Behna

Ayan Chakrabarti 20 Sep 18, 2021
Train a deep learning net with OpenStreetMap features and satellite imagery.

DeepOSM Classify roads and features in satellite imagery, by training neural networks with OpenStreetMap (OSM) data. DeepOSM can: Download a chunk of

TrailBehind, Inc. 1.3k Nov 24, 2022
ICCV2021: Code for 'Spatial Uncertainty-Aware Semi-Supervised Crowd Counting'

ICCV2021: Code for 'Spatial Uncertainty-Aware Semi-Supervised Crowd Counting'

Yanda Meng 14 May 13, 2022
An Implementation of SiameseRPN with Feature Pyramid Networks

SiameseRPN with FPN This project is mainly based on HelloRicky123/Siamese-RPN. What I've done is just add a Feature Pyramid Network method to the orig

3 Apr 16, 2022
This repo contains the code for the paper "Efficient hierarchical Bayesian inference for spatio-temporal regression models in neuroimaging" that has been accepted to NeurIPS 2021.

Dugh-NeurIPS-2021 This repo contains the code for the paper "Efficient hierarchical Bayesian inference for spatio-temporal regression models in neuroi

Ali Hashemi 5 Jul 12, 2022
CondNet: Conditional Classifier for Scene Segmentation

CondNet: Conditional Classifier for Scene Segmentation Introduction The fully convolutional network (FCN) has achieved tremendous success in dense vis

ycszen 31 Jul 22, 2022