[SIGGRAPH'22] StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets

Overview


This repository contains code for our SIGGRAPH'22 paper "StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets" by Axel Sauer, Katja Schwarz, and Andreas Geiger.

If you find our code or paper useful, please cite

@Article{Sauer2021ARXIV,
  author    = {Axel Sauer and Katja Schwarz and Andreas Geiger},
  title     = {StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets},
  journal   = {arXiv.org},
  volume    = {abs/2201.00273},
  year      = {2022},
  url       = {https://arxiv.org/abs/2201.00273},
}

Related Projects

  • Projected GANs Converge Faster (NeurIPS'21)  -  Official Repo  -  Projected GAN Quickstart
  • StyleGAN-XL + CLIP (Implemented by CasualGANPapers)  -  StyleGAN-XL + CLIP
  • StyleGAN-XL + CLIP (Modified by Katherine Crowson to optimize in W+ space)  -  StyleGAN-XL + CLIP

ToDos

  • Initial code release
  • Add pretrained models (ImageNet{16,32,64,128,256,512,1024}, FFHQ{256,512,1024}, Pokemon{256,512,1024})
  • Add StyleMC for editing
  • Add PTI for inversion

Requirements

  • 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
  • CUDA toolkit 11.1 or later.
  • GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see, for example, the CUDA 11.4 system requirements.
  • If you run into problems when setting up the custom CUDA kernels, refer to the Troubleshooting docs of the original StyleGAN3 repo and issue #23.
  • Windows users having trouble installing the environment might find issue #10 helpful.
  • Use the following commands with Miniconda3 to create and activate the sgxl Python environment:
    • conda env create -f environment.yml
    • conda activate sgxl
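As a quick sanity check before building the environment, the following minimal sketch compares the running interpreter and (if installed) PyTorch against the minimum versions listed above. It is not part of the repository and the helper names are hypothetical. Note that `torch.version.cuda` reports the CUDA version PyTorch was built with, which is only a proxy for the CUDA toolkit installed on your system (the toolkit is what compiles the custom kernels).

```python
import platform
import re
import sys


def parse_version(text: str) -> tuple:
    """Turn a version string such as '1.9.0+cu111' into (1, 9, 0)."""
    # Keep only the leading dotted-numeric part; drop local tags like '+cu111'.
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(found: str, minimum: str) -> bool:
    """True if `found` is at least `minimum`, comparing numeric components."""
    a, b = parse_version(found), parse_version(minimum)
    # Pad the shorter tuple with zeros so (3, 8) compares equal to (3, 8, 0).
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment() -> list:
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append(f"Python 3.8+ required, found {platform.python_version()}")
    if sys.maxsize <= 2**32:
        problems.append("a 64-bit Python build is required")
    try:
        import torch  # only inspected if it is installed
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if not meets_minimum(torch.__version__, "1.9.0"):
            problems.append(f"PyTorch 1.9.0+ required, found {torch.__version__}")
        # CUDA version PyTorch was built against (None for CPU-only builds).
        if torch.version.cuda is None:
            problems.append("this PyTorch build has no CUDA support")
        elif not meets_minimum(torch.version.cuda, "11.1"):
            problems.append(f"PyTorch built with CUDA {torch.version.cuda}, 11.1+ expected")
    return problems


if __name__ == "__main__":
    for problem in check_environment() or ["environment looks OK"]:
        print(problem)
```

Comparing parsed tuples rather than raw strings matters here: as strings, "1.10.2" would sort before "1.9.0".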

Data Preparation

For a quick start, you can download the few-shot datasets provided by the authors of FastGAN here. To prepare the dataset at the respective resolution, run

python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
  --resolution=256x256 --transform=center-crop

You need to follow our progressive growing scheme to get the best results. Therefore, you should prepare separate zips for each training resolution. You can get the datasets we used in our paper at their respective websites (FFHQ, ImageNet).
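Preparing one zip per resolution is easy to script. The sketch below only builds one dataset_tool.py call per resolution, reusing the flags from the command above; the resolution list (16² to 256²), the `./data/<name><res>.zip` naming, and the helper name are assumptions that mirror the Pokemon examples in this README.

```python
import sys


def dataset_commands(source, name, resolutions=(16, 32, 64, 128, 256)):
    """Build one dataset_tool.py invocation (as an argument list) per resolution."""
    commands = []
    for res in resolutions:
        commands.append([
            sys.executable, "dataset_tool.py",
            f"--source={source}",
            f"--dest=./data/{name}{res}.zip",  # e.g. ./data/pokemon16.zip
            f"--resolution={res}x{res}",
            "--transform=center-crop",
        ])
    return commands


if __name__ == "__main__":
    # Print the commands; run them from the repository root.
    for cmd in dataset_commands("./data/pokemon", "pokemon"):
        print(" ".join(cmd))
```

Each argument list can be passed to `subprocess.run(cmd, check=True)` from the repository root to actually build the zips.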

Training

For progressive growing, we train a stem at low resolution, e.g., 16² pixels. Once the stem is finished, i.e., its FID saturates, you can start training the upper stages; we refer to these as superresolution stages.

Training the stem

Training StyleGAN-XL on Pokemon using 8 GPUs:

python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon16.zip \
    --gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10

--batch specifies the overall batch size and --batch-gpu the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
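The relation between the three batch flags comes down to a small calculation. This is a hypothetical helper, assuming (as in StyleGAN3's train.py, on which this repo builds) that --batch must be divisible by --gpus times --batch-gpu; check the training script for the exact rule.

```python
def accumulation_rounds(batch: int, gpus: int, batch_gpu: int) -> int:
    """Passes each GPU runs per optimizer step to reach the overall --batch."""
    per_pass = gpus * batch_gpu  # images processed by all GPUs in one pass
    if batch % per_pass != 0:
        raise ValueError("--batch must be a multiple of --gpus times --batch-gpu")
    return batch // per_pass


if __name__ == "__main__":
    for gpus in (8, 4, 2, 1):
        print(gpus, "GPUs ->", accumulation_rounds(64, gpus, 8), "pass(es) per step")
```

With the Pokemon settings above (--batch=64, --batch-gpu 8), 8 GPUs need a single pass per step, while 2 GPUs accumulate gradients over 4 passes.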

Samples and metrics are saved in outdir. If you don't want to track metrics, set --metrics=none. You can inspect fid50k_full.json or run tensorboard in training-runs/ to monitor the training progress.
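To pick the best snapshot from a run programmatically, a minimal sketch follows. It assumes the metrics file is in JSON Lines format with a `results` dict and a `snapshot_pkl` field on each line, which is how StyleGAN3 writes its `metric-*.jsonl` files; the exact file name and fields in your run directory may differ, so treat both as assumptions. Lower FID is better, hence the minimum.

```python
import json
from pathlib import Path


def best_snapshot(lines, metric="fid50k_full"):
    """Return (snapshot_pkl, value) for the entry with the lowest `metric`."""
    best = None
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        record = json.loads(line)
        value = record["results"][metric]
        if best is None or value < best[1]:
            best = (record.get("snapshot_pkl"), value)
    if best is None:
        raise ValueError("no metric entries found")
    return best


def best_snapshot_in_file(path, metric="fid50k_full"):
    """Convenience wrapper that reads a JSON Lines metrics file from disk."""
    return best_snapshot(Path(path).read_text().splitlines(), metric)
```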

For a class-conditional dataset (ImageNet, CIFAR-10), add the flag --cond True. The dataset needs to contain the class labels; see the StyleGAN2-ADA repo for how to prepare class-conditional datasets.
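For a folder dataset, StyleGAN2-ADA's dataset_tool.py reads class labels from a `dataset.json` file in the source directory, of the form `{"labels": [["relative/path.png", class_index], ...]}`. The sketch below writes such a file, assuming a hypothetical layout with one subfolder per class (e.g., `cat/001.png`). Class indices are assigned by sorted folder name, which may not match an official index order such as ImageNet's.

```python
import json
from pathlib import Path, PurePosixPath

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def labels_from_paths(rel_paths):
    """Build the dataset.json dict; the top-level folder name is the class."""
    class_names = sorted({PurePosixPath(p).parts[0] for p in rel_paths})
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    return {"labels": [[p, class_to_idx[PurePosixPath(p).parts[0]]]
                       for p in sorted(rel_paths)]}


def write_dataset_json(root):
    """Scan root/<class>/... for images and write root/dataset.json."""
    root = Path(root)
    rel_paths = []
    for img in root.rglob("*"):
        if img.is_file() and img.suffix.lower() in IMAGE_SUFFIXES:
            rel = img.relative_to(root).as_posix()  # forward slashes on every OS
            if "/" in rel:  # ignore images that are not inside a class folder
                rel_paths.append(rel)
    meta = labels_from_paths(rel_paths)
    (root / "dataset.json").write_text(json.dumps(meta))
    return meta
```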

Training the super-resolution stages

Continuing with the pretrained stem:

python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon32.zip \
  --gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10 \
  --superres --up_factor 2 --head_layers 7 \
  --path_stem training-runs/pokemon/00000-stylegan3-t-pokemon16-gpus8-batch64/best_model.pkl

--up_factor allows training several stages at once; e.g., with --up_factor=4 and a 16² stem you can directly train at resolution 64².
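The arithmetic behind --up_factor is the target resolution divided by the stem resolution. The sketch below assumes, as with every resolution in this README, that both are powers of two and that each stage doubles the resolution; the helper names are hypothetical.

```python
def up_factor(stem_res: int, target_res: int) -> int:
    """Value for --up_factor when training directly from stem_res to target_res."""
    if target_res % stem_res != 0:
        raise ValueError("target resolution must be a multiple of the stem resolution")
    factor = target_res // stem_res
    # Assumption: power-of-two factors, like every resolution in this README.
    if factor < 1 or (factor & (factor - 1)) != 0:
        raise ValueError(f"expected a power-of-two factor, got {factor}")
    return factor


def doubling_steps(factor: int) -> int:
    """Number of 2x stages a power-of-two factor spans (e.g., 4 -> 2)."""
    return factor.bit_length() - 1
```

For the example above, `up_factor(16, 64)` gives 4, i.e., two doubling stages trained at once.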

If you have enough compute, a good tactic is to train several stages in parallel and then restart the superresolution stage training once in a while. The current stage will then reload its previous stem's best_model.pkl. Performance can sometimes drop at first because of domain shift, but the superresolution stage quickly recovers and improves further.

Training recommendations for datasets other than ImageNet

The default settings are tuned for ImageNet. For smaller datasets (<50k images) or well-curated datasets (e.g., FFHQ), you can significantly decrease the model size, enabling much faster training. Recommended settings are --cbase 128 --cmax 128 --syn_layers 4 and, for superresolution stages, --head_layers 4.
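The recommendation above can be expressed as a small flag-selection helper. It is a sketch, not part of the repo: the 50k-image threshold and the flag values come from the paragraph above, while the function name and its `well_curated` switch are hypothetical.

```python
def recommended_size_flags(num_images: int, well_curated: bool = False,
                           superres: bool = False) -> list:
    """Return the smaller-model flags suggested for non-ImageNet datasets."""
    if num_images >= 50_000 and not well_curated:
        return []  # keep the ImageNet-tuned defaults
    flags = ["--cbase", "128", "--cmax", "128", "--syn_layers", "4"]
    if superres:
        flags += ["--head_layers", "4"]  # only relevant for superresolution stages
    return flags
```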

If you want to train as few stages as possible, we recommend training a 32x32 or 64x64 stem and then scaling directly to the final resolution (as described above, you must adjust --up_factor accordingly). Generally, however, progressive growing yields better results faster because throughput is much higher at lower resolutions; see the corresponding figure in Karras et al., 2017.

Generating Samples & Interpolations

To generate samples and interpolation videos, run

python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 1 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl

and

python gen_video.py --output=lerp.mp4 --trunc=0.7 --seeds=0-31 --grid=4x2 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl

For class-conditional models, you can pass the class index via --class; an index-to-label dictionary for ImageNet can be found here. For interpolation between classes, provide, e.g., --cls=0-31 to gen_video.py. The list of classes must have the same length as --seeds.
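Seed and class lists such as `0-31` denote inclusive ranges. The sketch below is modeled on the `parse_range` helper in StyleGAN3's gen_images.py; we assume, but have not verified, that this repo's scripts parse ranges the same way. The hypothetical `pair_seeds_with_classes` also enforces the equal-length requirement for --cls and --seeds.

```python
import re


def parse_range(s: str) -> list:
    """Parse '1,2,5-10' into [1, 2, 5, 6, 7, 8, 9, 10]; ranges are inclusive."""
    values = []
    for part in s.split(","):
        m = re.fullmatch(r"(\d+)-(\d+)", part.strip())
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            values.append(int(part))
    return values


def pair_seeds_with_classes(seeds: str, classes: str) -> list:
    """Pair each seed with a class index; both lists must have equal length."""
    seed_list, class_list = parse_range(seeds), parse_range(classes)
    if len(seed_list) != len(class_list):
        raise ValueError(f"{len(seed_list)} seeds but {len(class_list)} classes")
    return list(zip(seed_list, class_list))
```

So `--seeds=10-15` in the gen_images.py example produces six images, and `--seeds=0-31` pairs with `--cls=0-31` one-to-one.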

To generate a conditional sample sheet, run

python gen_class_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --samples-per-class 4 --classes 0-32 --grid-width 32

For ImageNet models, we enable multi-modal truncation (proposed by Self-Distilled GAN). We generated 600k samples and found 10k cluster centroids via k-means. For a given sample, multi-modal truncation finds the closest centroid and interpolates towards it. To switch from uni-modal to multi-modal truncation, pass

--centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy

[Figure: No Truncation | Uni-Modal Truncation | Multi-Modal Truncation]
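The difference between the two truncation modes can be sketched in NumPy. Uni-modal truncation pulls a latent w toward the global average w_avg via w_avg + psi * (w - w_avg); multi-modal truncation replaces w_avg with the nearest centroid. This is an illustration, not the repo's implementation: the Euclidean distance and the (N, D) / (K, D) array shapes are assumptions, and we make no claim about how imagenet_centroids.npy is laid out.

```python
import numpy as np


def truncate_unimodal(w, w_avg, psi):
    """Classic truncation: pull each latent toward the global average latent."""
    return w_avg + psi * (w - w_avg)


def truncate_multimodal(w, centroids, psi):
    """Pull each latent toward its nearest cluster centroid instead.

    w: (N, D) latents, centroids: (K, D) k-means centers, psi in [0, 1].
    """
    # Squared Euclidean distance from every latent to every centroid: (N, K).
    d2 = ((w[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    nearest = centroids[d2.argmin(axis=1)]  # (N, D), one centroid per latent
    return nearest + psi * (w - nearest)
```

With psi = 1 both functions return w unchanged; with psi = 0 every latent collapses onto the average (uni-modal) or onto its own centroid (multi-modal), which is why the latter preserves more diversity at low psi.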

Image Editing

To use our reimplementation of StyleMC and generate an example edit, run

python run_stylemc.py --outdir=stylemc_out \
  --text-prompt "a chimpanzee | laughter | happyness| happy chimpanzee | happy monkey | smile | grin" \
  --seeds 0-256 --class-idx 367 --layers 10-30 --edit-strength 0.75 --init-seed 49 \
  --network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
  --bigger-network https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl

Recommended workflow:

  • Sample images via gen_images.py.
  • Pick a sample and use it as the initial image for run_stylemc.py by providing --init-seed and --class-idx.
  • Find a direction in style space via --text-prompt.
  • Fine-tune --edit-strength, --layers, and the number of --seeds.
  • Once you have found a good setting, provide a larger model via --bigger-network. The script still optimizes the direction for the smaller model but uses the bigger model for the final output.
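To build intuition for --edit-strength and --layers, the sketch below applies a StyleMC-style edit s' = s + strength * d to a chosen subset of style layers. The list-of-arrays layout and the way layer indices are selected are hypothetical; how run_stylemc.py stores styles and interprets --layers 10-30 may differ.

```python
import numpy as np


def apply_style_edit(styles, direction, strength, layers):
    """Return edited copies of per-layer style vectors.

    styles, direction: lists of 1-D arrays, one per style layer (hypothetical layout).
    layers: layer indices to edit; all other layers are copied unchanged.
    """
    selected = set(layers)
    return [s + strength * d if i in selected else s.copy()
            for i, (s, d) in enumerate(zip(styles, direction))]
```

Raising `strength` exaggerates the edit, while narrowing `layers` restricts which levels of detail it can touch.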

Pretrained Models

We provide the following pretrained models (pass the URL as PATH_TO_NETWORK_PKL):

Dataset   Res    FID    PATH
ImageNet  16²    0.73   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet16.pkl
ImageNet  32²    1.11   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet32.pkl
ImageNet  64²    1.52   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet64.pkl
ImageNet  128²   1.77   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl
ImageNet  256²   2.26   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl
ImageNet  512²   2.42   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl
ImageNet  1024²  2.51   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl
CIFAR10   32²    1.85   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/cifar10.pkl
FFHQ      256²   2.19   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl
FFHQ      512²   2.23   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq512.pkl
FFHQ      1024²  2.02   https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq1024.pkl
Pokemon   256²   23.97  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
Pokemon   512²   23.82  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon512.pkl
Pokemon   1024²  25.47  https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl
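All model URLs in the table share one prefix, so a hypothetical helper can build them from the dataset name and resolution. The URL pattern is read off the table; the function and its validation set are our own.

```python
from typing import Optional

BASE_URL = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models"

# Resolutions available per dataset, copied from the pretrained-model table.
AVAILABLE = {
    "imagenet": {16, 32, 64, 128, 256, 512, 1024},
    "ffhq": {256, 512, 1024},
    "pokemon": {256, 512, 1024},
}


def model_url(dataset: str, res: Optional[int] = None) -> str:
    """Return the pickle URL for a pretrained StyleGAN-XL model."""
    dataset = dataset.lower()
    if dataset == "cifar10":
        return f"{BASE_URL}/cifar10.pkl"  # single 32x32 model, no resolution suffix
    if res not in AVAILABLE.get(dataset, set()):
        raise ValueError(f"no pretrained {dataset} model at resolution {res}")
    return f"{BASE_URL}/{dataset}{res}.pkl"
```

For example, `model_url("imagenet", 128)` returns the imagenet128.pkl URL used in the sample-sheet command.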

Quality Metrics

By default, train.py tracks FID50k during training. To calculate metrics for a specific network snapshot, run

python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL

To see the available metrics, run

python calc_metrics.py --help

We provide precomputed FID statistics for all pretrained models:

wget https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip
unzip gan-metrics.zip -d dnnlib/
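On systems without wget or unzip (e.g., Windows), the same download can be done with the Python standard library. This is a minimal sketch with a hypothetical helper name; like the unzip command above, it extracts into dnnlib/.

```python
import urllib.request
import zipfile
from pathlib import Path

STATS_URL = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip"


def fetch_and_extract(url: str = STATS_URL, dest: str = "dnnlib/") -> list:
    """Download a zip archive, extract it into dest, and return the member names."""
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    archive, _ = urllib.request.urlretrieve(url)  # path to a local copy of the archive
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
        return zf.namelist()
```

Calling `fetch_and_extract()` with no arguments downloads the precomputed statistics archive.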

Further Information

This repo builds on the codebase of StyleGAN3 and our previous project Projected GANs Converge Faster.
