基于pytorch构建cyclegan示例

Overview

cyclegan-demo

基于Pytorch构建CycleGAN示例

如何运行

准备数据集

将数据集整理成4个文件,分别命名为

  • trainA, trainB:训练集,A、B代表两类图片
  • testA, testB:测试集,A、B代表两类图片

例如

D:\CODE\CYCLEGAN-DEMO\DATA\SUMMER2WINTER
├─testA
├─testB
├─trainA
└─trainB

之后在main.py中将root设为数据集的路径。

参数设置

main.py中的初始化参数

# 初始化参数
# seed: 随机种子
# root: 数据集路径
# output_model_root: 模型的输出路径
# image_size: 图片尺寸
# batch_size: 一次喂入的数据量
# lr: 学习率
# betas: 一阶和二阶动量
# epochs: 训练总次数
# historical_epochs: 历史训练次数
# - 0表示不沿用历史模型
# - >0表示对应训练次数的模型
# - -1表示最后一次训练的模型
# save_every: 保存频率
# loss_range: Loss的显示范围
seed = 123
data_root = 'D:/code/cyclegan-demo/data/summer2winter'
output_model_root = 'output/model'
image_size = 64
batch_size = 16
lr = 2e-4
betas = (.5, .999)
epochs = 100
historical_epochs = -1
save_every = 1
loss_range = 1000

安装和运行

  1. 安装依赖
pip install -r requirements.txt
  1. 打开命令行,运行Visdom
python -m visdom.server
  1. 运行主程序
python main.py

训练过程的可视化展示

访问地址http://localhost:8097即可进入Visdom可视化页面,页面中将展示:

  • A类真实图片 -【A2B生成器】 -> B类虚假图片 -【B2A生成器】 -> A类重构图片
  • B类真实图片 -【B2A生成器】 -> A类虚假图片 -【A2B生成器】 -> B类重构图片
  • 判别器A、B以及生成器的Loss曲线

一些可视化的具体用法可见Visdom的使用方法。

测试

TODO

介绍

目录结构

  • dataset.py 数据集
  • discriminator.py 判别器
  • generater.py 生成器
  • main.py 主程序
  • replay_buffer.py 缓冲区
  • resblk.py 残差块
  • util.py 工具方法

原理介绍

残差块是生成器的组成部分,其结构如下

Resblk(
  (main): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (2): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (6): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
)

生成器结构如下,由于采用全卷积结构,事实上其结构与图片尺寸无关

Generater(
  (input): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
  )
  (downsampling): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (resnet): Sequential(
    (0): Resblk
    (1): Resblk
    (2): Resblk
    (3): Resblk
    (4): Resblk
    (5): Resblk
    (6): Resblk
    (7): Resblk
    (8): Resblk
  )

  (upsampling): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (output): Sequential(
    (0): ReplicationPad2d((3, 3, 3, 3))
    (1): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
    (2): Tanh()
  )
)

判别器结构如下,池化层具体尺寸由图片尺寸决定,64x64的图片对应池化层为6x6

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (output): Sequential(
    (0): AvgPool2d(kernel_size=torch.Size([6, 6]), stride=torch.Size([6, 6]), padding=0)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
)

训练共有三个优化器,分别负责生成器、判别器A、判别器B的优化。

损失有三种类型:

  • 一致性损失:A(B)类真实图片与经生成器生成的图片的误差,该损失使得生成后的风格与原图更接近,采用L1Loss
  • 对抗损失:A(B)类图片经生成器得到B(A)类图片,再经判别器判别的错误率,采用MSELoss
  • 循环损失:A(B)类图片经生成器得到B(A)类图片,再经生成器得到A(B)类的重建图片,原图和重建图片的误差,采用L1Loss

生成器的训练过程:

  1. 将A(B)类真实图片送入生成器,得到生成的图片,计算生成图片与原图的一致性损失
  2. 将A(B)类真实图片送入生成器得到虚假图片,再送入判别器得到判别结果,计算判别结果与真实标签1的对抗损失(虚假图片应能被判别器判别为真实图片,即生成器能骗过判别器)
  3. 将A(B)类虚假图片送入生成器,得到重建图片,计算重建图片与原图的循环损失
  4. 计算、更新梯度

判别器A的训练过程:

  1. 将A类真实图片送入判别器A,得到判别结果,计算判别结果与真实标签1的对抗损失(判别器应将真实图片判别为真实)
  2. 将A类虚假图片送入判别器A,得到判别结果,计算判别结果与虚假标签0的对抗损失(判别器应将虚假图片判别为虚假)
  3. 计算、更新梯度
Owner
Koorye
学习?学个屁
Koorye
Is RobustBench/AutoAttack a suitable Benchmark for Adversarial Robustness?

Adversrial Machine Learning Benchmarks This code belongs to the papers: Is RobustBench/AutoAttack a suitable Benchmark for Adversarial Robustness? Det

