PyTorch implementation of the TTC algorithm

Overview

Trust-the-Critics

This repository is a PyTorch implementation of the TTC algorithm and the WGAN misalignment experiments presented in Trust the Critics: Generatorless and Multipurpose WGANs with Initial Convergence Guarantees.

How to run this code

  • Create a Python virtual environment with Python 3.8 installed.
  • Install the necessary Python packages listed in the requirements.txt file (this can be done through pip install -r /path/to/requirements.txt).

In the example_shell_scripts folder, we include samples of shell scripts we used to run our experiments. We note that training generative models is computationally demanding, and thus requires adequate computational resources (i.e. running this on your laptop is not recommended).

TTC algorithm

The various experiments we run with TTC are described in Section 5 and Appendix B of the paper. Illustrating the flexibility of the TTC algorithm, the image generation, denoising and translation experiments can all be run using the ttc.py script; the only necessary changes are the source and target datasets. Running TTC with a given source and a given target will train and save several critic neural networks that can subsequently be used to push the source distribution towards the target distribution by applying the 'steptaker' function found in TTC_utils/steptaker.py once for each critic.
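To make "once for each critic" concrete, here is a toy sketch in plain Python rather than PyTorch. It is not the code in TTC_utils/steptaker.py: it assumes each step moves a sample along the negative gradient of one critic, scaled by that critic's step size, and the names push_sample and toy_grad are hypothetical.

```python
from typing import Callable, Sequence

# Hypothetical stand-in for a trained critic: returns the critic's gradient at x.
GradFn = Callable[[float], float]


def push_sample(x: float, critic_grads: Sequence[GradFn],
                step_sizes: Sequence[float]) -> float:
    """Apply one gradient step per critic, in the order the critics were trained.

    Assumed update rule (not taken from the repository): x <- x - eta * grad u(x).
    """
    if len(critic_grads) != len(step_sizes):
        raise ValueError("need exactly one step size per critic")
    for grad, eta in zip(critic_grads, step_sizes):
        x = x - eta * grad(x)
    return x


def toy_grad(x: float) -> float:
    """Gradient of the toy critic u(x) = |x - 3|, whose minimum sits at the 'target' 3."""
    return 1.0 if x > 3.0 else -1.0


# Three toy critics with step sizes 1.0, 1.0 and 0.5 move the sample 0.0 towards 3.
moved = push_sample(0.0, [toy_grad] * 3, [1.0, 1.0, 0.5])
```

In the repository the critics are neural networks and the step sizes are read from log.pkl; the loop structure above only illustrates the one-step-per-critic idea.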

Necessary arguments for ttc.py are:

  • 'source' : The name of the distribution or dataset that is to be pushed towards the target (options are listed in ttc.py).
  • 'target' : The name of the target dataset (options are listed in ttc.py).
  • 'data' : The path of a directory where the necessary data is located. This includes the target dataset, in a format that can be accessed by a dataloader object obtained from the corresponding function in dataloader.py. Such a dataloader always belongs to the torch.utils.data.DataLoader class (e.g. if target=='mnist', then the corresponding dataloader will wrap an instance of torchvision.datasets.MNIST, and the MNIST dataset should be placed in 'data' in a way that reflects this). If the source is a dataset, it needs to be placed in 'data' as well. If source=='untrained_gen', then the untrained generator used to create the source distribution needs to be saved under 'data/ugen.pth'.
  • 'temp_dir' : The path of a directory where the trained critics will be saved, along with a few other files (including the log.pkl file that contains the step sizes). Despite the name, this folder isn't necessarily temporary.

Other optional arguments are described in a commented section at the top of the ttc.py script. Note that running ttc.py only trains the critics that the TTC algorithm uses to push the source distribution towards the target distribution; it does not actually push any samples from the source towards the target (as mentioned above, this is done using the steptaker function).
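As an illustration of the four necessary arguments, the sketch below assembles a command line for ttc.py. It assumes the arguments are passed as "--name value" flags, which should be checked against the argument parser in ttc.py; build_ttc_command and the example paths are hypothetical.

```python
import shlex
import sys
from typing import List


def build_ttc_command(source: str, target: str, data: str, temp_dir: str) -> List[str]:
    """Assemble an argument list for ttc.py (the '--flag' spelling is an assumption)."""
    return [
        sys.executable, "ttc.py",
        "--source", source,
        "--target", target,
        "--data", data,
        "--temp_dir", temp_dir,
    ]


# Placeholder paths; replace them with real directories before running.
cmd = build_ttc_command("noise", "mnist", "/path/to/data", "/path/to/temp_dir")
print(shlex.join(cmd))  # The list can also be handed to subprocess.run(cmd, check=True).
```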

TTC image generation
For a generative experiment, run ttc.py with the source argument set to either 'noise' or 'untrained_gen' and the target of your choice. Then, run ttc_eval.py, which will use the saved critics and step sizes to push noise inputs towards the target distribution according to the TTC algorithm (using the steptaker function), and which will optionally evaluate generative performance with FID and/or MMD (FID is used in the paper). The arguments 'source', 'target', 'data', 'temp_dir' and 'model' for ttc_eval.py should be set to the same values as when running ttc.py. If evaluating FID, the folder specified by 'temp_dir' should contain a subdirectory named 'temp_dir/{target}test' (e.g. 'temp_dir/mnisttest' if target=='mnist') containing the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).
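Before evaluating FID, it can help to confirm that the test-data folder follows the layout described above. The following is a minimal sketch using only the standard library; the helper names fid_test_dir and count_test_files are hypothetical and are not part of the repository.

```python
from pathlib import Path


def fid_test_dir(temp_dir: str, target: str) -> Path:
    """Return the folder expected to hold the test images, i.e. temp_dir/{target}test."""
    return Path(temp_dir) / f"{target}test"


def count_test_files(temp_dir: str, target: str) -> int:
    """Count the individual files in the FID test folder, failing early if it is missing."""
    folder = fid_test_dir(temp_dir, target)
    if not folder.is_dir():
        raise FileNotFoundError(f"expected test images in {folder}")
    return sum(1 for entry in folder.iterdir() if entry.is_file())
```

For target=='mnist', fid_test_dir('temp_dir', 'mnist') points at 'temp_dir/mnisttest', matching the example in the text.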

TTC denoising
For a denoising experiment, run ttc.py with source=='noisybsds500' and target=='bsds500' (specifying a noise level with the 'sigma' argument). Then, run denoise_eval.py (with the same 'temp_dir', 'data' and 'model' arguments), which will add noise to images, denoise them using the TTC algorithm and the saved critics, and evaluate the resulting PSNR values.
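For reference, PSNR is commonly defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value. The sketch below implements that standard definition on flat lists of pixel values; it assumes pixels scaled to [0, 1] and is not taken from denoise_eval.py, whose exact computation may differ.

```python
import math
from typing import Sequence


def psnr(clean: Sequence[float], denoised: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(clean) != len(denoised) or len(clean) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((c - d) ** 2 for c, d in zip(clean, denoised)) / len(clean)
    if mse == 0.0:
        return math.inf  # Identical images: PSNR is unbounded.
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values indicate that the denoised image is closer to the clean one.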

TTC Monet translation
For a translation experiment, run ttc.py with source=='photo' and target=='monet'. Then run ttc_eval.py (with the same 'source', 'target', 'temp_dir', 'data' and 'model' arguments, and presumably with no FID or MMD evaluation), which will sample realistic images from the source and make them look like Monet paintings.

WGAN misalignment

The WGAN misalignment experiments are described in Section 3 and Appendix B.1 of the paper, and are run using misalignments.py. This script trains a WGAN while, at some iterations, measuring how misaligned the movement of generated samples caused by updating the generator is from the critic's gradient. The generator's FID is also measured at the same iterations.

The required arguments for misalignments.py are:

  • 'target' : The dataset used to train the WGAN - can be either 'mnist' or 'fashion' (for Fashion-MNIST).
  • 'data' : The path of a folder where the MNIST (or Fashion-MNIST) dataset is located, in a format that can be accessed by an instance of the torchvision.datasets.MNIST class (resp. torchvision.datasets.FashionMNIST).
  • 'fid_data' : The path of a folder containing the test data from the MNIST dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).
  • 'checkpoints' : A string of integers separated by underscores. The integers specify the iterations at which misalignments and FID are computed, and training will continue until the largest iteration is reached.

Other optional arguments (including 'results_path' and 'temp_dir') are described in a commented section at the top of misalignments.py. The misalignment results reported in the paper (Tables 1 and 5, and Figure 3) correspond to using the default hyperparameters and to setting the 'checkpoints' argument roughly equal to '10_25000_40000', with '10' corresponding to the early stage in training, '25000' to the mid stage, and '40000' to the late stage.
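The format of the 'checkpoints' string can be illustrated with a short parser. This is a sketch of the format as described above, not the parsing code used in misalignments.py; parse_checkpoints is a hypothetical name.

```python
from typing import List


def parse_checkpoints(spec: str) -> List[int]:
    """Split an underscore-separated string such as '10_25000_40000' into integers."""
    iterations = [int(token) for token in spec.split("_")]
    if any(it <= 0 for it in iterations):
        raise ValueError("checkpoint iterations must be positive")
    return iterations


checkpoints = parse_checkpoints("10_25000_40000")
# Training continues until the largest listed iteration is reached.
last_iteration = max(checkpoints)
```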

WGAN generation

For completeness we include the code that was used to obtain the WGAN FID statistics in Table 3 of the paper, which includes the wgan_gp.py and wgan_gp_eval.py scripts. The former trains a WGAN with the InfoGAN architecture on the dataset specified by the 'target' argument, saving generator model dictionaries in the folder specified by 'temp_dir' at ten equally spaced stages in training. The wgan_gp_eval.py script evaluates the performance of the generator with the different model dictionaries in 'temp_dir'.

The necessary arguments to run wgan_gp.py are:

  • 'target' : The name of the dataset to generate (can be either 'mnist', 'fashion' or 'cifar10').
  • 'data' : Folder where the dataset is located.
  • 'temp_dir' : Folder where the model dictionaries are saved.

Once wgan_gp.py has run, wgan_gp_eval.py should be called with the same arguments for 'target', 'data' and 'temp_dir', and setting the 'model' argument to 'infogan'. If evaluating FID, the 'temp_dir' folder needs to contain the test data from the target dataset saved as individual files. For instance, this folder could contain files of the form '00001.jpg', '00002.jpg', etc. (although extensions other than .jpg can be used).

Reproducibility

This repository contains two branches: 'main' and 'reproducible'. You are currently viewing the 'main' branch, which contains a clean version of the code meant to be easy to read and interpret and to run more efficiently than the version on the 'reproducible' branch. The results obtained by running the code on this branch should be nearly (but not perfectly) identical to the results stated in the paper, the differences stemming from the randomness inherent to the experiments. The 'reproducible' branch allows one to replicate exactly the results stated in the paper (random seeds are specified) for the TTC experiments.

Computing architecture and running times

We ran different versions of the code presented here on Compute Canada (https://www.computecanada.ca/) clusters, always using a single NVIDIA V100 Volta or NVIDIA A100 Ampere GPU. Here are rough estimations of the running times for our experiments.

  • MNIST/Fashion MNIST generation training run (TTC): 60-90 minutes.
  • MNIST/Fashion MNIST generation training run (WGAN): 45-90 minutes (this includes misalignments computations).
  • CIFAR10 generation training run: 3-4 hours (TTC), 90 minutes (WGAN-GP).
  • Image translation training run: up to 20 hours.
  • Image denoising training run: 8-10 hours.

Assets

Portions of this code, as well as the datasets used to produce our experimental results, make use of existing assets. We provide here a list of all assets used, along with the licenses under which they are distributed, if specified by the originator:
