Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two

Overview

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN


Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".
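As a rough sketch of the skip-layer excitation idea (not the exact module used in this repository), a low-resolution feature map is squeezed into per-channel gates that rescale a high-resolution feature map, giving the generator a cheap long-range skip connection:

import torch
import torch.nn as nn

class SkipLayerExcitation(nn.Module):
    # minimal sketch of skip-layer excitation (SLE): a low-resolution feature map
    # is pooled into per-channel gates that modulate a high-resolution feature map
    def __init__(self, chan_low, chan_high):
        super().__init__()
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),            # squeeze spatial dims to 4x4
            nn.Conv2d(chan_low, chan_high, 4),  # 4x4 -> 1x1
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_high, chan_high, 1),
            nn.Sigmoid(),                       # per-channel gates in [0, 1]
        )

    def forward(self, feat_high, feat_low):
        return feat_high * self.gate(feat_low)

# usage: gate 128x128 features (64 channels) with 8x8 features (512 channels)
sle = SkipLayerExcitation(chan_low=512, chan_high=64)
out = sle(torch.randn(1, 64, 128, 128), torch.randn(1, 512, 8, 8))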

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model will be saved to ./models/{name} every 1000 iterations, and samples from the model will be saved to ./results/{name}. name will be default, by default.

Training settings

Pretty self-explanatory for deep learning practitioners

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000
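With the settings above, each optimizer step effectively sees 16 × 4 = 64 images, since gradients are accumulated over 4 forward/backward passes before every update.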

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low data setting

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following.

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]
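Conceptually, --aug-prob controls how often the chosen augmentations are applied to the images the discriminator sees. The following is a minimal sketch of that idea in PyTorch, with hypothetical helpers rather than the repository's AugWrapper:

import random

import torch
import torch.nn.functional as F

def random_cutout(images, size=0.5):
    # zero out a random square patch covering `size` of each side (black box)
    _, _, h, w = images.shape
    ch, cw = int(h * size), int(w * size)
    y, x = random.randint(0, h - ch), random.randint(0, w - cw)
    images = images.clone()
    images[:, :, y:y + ch, x:x + cw] = 0.
    return images

def random_translation(images, shift=0.125):
    # translate the whole batch by a random offset, filling exposed pixels with black
    h, w = images.shape[-2:]
    max_dy, max_dx = int(h * shift), int(w * shift)
    dy, dx = random.randint(-max_dy, max_dy), random.randint(-max_dx, max_dx)
    padded = F.pad(images, (max_dx, max_dx, max_dy, max_dy))
    return padded[:, :, max_dy + dy:max_dy + dy + h, max_dx + dx:max_dx + dx + w]

def augment(images, prob=0.25, augs=(random_cutout, random_translation)):
    # with probability `prob`, run the batch through the augmentation pipeline
    if random.random() < prob:
        for aug in augs:
            images = aug(images)
    return images

# both real and generated batches would pass through `augment` before the discriminator
fake = torch.randn(4, 3, 256, 256)
fake_aug = augment(fake, prob=0.25)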

Test augmentation

You can preview how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:

Basic usage

To augment your image, pass --aug-test and set --data to the path of your image:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

This will create a file lena_augs.jpg that will look something like this:

Options

You can use these options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will be something like this:

Types of augmentations

This library contains several types of built-in augmentations.
Some of them run by default; others can be enabled as options in --aug-types:

  • Horizontal flip (applied by default, not configurable, runs in the AugWrapper class);
  • color randomly changes brightness, saturation and contrast;
  • cutout creates random black boxes on the image;
  • offset randomly shifts the image along the x- and y-axes, wrapping the image around;
    • offset_h only along the x-axis;
    • offset_v only along the y-axis;
  • translation randomly moves the image on the canvas with a black background;

Full setup of augmentations is --aug-types [color,cutout,offset,translation].
The general recommendation is to use augmentations that suit your data, and as many as possible; then, after some time of training, disable the most destructive (for the image) ones.

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with one flag --amp

You should expect it to be 33% faster and save up to 40% memory
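The flag enables standard PyTorch automatic mixed precision; the savings come from running most of the forward/backward pass in half precision while scaling the loss to avoid underflow. A generic sketch of the mechanism (not this repository's training loop):

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, images, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():       # forward pass runs in mixed precision
        loss = loss_fn(model(images))
    scaler.scale(loss).backward()         # scale loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()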

Multiple GPUs

Also one flag to use: --multi-gpus

Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it will default to the latest.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command, you will get a folder next to the results folder with the postfix "-generated-{checkpoint num}".

You can also generate interpolations

$ lightweight_gan --name {name of run} --generate-interpolation
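Interpolation just walks between two latent vectors and decodes every intermediate point with the generator. A hedged sketch of the idea, where generator and latent_dim are placeholders rather than the repository's internals:

import torch

@torch.no_grad()
def interpolate(generator, latent_dim=256, steps=8, device='cpu'):
    # linearly interpolate between two random latents and decode each point
    z1 = torch.randn(1, latent_dim, device=device)
    z2 = torch.randn(1, latent_dim, device=device)
    frames = []
    for t in torch.linspace(0., 1., steps):
        z = torch.lerp(z1, z2, t.item())
        frames.append(generator(z))
    return torch.cat(frames, dim=0)   # (steps, C, H, W)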

Show progress

After several model checkpoints have been created, you can render training progress as a sequence of images with:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command, you will get a new folder in the results folder, with the postfix "-progress". You can convert the images to a video with ffmpeg using the command "ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4".

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets (as an example, 5x5 works better for art than for faces). You can toggle this with a single flag

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5
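The flag only changes the spatial size of the discriminator's output map; with a 5x5 output the loss is averaged over a grid of patch scores instead of a single scalar. A small illustration with a hypothetical final feature map:

import torch
import torch.nn as nn

feat = torch.randn(8, 512, 8, 8)      # hypothetical final discriminator features

to_logit_1x1 = nn.Conv2d(512, 1, 8)   # 8x8 kernel -> a single score per image
to_logit_5x5 = nn.Conv2d(512, 1, 4)   # 4x4 kernel -> a 5x5 grid of patch scores

print(to_logit_1x1(feat).shape)       # torch.Size([8, 1, 1, 1])
print(to_logit_5x5(feat).shape)       # torch.Size([8, 1, 5, 5])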

Attention

You can add linear + axial attention to specific resolution layers with the following

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25
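The listed numbers are the feature-map resolutions (32x32 and 64x64 here) at which attention is inserted. As a rough, hedged sketch of a single-head linear attention layer over a feature map (not the repository's exact module):

import torch
import torch.nn as nn

class LinearAttention2d(nn.Module):
    # sketch of linear (efficient) attention over a 2d feature map, single head
    def __init__(self, dim, dim_head=64):
        super().__init__()
        self.to_qkv = nn.Conv2d(dim, dim_head * 3, 1, bias=False)
        self.to_out = nn.Conv2d(dim_head, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=1)          # each (b, d, h, w)
        q, k, v = [t.flatten(2) for t in (q, k, v)]       # (b, d, h*w)
        q = q.softmax(dim=1)                              # softmax over features
        k = k.softmax(dim=-1)                             # softmax over positions
        context = torch.einsum('bdn,ben->bde', k, v)      # (b, d, d)
        out = torch.einsum('bde,bdn->ben', context, q)    # (b, d, h*w)
        return x + self.to_out(out.reshape(b, -1, h, w))  # residual connection

attn = LinearAttention2d(dim=256)
y = attn(torch.randn(1, 256, 32, 32))   # attention at resolution 32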

Bonus

You can also train with transparent images

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale

$ lightweight_gan --data ./path/to/images --greyscale
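Both flags simply change the number of image channels the networks operate on. A hypothetical helper mirroring that mapping:

# hypothetical helper: RGBA for --transparent, a single channel for --greyscale, RGB otherwise
def num_channels(transparent=False, greyscale=False):
    if transparent:
        return 4   # RGBA
    if greyscale:
        return 1   # greyscale
    return 3       # RGB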

Alternatives

If you want the current state of the art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title={Towards Faster and Stabilized {GAN} Training for High-fidelity Few-shot Image Synthesis},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1Fqg133qRaI},
    note={under review}
}
@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}
@misc{cao2020global,
    title={Global Context Networks},
    author={Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year={2020},
    eprint={2012.13375},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{qin2020fcanet,
    title={FcaNet: Frequency Channel Attention Networks},
    author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year={2020},
    eprint={2012.11879},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{sinha2020topk,
    title={Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author={Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year={2020},
    eprint={2002.06224},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

What I cannot create, I do not understand - Richard Feynman

Comments
  • Troubles with global context module in 0.15.0

    @lucidrains

    After updating to this version https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0 I can't continue training my network and had to start from zero. The previous version was at 117k batches of 4 (468k images, around 66 hours of training) and was pretty good. In the new version 0.15.0, on the same dataset with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see artifacts that look like oil puddles. Have you encountered this, or do you know how to avoid it?

    image

    In the previous version with sle-spatial I didn't see anything like this.

    opened by Dok11 9
  • What is sle_spatial?

    I have seen this function argument mentioned in this issue:

    https://github.com/lucidrains/lightweight-gan/issues/14#issuecomment-733432989

    What is sle_spatial?

    opened by woctezuma 8
  • unable to load save model. please try downgrading the package to the version specified by the saved model

    I have had the following problem since today. How can I solve this?

    continuing from previous epoch - 118 loading from version 0.21.4 unable to load save model. please try downgrading the package to the version specified by the saved model Traceback (most recent call last): File "/opt/conda/bin/lightweight_gan", line 8, in sys.exit(main()) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main fire.Fire(train_from_folder) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training model.load(load_from) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load raise e File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load self.GAN.load_state_dict(load_data['GAN']) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for LightweightGAN: Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", "G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", 
"GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", "GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", 
"GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", "GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). 
size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

    opened by sebastiantrella 7
  • Greyscale image generation

    Hi,

    thank you for this repo; I've been playing with it a bit and it seems very good! I am trying to generate greyscale images, so I modified the channel count accordingly:

    init_channel = 4 if transparent else 1

    Unfortunately, this seemed to have no effect, as the generated images are still RGB (even though they converge towards greyscale over time). Even weirder, IMO, is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

    I have also changed this part, to no effect:

    convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
    num_channels = 1 if not transparent else 4

    Am I missing something here?

    opened by stefanorosss 7
  • Getting NoneType is not subscriptable when trying to start training.

    I've been able to train models before but after changing my dataset I'm getting the error.

    My trace:

    File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
        name = checkpoints[-1]
    TypeError: 'NoneType' object is not subscriptable

    opened by TomCallan 6
  • Optimal parameters for Google Colab

    Hello,

    First of all, thank you for sharing your code and insights with the rest of us!

    As for your code, I plan to run it for 12 hours on Google Colab, similarly to the set-up shown in the README.

    My dataset consists of images of 256x256 resolution, and I have started training with the following command-line:

    !lightweight_gan \
     --data {image_dir} \
     --disc-output-size 5 \
     --aug-prob 0.25 \
     --aug-types [translation,cutout,color] \
     --amp \
    

    I have noticed that the expected training time is 112.5 hours with 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, it is ~ 9 times more than what is shown in the README. So I wonder if I am doing something wrong, and I see 2 solutions.

    First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

     --num-train-steps 16000 \
    

    Is it what you have done for the results shown in the README?

    Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in https://github.com/lucidrains/lightweight-gan/issues/13#issuecomment-732486110. Edit: However, the training time increases with a larger batch size. For instance, I am using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

     --batch-size 32 \
     --gradient-accumulate-every 4 \
    
    opened by woctezuma 6
  • Added Experimentation Tracking.

    Added Experimentation Tracking using Aim.

    Now you can:

    • Track all the model hyperparameters and architectural choices.
    • Track all types of losses.
    • Filter all the experiments with respect to hyperparameters or the architecture.
    • Group and aggregate w.r.t. all the trackables to dive into granular experimentation assessment.
    • Track the generated images to see how the model improves.

    opened by hnhnarek 5
  • Aim installation error

    I'm trying to run the generator after training, to generate fake samples using the following command

    lightweight_gan --generate --load-from 299

    I get the following error:

    Traceback (most recent call last):
      File "C:\anaconda3\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\anaconda3\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "C:\anaconda3\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 195, in main
        fire.Fire(train_from_folder)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 158, in train_from_folder
        model = Trainer(**model_args)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1057, in __init__
        self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
    AttributeError: 'Trainer' object has no attribute 'aim'
    

    and when I try to run pip install aim, I get a dependency error with aimrocks

      ERROR: Command errored out with exit status 1:
       command: 'C:\Anaconda3\envs\aerialweb\python.exe' 'C:\Anaconda3\envs\aerialweb\lib\site-packages\pip' install --ignore-installed --no-user --prefix 'C:\Users\ahmed\AppData\Local\Temp\pip-build-env-b2ysw94t\overlay' --no-warn-script-location --no-binary :none: --only-binary :none: -i https://pypi.org/simple -- setuptools 'cython >= 3.0.0a9' 'aimrocks == 0.2.1'
           cwd: None
      Complete output (12 lines):
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      Collecting setuptools
        Using cached setuptools-59.6.0-py3-none-any.whl (952 kB)
      Collecting cython>=3.0.0a9
        Using cached Cython-3.0.0a10-py2.py3-none-any.whl (1.1 MB)
      ERROR: Could not find a version that satisfies the requirement aimrocks==0.2.1 (from versions: 0.1.3a14, 0.2.0.dev1, 0.2.0)
      ERROR: No matching distribution found for aimrocks==0.2.1
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
    

    What is aimrocks and what does it actually do? I am unable to find a matching distribution or even a wheel file to install it manually. Please help

    opened by demiahmed 4
  • Can't find "__main__" module (sorry if noob question)

    Hello, I hope it's not too much of a noob question, I don't have any background in coding.

    After creating the env and installing PyTorch, I ran "python setup.py install" and then "python lightweight_gan --data /source --image-size 512" (I filled a "source" folder with pictures of fish), but I get the error "can't find 'main' module". More exactly: C:\Programmes perso\Logiciels\Anaconda\envs\lightweightgan\python.exe: can't find 'main' module in 'C:\Programmes perso\Logiciels\LightweightGan\lightweight_gan'. I tried to copy and rename some of the other modules (init, lightweight_gan...); the code seems to start running but stops before doing anything. So I guess some file must be missing, or did I do something wrong?

    Thanks a lot for the repo and have a nice day

    opened by SPotaufeux 4
  • Hard cutoff straight lines/boxes of nothing in generated images

    Hello! Training on Google Colab with

    !lightweight_gan --data my/images/ --name my-name --image-size 256 --transparent --dual-contrast-loss --num-train-steps 250000
    

    I'm at 250k iterations over the course of 5 days at 2s/it, and have gotten strange results with boxes.

    I've circled some examples of this below. image

    My training data is 22k images of 256x256 .pngs that do not contain large hard edges or boxes like this. They're video game sprites with hard edges being limited to at most 10x10px

    Are there any arguments I could adjust in order to decrease the chance of the model learning that transparent boxes are good? Would converting to a white background help?

    Thank you!

    opened by timendez 4
  • Amount of training steps

    If I bring down the number of training steps from 150 000 to 30 000, will the trained model be overall bad? Does it really need the 100 000 or 150 000 training steps?

    opened by MartimQS 4
  • Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    A rather frustrating issue:

    calling it with a trailing \ like lightweight_gan --data full_cleaned256/ --name c256 --models_dir models --results_dir results \

    sets the variable new to the truthy value '\' and deletes all progress.

    This might well be an issue with Fire, but it might be mitigated or fixed here too; I am unsure about that.

    Thanks. Jonas

    opened by deklesen 0
  • Projecting generated Images to Latent Space

    Is there any way to reverse engineer the generated images into the latent space?

    I am trying to embed fresh RGB images, as well as ones generated by the Generator, into the latent space so I can find their nearest neighbours, pretty much like AI image editing tools.

    I plan to convert my RGB image into tensor embeddings based on my trained model and tweak the feature vectors.

    How can I achieve this with lightweight-gan?

    opened by demiahmed 0
  • Discriminator Loss converges to 0 while Generator loss pretty high

    I am trying to train with a custom image dataset for about 600,000 epochs. At about the halfway point, my D_loss converges to 0 while my G_loss stays put at 2.5.

    My evaluation outputs are slowly starting to fade out to either black or white.

    Is there anything I could do to tweak my model? Either by increasing the threshold for the Discriminator or by training the Generator only?

    opened by demiahmed 3
  • loss implementation differs from paper

    Hi,

    Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper:

    paper_losses

    • in red, the discriminator loss (D loss) on the true labels,
    • in green the D loss on labels for fake generated images,
    • and in blue, the generator loss (G loss) on labels for fake images.

    This makes sense to me. Since it is assumed that D outputs values between 0 and 1 (0 = fake, 1 = real):

    • in red, we want D to output 1 for true images → let's assume D indeed outputs 1 for true images : -min(0, -1 + D(x)) = 0, which is indeed the minimum achievable,
    • in green, we want D to output 0 (from the discriminator perspective) for fake images → let's assume D indeed outputs 0 for fake images : -min(0, -1 - D(x^)) = 1, which is the minimum achievable if D outputs values only between 0 and 1,
    • in blue, we want D to output 1 (from the generator perspective) for fake images : the equation follows directly.

    Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the above picture)

    og_code_loss_d_real og_code_loss_d_fake og_code_loss_g

    Except for the strange randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper equations.
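    For reference, the paper equations above (red, green, blue) correspond to the standard hinge GAN losses; a minimal sketch, assuming D outputs unbounded real-valued logits:

    import torch
    import torch.nn.functional as F

    def d_hinge_loss(real_logits, fake_logits):
        # red + green terms: -min(0, -1 + D(x)) and -min(0, -1 - D(x_hat))
        return F.relu(1. - real_logits).mean() + F.relu(1. + fake_logits).mean()

    def g_hinge_loss(fake_logits):
        # blue term: the generator pushes D's output on fakes upward
        return -fake_logits.mean()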


    The way it is implemented in this repo, however, is quite different, and I do not understand why.

    lighweight_gan_losses

    Let's start with the discriminator loss :

    • in red, you want D to output small values (negative if allowed), to set this term as small as possible (0 if D can output negative values)
    • in green, you want D to output values as large as possible (larger or equal to 1) to cancel this term out as well

    For the generator loss :

    • in blue, you want the opposite of green, that is for D to output values as small as possible

    This implementation seems meaningful and yields coherent results (as the examples prove). It also seems to me that D is not limited to outputting values between 0 and 1, but can output any real value (I might be wrong). I am just wondering why this choice was made. Could you perhaps elaborate on why you decided to implement the loss differently from the original paper?

    opened by maximeraafat 1
  • showing results while training?

    How do I show generator results after every epoch during training?

    This is my current configuration:

     lightweight_gan \
      --data "/content/dataset/Dataset/" \
      --num-train-steps 100000 \
      --image-size 128 \
      --name GAN2DBlood5k \
      --batch-size 32 \
      --gradient-accumulate-every 5 \
      --disc-output-size 1 \
      --dual-contrast-loss \
      --attn-res-layers [] \
      --calculate_fid_every 1000 \
      --greyscale \
      --amp
    

    Using --show-progress only works after training. Also, it seems that there are no longer checkpoints per epoch.

    opened by galaelized 2
Releases (1.1.1)
Owner
Phil Wang
Working with Attention. It's all we need.