Code for paper "Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking"

Overview

model_based_energy_constrained_compression

Code for paper "Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking" (https://openreview.net/pdf?id=BylBr3C9K7)

@inproceedings{yang2018energy,
  title={Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking},
  author={Yang, Haichuan and Zhu, Yuhao and Liu, Ji},
  booktitle={ICLR},
  year={2019}
}

Prerequisites

Python (3.6)
PyTorch 1.0

To use the ImageNet dataset, download the dataset and move validation images to labeled subfolders (e.g., using https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)

Training and testing

example

To run the training with energy constraint on AlexNet,

python energy_proj_train.py --net alexnet --dataset imagenet --datadir [imagenet-folder with train and val folders] --batch_size 128 --lr 1e-3 --momentum 0.9 --l2wd 1e-4 --proj_int 10 --logdir ./log/path-of-log --num_workers 8 --exp_bdecay --epochs 30 --distill 0.5 --nodp --budget 0.2

usage

usage: energy_proj_train.py [-h] [--net NET] [--dataset DATASET]
                            [--datadir DATADIR] [--batch_size BATCH_SIZE]
                            [--val_batch_size VAL_BATCH_SIZE]
                            [--num_workers NUM_WORKERS] [--epochs EPOCHS]
                            [--lr LR] [--xlr XLR] [--l2wd L2WD]
                            [--xl2wd XL2WD] [--momentum MOMENTUM]
                            [--lr_decay LR_DECAY] [--lr_decay_e LR_DECAY_E]
                            [--lr_decay_add] [--proj_int PROJ_INT] [--nodp]
                            [--input_mask] [--randinit] [--pretrain PRETRAIN]
                            [--eval] [--seed SEED]
                            [--log_interval LOG_INTERVAL]
                            [--test_interval TEST_INTERVAL]
                            [--save_interval SAVE_INTERVAL] [--logdir LOGDIR]
                            [--distill DISTILL] [--budget BUDGET]
                            [--exp_bdecay] [--mgpu] [--skip1]

Model-Based Energy Constrained Training

optional arguments:
  -h, --help            show this help message and exit
  --net NET             network arch
  --dataset DATASET     dataset used in the experiment
  --datadir DATADIR     dataset dir in this machine
  --batch_size BATCH_SIZE
                        batch size for training
  --val_batch_size VAL_BATCH_SIZE
                        batch size for evaluation
  --num_workers NUM_WORKERS
                        number of workers for training loader
  --epochs EPOCHS       number of epochs to train
  --lr LR               learning rate
  --xlr XLR             learning rate for input mask
  --l2wd L2WD           l2 weight decay
  --xl2wd XL2WD         l2 weight decay (for input mask)
  --momentum MOMENTUM   momentum
  --proj_int PROJ_INT   how many batches for each projection
  --nodp                turn off dropout
  --input_mask          enable input mask
  --randinit            use random init
  --pretrain PRETRAIN   file to load pretrained model
  --eval                evaluate testset in the begining
  --seed SEED           random seed
  --log_interval LOG_INTERVAL
                        how many batches to wait before logging training
                        status
  --test_interval TEST_INTERVAL
                        how many epochs to wait before another test
  --save_interval SAVE_INTERVAL
                        how many epochs to wait before save a model
  --logdir LOGDIR       folder to save to the log
  --distill DISTILL     distill loss weight
  --budget BUDGET       energy budget (relative)
  --exp_bdecay          exponential budget decay
  --mgpu                enable using multiple gpus
  --skip1               skip the first W update
Owner
Haichuan Yang
Haichuan Yang
Implements pytorch code for the Accelerated SGD algorithm.

AccSGD This is the code associated with Accelerated SGD algorithm used in the paper On the insufficiency of existing momentum schemes for Stochastic O

205 Jan 02, 2023
PyTorch implementations of normalizing flow and its variants.

PyTorch implementations of normalizing flow and its variants.

Tatsuya Yatagawa 55 Dec 01, 2022
PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf

README TabNet : Attentive Interpretable Tabular Learning This is a pyTorch implementation of Tabnet (Arik, S. O., & Pfister, T. (2019). TabNet: Attent

DreamQuark 2k Dec 27, 2022
The goal of this library is to generate more helpful exception messages for numpy/pytorch matrix algebra expressions.

Tensor Sensor See article Clarifying exceptions and visualizing tensor operations in deep learning code. One of the biggest challenges when writing co

Terence Parr 704 Dec 14, 2022
A tiny package to compare two neural networks in PyTorch

Compare neural networks by their feature similarity

Anand Krishnamoorthy 180 Dec 30, 2022
A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision

đŸ¤— Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.

Hugging Face 3.5k Jan 08, 2023
Model summary in PyTorch similar to `model.summary()` in Keras

Keras style model.summary() in PyTorch Keras has a neat API to view the visualization of the model which is very helpful while debugging your network.

Shubham Chandel 3.7k Dec 29, 2022
High-fidelity performance metrics for generative models in PyTorch

High-fidelity performance metrics for generative models in PyTorch

Vikram Voleti 5 Oct 24, 2021
PyTorch toolkit for biomedical imaging

farabio is a minimal PyTorch toolkit for out-of-the-box deep learning support in biomedical imaging. For further information, see Wikis and Docs.

San Askaruly 47 Dec 28, 2022
PyTorch extensions for fast R&D prototyping and Kaggle farming

Pytorch-toolbelt A pytorch-toolbelt is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming: What

Eugene Khvedchenya 1.3k Jan 05, 2023
A PyTorch implementation of Learning to learn by gradient descent by gradient descent

Intro PyTorch implementation of Learning to learn by gradient descent by gradient descent. Run python main.py TODO Initial implementation Toy data LST

Ilya Kostrikov 300 Dec 11, 2022
Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Quiver Team 221 Dec 22, 2022
S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

Amazon Web Services 138 Jan 03, 2023
ocaml-torch provides some ocaml bindings for the PyTorch tensor library.

ocaml-torch provides some ocaml bindings for the PyTorch tensor library. This brings to OCaml NumPy-like tensor computations with GPU acceleration and tape-based automatic differentiation.

Laurent Mazare 369 Jan 03, 2023
A tiny scalar-valued autograd engine and a neural net library on top of it with PyTorch-like API

micrograd A tiny Autograd engine (with a bite! :)). Implements backpropagation (reverse-mode autodiff) over a dynamically built DAG and a small neural

Andrej 3.5k Jan 08, 2023
A PyTorch implementation of L-BFGS.

PyTorch-LBFGS: A PyTorch Implementation of L-BFGS Authors: Hao-Jun Michael Shi (Northwestern University) and Dheevatsa Mudigere (Facebook) What is it?

Hao-Jun Michael Shi 478 Dec 27, 2022
A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.

A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.

Fidelity Investments 56 Sep 13, 2022
A simplified framework and utilities for PyTorch

Here is Poutyne. Poutyne is a simplified framework for PyTorch and handles much of the boilerplating code needed to train neural networks. Use Poutyne

GRAAL/GRAIL 534 Dec 17, 2022
pip install antialiased-cnns to improve stability and accuracy

Antialiased CNNs [Project Page] [Paper] [Talk] Making Convolutional Networks Shift-Invariant Again Richard Zhang. In ICML, 2019. Quick & easy start Ru

Adobe, Inc. 1.6k Dec 28, 2022
A pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch.

Compact Bilinear Pooling for PyTorch. This repository has a pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch. This

Grégoire Payen de La Garanderie 234 Dec 07, 2022