Contrastive unpaired image-to-image translation, with faster and lighter training than CycleGAN (ECCV 2020, in PyTorch)

Overview

Contrastive Unpaired Translation (CUT)

video (1m) | video (10m) | website | paper





We provide our PyTorch implementation of unpaired image-to-image translation based on patchwise contrastive learning and adversarial learning. Neither a hand-crafted loss nor an inverse network is used. Compared to CycleGAN, our model trains faster and uses less memory. In addition, our method can be extended to single-image training, where each “domain” is only a single image.

Contrastive Learning for Unpaired Image-to-Image Translation
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
UC Berkeley and Adobe Research
In ECCV 2020




Pseudo code

import torch
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

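    # The diagonal entries are not negatives. Remove them.
    # masked_fill_ requires a boolean mask on the same device as l_neg.
    identity_matrix = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))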

    # calculate logits: (B)x(S)x(S+1)
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return PatchNCE loss
    predictions = logits.flatten(0, 1)
    # The positive logit sits at index 0 of every row.
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)
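
The loss above expects features that have already been sampled at S spatial locations. A minimal sketch of how such BxCxS inputs might be produced from BxCxHxW feature maps is given below. The helper `sample_patches` and its arguments are hypothetical names, and the L2 normalization along the channel axis is an assumption of this sketch (it turns the dot products into cosine similarities); the repository's own sampling code may differ.

```python
import torch
import torch.nn.functional as F


def sample_patches(feat, num_patches=256, patch_ids=None):
    """Sample feature vectors at random spatial locations.

    feat: BxCxHxW feature map. Returns (BxCxS tensor, indices used).
    """
    B, C, H, W = feat.shape
    flat = feat.flatten(2)  # BxCx(H*W)
    if patch_ids is None:
        # Pick S distinct locations, shared across the batch.
        S = min(num_patches, H * W)
        patch_ids = torch.randperm(H * W, device=feat.device)[:S]
    sampled = flat[:, :, patch_ids]  # BxCxS
    # L2-normalize along channels (an assumption of this sketch).
    return F.normalize(sampled, dim=1), patch_ids


# Random stand-ins for the two feature maps named in the comments above.
feat_q = torch.randn(2, 8, 16, 16)  # plays the role of H(G_enc(x))
feat_k = torch.randn(2, 8, 16, 16)  # plays the role of H(G_enc(G(x)))

# Reuse the same locations so that position s in f_q matches position s in f_k.
f_q, ids = sample_patches(feat_q, num_patches=64)
f_k, _ = sample_patches(feat_k, patch_ids=ids)
```

Sharing `patch_ids` between the two calls matters: the positive pair for each query is the key at the same spatial location, which is what the diagonal of the BxSxS similarity matrix represents.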

Example Results

Unpaired Image-to-Image Translation

Single Image Unpaired Translation

Russian Blue Cat to Grumpy Cat

Parisian Street to Burano's painted houses

Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Update log

9/12/2020: Added single-image translation.

Getting started

  • Clone this repo:
git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
  • Install PyTorch 1.1 and other dependencies (e.g., torchvision, visdom, dominate, gputil).

    For pip users, please type the command pip install -r requirements.txt.

    For Conda users, you can create a new Conda environment using conda env create -f environment.yml.

CUT and FastCUT Training and Test

  • Download the grumpifycat dataset (Fig 8 of the paper. Russian Blue -> Grumpy Cats)
bash ./datasets/download_cut_dataset.sh grumpifycat

The dataset is downloaded and unzipped at ./datasets/grumpifycat/.

  • To view training results and loss plots, run python -m visdom.server and click the URL http://localhost:8097.

  • Train the CUT model:

python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT

Or train the FastCUT model:

python train.py --dataroot ./datasets/grumpifycat --name grumpycat_FastCUT --CUT_mode FastCUT

The checkpoints will be stored at ./checkpoints/grumpycat_*/web.

  • Test the CUT model:
python test.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT --phase train

The test results will be saved to an HTML file here: ./results/grumpifycat/latest_train/index.html.

CUT, FastCUT, and CycleGAN


CUT is trained with the identity preservation loss and with lambda_NCE=1, while FastCUT is trained without the identity loss but with a higher lambda_NCE=10.0. Compared to CycleGAN, CUT learns to perform more powerful distribution matching, while FastCUT is designed as a lighter (half the GPU memory, so it can fit larger images) and faster (twice as fast to train) alternative to CycleGAN. Please refer to the paper for more details.
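
The difference between the two modes can be sketched as a weighting of already-computed loss terms. This is an illustrative sketch only: the function name `generator_loss` and its arguments are hypothetical, and treating the identity term as a second NCE-style term weighted by the same lambda_NCE is an assumption. The exact weighting used in the repository may differ.

```python
def generator_loss(gan, nce, nce_idt=0.0, cut_mode="CUT"):
    """Combine already-computed scalar loss terms for one generator step."""
    if cut_mode == "CUT":
        lambda_nce, use_idt = 1.0, True    # identity loss on, lambda_NCE=1
    elif cut_mode == "FastCUT":
        lambda_nce, use_idt = 10.0, False  # identity loss off, lambda_NCE=10
    else:
        raise ValueError(f"unknown mode: {cut_mode}")
    total = gan + lambda_nce * nce
    if use_idt:
        # Assumed form of the identity preservation term (see lead-in).
        total += lambda_nce * nce_idt
    return total


cut_total = generator_loss(1.0, 0.5, 0.25, cut_mode="CUT")
fastcut_total = generator_loss(1.0, 0.5, 0.25, cut_mode="FastCUT")
```

Skipping the identity term means FastCUT needs no extra pass on target-domain images, which is consistent with the lower memory use and faster training described above.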

In the above figure, we measure the percentage of pixels belonging to the horse/zebra bodies, using a pre-trained semantic segmentation model. We find a distribution mismatch between the sizes of horses and zebras in the images: zebras usually appear larger (36.8% vs. 17.9%). Our full method, CUT, has the flexibility to enlarge the horses as a means of matching the training statistics better than CycleGAN does. FastCUT behaves more conservatively, like CycleGAN.
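
The statistic above is a foreground pixel fraction averaged over a set of images. A minimal sketch of that computation is shown below, assuming the segmentation model's output has already been reduced to binary masks (1 for horse or zebra pixels, 0 otherwise). The function names are hypothetical and the masks here are toy data, not results from the paper.

```python
def foreground_fraction(mask):
    """mask: 2D list of 0/1 labels (1 = horse or zebra pixel)."""
    total = sum(len(row) for row in mask)
    foreground = sum(sum(row) for row in mask)
    return foreground / total


def mean_foreground_percent(masks):
    """Average foreground percentage over a collection of masks."""
    return 100.0 * sum(foreground_fraction(m) for m in masks) / len(masks)


# Two toy 2x4 masks: 2 of 8 and 4 of 8 pixels are foreground.
toy_masks = [
    [[0, 1, 1, 0], [0, 0, 0, 0]],
    [[1, 1, 0, 0], [1, 1, 0, 0]],
]
percent = mean_foreground_percent(toy_masks)
```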

Training using our launcher scripts

Please see experiments/grumpifycat_launcher.py, which generates the above command-line arguments. The launcher scripts are useful for configuring rather complicated command-line arguments for training and testing.

Using the launcher, the commands below generate the training commands for CUT and FastCUT.

python -m experiments grumpifycat train 0   # CUT
python -m experiments grumpifycat train 1   # FastCUT

To test using the launcher,

python -m experiments grumpifycat test 0   # CUT
python -m experiments grumpifycat test 1   # FastCUT

Possible commands are run, run_test, launch, close, and so on. Please see experiments/__main__.py for all commands. A launcher is easy and quick to define and use. For example, the grumpifycat launcher is defined in a few lines:

from .tmux_launcher import Options, TmuxLauncher


class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            Options(    # Command 0
                dataroot="./datasets/grumpifycat",
                name="grumpifycat_CUT",
                CUT_mode="CUT"
            ),

            Options(    # Command 1
                dataroot="./datasets/grumpifycat",
                name="grumpifycat_FastCUT",
                CUT_mode="FastCUT",
            )
        ]

    def commands(self):
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        # Russian Blue -> Grumpy Cats dataset does not have test split.
        # Therefore, let's set the test split to be the "train" set.
        return ["python test.py " + str(opt.set(phase='train')) for opt in self.common_options()]
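
To make the string building in the launcher concrete, here is a minimal, hypothetical stand-in for the `Options` class. It is not the implementation in tmux_launcher.py; it only assumes the two behaviors visible above, namely that `str(opt)` renders command-line flags and that `opt.set(...)` returns options with extra values applied.

```python
import copy


class Options:
    """Hypothetical stand-in: holds keyword options and renders CLI flags."""

    def __init__(self, **kwargs):
        self.kwargs = dict(kwargs)

    def set(self, **kwargs):
        # Return a modified copy so the original options are untouched.
        new = copy.deepcopy(self)
        new.kwargs.update(kwargs)
        return new

    def __str__(self):
        return " ".join(f"--{key} {value}" for key, value in self.kwargs.items())


opt = Options(dataroot="./datasets/grumpifycat", name="grumpifycat_CUT", CUT_mode="CUT")
train_cmd = "python train.py " + str(opt)
test_cmd = "python test.py " + str(opt.set(phase="train"))
```

Returning a copy from `set` is a design choice of this sketch: it keeps the training options reusable after the test command has been derived from them.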

Apply a pre-trained CUT model and evaluate FID

To run the pretrained models, run the following.

# Download and unzip the pretrained models. The weights should be located at
# checkpoints/horse2zebra_cut_pretrained/latest_net_G.pth, for example.
wget http://efrosgans.eecs.berkeley.edu/CUT/pretrained_models.tar
tar -xf pretrained_models.tar

# Generate outputs. The dataset paths might need to be adjusted.
# To do this, modify the lines of experiments/pretrained_launcher.py
# [id] corresponds to the respective commands defined in pretrained_launcher.py
# 0 - CUT on Cityscapes
# 1 - FastCUT on Cityscapes
# 2 - CUT on Horse2Zebra
# 3 - FastCUT on Horse2Zebra
# 4 - CUT on Cat2Dog
# 5 - FastCUT on Cat2Dog
python -m experiments pretrained run_test [id]

# Evaluate FID. To do this, first install pytorch-fid of https://github.com/mseitzer/pytorch-fid
# pip install pytorch-fid
# For example, to evaluate horse2zebra FID of CUT,
# python -m pytorch_fid ./datasets/horse2zebra/testB/ results/horse2zebra_cut_pretrained/test_latest/images/fake_B/
# To evaluate Cityscapes FID of FastCUT,
# python -m pytorch_fid ./datasets/cityscapes/valA/ ~/projects/contrastive-unpaired-translation/results/cityscapes_fastcut_pretrained/test_latest/images/fake_B/
# Note that a special dataset needs to be used for the Cityscapes model. Please read below. 
python -m pytorch_fid [path to real test images] [path to generated images]

Note: the Cityscapes pretrained model was trained and evaluated on a resized and JPEG-compressed version of the original Cityscapes dataset. To reproduce the evaluation, please download this validation set and evaluate on it.
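
For reference, FID (Fréchet Inception Distance) fits a Gaussian to Inception-network features of each image set and measures the distance between the two Gaussians; lower is better. The standard definition is given below, where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ denote the feature mean and covariance for the real and generated images. The symbol names are chosen here for illustration.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```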

SinCUT Single Image Unpaired Training

To train SinCUT (single-image translation, shown in Fig 9, 13 and 14 of the paper), you need to

  1. set the --model option as --model sincut, which invokes the configuration and code at ./models/sincut_model.py, and
  2. specify the dataset directory of one image in each domain, such as the example dataset included in this repo at ./datasets/single_image_monet_etretat/.

For example, to train a model for the Etretat cliff (first image of Figure 13), please use the following command.

python train.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat

or by using the experiment launcher script,

python -m experiments singleimage run 0

For single-image translation, we adopt network architectural components of StyleGAN2, as well as the pixel identity preservation loss used in DTN and CycleGAN. In particular, we adopted the code of rosinality, which is located at models/stylegan_networks.py.

The training takes several hours. To generate the final image using the checkpoint,

python test.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat

or simply

python -m experiments singleimage run_test 0

Datasets

Download CUT/CycleGAN/pix2pix datasets. For example,

bash ./datasets/download_cut_dataset.sh horse2zebra

The Cat2Dog dataset is prepared from the AFHQ dataset. Please visit https://github.com/clovaai/stargan-v2 and download the AFHQ dataset by running bash download.sh afhq-dataset from that GitHub repo. Then reorganize the directories as follows.

mkdir -p datasets/cat2dog
# ln -s takes the existing target first, then the link name.
# Use an absolute [path_to_afhq]; a relative target is resolved relative to datasets/cat2dog.
ln -s [path_to_afhq]/train/cat datasets/cat2dog/trainA
ln -s [path_to_afhq]/train/dog datasets/cat2dog/trainB
ln -s [path_to_afhq]/test/cat datasets/cat2dog/testA
ln -s [path_to_afhq]/test/dog datasets/cat2dog/testB
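
The same directory layout can also be created from Python. The sketch below is an illustration, not part of the repository: `link_cat2dog` is a hypothetical helper, and it assumes the AFHQ copy contains the train/cat, train/dog, test/cat, and test/dog folders named above.

```python
from pathlib import Path


def link_cat2dog(afhq_root, out_dir="datasets/cat2dog"):
    """Create trainA/trainB/testA/testB links pointing into an AFHQ copy."""
    afhq_root = Path(afhq_root).resolve()  # absolute, so links do not dangle
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    layout = {
        "trainA": afhq_root / "train" / "cat",
        "trainB": afhq_root / "train" / "dog",
        "testA": afhq_root / "test" / "cat",
        "testB": afhq_root / "test" / "dog",
    }
    for name, target in layout.items():
        link = out_dir / name
        # Skip names that already exist, including dangling symlinks.
        if not link.is_symlink() and not link.exists():
            link.symlink_to(target, target_is_directory=True)
    return layout
```

A call such as `link_cat2dog("/data/afhq")` (the path is a placeholder) would then make `datasets/cat2dog/trainA` resolve to the cat training images.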

The Cityscapes dataset can be downloaded from https://cityscapes-dataset.com. After that, use the script ./datasets/prepare_cityscapes_dataset.py to prepare the dataset.

Preprocessing of input images

The preprocessing of the input images, such as resizing or random cropping, is controlled by the options --preprocess, --load_size, and --crop_size. The usage follows the CycleGAN/pix2pix repo.

For example, the default setting --preprocess resize_and_crop --load_size 286 --crop_size 256 resizes the input image to 286x286, and then makes a random crop of size 256x256 as a way to perform data augmentation. Other preprocessing options are available; they are defined in base_dataset.py. Below are some example options.

  • --preprocess none: does not perform any preprocessing. Note that the image size is still scaled to the closest multiple of 4, because the convolutional generator cannot maintain the same image size otherwise.
  • --preprocess scale_width --load_size 768: scales the width of the image to be of size 768.
  • --preprocess scale_shortside_and_crop: scales the image preserving aspect ratio so that the short side is load_size, and then performs random cropping of window size crop_size.

More preprocessing options can be added by modifying get_transform() of base_dataset.py.
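
The effect of these options on image size can be sketched as follows. This is a simplified illustration under stated assumptions, not the code in base_dataset.py: `output_size` and `round_to_multiple` are hypothetical helpers, `scale_width` is assumed to preserve the aspect ratio, and load_size is assumed to be at least crop_size.

```python
def round_to_multiple(x, base=4):
    """Nearest multiple of `base` (Python's round() sends exact ties to even)."""
    return int(round(x / base) * base)


def output_size(w, h, preprocess, load_size=286, crop_size=256):
    """Return the (width, height) an image of size (w, h) would end up with."""
    if preprocess == "resize_and_crop":
        # Resize to load_size x load_size, then take a random crop.
        return crop_size, crop_size
    if preprocess == "scale_width":
        # Width becomes load_size; height follows the aspect ratio (assumed).
        return load_size, round(h * load_size / w)
    if preprocess == "scale_shortside_and_crop":
        # Short side becomes load_size, then a random crop is taken.
        return crop_size, crop_size
    if preprocess == "none":
        # Only snapped to a multiple of 4 for the convolutional generator.
        return round_to_multiple(w), round_to_multiple(h)
    raise ValueError(f"unknown preprocess mode: {preprocess}")


default_size = output_size(500, 400, "resize_and_crop")
wide_size = output_size(1024, 512, "scale_width", load_size=768)
raw_size = output_size(253, 255, "none")
```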

Citation

If you use this code for your research, please cite our paper.

@inproceedings{park2020cut,
  title={Contrastive Learning for Unpaired Image-to-Image Translation},
  author={Taesung Park and Alexei A. Efros and Richard Zhang and Jun-Yan Zhu},
  booktitle={European Conference on Computer Vision},
  year={2020}
}

If you use the original pix2pix and CycleGAN models included in this repo, please cite the following papers.

@inproceedings{CycleGAN2017,
  title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
  author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
  booktitle={IEEE International Conference on Computer Vision (ICCV)},
  year={2017}
}


@inproceedings{isola2017image,
  title={Image-to-Image Translation with Conditional Adversarial Networks},
  author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2017}
}

Acknowledgments

We thank Allan Jabri and Phillip Isola for helpful discussion and feedback. Our code is developed based on pytorch-CycleGAN-and-pix2pix. We also thank pytorch-fid for FID computation, drn for mIoU computation, and stylegan2-pytorch for the PyTorch implementation of StyleGAN2 used in our single-image translation setting.
