Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Overview

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.


Why does this exist?

PyTorch does not allow autograd-tracked in-place operations on module parameters (a restriction that is usually desirable):

import torch

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
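
Plain assignment runs into a related restriction. The sketch below uses only stock PyTorch (not this library) to show that an out-of-place result keeps its autograd history but cannot be assigned back onto a registered parameter, and that re-wrapping it in torch.nn.Parameter drops the history. The variable names are illustrative.

```python
import torch

linear_layer = torch.nn.Linear(1, 1)

# An out-of-place operation is tracked by autograd ...
shifted_weight = linear_layer.weight + 1.0
has_grad_fn = shifted_weight.grad_fn is not None  # True

# ... but a plain tensor cannot be assigned onto a registered parameter.
try:
    linear_layer.weight = shifted_weight
    assignment_error = None
except TypeError as err:
    assignment_error = str(err)

# Re-wrapping as a Parameter makes the assignment legal, but a Parameter
# is a new leaf tensor, so the autograd history is lost.
linear_layer.weight = torch.nn.Parameter(shifted_weight.detach())
history_kept = linear_layer.weight.grad_fn is not None  # False
```

This is the gap that mutable modules are meant to close: storing a tensor on the module while keeping its grad_fn.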

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if one neural network (net_1) predicts the parameter values for another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

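import torch
from torch_mutable_modules import to_mutable_module

# input_0, input_1, label, criterion, and optimizer are assumed to be defined elsewhere

# create a parameter predictor network (net_1)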
net_1 = torch.nn.Linear(1, 2)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()
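
To make the gradient path concrete, here is a minimal self-contained sketch of the same idea written with stock PyTorch only. It does not use torch-mutable-modules: the predicted tensors are passed to torch.nn.functional.linear instead of being stored on a module. The toy data, the MSE loss, and the SGD optimizer are assumptions made up for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data, invented for this sketch.
input_0 = torch.ones(1, 1)    # input to the predictor network
input_1 = torch.randn(4, 1)   # batch fed to the predicted linear layer
label = torch.zeros(4, 1)

# Predictor network (net_1) outputs two numbers: one weight and one bias.
net_1 = torch.nn.Linear(1, 2)
optimizer = torch.optim.SGD(net_1.parameters(), lr=0.1)

predicted = net_1(input_0)    # shape (1, 2)
p_weight = predicted[:, 0:1]  # shape (1, 1), same shape as Linear(1, 1).weight
p_bias = predicted[:, 1]      # shape (1,), same shape as Linear(1, 1).bias

# Apply the predicted parameters functionally instead of storing them on a module.
output = F.linear(input_1, p_weight, p_bias)  # shape (4, 1)
loss = F.mse_loss(output, label)

optimizer.zero_grad()
loss.backward()
grad_reached_net_1 = net_1.weight.grad is not None  # True
optimizer.step()
```

The functional form works for a single linear layer, but the forward pass has to be rewritten by hand for every layer type. Assigning the predicted tensors to module attributes, as in the snippet above, keeps the module's own forward method.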

This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
converted_module.weight += 69
converted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    
    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)

Usage with CUDA

To create a mutable module on the GPU, simply pass a PyTorch module that is already on the GPU to the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
) # converted_module is now a mutable module on the GPU

Moving a module to the GPU with .to() or .cuda() after instantiation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts.

# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")

Detailed examples

Please check out example.py to see more detailed example usages of the to_mutable_module function.

Contributing

Please feel free to submit issues or pull requests!

Releases(v1.1.2)
Owner
Kento Nishi
17-year-old programmer at Lynbrook High School, with strong interests in AI/Machine Learning. Open source developer and researcher at the Four Eyes Lab.