CVPR2021 Content-Aware GAN Compression

Overview

Content-Aware GAN Compression [ArXiv]

Paper accepted to CVPR2021.

@inproceedings{liu2021content,
  title     = {Content-Aware GAN Compression},
  author    = {Liu, Yuchen and Shu, Zhixin and Li, Yijun and Lin, Zhe and Perazzi, Federico and Kung, S.Y.},
  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021},
}

Overview

We propose a novel content-aware approach for GAN compression. With content-awareness, our 11x-accelerated GAN performs comparably with the full-size model on image generation and image editing.

Image Generation

We show an example above on the generative ability of our 11x-accelerated generator vs. the full-size one. In particular, our model generates the interested contents visually comparable to the full-size model.

Image Editing

We show an example typifying the effectiveness of our compressed StyleGAN2 for image style-mixing and morphing above. When we mix middle styles from B, the original full-size model has a significant identity loss, while our approach better preserves the person’s identity. We also observe that our morphed images have a smoother expression transition compared the full-size model in the beard, substantiating our advantage in latent space smoothness.

We provide an additional example above.

Methodology

In our work, we make the first attempt to bring content awareness into channel pruning and knowledge distillation.

Specifically, we leverage a content-parsing network to identify contents of interest (COI), a set of spatial locations with salient semantic concepts, within the generated images. We design a content-aware pruning metric (with a forward and backward path) to remove channels that are least sensitive to COI in the generated images. For knowledge distillation, we focus our distillation region only to COI of the teacher’s outputs which further enhances target contents’ distillation.

Usage

Prerequisite

We have tested our codes under the following environments:

python == 3.6.5
pytorch == 1.6.0
torchvision == 0.7.0
CUDA == 10.2

Pretrained Full-Size Generator Checkpoint

To start with, you can first download a full-size generator checkpoint from:

256px StyleGAN2

1024px StyleGAN2

and place it under the folder ./Model/full_size_model/.

Pruning

Once you get the full-size checkpoint, you can prune the generator by:

python3 prune.py \
	--generated_img_size=256 \
	--ckpt=/path/to/full/size/model/ \
	--remove_ratio=0.7 \
	--info_print

We adopt a uniform channel pruning ratio for every layer. Above procedure will remove 70% of channels from the generator in each layer. The pruned checkpoint will be saved at ./Model/pruned_model/.

Retraining

We then retrain the pruned generator by:

python3 train.py \
	--size=256 \
	--path=/path/to/ffhq/data/folder/ \
	--ckpt=/path/to/pruned/model/ \
	--teacher_ckpt=/path/to/full/size/model/ \
	--iter=450001 \
	--batch_size=16

You may adjust the variables gpu_device_ids and primary_device for the GPU setup in train_hyperparams.py.

Training Log

The time for retraining 11x-compressed models on V100 GPUs:

Model Batch Size Iterations # GPUs Time (Hour)
256px StyleGAN2 16 450k 2 131
1024px StyleGAN2 16 450k 4 251

A typical training curve for the 11x-compressed 256px StyleGAN2:

Evaluation

To evaluate the model quantitatively, we provide get_fid.py and get_ppl.py to get model's FID and PPL sores.

FID Evaluation:

python3 get_fid.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--n_sample=50000 \
	--batch_size=64 \
	--info_print

PPL Evaluation:

python3 get_ppl.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--n_sample=5000 \
	--eps=1e-4 \
	--info_print

We also provide an image projector which return a (real image, projected image) pair in Image_Projection_Visualization.png as well as the PSNR and LPIPS score between this pair:

python3 get_projected_image.py \
	--generated_img_size=256 \
	--ckpt=/path/to/model/ \
	--image_file=/path/to/an/RGB/image/ \
	--num_iters=800 \
	--info_print

An example of Image_Projection_Visualization.png projected by a full-size 256px StyleGAN2:

Helen-Set55

We provide the Helen-Set55 on Google Drive.

11x-Accelerated Generator Checkpoint

We provide the following checkpoints of our content-aware compressed StyleGAN2:

Compressed 256px StyleGAN2

Compressed 1024px StyleGAN2

Acknowledgement

PyTorch StyleGAN2: https://github.com/rosinality/stylegan2-pytorch

Face Parsing BiSeNet: https://github.com/zllrunning/face-parsing.PyTorch

Fréchet Inception Distance: https://github.com/mseitzer/pytorch-fid

Learned Perceptual Image Patch Similarity: https://github.com/richzhang/PerceptualSimilarity

Owner
Yuchen Liu, Ph.D. Candidate at Princeton University
A Pytree Module system for Deep Learning in JAX

