Efficient Lottery Ticket Finding: Less Data is More

Overview

Efficient Lottery Ticket Finding: Less Data is More

License: MIT

Codes for this paper Efficient Lottery Ticket Finding: Less Data is More. [ICML 2021]

Zhenyu Zhang*, Xuxi Chen*, Tianlong Chen*, Zhangyang Wang

Overview

The lottery ticket hypothesis (LTH) reveals the existence of winning tickets (sparse but critical subnetworks) for dense networks, that can be trained in isolation from random initialization to match the latter’s accuracies. However, finding winning tickets requires burdensome computations in the train-prune-retrain process, especially on large-scale datasets (e.g., ImageNet), restricting their practical benefits. This paper explores a new perspective on finding lottery tickets more efficiently, by doing so only with a specially selected subset of data, called Pruning- Aware Critical set (PrAC set), rather than using the full training set. The concept of PrAC set was inspired by the recent observation, that deep networks have samples that are either hard to memorize during training, or easy to forget during pruning. A PrAC set is thus hypothesized to capture those most challenging and informative examples for the dense model. We observe that a high-quality winning ticket can be found with training and pruning the dense network on the very compact PrAC set, which can substantially save training iterations for the ticket finding process.

Prerequisites

Pytorch >= 1.4

torchvision

advertorch

Usage

Vanilla Lottery Tickets

python -u main_imp.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--batch_size 128 \
	--lr 0.1 \
	--pruning_times 16 \
	--prune_type rewind_lt \
	--rewind_epoch 2 \
	--save_dir lt_cifar10_res20s

PrAC Lottery Tickets

python -u main_PrAC_imp.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--split_file npy_files/cifar10-train-val.npy \
	--batch_size 128 \
	--lr 0.1 \
	--pruning_times 16 \
	--eb_eps 0.08 \
	--prune_type rewind_lt \
	--rewind_epoch 2 \
	--threshold 0 \
	--save_dir PrAC_lt_cifar10_res20s
	

Train subnetworks

python -u main_train.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--batch_size 128 \
	--lr 0.1 \
	--init_dir PrAC_lt_cifar10_res20s/1checkpoint.pth.tar \ 
	--mask_dir PrAC_lt_cifar10_res20s/1checkpoint.pth.tar \ # sparsity=20%
	--save_dir retrain_PrAC_lt_cifar10_res20s/1

Citation


Owner
VITA
Visual Informatics Group @ University of Texas at Austin
VITA
Implementation of Change-Based Exploration Transfer (C-BET)

Implementation of Change-Based Exploration Transfer (C-BET), as presented in Interesting Object, Curious Agent: Learning Task-Agnostic Exploration.

Simone Parisi 29 Dec 04, 2022
HiPAL: A Deep Framework for Physician Burnout Prediction Using Activity Logs in Electronic Health Records

HiPAL Code for KDD'22 Applied Data Science Track submission -- HiPAL: A Deep Framework for Physician Burnout Prediction Using Activity Logs in Electro

Hanyang Liu 4 Aug 08, 2022
A series of Jupyter notebooks with Chinese comment that walk you through the fundamentals of Machine Learning and Deep Learning in python using Scikit-Learn and TensorFlow.

Hands-on-Machine-Learning 目的 这份笔记旨在帮助中文学习者以一种较快较系统的方式入门机器学习, 是在学习Hands-on Machine Learning with Scikit-Learn and TensorFlow这本书的 时候做的个人笔记: 此项目的可取之处 原书的

Baymax 1.5k Dec 21, 2022
Code for "Steerable Pyramid Transform Enables Robust Left Ventricle Quantification"

Code for "Steerable Pyramid Transform Enables Robust Left Ventricle Quantification" This is an end-to-end framework for accurate and robust left ventr

2 Jul 09, 2022
(AAAI2020)Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing

Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing This repository contains pytorch source code for AAAI2020 oral paper: Grapy-ML

54 Aug 04, 2022
[CVPR 2021] Monocular depth estimation using wavelets for efficiency

Single Image Depth Prediction with Wavelet Decomposition Michaël Ramamonjisoa, Michael Firman, Jamie Watson, Vincent Lepetit and Daniyar Turmukhambeto

Niantic Labs 205 Jan 02, 2023
The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dea

MIC-DKFZ 1.2k Jan 04, 2023
Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation

Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation

NVIDIA Research Projects 4.8k Jan 09, 2023
SpeechBrain is an open-source and all-in-one speech toolkit based on PyTorch.

The SpeechBrain Toolkit SpeechBrain is an open-source and all-in-one speech toolkit based on PyTorch. The goal is to create a single, flexible, and us

SpeechBrain 5.1k Jan 02, 2023
A toolkit for document-level event extraction, containing some SOTA model implementations

❤️ A Toolkit for Document-level Event Extraction with & without Triggers Hi, there 👋 . Thanks for your stay in this repo. This project aims at buildi

Tong Zhu(朱桐) 159 Dec 22, 2022
Learning Off-Policy with Online Planning, CoRL 2021

LOOP: Learning Off-Policy with Online Planning Accepted in Conference of Robot Learning (CoRL) 2021. Harshit Sikchi, Wenxuan Zhou, David Held Paper In

Harshit Sikchi 24 Nov 22, 2022
P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks

P-tuning v2 P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks An optimized prompt tuning strategy achievi

THUDM 540 Dec 30, 2022
Raster Vision is an open source Python framework for building computer vision models on satellite, aerial, and other large imagery sets

Raster Vision is an open source Python framework for building computer vision models on satellite, aerial, and other large imagery sets (including obl

Azavea 1.7k Dec 22, 2022
Code for SALT: Stackelberg Adversarial Regularization, EMNLP 2021.

SALT: Stackelberg Adversarial Regularization Code for Adversarial Regularization as Stackelberg Game: An Unrolled Optimization Approach, EMNLP 2021. R

Simiao Zuo 10 Jan 10, 2022
TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

TorchMultimodal (Alpha Release) Introduction TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

Meta Research 663 Jan 06, 2023
On Nonlinear Latent Transformations for GAN-based Image Editing - PyTorch implementation

On Nonlinear Latent Transformations for GAN-based Image Editing - PyTorch implementation On Nonlinear Latent Transformations for GAN-based Image Editi

Valentin Khrulkov 22 Oct 24, 2022
Python package for downloading ECMWF reanalysis data and converting it into a time series format.

ecmwf_models Readers and converters for data from the ECMWF reanalysis models. Written in Python. Works great in combination with pytesmo. Citation If

TU Wien - Department of Geodesy and Geoinformation 31 Dec 26, 2022
Rasterize with the least efforts for researchers.

utils3d Rasterize and do image-based 3D transforms with the least efforts for researchers. Based on numpy and OpenGL. It could be helpful when you wan

Ruicheng Wang 8 Dec 15, 2022
sense-py-AnishaBaishya created by GitHub Classroom

Compute Statistics Here we compute statistics for a bunch of numbers. This project uses the unittest framework to test functionality. Pass the tests T

1 Oct 21, 2021
Additional environments compatible with OpenAI gym

Decentralized Control of Quadrotor Swarms with End-to-end Deep Reinforcement Learning A codebase for training reinforcement learning policies for quad

Zhehui Huang 40 Dec 06, 2022