Code and project page for ICCV 2021 paper "DisUnknown: Distilling Unknown Factors for Disentanglement Learning"

Overview

DisUnknown: Distilling Unknown Factors for Disentanglement Learning

See the introduction on our project page.

Requirements

  • PyTorch >= 1.8.0
  • PyYAML, for loading configuration files
  • Optional: h5py, for using the 3D Shapes dataset
  • Optional: Matplotlib, for plotting sample distributions in code space

Preparing Datasets

Dataset classes and default configurations are provided for the following datasets. See below for how to add new datasets, or open an issue and the author might consider adding it. Some datasets need to be prepared before use:

$ python disentangler.py prepare_data <dataset_name> --data_path </path/to/dataset>

If the dataset does not have a standard training/test split, it will be split randomly. Use the --test_portion <portion> option to set the portion of test samples. Some datasets have additional options.
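
A random training/test split of this kind can be sketched as below. This is not the repository's code: the helper name random_split, the fixed seed and the rounding of the test count are assumptions made for illustration.

```python
import numpy as np


def random_split(num_samples, test_portion, seed=0):
    """Return sorted (train_idx, test_idx) index arrays for a random split.

    Hypothetical helper: the real prepare_data code may seed, round or store
    the split differently.
    """
    if not 0.0 <= test_portion <= 1.0:
        raise ValueError("test_portion must be in [0, 1]")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_samples)          # shuffle all sample indices
    num_test = int(round(num_samples * test_portion))
    test_idx = np.sort(perm[:num_test])          # first num_test shuffled indices
    train_idx = np.sort(perm[num_test:])         # everything else
    return train_idx, test_idx


train_idx, test_idx = random_split(1000, test_portion=0.1)
print(len(train_idx), len(test_idx))  # 900 100
```

Fixing the seed makes the split reproducible, which matters if prepare_data is run once and training is resumed later.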

  • MNIST, Fashion-MNIST, QMNIST, SVHN
    • Dataset names are mnist, fashion_mnist, qmnist, svhn.
    • data_path should be the same as those for the built-in dataset classes provided by torchvision.
    • We use the full NIST digit dataset from QMNIST (what = 'nist') and it needs to be split.
    • For SVHN, set include_extra: true in dataset_args in the configuration file (this is the default) to include the extra training images in the training set.
  • 3D Chairs
    • Dataset name is chairs.
    • data_path should be the folder containing the rendered_chairs folder.
    • Needs to be split.
    • You may use --compress to down-sample all images and save them as a NumPy array of PNG-encoded bytes. Use --downsample_size <size> to set the image size, which defaults to 128. Note that this does not dictate the training-time image size, which is configured separately. Compressing the images speeds up training only slightly if a multi-processing dataloader is used, but makes plotting significantly faster.
    • Unrelated to this work, but the author wants to note that this dataset curiously contains 31 azimuth angles times two altitudes, for a total of 62 images per chair, with image id 031 skipped. Apparently 32 was the intended number of azimuth angles, but the angles were generated using numpy.linspace(0, 360, 32), ignoring the fact that 0 and 360 are the same angle; the duplicated images 031 and 063 were then removed once the mistake was noticed. Beware of off-by-one errors in linspace, especially if it is also circular!
  • 3D shapes
    • Dataset name is 3dshapes.
    • data_path should be the folder containing 3dshapes.h5.
    • Needs to be split.
    • You may use --compress to extract all images and then save them as a NumPy array of PNG-encoded bytes. This is mainly for saving space: the original dataset, when extracted from HDF5, takes 5.9 GB of memory, while the re-compressed version takes 2.2 GB. Extraction and compression take about an hour.
  • dSprites
    • Dataset name is dsprites.
    • data_path should be the folder containing the .npz file.
    • Needs to be split.
    • This dataset is problematic. I found that orientation 0 and orientation 39 are the same, presumably because, similar to 3D Chairs, something like linspace(0, 360, 40) was used to generate the angles. So yes, I'm telling you again: beware of off-by-one errors in linspace, especially if it is also circular! Anyway, my dataset class discards orientation 39, so there are only 39 different orientations and 3 * 6 * 39 * 32 * 32 = 718848 images.
    • The bigger problem is that each of the three shapes (square, ellipse, heart) has a different symmetry. For hearts, each image uniquely determines an orientation angle; for ellipses, each image has two possible orientation angles; and for squares, each image has four possible orientation angles. They managed to make the dataset so that (apart from orientations 0 and 39 being the same) different orientations correspond to different images, because 2 and 4 are not divisors of 39 (which makes me wonder if the off-by-one error was intentional). But the orientation labels are still conceptually wrong: if you consider the orientation angles of ellipses modulo 180 or the orientation angles of squares modulo 90, the orientation class IDs are not in increasing order of orientation angle. Instead, as the class ID goes from 0 to 38, the ellipse angle goes around this range twice and the square angle goes around it four times. To solve this problem, I included an option relabel_orientation: true in dataset_args in the configuration file (this is the default), which re-labels the orientations of ellipses and squares in the correct order. Specifically, for ellipses orientation t is re-labeled as (t * 2) % 39, and for squares orientation t is re-labeled as (t * 4) % 39. However, this causes ellipses to rotate twice as slowly and squares four times as slowly as the orientation label increases, which is still not ideal. When shapes with different symmetries are mixed there is simply no easy solution, so do not expect good results on this dataset if the unknown factor contains the orientation.
    • --compress does the same thing as in 3D Shapes.
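
Both angle issues described in the list above can be checked numerically. The sketch below, assuming NumPy, shows the circular linspace pitfall and applies the dSprites relabeling rule; relabel_orientation is a hypothetical helper written for illustration, not the repository's dataset code.

```python
import numpy as np

# Circular off-by-one: linspace includes both endpoints, and 0 and 360 are the same angle.
bad = np.linspace(0, 360, 32)                   # step 360/31; last value duplicates the first
good = np.linspace(0, 360, 32, endpoint=False)  # 32 distinct angles, step 11.25
assert bad[-1] % 360 == bad[0] % 360
assert len(np.unique(good % 360)) == 32


def relabel_orientation(t, shape, n=39):
    """Hypothetical helper applying the relabeling rule described above.

    With n = 39 orientations, index t has angle t * 360 / n. An ellipse repeats
    every 180 degrees and a square every 90, so the label that sorts by angle
    modulo that period is (t * 2) % n or (t * 4) % n respectively.
    """
    fold = {"heart": 1, "ellipse": 2, "square": 4}[shape]
    return (t * fold) % n


t = np.arange(39)
for shape, period in (("ellipse", 180.0), ("square", 90.0)):
    label = relabel_orientation(t, shape)
    # Still a bijection on 0..38, because 39 is odd.
    assert sorted(label.tolist()) == list(range(39))
    # After relabeling, label k corresponds to angle k * period / 39 (mod period).
    angle_mod_period = (t * 360.0 / 39) % period
    assert np.allclose(angle_mod_period, label * period / 39)
```

The last assertion also makes the remaining drawback visible: one label step is 180/39 degrees for ellipses and 90/39 degrees for squares, versus 360/39 degrees for hearts.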

Training

To train, use

$ python disentangler.py train --config_file </path/to/config/file> --save_path </path/to/save/folder>

The configuration file is in YAML. See the commented example for explanations. If config_file is omitted, save_path is expected to already exist and contain config.yaml. Otherwise save_path will be created if it does not exist, and config_file will be copied into it. If save_path already contains a previous training run that has been halted, training will by default resume from the latest checkpoint. Use --start_from <stage_name> [<iteration>] to choose another restarting point, or --start_from stage1 to restart from scratch. Specifying --data_path or --device will override those settings in the configuration file.
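
The command-line rules above can be sketched with argparse as follows. This only mirrors the documented flags; it is not the actual parser in disentangler.py. The helper names are hypothetical, and copying config_file into save_path under the name config.yaml is inferred from the text rather than confirmed.

```python
import argparse
import os
import shutil


def build_parser():
    # Hypothetical parser mirroring the documented `train` flags only.
    parser = argparse.ArgumentParser(prog="disentangler.py")
    commands = parser.add_subparsers(dest="command", required=True)
    train = commands.add_parser("train")
    train.add_argument("--config_file")
    train.add_argument("--save_path", required=True)
    train.add_argument("--start_from", nargs="+")  # <stage_name> [<iteration>]
    train.add_argument("--data_path")  # overrides the configuration file
    train.add_argument("--device")     # overrides the configuration file
    return parser


def parse_start_from(values):
    """Turn ['stage1'] or ['stage1', '5000'] into (stage, iteration or None)."""
    if values is None:
        return None  # default: resume from the latest checkpoint, if any
    if len(values) > 2:
        raise ValueError("--start_from takes <stage_name> [<iteration>]")
    return values[0], (int(values[1]) if len(values) == 2 else None)


def resolve_config(config_file, save_path):
    """Return the path of the configuration file used for this run."""
    target = os.path.join(save_path, "config.yaml")
    if config_file is None:
        if not os.path.isfile(target):
            raise FileNotFoundError(f"no --config_file given and {target} not found")
        return target
    os.makedirs(save_path, exist_ok=True)
    shutil.copyfile(config_file, target)  # assumption: copied as config.yaml
    return target


args = build_parser().parse_args(
    ["train", "--save_path", "runs/mnist", "--start_from", "stage1", "5000"]
)
print(parse_start_from(args.start_from))  # ('stage1', 5000)
```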

Although our goal is to deal with the cases where some factors are labeled and some factors are unknown, it feels wrong not to extend this to the cases where all factors are labeled or where all factors are unknown. We do allow these, but some parts of our method will become unnecessary and will be discarded accordingly. In particular, if all factors are unknown then we just train a VAE in stage I and then a GAN having the same code space in stage II, so you can use this code for just training a GAN. We don't have the myriad of GAN tricks, though.

Meaning of Visualization Images

During training, images generated for visualization will be saved in the subfolder samples. test_images.jpg contains images from the test set in even-numbered columns (starting from zero), with odd-numbered columns being empty. The generated images contain the corresponding reconstructed images in even-numbered columns, while each image in an odd-numbered column is generated by combining the unknown code from the image to its left and the labeled code from the image to its right (wrapping around to the next row).
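
The grid layout can be made concrete by computing, for every odd-column cell, which even-column cells its two codes come from. This is an illustrative sketch, not code from the repository: odd_cell_sources is a hypothetical helper, and it assumes an even number of columns read row by row. What the very last cell's right neighbour is, is not stated here, so wrapping it to (0, 0) is a guess.

```python
def odd_cell_sources(rows, cols):
    """Map each odd-column cell of the generated-image grid to the even-column
    cells supplying its codes: (unknown_code_source, labeled_code_source).

    Assumptions: cols is even, cells are read row by row (so the right
    neighbour of the last column is the first cell of the next row), and the
    very last cell wraps to (0, 0), which is a guess.
    """
    if cols % 2:
        raise ValueError("cols must be even")
    total = rows * cols
    sources = {}
    for k in range(1, total, 2):  # odd flat index == odd column, since cols is even
        left, right = k - 1, (k + 1) % total
        sources[divmod(k, cols)] = (divmod(left, cols), divmod(right, cols))
    return sources


grid = odd_cell_sources(rows=2, cols=4)
print(grid[(0, 1)])  # ((0, 0), (0, 2))
print(grid[(0, 3)])  # ((0, 2), (1, 0)): the right neighbour wraps to the next row
```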

Example test images:

[Image: example test images]

Example generated images:

[Image: example generated images]

Adding a New Dataset

__init__() should accept four positional arguments root, part, labeled_factors, transform in that order, plus any additional keyword arguments that one expects to receive from dataset_args in the configuration file. root is the path to the dataset folder, and transform is the image transform, as usual for PyTorch datasets. part can be train, test or plot, specifying which subset of the dataset to load. The plotting set is generally the same as the test set, but part = 'plot' is passed in so that a smaller plotting set can be used if the test set is too large.

labeled_factors is a list of factor names. __getitem__() should return a tuple (image, labels) where image is the image and labels is a one-dimensional PyTorch tensor of type torch.int64, containing the labels for that image in the order listed in labeled_factors. labels should always be a one-dimensional tensor even if there is only one labeled factor, not a Python int or a zero-dimensional tensor. If labeled_factors is empty then __getitem__() should return image only.

In addition, metadata about the factors should be available in the following properties: nclass should be a list of ints containing the number of classes of each factor, and class_freq should be a list of PyTorch tensors, each being one-dimensional, containing the distribution of classes of each factor in (the current split of) the dataset.
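
As an illustration, class_freq can be computed from the labels of the current split with torch.bincount. This is a sketch under stated assumptions: class_freq_from_labels is a hypothetical helper, and normalizing the counts so each distribution sums to 1 is an assumption based on the word "distribution" above.

```python
import torch


def class_freq_from_labels(labels, nclass):
    """Per-factor class distributions for an (N, F) int64 label tensor.

    nclass is passed in rather than inferred from labels.max(), so a class that
    is absent from the current split still gets a (zero) entry.
    """
    labels = torch.as_tensor(labels, dtype=torch.int64)
    if labels.dim() != 2 or labels.size(1) != len(nclass):
        raise ValueError("labels must have shape (num_samples, len(nclass))")
    return [
        torch.bincount(labels[:, f], minlength=n).float() / labels.size(0)
        for f, n in enumerate(nclass)
    ]


labels = torch.tensor([[0, 1], [1, 1], [1, 0], [2, 1]])
class_freq = class_freq_from_labels(labels, nclass=[3, 2])
print(class_freq)  # [tensor([0.2500, 0.5000, 0.2500]), tensor([0.2500, 0.7500])]
```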

If any preparation is required, implement a static method prepare_data(args), where args is the return value of argparse.ArgumentParser.parse_args() and has the attributes data_path and test_portion by default. If additional command-line arguments are needed, implement a static method add_prepare_args(parser), in which parser.add_argument() can be called.

Finally add it to the dictionary of recognized datasets in data/__init__.py.
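
Putting the requirements above together, a toy dataset class might look like the sketch below. Everything specific in it is hypothetical: the class name ToyFactorsDataset, the toy.npz layout, the factor names, the plot_size keyword (standing in for a dataset_args entry), the --seed option, and the choice to report nclass and class_freq for the labeled factors only. The class would still need an entry in the dictionary in data/__init__.py, whose exact name is not given here.

```python
import argparse
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class ToyFactorsDataset(Dataset):
    """Hypothetical dataset following the interface described above.

    Assumes <root>/toy.npz holds `images` (N, H, W) uint8, `labels` (N, 2) ints
    for the factors below, and a boolean `split` (True = test) from prepare_data.
    """

    all_factors = ["shape", "color"]  # hypothetical factor names
    all_nclass = [3, 4]

    def __init__(self, root, part, labeled_factors, transform, plot_size=16):
        with np.load(os.path.join(root, "toy.npz")) as data:
            images, labels, is_test = data["images"], data["labels"], data["split"]
        idx = np.flatnonzero(~is_test if part == "train" else is_test)
        if part == "plot":
            idx = idx[:plot_size]  # smaller plotting subset of the test set
        cols = [self.all_factors.index(name) for name in labeled_factors]
        self.images = images[idx]
        self.labels = torch.as_tensor(labels[idx][:, cols], dtype=torch.int64)
        self.labeled_factors = list(labeled_factors)
        self.transform = transform
        # Metadata, here given for the labeled factors only (an assumption).
        self.nclass = [self.all_nclass[c] for c in cols]
        self.class_freq = [
            torch.bincount(self.labels[:, i], minlength=n).float() / len(idx)
            for i, n in enumerate(self.nclass)
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        image = self.images[i]
        if self.transform is not None:
            image = self.transform(image)
        if not self.labeled_factors:
            return image  # no labeled factors: image only
        return image, self.labels[i]  # 1-D int64 tensor, even for one factor

    @staticmethod
    def add_prepare_args(parser):
        parser.add_argument("--seed", type=int, default=0)  # hypothetical option

    @staticmethod
    def prepare_data(args):
        path = os.path.join(args.data_path, "toy.npz")
        with np.load(path) as f:
            data = dict(f)
        n = len(data["labels"])
        test = np.random.default_rng(args.seed).permutation(n)[: round(n * args.test_portion)]
        data["split"] = np.zeros(n, dtype=bool)
        data["split"][test] = True
        np.savez(path, **data)


def build_toy_train_set():
    """Write a synthetic toy.npz to a temp folder, prepare it, load the train part."""
    with tempfile.TemporaryDirectory() as root:
        rng = np.random.default_rng(0)
        np.savez(os.path.join(root, "toy.npz"),
                 images=rng.integers(0, 256, (100, 8, 8), dtype=np.uint8),
                 labels=np.stack([rng.integers(0, 3, 100), rng.integers(0, 4, 100)], axis=1))
        parser = argparse.ArgumentParser()
        parser.add_argument("--data_path")
        parser.add_argument("--test_portion", type=float, default=0.1)
        ToyFactorsDataset.add_prepare_args(parser)
        ToyFactorsDataset.prepare_data(parser.parse_args(["--data_path", root]))
        return ToyFactorsDataset(root, "train", ["color"], transform=None)


train_set = build_toy_train_set()
image, labels = train_set[0]
print(len(train_set), image.shape, labels.shape, labels.dtype)
# 90 (8, 8) torch.Size([1]) torch.int64
```

Note that labels stays a one-dimensional tensor of shape (1,) even though only one factor is labeled, as required above.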

Default configuration should also be created as default_config/datasets/<dataset_name>.yaml. It should at a minimum contain image_size, image_channels and factors. factors has the same syntax as labeled_factors as explained in the example training configuration. It should contain a complete list of all factors. In particular, if the dataset does not include a complete set of labels, there should be a factor called unknown which will become the default unknown factor if labeled_factors is not set in the training configuration.

Any additional settings in the default configuration will override global defaults in default_config/default_config.yaml.
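
The override behavior can be pictured as a dictionary merge of the dataset's default configuration over the global defaults. This is a sketch assuming PyYAML and a recursive merge of nested mappings; whether the repository merges nested sections recursively or only at the top level is not stated here, and the settings shown are made up for illustration apart from the key names used in this README.

```python
import yaml


def merge_config(base, override):
    """Return a copy of `base` with values from `override` applied.

    Nested dicts are merged key by key (an assumption); any other value in
    `override` replaces the one in `base`.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


global_defaults = yaml.safe_load("""
image_size: 64
image_channels: 3
dataset_args: {}
""")
dataset_defaults = yaml.safe_load("""
image_channels: 1
dataset_args:
  include_extra: true
""")
config = merge_config(global_defaults, dataset_defaults)
print(config)  # {'image_size': 64, 'image_channels': 1, 'dataset_args': {'include_extra': True}}
```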

Citing This Work (BibTeX)

@inproceedings{xiang2021disunknown,
  title={DisUnknown: Distilling Unknown Factors for Disentanglement Learning},
  author={Xiang, Sitao and Gu, Yuming and Xiang, Pengda and Chai, Menglei and Li, Hao and Zhao, Yajie and He, Mingming},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={14810--14819},
  year={2021}
}
Owner
Sitao Xiang
Computer Graphics PhD student at University of Southern California. Twitter: StormRaiser123