Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Overview

Dataset Distillation by Matching Training Trajectories

Project Page | Paper


Teaser image

This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone will have similar test performance as a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the fake images to induce similar network training dynamics as the full, real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" and "expert" networks trained on real data, and back-propagate through all the student network updates to optimize the synthetic pixels.

Wearable ImageNet: Synthesizing Tileable Textures

Teaser image

Instead of treating our synthetic data as individual images, we can instead encourage every random crop (with circular padding) on a larger canvas of pixels to induce a good training trajectory. This results in class-based textures that are continuous around their edges.

Given these tileable textures, we can apply them to areas that require such properties, such as clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For an express instillation, we include .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation
Quadro Users Take Note:

torch.nn.DataParallel seems to not work on Quadro A5000 GPUs, and this may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. Worse (but still interesting) results can be obtained faster through training fewer experts by changing --num_experts. Note that experts need only be trained once and can be re-used for multiple distillation experiments.

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

ImageNet

Our method can also distill subsets of ImageNet into low-support synthetic sets.

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into the following 10 images

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list with the desired class ID's and add it to the dictionary.

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into the following 10 textures

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2020
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
Owner
George Cazenavette
Carnegie Mellon University
George Cazenavette
Flower classification model that classifies flowers in 10 classes made using transfer learning (~85% accuracy).

flower-classification-inceptionV3 Flower classification model that classifies flowers in 10 classes. Training and validation are done using a pre-anot

Ivan R. Mršulja 1 Dec 12, 2021
Code for the paper "Implicit Representations of Meaning in Neural Language Models"

Implicit Representations of Meaning in Neural Language Models Preliminaries Create and set up a conda environment as follows: conda create -n state-pr

Belinda Li 39 Nov 03, 2022
Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Torch Mutable Modules Use in-place and assignment operations on PyTorch module p

Kento Nishi 7 Jun 06, 2022
Project page for the paper Semi-Supervised Raw-to-Raw Mapping 2021.

Project page for the paper Semi-Supervised Raw-to-Raw Mapping 2021.

Mahmoud Afifi 22 Nov 08, 2022
Deep Q-network learning to play flappybird.

AI Plays Flappy Bird I've trained a DQN that learns to play flappy bird on it's own. Try the pre-trained model First install the pip requirements and

Anish Shrestha 3 Mar 01, 2022
blind SQLIpy sebuah alat injeksi sql yang menggunakan waktu sql untuk mendapatkan sebuah server database.

blind SQLIpy Alat blind SQLIpy ini merupakan alat injeksi sql yang menggunakan metode time based blind sql injection metode tersebut membutuhkan waktu

Galih Anggoro Prasetya 4 Feb 24, 2022
Fibonacci Method Gradient Descent

An implementation of the Fibonacci method for gradient descent, featuring a TKinter GUI for inputting the function / parameters to be examined and a matplotlib plot of the function and results.

Emma 1 Jan 28, 2022
Contrastive Multi-View Representation Learning on Graphs

Contrastive Multi-View Representation Learning on Graphs This work introduces a self-supervised approach based on contrastive multi-view learning to l

Kaveh 208 Dec 23, 2022
End-to-end image segmentation kit based on PaddlePaddle.

English | 简体中文 PaddleSeg PaddleSeg has released the new version including the following features: Our team won the 6.2k Jan 02, 2023

This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationships.

Auto-Lambda This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationship

Shikun Liu 76 Dec 20, 2022
Transformers provides thousands of pretrained models to perform tasks on different modalities such as text, vision, and audio.

English | 简体中文 | 繁體中文 | 한국어 State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow 🤗 Transformers provides thousands of pretrained models

Clara Meister 50 Nov 12, 2022
[ICCV'21] NEAT: Neural Attention Fields for End-to-End Autonomous Driving

NEAT: Neural Attention Fields for End-to-End Autonomous Driving Paper | Supplementary | Video | Poster | Blog This repository is for the ICCV 2021 pap

254 Jan 02, 2023
Stock-history-display - something like a easy yearly review for your stock performance

Stock History Display Available on Heroku: https://stock-history-display.herokua

LiaoJJ 1 Jan 07, 2022
The GitHub repository for the paper: “Time Series is a Special Sequence: Forecasting with Sample Convolution and Interaction“.

SCINet This is the original PyTorch implementation of the following work: Time Series is a Special Sequence: Forecasting with Sample Convolution and I

386 Jan 01, 2023
History Aware Multimodal Transformer for Vision-and-Language Navigation

History Aware Multimodal Transformer for Vision-and-Language Navigation This repository is the official implementation of History Aware Multimodal Tra

Shizhe Chen 46 Nov 23, 2022
Inverse Optimal Control Adapted to the Noise Characteristics of the Human Sensorimotor System

Inverse Optimal Control Adapted to the Noise Characteristics of the Human Sensorimotor System This repository contains code for the paper Schultheis,

2 Oct 28, 2022
A set of tests for evaluating large-scale algorithms for Wasserstein-2 transport maps computation.

Continuous Wasserstein-2 Benchmark This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Co

Alexander 22 Dec 12, 2022
Python wrapper class for OpenVINO Model Server. User can submit inference request to OVMS with just a few lines of code

Python wrapper class for OpenVINO Model Server. User can submit inference request to OVMS with just a few lines of code.

Yasunori Shimura 7 Jul 27, 2022
A modular, research-friendly framework for high-performance and inference of sequence models at many scales

T5X T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of

Google Research 1.1k Jan 08, 2023
Official implementation of VQ-Diffusion

Official implementation of VQ-Diffusion: Vector Quantized Diffusion Model for Text-to-Image Synthesis

Microsoft 592 Jan 03, 2023