Treex A Pytree-based Module system for Deep Learning in JAX Intuitive: Modules are simple Python objects that respect Object-Oriented semantics and sh

Cristian Garcia 216 Dec 20, 2022
Toolkit for collecting and applying prompts

PromptSource Promptsource is a toolkit for collecting and applying prompts to NLP datasets. Promptsource uses a simple templating language to programa

BigScience Workshop 998 Jan 03, 2023
Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Project looking into use of autoencoder for semi-supervised learning and comparing data requirements compared to supervised learning.

Tom-R.T.Kvalvaag 2 Dec 17, 2021
Meta Learning Backpropagation And Improving It (VSML)

Meta Learning Backpropagation And Improving It (VSML) This is research code for the NeurIPS 2021 publication Kirsch & Schmidhuber 2021. Many concepts

Louis Kirsch 22 Dec 21, 2022
SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data

SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data Au

14 Nov 28, 2022
Lepard: Learning Partial point cloud matching in Rigid and Deformable scenes

Lepard: Learning Partial point cloud matching in Rigid and Deformable scenes [Paper] Method overview 4DMatch Benchmark 4DMatch is a benchmark for matc

103 Jan 06, 2023
Joint Gaussian Graphical Model Estimation: A Survey

Joint Gaussian Graphical Model Estimation: A Survey Test Models Fused graphical lasso [1] Group graphical lasso [1] Graphical lasso [1] Doubly joint s

Koyejo Lab 1 Aug 10, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 05, 2023
Adaptive Denoising Training (ADT) for Recommendation.

DenoisingRec Adaptive Denoising Training for Recommendation. This is the pytorch implementation of our paper at WSDM 2021: Denoising Implicit Feedback

Wenjie Wang 51 Dec 30, 2022
nanodet_plus,yolov5_v6.0

OAK_Detection OAK设备上适配nanodet_plus,yolov5_v6.0 Environment pytorch = 1.7.0

炼丹去了 1 Feb 18, 2022
SCALoss: Side and Corner Aligned Loss for Bounding Box Regression (AAAI2022).

SCALoss PyTorch implementation of the paper "SCALoss: Side and Corner Aligned Loss for Bounding Box Regression" (AAAI 2022). Introduction IoU-based lo

TuZheng 20 Sep 07, 2022
Extracting knowledge graphs from language models as a diagnostic benchmark of model performance.

Interpreting Language Models Through Knowledge Graph Extraction Idea: How do we interpret what a language model learns at various stages of training?

EPFL Machine Learning and Optimization Laboratory 9 Oct 25, 2022
Object detection (YOLO) with pytorch, OpenCV and python

Real Time Object/Face Detection Using YOLO-v3 This project implements a real time object and face detection using YOLO algorithm. You only look once,

1 Aug 04, 2022
Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit

CNTK Chat Windows build status Linux build status The Microsoft Cognitive Toolkit (https://cntk.ai) is a unified deep learning toolkit that describes

Microsoft 17.3k Dec 29, 2022
Constructing Neural Network-Based Models for Simulating Dynamical Systems

Constructing Neural Network-Based Models for Simulating Dynamical Systems Note this repo is work in progress prior to reviewing This is a companion re

Christian Møldrup Legaard 21 Nov 25, 2022
MT3: Multi-Task Multitrack Music Transcription

MT3: Multi-Task Multitrack Music Transcription MT3 is a multi-instrument automatic music transcription model that uses the T5X framework. This is not

Magenta 867 Dec 29, 2022
Hippocampal segmentation using the UNet network for each axis

Hipposeg Hippocampal segmentation using the UNet network for each axis, inspired by https://github.com/MICLab-Unicamp/e2dhipseg Red: False Positive Gr

Juan Carlos Aguirre Arango 0 Sep 02, 2021
Direct Multi-view Multi-person 3D Human Pose Estimation

Implementation of NeurIPS-2021 paper: Direct Multi-view Multi-person 3D Human Pose Estimation [paper] [video-YouTube, video-Bilibili] [slides] This is

Sea AI Lab 251 Dec 30, 2022
A computational block to solve entity alignment over textual attributes in a knowledge graph creation pipeline.

How to apply? Create your config.ini file following the example provided in config.ini Choose one of the options below to run: Run with Python3 pip in

Scientific Data Management Group 3 Jun 23, 2022
Rethinking Nearest Neighbors for Visual Classification

Rethinking Nearest Neighbors for Visual Classification arXiv Environment settings Check out scripts/env_setup.sh Setup data Download the following fin

Menglin Jia 29 Oct 11, 2022