Adversarial Machine Learning 9 Nov 27, 2022
This respository includes implementations on Manifoldron: Direct Space Partition via Manifold Discovery

Manifoldron: Direct Space Partition via Manifold Discovery This respository includes implementations on Manifoldron: Direct Space Partition via Manifo

dayang_wang 4 Apr 28, 2022
Code for Environment Inference for Invariant Learning (ICML 2020 UDL Workshop Paper)

Environment Inference for Invariant Learning This code accompanies the paper Environment Inference for Invariant Learning, which appears at ICML 2021.

Elliot Creager 40 Dec 09, 2022
Computational modelling of ray propagation through optical elements using the principles of geometric optics (Ray Tracer)

Computational modelling of ray propagation through optical elements using the principles of geometric optics (Ray Tracer) Introduction By applying the

Son Gyo Jung 1 Jul 09, 2022
Neural Scene Graphs for Dynamic Scene (CVPR 2021)

Implementation of Neural Scene Graphs, that optimizes multiple radiance fields to represent different objects and a static scene background. Learned representations can be rendered with novel object

151 Dec 26, 2022
A Python library for common tasks on 3D point clouds

Point Cloud Utils (pcu) - A Python library for common tasks on 3D point clouds Point Cloud Utils (pcu) is a utility library providing the following fu

Francis Williams 622 Dec 27, 2022
OpenMMLab Image and Video Editing Toolbox

Introduction MMEditing is an open source image and video editing toolbox based on PyTorch. It is a part of the OpenMMLab project. The master branch wo

OpenMMLab 3.9k Jan 04, 2023
Time Series Forecasting with Temporal Fusion Transformer in Pytorch

Forecasting with the Temporal Fusion Transformer Multi-horizon forecasting often contains a complex mix of inputs – including static (i.e. time-invari

Nicolás Fornasari 6 Jan 24, 2022
This tutorial aims to learn the basics of deep learning by hands, and master the basics through combination of lectures and exercises

2021-Deep-learning This tutorial aims to learn the basics of deep learning by hands, and master the basics through combination of paper and exercises.

108 Feb 24, 2022
Tiny Kinetics-400 for test

Kinetics-400迷你数据集 English | 简体中文 该数据集旨在解决的问题:参照Kinetics-400数据格式,训练基于自己数据的视频理解模型。 数据集介绍 Kinetics-400是视频领域benchmark常用数据集,详细介绍可以参考其官方网站Kinetics。整个数据集包含40

38 Jan 06, 2023
Repo for the Video Person Clustering dataset, and code for the associated paper

Video Person Clustering Repo for the Video Person Clustering dataset, and code for the associated paper. This reporsitory contains the Video Person Cl

Andrew Brown 47 Nov 02, 2022
Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

CQL-JAX This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on

Karush Suri 8 Nov 07, 2022
[AI6122] Text Data Management & Processing

[AI6122] Text Data Management & Processing is an elective course of MSAI, SCSE, NTU, Singapore. The repository corresponds to the AI6122 of Semester 1, AY2021-2022, starting from 08/2021. The instruc

HT. Li 1 Jan 17, 2022
Code for Recurrent Mask Refinement for Few-Shot Medical Image Segmentation (ICCV 2021).

Recurrent Mask Refinement for Few-Shot Medical Image Segmentation Steps Install any missing packages using pip or conda Preprocess each dataset using

XIE LAB @ UCI 39 Dec 08, 2022
Visual Adversarial Imitation Learning using Variational Models (VMAIL)

Visual Adversarial Imitation Learning using Variational Models (VMAIL) This is the official implementation of the NeurIPS 2021 paper. Project website

14 Nov 18, 2022
Unofficial PyTorch implementation of SimCLR by Google Brain

Unofficial PyTorch implementation of SimCLR by Google Brain

Rishabh Anand 2 Oct 13, 2021
[Preprint] ConvMLP: Hierarchical Convolutional MLPs for Vision, 2021

Convolutional MLP ConvMLP: Hierarchical Convolutional MLPs for Vision Preprint link: ConvMLP: Hierarchical Convolutional MLPs for Vision By Jiachen Li

SHI Lab 143 Jan 03, 2023
Deploy a ML inference service on a budget in less than 10 lines of code.

BudgetML is perfect for practitioners who would like to quickly deploy their models to an endpoint, but not waste a lot of time, money, and effort trying to figure out how to do this end-to-end.

1.3k Dec 25, 2022
Flower classification model that classifies flowers in 10 classes made using transfer learning (~85% accuracy).

flower-classification-inceptionV3 Flower classification model that classifies flowers in 10 classes. Training and validation are done using a pre-anot

Ivan R. Mršulja 1 Dec 12, 2021
Official PyTorch Implementation for "Recurrent Video Deblurring with Blur-Invariant Motion Estimation and Pixel Volumes"

PVDNet: Recurrent Video Deblurring with Blur-Invariant Motion Estimation and Pixel Volumes This repository contains the official PyTorch implementatio

Junyong Lee 98 Nov 06, 2022