[SIGGRAPH'22] StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets

Overview

[Project] [PDF] Hugging Face Spaces

This repository contains code for our SIGGRAPH'22 paper "StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets" by Axel Sauer, Katja Schwarz, and Andreas Geiger.

If you find our code or paper useful, please cite:

@InProceedings{Sauer2021ARXIV,
  author    = {Axel Sauer and Katja Schwarz and Andreas Geiger},
  title     = {StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets},
  journal   = {arXiv.org},
  volume    = {abs/2201.00273},
  year      = {2022},
  url       = {https://arxiv.org/abs/2201.00273},
}

Related Projects

  • Projected GANs Converge Faster (NeurIPS'21)  -  Official Repo  -  Projected GAN Quickstart
  • StyleGAN-XL + CLIP (Implemented by CasualGANPapers)  -  StyleGAN-XL + CLIP
  • StyleGAN-XL + CLIP (Modified by Katherine Crowson to optimize in W+ space)  -  StyleGAN-XL + CLIP

ToDos

  • Initial code release
  • Add pretrained models (ImageNet{16,32,64,128,256,512,1024}, FFHQ{256,512,1024}, Pokemon{256,512,1024})
  • Add StyleMC for editing
  • Add PTI for inversion

Requirements

  • 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
  • CUDA toolkit 11.1 or later.
  • GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see, for example, the CUDA 11.4 system requirements.
  • If you run into problems setting up the custom CUDA kernels, refer to the Troubleshooting docs of the original StyleGAN3 repo and issue #23.
  • Windows users struggling to install the environment might find #10 helpful.
  • Use the following commands with Miniconda3 to create and activate your StyleGAN-XL Python environment:
    • conda env create -f environment.yml
    • conda activate sgxl

Data Preparation

For a quick start, you can download the few-shot datasets provided by the authors of FastGAN here. To prepare a dataset at the desired resolution, run

python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
  --resolution=256x256 --transform=center-crop
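The center-crop transform cuts each image to a square before it is resized to the target resolution. The arithmetic is sketched below, assuming the crop follows the StyleGAN3 dataset_tool.py convention of keeping the largest centered square; center_crop_box is a hypothetical helper, not part of this repo.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square.

    Hypothetical helper; assumes dataset_tool.py crops like StyleGAN3's.
    """
    side = min(width, height)      # the square's edge is the shorter image side
    left = (width - side) // 2     # horizontal offset that centers the crop
    top = (height - side) // 2     # vertical offset that centers the crop
    return left, top, left + side, top + side


# A 640x480 image keeps its central 480x480 region, which is then resized to 256x256.
print(center_crop_box(640, 480))  # (80, 0, 560, 480)
```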

You need to follow our progressive growing scheme to get the best results. Therefore, you should prepare separate zips for each training resolution. You can get the datasets we used in our paper at their respective websites (FFHQ, ImageNet).
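Because each stage trains on its own zip, it can help to generate the dataset_tool.py calls for every resolution at once. A small sketch, assuming the {name}{res}.zip naming used by the training commands in this README (pokemon16.zip, pokemon32.zip, ...); dataset_commands is a hypothetical helper.

```python
import shlex
from typing import List


def dataset_commands(source: str, name: str, resolutions: List[int]) -> List[str]:
    """Build one dataset_tool.py invocation per training resolution (hypothetical helper)."""
    cmds = []
    for res in resolutions:
        args = [
            "python", "dataset_tool.py",
            f"--source={source}",
            f"--dest=./data/{name}{res}.zip",  # e.g. pokemon16.zip, pokemon32.zip, ...
            f"--resolution={res}x{res}",
            "--transform=center-crop",
        ]
        cmds.append(shlex.join(args))  # shell-safe quoting (Python 3.8+)
    return cmds


for cmd in dataset_commands("./data/pokemon", "pokemon", [16, 32, 64, 128, 256]):
    print(cmd)
```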

Training

For progressive growing, we train a stem at low resolution, e.g., 16² pixels. When the stem is finished, i.e., when FID saturates, you can start training the upper stages; we refer to these as superresolution stages.
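The README does not define "saturating" precisely. One possible heuristic is sketched below: the stem counts as finished when the best FID of the last few snapshots no longer beats the earlier best by a meaningful margin. The function name, the window of 5 snapshots, and the 0.1 threshold are all hypothetical choices, and the FID values are assumed to be read from your own training logs.

```python
from typing import Sequence


def fid_saturated(history: Sequence[float], window: int = 5, min_gain: float = 0.1) -> bool:
    """Hypothetical stopping heuristic over a chronological list of FID values.

    Returns True if the best FID within the last `window` snapshots improved
    on the best earlier FID by less than `min_gain`.
    """
    if len(history) <= window:
        return False                       # not enough snapshots to judge yet
    best_before = min(history[:-window])   # best FID before the recent window
    best_recent = min(history[-window:])   # best FID within the recent window
    return best_before - best_recent < min_gain


print(fid_saturated([40.0, 25.0, 18.0, 15.0, 14.0, 13.98, 13.95, 13.97, 13.96, 13.99]))  # True
```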

Training the stem

Training StyleGAN-XL on Pokemon using 8 GPUs:

python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon16.zip \
    --gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10

--batch specifies the overall batch size and --batch-gpu the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
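The number of accumulation rounds follows from the three flags. A sketch, assuming StyleGAN-XL keeps the StyleGAN3 rule that --batch must be a multiple of --gpus times --batch-gpu; accumulation_rounds is a hypothetical helper.

```python
def accumulation_rounds(batch: int, batch_gpu: int, gpus: int) -> int:
    """Forward/backward passes per GPU for one optimizer step (hypothetical helper)."""
    per_pass = batch_gpu * gpus  # images processed in one pass across all GPUs
    if batch % per_pass != 0:
        raise ValueError("--batch must be a multiple of --gpus times --batch-gpu")
    return batch // per_pass


print(accumulation_rounds(64, 8, 8))  # 1: the 8-GPU command above needs no accumulation
print(accumulation_rounds(64, 8, 2))  # 4: two GPUs accumulate over four passes
```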

Samples and metrics are saved in outdir. If you don't want to track metrics, set --metrics=none. You can inspect fid50k_full.json or run tensorboard in training-runs/ to monitor the training progress.

For a class-conditional dataset (ImageNet, CIFAR-10), add the flag --cond True. The dataset needs to contain class labels; see the StyleGAN2-ADA repo on how to prepare class-conditional datasets.
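In the StyleGAN2-ADA format, class labels live in a dataset.json file at the root of the source folder, holding a "labels" list of [relative_path, class_index] pairs. A minimal sketch that writes such a file, assuming this repo's dataset_tool.py (inherited from StyleGAN3) reads the same format; write_dataset_json is a hypothetical helper.

```python
import json
from pathlib import Path
from typing import Dict


def write_dataset_json(source_dir: str, labels: Dict[str, int]) -> Path:
    """Write a StyleGAN2-ADA style dataset.json into source_dir (hypothetical helper).

    `labels` maps each image path, relative to source_dir with forward slashes,
    to its integer class index.
    """
    root = Path(source_dir)
    entries = [[rel_path, int(cls)] for rel_path, cls in sorted(labels.items())]
    out_path = root / "dataset.json"
    out_path.write_text(json.dumps({"labels": entries}))
    return out_path
```

Run it once on the image folder, then pass that folder as --source to dataset_tool.py.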

Training the super-resolution stages

Continuing with the pretrained stem:

python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon32.zip \
  --gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10 \
  --superres --up_factor 2 --head_layers 7 \
  --path_stem training-runs/pokemon/00000-stylegan3-t-pokemon16-gpus8-batch64/best_model.pkl

--up_factor allows you to train several stages at once; e.g., with --up_factor=4 and a 16² stem, you can directly train at resolution 64².
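The relationship between stem resolution, --up_factor, and the number of stages trained in one run is simple arithmetic. The sketch assumes --up_factor is a power of two (as in the values 2 and 4 used here) and that each stage doubles the resolution; superres_plan is a hypothetical helper.

```python
def superres_plan(stem_res: int, up_factor: int) -> dict:
    """Target resolution and number of 2x stages covered by one --up_factor run.

    Hypothetical helper; assumes up_factor is a power of two (2, 4, 8, ...).
    """
    if up_factor < 2 or up_factor & (up_factor - 1):
        raise ValueError("up_factor is assumed to be a power of two >= 2")
    return {
        "target_res": stem_res * up_factor,    # e.g. 16 * 4 = 64
        "stages": up_factor.bit_length() - 1,  # log2(up_factor), one doubling per stage
    }


print(superres_plan(16, 4))  # {'target_res': 64, 'stages': 2}
print(superres_plan(32, 8))  # {'target_res': 256, 'stages': 3}
```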

If you have enough compute, a good tactic is to train several stages in parallel and then restart the superresolution stage training once in a while. The current stage will then reload its previous stem's best_model.pkl. Performance can sometimes drop at first because of domain shift, but the superresolution stage quickly recovers and improves further.

Training recommendations for datasets other than ImageNet

The default settings are tuned for ImageNet. For smaller datasets (<50k images) or well-curated datasets (FFHQ), you can significantly decrease the model size, enabling much faster training. Recommended settings are --cbase 128 --cmax 128 --syn_layers 4 and, for superresolution stages, --head_layers 4.
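The recommendation can be encoded as a small flag builder. This is a sketch only: small_dataset_flags is a hypothetical helper, the 50k threshold comes from the paragraph above, and whether a dataset counts as "well-curated" is left to the caller via the curated argument.

```python
from typing import List


def small_dataset_flags(num_images: int, superres: bool, curated: bool = False) -> List[str]:
    """Smaller-model train.py flags for small or curated datasets (hypothetical helper)."""
    if num_images >= 50_000 and not curated:
        return []                                    # keep the ImageNet-tuned defaults
    flags = ["--cbase", "128", "--cmax", "128", "--syn_layers", "4"]
    if superres:
        flags += ["--head_layers", "4"]              # only passed for superresolution stages
    return flags


print(" ".join(small_dataset_flags(1_000, superres=True)))
```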

If you want to train as few stages as possible, we recommend training a 32x32 or 64x64 stem and then directly scaling to the final resolution (as described above, you must adjust --up_factor accordingly). However, progressive growing generally yields better results faster, as throughput is much higher at lower resolutions. This can be seen in this figure by Karras et al., 2017:

Generating Samples & Interpolations

To generate samples and interpolation videos, run

python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 1 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl

and

python gen_video.py --output=lerp.mp4 --trunc=0.7 --seeds=0-31 --grid=4x2 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
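Seed arguments such as --seeds=10-15 are inclusive ranges. The parser below mirrors the parse_range helper in StyleGAN3's generation scripts; we assume this repo keeps the same convention. The keyframe count assumes gen_video.py, like StyleGAN3's version, splits the seeds into groups of grid width times height.

```python
import re
from typing import List

_RANGE_RE = re.compile(r"^(\d+)-(\d+)$")


def parse_range(spec: str) -> List[int]:
    """Parse '10-15' or '1,2,5-10' into a list of ints; ranges are inclusive."""
    values: List[int] = []
    for part in spec.split(","):
        match = _RANGE_RE.match(part.strip())
        if match:
            start, end = int(match.group(1)), int(match.group(2))
            values.extend(range(start, end + 1))
        else:
            values.append(int(part))
    return values


grid_w, grid_h = 4, 2
print(len(parse_range("10-15")))                   # 6 images from gen_images.py
print(len(parse_range("0-31")) // (grid_w * grid_h))  # 4 keyframes for a 4x2 grid (assumed)
```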

For class-conditional models, you can pass the class index via --class; an index-to-label dictionary for ImageNet can be found here. For interpolation between classes, provide, e.g., --cls=0-31 to gen_video.py. The list of classes must have the same length as --seeds.
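StyleGAN2-ADA and StyleGAN3 condition the generator on a one-hot label vector, and we assume StyleGAN-XL does the same. The sketch builds that vector for a class index and enforces the one-class-per-seed rule for interpolation videos; one_hot and pair_seeds_with_classes are hypothetical helpers.

```python
from typing import List, Tuple


def one_hot(class_idx: int, num_classes: int = 1000) -> List[float]:
    """Conditioning vector with a 1 at class_idx (ImageNet has 1000 classes)."""
    if not 0 <= class_idx < num_classes:
        raise ValueError(f"class index {class_idx} is out of range")
    vec = [0.0] * num_classes
    vec[class_idx] = 1.0
    return vec


def pair_seeds_with_classes(seeds: List[int], classes: List[int]) -> List[Tuple[int, int]]:
    """Match each seed to exactly one class, as required for --cls in gen_video.py."""
    if len(seeds) != len(classes):
        raise ValueError("--cls must list as many classes as --seeds")
    return list(zip(seeds, classes))
```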

To generate a conditional sample sheet, run

python gen_class_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --samples-per-class 4 --classes 0-32 --grid-width 32

For ImageNet models, we enable multi-modal truncation (proposed by Self-Distilled GAN). We generated 600k samples and computed 10k cluster centroids via k-means. For a given sample, multi-modal truncation finds the closest centroid and interpolates towards it. To switch from uni-modal to multi-modal truncation, pass

--centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy
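Both variants move a latent w toward a center by the truncation factor psi: uni-modal truncation uses the global mean latent, while multi-modal truncation uses the nearest k-means centroid. A list-based sketch of the idea; the Euclidean nearest-centroid search is an assumption, and the repo operates on PyTorch tensors rather than Python lists.

```python
import math
from typing import List, Sequence


def truncate(w: Sequence[float], center: Sequence[float], psi: float) -> List[float]:
    """Truncation trick: psi=1 keeps w unchanged, psi=0 collapses it onto center."""
    return [c + psi * (x - c) for x, c in zip(w, center)]


def multimodal_truncate(w: Sequence[float], centroids: Sequence[Sequence[float]],
                        psi: float) -> List[float]:
    """Truncate towards the nearest centroid instead of the single mean latent."""
    nearest = min(centroids, key=lambda c: math.dist(w, c))  # Euclidean distance (assumed)
    return truncate(w, nearest, psi)


w = [2.0, 2.0]
print(truncate(w, [0.0, 0.0], 0.5))                           # [1.0, 1.0], uni-modal
print(multimodal_truncate(w, [[0.0, 0.0], [3.0, 3.0]], 0.5))  # [2.5, 2.5], multi-modal
```

In this picture, --trunc presumably plays the role of psi, as it does in StyleGAN3's generation scripts.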

[Figure: No Truncation | Uni-Modal Truncation | Multi-Modal Truncation]

Image Editing

To use our reimplementation of StyleMC and generate the example above, run

python run_stylemc.py --outdir=stylemc_out \
  --text-prompt "a chimpanzee | laughter | happyness| happy chimpanzee | happy monkey | smile | grin" \
  --seeds 0-256 --class-idx 367 --layers 10-30 --edit-strength 0.75 --init-seed 49 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --bigger-network https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl

Recommended workflow:

  • Sample images via gen_images.py.
  • Pick a sample and use it as the initial image for run_stylemc.py by providing --init-seed and --class-idx.
  • Find a direction in style space via --text-prompt.
  • Fine-tune --edit-strength, --layers, and the number of --seeds.
  • Once you have found a good setting, provide a larger model via --bigger-network. The script still optimizes the direction for the smaller model but uses the bigger model for the final output.
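Conceptually, the workflow produces a direction in style space that is added, scaled by the edit strength, to the styles of the selected layers. A simplified sketch, assuming the edit is additive, that --layers 10-30 means an inclusive range of style-layer indices, and that styles are plain per-layer lists; apply_edit is hypothetical and not the script's implementation.

```python
from typing import Dict, Iterable, List


def apply_edit(styles: Dict[int, List[float]], direction: Dict[int, List[float]],
               layers: Iterable[int], strength: float) -> Dict[int, List[float]]:
    """Add strength * direction to the style vectors of the selected layers only."""
    selected = set(layers)
    edited = {}
    for idx, style in styles.items():
        if idx in selected and idx in direction:
            edited[idx] = [s + strength * d for s, d in zip(style, direction[idx])]
        else:
            edited[idx] = list(style)  # layers outside the selection stay unchanged
    return edited


layers = range(10, 31)  # "--layers 10-30", assumed inclusive
```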

Pretrained Models

We provide the following pretrained models (pass the URL as PATH_TO_NETWORK_PKL):

Dataset   Res     FID    PATH
ImageNet  16²     0.73   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet16.pkl
ImageNet  32²     1.11   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet32.pkl
ImageNet  64²     1.52   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet64.pkl
ImageNet  128²    1.77   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl
ImageNet  256²    2.26   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl
ImageNet  512²    2.42   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl
ImageNet  1024²   2.51   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl
CIFAR10   32²     1.85   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/cifar10.pkl
FFHQ      256²    2.19   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl
FFHQ      512²    2.23   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq512.pkl
FFHQ      1024²   2.02   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq1024.pkl
Pokemon   256²    23.97  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
Pokemon   512²    23.82  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon512.pkl
Pokemon   1024²   25.47  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl
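The checkpoint URLs follow a {dataset}{resolution}.pkl pattern, with CIFAR-10 as the single exception. A hypothetical helper that builds the URL from the entries in the table, so that a mistyped dataset or resolution fails immediately rather than at download time.

```python
BASE_URL = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models"

# Resolutions listed in the pretrained-model table, per dataset.
AVAILABLE = {
    "imagenet": [16, 32, 64, 128, 256, 512, 1024],
    "cifar10": [32],
    "ffhq": [256, 512, 1024],
    "pokemon": [256, 512, 1024],
}


def model_url(dataset: str, res: int) -> str:
    """Return the pretrained checkpoint URL for a dataset and resolution (hypothetical helper)."""
    dataset = dataset.lower()
    if res not in AVAILABLE.get(dataset, []):
        raise ValueError(f"no pretrained {dataset} model at {res}x{res}")
    if dataset == "cifar10":
        return f"{BASE_URL}/cifar10.pkl"  # the only file without a resolution suffix
    return f"{BASE_URL}/{dataset}{res}.pkl"


print(model_url("ImageNet", 128))
```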

Quality Metrics

By default, train.py tracks FID50k during training. To calculate metrics for a specific network snapshot, run

python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL

To see the available metrics, run

python calc_metrics.py --help
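FID models real and generated Inception features as Gaussians and measures FID = ||mu_r - mu_g||² + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The sketch below evaluates this formula for the special case of diagonal covariances, where the matrix square root reduces to elementwise square roots. It illustrates the formula only; the real metric uses full covariance matrices of high-dimensional Inception features, as computed by calc_metrics.py.

```python
import math
from typing import Sequence


def frechet_distance_diag(mu1: Sequence[float], var1: Sequence[float],
                          mu2: Sequence[float], var2: Sequence[float]) -> float:
    """Fréchet distance between two Gaussians with diagonal covariances."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))  # ||mu1 - mu2||^2
    # Trace term; for diagonal matrices (Sigma1 Sigma2)^(1/2) = diag(sqrt(v1 * v2)).
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + cov_term


print(frechet_distance_diag([0.0, 0.0], [1.0, 1.0], [3.0, 4.0], [4.0, 9.0]))  # 30.0
```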

We provide precomputed FID statistics for all pretrained models:

wget https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip
unzip gan-metrics.zip -d dnnlib/
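Where wget or unzip is not available (e.g., on Windows), the same two steps can be done with the Python standard library. fetch_and_extract is a hypothetical helper; unlike wget, it downloads to a temporary file instead of keeping the zip in the working directory.

```python
import posixpath
import tempfile
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path
from typing import List


def fetch_and_extract(url: str, dest: str = "dnnlib/") -> List[str]:
    """Download a zip archive and extract it into dest, like `wget` + `unzip -d`."""
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    with tempfile.TemporaryDirectory() as tmp:
        name = posixpath.basename(urllib.parse.urlparse(url).path) or "archive.zip"
        archive = Path(tmp) / name
        urllib.request.urlretrieve(url, str(archive))  # download step
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest_dir)                    # extraction step
            return zf.namelist()


# fetch_and_extract("https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip")
```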

Further Information

This repo builds on the codebase of StyleGAN3 and our previous project Projected GANs Converge Faster.
