Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Overview

Dataset Distillation by Matching Training Trajectories

Project Page | Paper



This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone will have test performance similar to that of a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the fake images to induce network training dynamics similar to those of the full, real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" networks and "expert" networks trained on real data, and back-propagate through all the student network updates to optimize the synthetic pixels.
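A minimal sketch of the objective described above, assuming parameters are flattened into plain lists and a toy 1-D linear regressor stands in for the repo's ConvNet. The normalized form (student-to-expert distance divided by the distance the expert itself traveled over the same segment) follows the paper; the names trajectory_matching_loss and train_student are hypothetical and are not taken from distill.py. The sketch only computes the loss; the actual method back-propagates it through every student update to the synthetic pixels, which requires an autodiff framework such as PyTorch.

```python
# Illustrative sketch only: flat parameter lists and a toy linear student
# (y = w * x + b) replace the real ConvNet and image tensors.

def sq_dist(a, b):
    """Squared Euclidean distance between two flat parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def trajectory_matching_loss(student_end, expert_start, expert_end):
    """Normalized parameter-space error between student and expert endpoints.

    Dividing by the distance the expert traveled over the same segment keeps
    the loss on a comparable scale across different parts of the trajectory.
    """
    return sq_dist(student_end, expert_end) / sq_dist(expert_start, expert_end)


def train_student(params, syn_x, syn_y, lr, steps):
    """Run `steps` full-batch SGD updates of mean-squared error on synthetic data."""
    w, b = params
    n = len(syn_x)
    for _ in range(steps):
        # Gradients of mean((w*x + b - y)^2) with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(syn_x, syn_y)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(syn_x, syn_y)) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return [w, b]


# Hypothetical expert segment: start at [0, 0], end at [2, 1].
expert_start, expert_end = [0.0, 0.0], [2.0, 1.0]
# Two synthetic points lying on y = 2x + 1, so the student is pulled toward the expert endpoint.
student_end = train_student(expert_start, [1.0, -1.0], [3.0, -1.0], lr=0.1, steps=50)
loss = trajectory_matching_loss(student_end, expert_start, expert_end)  # close to 0
```

A student that never moves from the expert's starting point scores exactly 1.0, which gives the normalized loss an intuitive reference value.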

Wearable ImageNet: Synthesizing Tileable Textures


Instead of treating our synthetic data as individual images, we can encourage every random crop (with circular padding) on a larger canvas of pixels to induce a good training trajectory. This results in class-based textures that are continuous around their edges.
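The random-crop-with-circular-padding idea can be sketched as follows, treating the canvas as a 2-D list of pixel values for simplicity. Taking indices modulo the canvas size is equivalent to circular padding, so a crop that runs off one edge continues from the opposite edge. The helper names are hypothetical; the repo's --texture code path works on image tensors and may be implemented differently.

```python
import random


def circular_crop(canvas, top, left, size):
    """Return a size x size crop of a 2-D canvas, wrapping around its edges.

    Indices are taken modulo the canvas height and width, which is
    equivalent to cropping from a circularly padded canvas.
    """
    h, w = len(canvas), len(canvas[0])
    return [[canvas[(top + i) % h][(left + j) % w] for j in range(size)]
            for i in range(size)]


def random_circular_crop(canvas, size, rng=random):
    """Pick a crop origin uniformly anywhere on the canvas (wraparound allowed)."""
    top = rng.randrange(len(canvas))
    left = rng.randrange(len(canvas[0]))
    return circular_crop(canvas, top, left, size)


canvas = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 toy "image"
print(circular_crop(canvas, 3, 3, 2))  # [[15, 12], [3, 0]]
```

Because every origin is valid, each pixel of the canvas appears in crops at every relative position, which is what pushes the optimized texture to be seamless across its borders.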

Given these tileable textures, we can apply them to applications that need seamless tiling, such as clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For an express installation, we include conda environment .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation
Quadro Users Take Note:

torch.nn.DataParallel does not seem to work on Quadro A5000 GPUs, and this may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.
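If you launch runs from a Python script rather than a shell, the same workaround can be applied by passing a modified environment to the child process. This is a minimal sketch; the helper names are hypothetical. If you instead set the variable inside the training process itself, it must happen before PyTorch initializes CUDA.

```python
import os
import subprocess
import sys


def single_gpu_env(gpu_index=0, base_env=None):
    """Return a copy of the environment that exposes only one GPU to child processes."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


def run_on_single_gpu(args, gpu_index=0):
    """Launch a script (e.g. distill.py) with only one GPU visible."""
    return subprocess.run([sys.executable, *args], env=single_gpu_env(gpu_index), check=True)


# Example (not executed here):
# run_on_single_gpu(["distill.py", "--dataset=CIFAR100", "--ipc=1", ...])
```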

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py.

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. Worse (but still interesting) results can be obtained faster by training fewer experts via --num_experts. Note that experts need only be trained once and can be reused for multiple distillation experiments.
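Conceptually, an expert trajectory is the sequence of parameter checkpoints saved once per epoch while training on real data, which is why the buffer can be reused across distillation runs. Below is a minimal sketch of that bookkeeping, assuming parameters are plain lists and train_one_epoch is a user-supplied callable. All names are hypothetical, and buffer.py's actual storage format (file names, tensor layout) is not described in this README.

```python
import copy


def record_trajectory(params, train_one_epoch, num_epochs):
    """Train for num_epochs, snapshotting parameters at init and after every epoch.

    Returns num_epochs + 1 checkpoints.
    """
    trajectory = [copy.deepcopy(params)]
    for _ in range(num_epochs):
        params = train_one_epoch(params)
        trajectory.append(copy.deepcopy(params))
    return trajectory


def build_buffer(init_fn, train_one_epoch, num_experts, num_epochs):
    """Train num_experts independent experts, each from its own initialization."""
    return [record_trajectory(init_fn(i), train_one_epoch, num_epochs)
            for i in range(num_experts)]


# Toy usage: each "epoch" halves every parameter.
buffer = build_buffer(lambda i: [float(i + 1)],
                      lambda p: [x * 0.5 for x in p],
                      num_experts=3, num_epochs=4)
```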

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
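Roughly, each distillation iteration picks a segment of an expert trajectory: a starting checkpoint before --max_start_epoch and a target checkpoint --expert_epochs later. The student is initialized at the starting checkpoint, takes --syn_steps updates on the synthetic images, and is compared against the target. A sketch of the segment sampling follows, assuming the buffer is a list of trajectories, each a list of per-epoch checkpoints; sample_segment is a hypothetical name, and the exclusive upper bound on the start epoch is an assumption of this sketch.

```python
import random


def sample_segment(buffer, max_start_epoch, expert_epochs, rng=random):
    """Pick one expert and a start epoch; return (start_params, target_params).

    The start epoch is drawn uniformly from [0, max_start_epoch), and the
    target is the checkpoint expert_epochs later on the same trajectory.
    """
    trajectory = rng.choice(buffer)
    start = rng.randrange(max_start_epoch)
    target = start + expert_epochs
    if target >= len(trajectory):
        raise ValueError("trajectory too short for max_start_epoch + expert_epochs")
    return trajectory[start], trajectory[target]


# With 50-epoch experts (51 checkpoints), --max_start_epoch=20 and
# --expert_epochs=3 keep every target at or before epoch 22.
toy_buffer = [[[float(e)] for e in range(51)] for _ in range(2)]
start_params, target_params = sample_segment(toy_buffer, 20, 3)
```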

ImageNet

Our method can also distill subsets of ImageNet into low-support synthetic sets.

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into 10 synthetic images, one per class.

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list with the desired class IDs and add it to the dictionary.

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.
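The exact layout of the Config class is not reproduced in this README, so the sketch below only illustrates the pattern: a class attribute holding a list of ImageNet class indices, plus a dictionary mapping the name passed to --subset onto that list. The subset name imagemysubset, the attribute name subsets, and the class IDs are all hypothetical placeholders; check utils.py for the real attribute names before copying this.

```python
class Config:
    # Hypothetical subset: ten ImageNet class indices (0-999).
    # These IDs are placeholders; look up the classes you actually want.
    imagemysubset = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Hypothetical name -> class-list mapping consulted for --subset.
    subsets = {
        "imagemysubset": imagemysubset,
    }


classes = Config.subsets["imagemysubset"]  # what --subset=imagemysubset would resolve to
```

Using ten classes matches the existing subsets above, so --ipc=1 again yields 10 synthetic images.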

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into 10 textures, one per class.

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2021
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}