StarGAN - Official PyTorch Implementation

Overview

***** New: StarGAN v2 is available at https://github.com/clovaai/stargan-v2 *****

This repository provides the official PyTorch implementation of the following paper:

StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation
Yunjey Choi1,2, Minje Choi1,2, Munyoung Kim2,3, Jung-Woo Ha2, Sung Kim2,4, Jaegul Choo1,2    
1Korea University, 2Clova AI Research, NAVER Corp.
3The College of New Jersey, 4Hong Kong University of Science and Technology
https://arxiv.org/abs/1711.09020

Abstract: Recent studies have shown remarkable success in image-to-image translation for two domains. However, existing approaches have limited scalability and robustness in handling more than two domains, since different models should be built independently for every pair of image domains. To address this limitation, we propose StarGAN, a novel and scalable approach that can perform image-to-image translations for multiple domains using only a single model. Such a unified model architecture of StarGAN allows simultaneous training of multiple datasets with different domains within a single network. This leads to StarGAN's superior quality of translated images compared to existing models as well as the novel capability of flexibly translating an input image to any desired target domain. We empirically demonstrate the effectiveness of our approach on a facial attribute transfer and a facial expression synthesis tasks.

Dependencies

Downloading datasets

To download the CelebA dataset:

git clone https://github.com/yunjey/StarGAN.git
cd StarGAN/
bash download.sh celeba

To download the RaFD dataset, you must request access to the dataset from the Radboud Faces Database website. Then, you need to create a folder structure as described here.

Training networks

To train StarGAN on CelebA, run the training script below. See here for a list of selectable attributes in the CelebA dataset. If you change the selected_attrs argument, you should also change the c_dim argument accordingly.

# Train StarGAN using the CelebA dataset
python main.py --mode train --dataset CelebA --image_size 128 --c_dim 5 \
               --sample_dir stargan_celeba/samples --log_dir stargan_celeba/logs \
               --model_save_dir stargan_celeba/models --result_dir stargan_celeba/results \
               --selected_attrs Black_Hair Blond_Hair Brown_Hair Male Young

# Test StarGAN using the CelebA dataset
python main.py --mode test --dataset CelebA --image_size 128 --c_dim 5 \
               --sample_dir stargan_celeba/samples --log_dir stargan_celeba/logs \
               --model_save_dir stargan_celeba/models --result_dir stargan_celeba/results \
               --selected_attrs Black_Hair Blond_Hair Brown_Hair Male Young
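
The link between selected_attrs and c_dim described above can be kept consistent by deriving one from the other. The sketch below is not part of the repository: build_celeba_command is a hypothetical helper that assembles only the flags shown in the commands above and sets --c_dim to the number of selected attributes.

```python
def build_celeba_command(selected_attrs, mode="train", image_size=128,
                         out_root="stargan_celeba"):
    """Assemble a main.py command in which --c_dim always matches --selected_attrs.

    Hypothetical helper for illustration; it only uses flags shown in this README.
    """
    if not selected_attrs:
        raise ValueError("select at least one CelebA attribute")
    if len(set(selected_attrs)) != len(selected_attrs):
        raise ValueError("duplicate attribute names would inflate c_dim")
    return [
        "python", "main.py",
        "--mode", mode,
        "--dataset", "CelebA",
        "--image_size", str(image_size),
        "--c_dim", str(len(selected_attrs)),  # derived, never typed by hand
        "--sample_dir", f"{out_root}/samples",
        "--log_dir", f"{out_root}/logs",
        "--model_save_dir", f"{out_root}/models",
        "--result_dir", f"{out_root}/results",
        "--selected_attrs", *selected_attrs,
    ]


cmd = build_celeba_command(["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"])
print(" ".join(cmd))
```

The returned list can be passed to subprocess.run as-is, which avoids shell quoting issues.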

To train StarGAN on RaFD:

# Train StarGAN using the RaFD dataset
python main.py --mode train --dataset RaFD --image_size 128 \
               --c_dim 8 --rafd_image_dir data/RaFD/train \
               --sample_dir stargan_rafd/samples --log_dir stargan_rafd/logs \
               --model_save_dir stargan_rafd/models --result_dir stargan_rafd/results

# Test StarGAN using the RaFD dataset
python main.py --mode test --dataset RaFD --image_size 128 \
               --c_dim 8 --rafd_image_dir data/RaFD/test \
               --sample_dir stargan_rafd/samples --log_dir stargan_rafd/logs \
               --model_save_dir stargan_rafd/models --result_dir stargan_rafd/results

To train StarGAN on both CelebA and RaFD:

# Train StarGAN using both CelebA and RaFD datasets
python main.py --mode train --dataset Both --image_size 256 --c_dim 5 --c2_dim 8 \
               --sample_dir stargan_both/samples --log_dir stargan_both/logs \
               --model_save_dir stargan_both/models --result_dir stargan_both/results

# Test StarGAN using both CelebA and RaFD datasets
python main.py --mode test --dataset Both --image_size 256 --c_dim 5 --c2_dim 8 \
               --sample_dir stargan_both/samples --log_dir stargan_both/logs \
               --model_save_dir stargan_both/models --result_dir stargan_both/results

To train StarGAN on your own dataset, create a folder structure in the same format as RaFD and run the command:

# Train StarGAN on custom datasets
python main.py --mode train --dataset RaFD --rafd_crop_size CROP_SIZE --image_size IMG_SIZE \
               --c_dim LABEL_DIM --rafd_image_dir TRAIN_IMG_DIR \
               --sample_dir stargan_custom/samples --log_dir stargan_custom/logs \
               --model_save_dir stargan_custom/models --result_dir stargan_custom/results

# Test StarGAN on custom datasets
python main.py --mode test --dataset RaFD --rafd_crop_size CROP_SIZE --image_size IMG_SIZE \
               --c_dim LABEL_DIM --rafd_image_dir TEST_IMG_DIR \
               --sample_dir stargan_custom/samples --log_dir stargan_custom/logs \
               --model_save_dir stargan_custom/models --result_dir stargan_custom/results
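
For a custom dataset, LABEL_DIM has to match the number of classes in the folder structure. The sketch below assumes the RaFD-style layout means one subdirectory per class under the training folder (the layout that torchvision's ImageFolder reads); that layout is an assumption here, and count_label_dim and the class names are made up for illustration.

```python
import os
import tempfile


def count_label_dim(image_dir):
    """Count class subdirectories, skipping hidden entries and plain files."""
    return sum(
        1
        for name in os.listdir(image_dir)
        if not name.startswith(".") and os.path.isdir(os.path.join(image_dir, name))
    )


# Demo on a throwaway directory with three made-up class names.
with tempfile.TemporaryDirectory() as root:
    train_dir = os.path.join(root, "train")
    for cls in ("classA", "classB", "classC"):
        os.makedirs(os.path.join(train_dir, cls))
    label_dim = count_label_dim(train_dir)

print(label_dim)  # 3
```

The result would be passed as --c_dim in place of LABEL_DIM, with TRAIN_IMG_DIR pointing at the directory that was counted.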

Using pre-trained networks

To download a pre-trained model checkpoint, run the script below. The checkpoint will be downloaded and saved into the ./stargan_celeba_128/models directory.

$ bash download.sh pretrained-celeba-128x128

To translate images using the pre-trained model, run the evaluation script below. The translated images will be saved into the ./stargan_celeba_128/results directory.

$ python main.py --mode test --dataset CelebA --image_size 128 --c_dim 5 \
                 --selected_attrs Black_Hair Blond_Hair Brown_Hair Male Young \
                 --model_save_dir='stargan_celeba_128/models' \
                 --result_dir='stargan_celeba_128/results'

Citation

If you find this work useful for your research, please cite our paper:

@inproceedings{choi2018stargan,
author={Yunjey Choi and Minje Choi and Munyoung Kim and Jung-Woo Ha and Sunghun Kim and Jaegul Choo},
title={StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2018}
}

Acknowledgements

This work was mainly done while the first author did a research internship at Clova AI Research, NAVER. We thank all the researchers at NAVER, especially Donghyun Kwak, for insightful discussions.

Issues
  • Is it possible to test the data that isn't in the train-set? New images with pre-trained model.

    Is it possible to test data with a pre-trained model?

    I am curious whether there is a way to test new images with a model I trained earlier, for example training only on CelebA and later testing on a custom dataset.

    opened by wjun0830 15
  • RuntimeError: dimension out of range

    Hi,

    I am trying to train the model on RaFD using python3 and torch v3. I've already prepared the dataset but after one epoch, this is what I am getting. Can someone please help me solve this?

    I printed the size of out_cls, real_label.

    I also tried training on CelebA+RaFD, but I got a CUDA out-of-memory error. My GPU has only 8 GB of memory (I have two of them). Is there any workaround?

    Thanks!

    D
    Discriminator(
      (main): Sequential(
        (0): Conv2d (3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(0.01, inplace)
        (2): Conv2d (64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (3): LeakyReLU(0.01, inplace)
        (4): Conv2d (128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (5): LeakyReLU(0.01, inplace)
        (6): Conv2d (256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (7): LeakyReLU(0.01, inplace)
        (8): Conv2d (512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (9): LeakyReLU(0.01, inplace)
        (10): Conv2d (1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (11): LeakyReLU(0.01, inplace)
      )
      (conv1): Conv2d (2048, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d (2048, 8, kernel_size=(2, 2), stride=(1, 1), bias=False)
    )
    The number of parameters: 44786624
    Classification Acc (8 emotional expressions): ['12.50']
    Elapsed [0:00:04.683473], Epoch [1/200], Iter [10/31], G/loss_cls: 2.1469, G/loss_rec: 0.6333, D/loss_fake: 2.5479, D/loss_cls: 6.0393, D/loss_real: -26.6812, G/loss_fake: -1.5794, D/loss_gp: 0.3869
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    Classification Acc (8 emotional expressions): ['18.75']
    Elapsed [0:00:07.881092], Epoch [1/200], Iter [20/31], G/loss_cls: 3.8383, G/loss_rec: 0.5828, D/loss_fake: -5.0278, D/loss_cls: 9.1143, D/loss_real: -30.6130, G/loss_fake: 4.7769, D/loss_gp: 0.7379
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    torch.Size([16, 8]) torch.Size([16])
    Classification Acc (8 emotional expressions): ['31.25']
    Elapsed [0:00:11.084384], Epoch [1/200], Iter [30/31], G/loss_cls: 3.3109, G/loss_rec: 0.5486, D/loss_fake: 3.4921, D/loss_cls: 4.8124, D/loss_real: -38.8579, G/loss_fake: -0.7082, D/loss_gp: 0.6353
    torch.Size([8]) torch.Size([1])
    Traceback (most recent call last):
      File "main.py", line 106, in <module>
        main(config)
      File "main.py", line 41, in main
        solver.train()
      File "/home/ramin/codes/StarGAN/solver.py", line 274, in train
        d_loss_cls = F.cross_entropy(out_cls, real_label)
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 1140, in cross_entropy
        return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py", line 786, in log_softmax
        return torch._C._nn.log_softmax(input, dim)
    RuntimeError: dimension out of range (expected to be in range of [-1, 0], but got 1)

    opened by rAm1n 10
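
The sizes printed just before the traceback in the issue above (torch.Size([8]) and torch.Size([1])) suggest that the last mini-batch of the epoch held a single sample, leaving a 1-D classifier output for which dimension 1 does not exist. This reading is inferred from the log, not a confirmed diagnosis. The sketch below only does the batch arithmetic; 481 is a hypothetical dataset size chosen because it gives 31 iterations at batch size 16, as in the log.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Return the size of every mini-batch in one epoch."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # the trailing, smaller batch
    return sizes


# Hypothetical dataset size consistent with "Iter [../31]" at batch size 16.
sizes = batch_sizes(481, 16)
print(len(sizes), sizes[-1])                       # 31 1
print(len(batch_sizes(481, 16, drop_last=True)))   # 30
```

If this is the cause, two possible workarounds are choosing a batch size that divides the dataset evenly, or constructing the PyTorch DataLoader with drop_last=True so that the trailing batch is skipped.
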
  • Please help me to solve this

    opened by warisha 8
  • AttributeError: module 'tensorboard.summary._tf.summary' has no attribute 'FileWriter'

    Hi, I have tried setting up the project on my local system, but I am facing an issue with tensorboard.

    The number of parameters: 44762048
    Traceback (most recent call last):
      File "main.py", line 110, in <module>
        main(config)
      File "main.py", line 40, in main
        solver = Solver(celeba_loader, rafd_loader, config)
      File "/home/ubuntu/princess/StarGAN/solver.py", line 70, in __init__
        self.build_tensorboard()
      File "/home/ubuntu/princess/StarGAN/solver.py", line 109, in build_tensorboard
        self.logger = Logger(self.log_dir)
      File "/home/ubuntu/princess/StarGAN/logger.py", line 9, in __init__
        self.writer = tf.summary.FileWriter(log_dir)
    AttributeError: module 'tensorboard.summary._tf.summary' has no attribute 'FileWriter'

    Can you please help me resolve this issue? I tried both downgrading and upgrading tensorboard, but neither worked.

    Thanks in advance.

    opened by rameshd-ai 7
  • Out of memory on GPU Titan X

    Hi,

    I set batch_size=3, and tried the code on GPU Titan X. The memory cost keeps on increasing, and soon causes out of memory. Must I use a GPU with 24 GB memory? Any suggestions are appreciated.

    Thanks!

    opened by tyshiwo 6
  • Pretrained model on RaFD

    Your work is very interesting. Could you provide the pretrained model on RaFD about facial expression synthesis? Thanks!

    opened by leoliu37 6
  • AssertionError: MaskedFill can't differentiate the mask

    Hi, when I run the following command, I get an exception. Can you help me? @yunjey

    python3 main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 --sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' --model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'

    The number of parameters: 44762048
    Traceback (most recent call last):
      File "main.py", line 106, in <module>
        main(config)
      File "main.py", line 41, in main
        solver.train()
      File "/data1/software/deeplearning/StarGAN/solver.py", line 279, in train
        accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
      File "/data1/software/deeplearning/StarGAN/solver.py", line 145, in compute_accuracy
        predicted = self.threshold(x)
      File "/data1/software/deeplearning/StarGAN/solver.py", line 138, in threshold
        x[x >= 0.5] = 1
      File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 85, in __setitem__
        return MaskedFill.apply(self, key, value, True)
      File "/usr/local/lib/python3.5/dist-packages/torch/autograd/_functions/tensor.py", line 440, in forward
        assert not ctx.needs_input_grad[1], "MaskedFill can't differentiate the mask"
    AssertionError: MaskedFill can't differentiate the mask

    opened by viekie 5
  • solver.py problem

    In solver.py, the .cuda() calls raise a runtime error on a machine without a GPU, so the code should check torch.cuda.is_available() first:

            # Compute gradient penalty
            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
    
    opened by seanyuyuyu 4
  • No data_loader defined for 'test' mode, and small documentation issue

    At this point self.data_loader is not defined:

    https://github.com/yunjey/StarGAN/blob/master/solver.py#L671

    I solved it on my local machine doing:

            # Set dataloader
            if self.dataset == 'CelebA':
                self.data_loader = self.celebA_loader
            else:
                self.data_loader = self.rafd_loader
    
    

    at the start of test, but not sure how you want to handle 'Both' there so I didn't submit a PR.

    Finally, the application states that:

    Translated test images and saved into ./test/results..!
    Translated test images and saved into ./test/results..!
    Translated test images and saved into ./test/results..!
    Translated test images and saved into ./test/results..!
    Translated test images and saved into ./test/results..!
    

    but it appears to actually be putting them into ./test/samples

    Thanks for releasing this.

    opened by benburkhart1 3
  • Question: Instance Normalization with track_running_stats= True

    Hi, I have noticed that the Generator uses "track_running_stats= True". Is there a particular reason for it? From what I can understand when the track_running_stats is True it acts as a BatchNormalization layer but I am not sure. Also, I have noticed that the new StarGAN has a track_running_stats= False. Is this some kind of a bug or what?

    Would be grateful for any clarification, Best Regards

    opened by Mypathissional 0
  • Where does the extra dimension come from in the target labels list?

    create_labels (in the solver) is called when generating target labels for network testing.

    The shape of the output (c_trg_list) is [c_dim x batch_size x c_dim]. Why is c_dim used for two dimensions, not just one? Why are the original (input) labels cloned?

    opened by aplumley 0
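
One way to read the [c_dim x batch_size x c_dim] shape from the issue above: the outer c_dim indexes which attribute is being edited, and the inner c_dim is the full label vector of each image. The sketch below reproduces that shape with plain Python lists. It is an illustration, not the repository's create_labels code, and it assumes each target starts as a copy of the original labels so that untouched attributes keep their input values. The function name and the hair-colour grouping are hypothetical.

```python
import copy


def make_target_labels(c_org, exclusive=()):
    """Build one batch of target labels per attribute.

    c_org:     batch_size x c_dim list of 0/1 labels
    exclusive: indices of mutually exclusive attributes (e.g. hair colours)
    Returns a list of length c_dim; entry i is a batch_size x c_dim batch
    in which only attribute i (and its exclusive group) has been changed.
    """
    c_dim = len(c_org[0])
    targets = []
    for i in range(c_dim):
        c_trg = copy.deepcopy(c_org)  # clone so the input labels stay intact
        for row in c_trg:
            if i in exclusive:
                for j in exclusive:   # switch the whole group to attribute i
                    row[j] = 1 if j == i else 0
            else:
                row[i] = 1 - row[i]   # reverse a binary attribute
        targets.append(c_trg)
    return targets


# Two images, five attributes; the first three form a hypothetical hair-colour group.
c_org = [[1, 0, 0, 1, 0],
         [0, 1, 0, 0, 1]]
c_trg_list = make_target_labels(c_org, exclusive=(0, 1, 2))
print(len(c_trg_list), len(c_trg_list[0]), len(c_trg_list[0][0]))  # 5 2 5
print(c_trg_list[3])  # [[1, 0, 0, 0, 0], [0, 1, 0, 1, 1]]
```

Under this reading, cloning matters because every target differs from the input in one attribute (or one exclusive group) only. The same flip also illustrates what a "reverse attribute value" step would do for a binary attribute.
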
  • Hello, can you provide the multi-attribute translation task code of solver.py?

    I want to reproduce your H + G, A + G, and other multi-attribute translation images. Could you give me some help? Thank you very much; your work is very nice.

    opened by zhangqian001 0
  • Reverse attribute value

    In line 168 of the solver, why do we do c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value ?

    opened by yy97831 0
  • zero vector order in the domain label vector

    Hi, can I change the order of the zero vector in the domain label from

    c_org = torch.cat([zero, c_org, mask], dim=1)
    c_trg = torch.cat([zero, c_trg, mask], dim=1)

    to

    c_org = torch.cat([c_org, zero, mask], dim=1)
    c_trg = torch.cat([c_trg, zero, mask], dim=1)

    It should be the same, since the network will ignore the zeros anyway. Am I right?

    thanks

    opened by yy97831 0
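
The concatenation order asked about in the issue above can be checked with a small sketch. It uses the sizes from the README's joint-training command (a 5-dimensional CelebA label, c_dim, and an 8-dimensional RaFD label, c2_dim) plus a 2-dimensional one-hot mask, laid out as [CelebA slot | RaFD slot | mask]. The layout, the mask values, and the joint_label helper are assumptions for illustration, not the repository's code.

```python
C_DIM, C2_DIM = 5, 8  # CelebA and RaFD label sizes from the joint-training command


def joint_label(label, dataset):
    """Place a dataset's label in its own slot and zero-fill the other slot."""
    if dataset == "CelebA":
        assert len(label) == C_DIM
        return label + [0] * C2_DIM + [1, 0]
    if dataset == "RaFD":
        assert len(label) == C2_DIM
        return [0] * C_DIM + label + [0, 1]
    raise ValueError(f"unknown dataset: {dataset}")


rafd = [0, 0, 1, 0, 0, 0, 0, 0]
print(joint_label(rafd, "RaFD"))
# [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]

# Swapping the order moves the RaFD label into positions 0..7,
# which overlap the slot reserved for CelebA attributes.
swapped = rafd + [0] * C_DIM + [0, 1]
print(swapped == joint_label(rafd, "RaFD"))  # False
```

Under this assumed layout the two orders are not interchangeable: the zeros mark which slot is unused, so moving them changes which positions carry the label. Whatever order is chosen would have to be identical for every dataset and for both training and testing.
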
  • how to save the generated images separated

    How can I save the generated images as separate files in a single folder? I do not want all the images combined into one picture.

    opened by engineer-38 1
  • [Question] Mask vector in this paper

    Thank you for your paper! I have a question about this paper.

    The paper introduces multi-dataset, multi-domain image-to-image translation with the CelebA and RaFD datasets. StarGAN uses the mask vector to select the desired domain for a particular dataset. What should I do if I want to target domains from both datasets at the same time? For example, if I want to create an image with CelebA = [1, 0, 0, 1, 1] and RaFD = [0, 0, 1, 0, 0], is that impossible in one pass because of the mask vector? Do I have to proceed sequentially?

    Thank you!!

    opened by YoonSungLee 0
  • Can you share your pre-trained model on RaFD dataset?

    I'm citing your paper as a baseline method. Could you give me your pre-trained model on RaFD dataset? Thanks a lot.

    opened by realliujiaxu 1
  • Question about parameter adjustment when changing attributes

    Hi,

    I want to add more attributes to StarGAN, and I managed to increase the number of attributes from 5 to 10. However, the results are not clear enough at 200000 iterations. Training for more iterations gives only a slight improvement. Which other parameters should I adjust to get better performance? Thank you.

    opened by Kirito0816 0
  • Disable tensorflow v2 behavior in logger.py

    TensorFlow version 2 does not provide tf.summary.FileWriter. This code is based on TensorFlow version 1, so to use the original code we need to apply this commit (it just disables TensorFlow v2 behavior and uses TensorFlow v1).

    opened by Eun0 0
Owner
Yunjey Choi