Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Overview

Description

Recent research has shown that numerous human-interpretable directions exist in the latent space of GANs. In this paper, we develop an automatic procedure for finding directions that lead to foreground-background image separation, and we use these directions to train an image segmentation model without human supervision. Our method is generator-agnostic, producing strong segmentation results with a wide range of different GAN architectures. Furthermore, by leveraging GANs pretrained on large datasets such as ImageNet, we are able to segment images from a range of domains without further training or finetuning. Evaluating our method on image segmentation benchmarks, we compare favorably to prior work while using neither human supervision nor access to the training data. Broadly, our results demonstrate that automatically extracting foreground-background structure from pretrained deep generative models can serve as a remarkably effective substitute for human supervision.

How to run

Dependencies

This code depends on pytorch-pretrained-gans, a repository I developed that exposes a standard interface for a variety of pretrained GANs. Install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

The pretrained weights for most GANs are downloaded automatically. For those that are not, I have provided scripts in that repository.
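The repository exposes a unified sampling interface along these lines (see its README for the exact, up-to-date names; the shapes below are examples for BigGAN):

from pytorch_pretrained_gans import make_gan

G = make_gan(gan_type='biggan')    # -> a standard torch.nn.Module
y = G.sample_class(batch_size=1)   # class vector, e.g. shape (1, 1000)
z = G.sample_latent(batch_size=1)  # latent vector, e.g. shape (1, 128)
x = G(z=z, y=y)                    # image batch, e.g. shape (1, 3, 256, 256)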

There are also some standard dependencies (Hydra, PyTorch Lightning, Albumentations, tqdm, retry, and Kornia). Install them with:

pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia

General Approach

Our unsupervised segmentation approach has two steps: (1) finding a good direction in latent space, and (2) training a segmentation model from data and masks that are generated using this direction.

In detail, this means:

  1. We use optimization/main.py to find a salient direction (or two salient directions) in the latent space of a given pretrained GAN that leads to foreground-background image separation.
  2. We use segmentation/main.py to train a standard segmentation network (a UNet) on generated data. The data can be generated in two ways: (1) you can generate the images on-the-fly during training, or (2) you can generate the images before training the segmentation model using segmentation/generate_and_save.py and then train the segmentation network afterward. The second approach is faster but requires more disk space (~10GB for 1 million images). We will also provide a pre-generated dataset (coming soon). A conceptual sketch of how a latent direction turns into a mask follows this list.
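To make the idea concrete, here is a purely illustrative sketch of step 2. Everything in it (the toy linear generator, latent_dim, alpha, tau, and the simple thresholding rule) is a stand-in, not the repository's actual mask-generation code:

import torch

# Toy stand-ins: a linear "generator" and a random unit direction in latent space.
latent_dim, alpha, tau = 128, 5.0, 0.1
G = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 3 * 64 * 64),
    torch.nn.Unflatten(1, (3, 64, 64)),
)
direction = torch.nn.functional.normalize(torch.randn(latent_dim), dim=0)

z = torch.randn(1, latent_dim)
image = G(z)                        # original generated image
shifted = G(z + alpha * direction)  # image shifted along the latent direction
# If the shift predominantly alters one layer (say, the background), a
# foreground-background mask can be read off from where the two images differ:
mask = (shifted - image).abs().mean(dim=1, keepdim=True) > tau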

Configuration and Logging

We use Hydra for configuration and Weights & Biases for logging. With Hydra, you can specify a config file (found in config/) with --config-name=myconfig.yaml. You can also override the config from the command line by specifying the overriding arguments (without --). For example, you can enable Weights & Biases with wandb=True and name the run with name=myname.
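For instance, the following (hypothetical) invocation selects the segmentation config, enables Weights & Biases, and names the run:

PYTHONPATH=. python segmentation/main.py --config-name=segment.yaml wandb=True name=myname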

The structure of the configs is as follows:

config
├── data_gen
│   ├── generated.yaml  # <- for generating data with 1 latent direction
│   ├── generated-dual.yaml   # <- for generating data with 2 latent directions
│   ├── generator  # <- different types of GANs for generating data
│   │   ├── bigbigan.yaml
│   │   ├── pretrainedbiggan.yaml
│   │   ├── selfconditionedgan.yaml
│   │   ├── studiogan.yaml
│   │   └── stylegan2.yaml 
│   └── saved.yaml  # <- for using pre-generated data
├── optimize.yaml  # <- for optimization
└── segment.yaml   # <- for segmentation

Code Structure

The code is structured as follows:

src
├── models  # <- segmentation model
│   ├── __init__.py
│   ├── latent_shift_model.py  # <- shifts direction in latent space
│   ├── unet_model.py  # <- segmentation model
│   └── unet_parts.py
├── config  # <- configuration, explained above
│   ├── ... 
├── datasets  # <- classes for loading datasets during segmentation/generation
│   ├── __init__.py
│   ├── gan_dataset.py  # <- for generating dataset
│   ├── saved_gan_dataset.py  # <- for pre-generated dataset
│   └── real_dataset.py  # <- for evaluation datasets (i.e. real images)
├── optimization
│   ├── main.py  # <- main script
│   └── utils.py  # <- helper functions
└── segmentation
    ├── generate_and_save.py  # <- for generating a dataset and saving it to disk
    ├── main.py  # <- main script, uses PyTorch Lightning 
    ├── metrics.py  # <- for mIoU/F-score calculations
    └── utils.py  # <- helper functions
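As an aside, the mIoU computed in segmentation/metrics.py is, for binary foreground-background masks, essentially the following (a minimal sketch, not the repository's exact implementation):

import torch

def binary_iou(pred, target, eps=1e-6):
    # pred, target: boolean masks of shape (N, H, W)
    intersection = (pred & target).sum(dim=(1, 2)).float()
    union = (pred | target).sum(dim=(1, 2)).float()
    return (intersection + eps) / (union + eps)

pred = torch.rand(4, 128, 128) > 0.5    # stand-in for predicted masks
target = torch.rand(4, 128, 128) > 0.5  # stand-in for ground-truth masks
miou = 0.5 * (binary_iou(pred, target) + binary_iou(~pred, ~target))
print(miou.mean().item())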

Datasets

The datasets should have the following structure. You can easily add your own datasets, or use only a subset of these datasets, by modifying config/segment.yaml. Specify your datasets directory by modifying root on line 19 of that file, or by passing data_seg.root=MY_DIR on the command line whenever you call python segmentation/main.py (see the example after the directory listing below).

├── DUT_OMRON
│   ├── DUT-OMRON-image
│   │   └── ...
│   └── pixelwiseGT-new-PNG
│       └── ...
├── DUTS
│   ├── DUTS-TE
│   │   ├── DUTS-TE-Image
│   │   │   └── ...
│   │   └── DUTS-TE-Mask
│   │       └── ...
│   └── DUTS-TR
│       ├── DUTS-TR-Image
│       │   └── ...
│       └── DUTS-TR-Mask
│           └── ...
├── ECSSD
│   ├── ground_truth_mask
│   │   └── ...
│   └── images
│       └── ...
├── CUB_200_2011
│   ├── train_images
│   │   └── ...
│   ├── train_segmentations
│   │   └── ...
│   ├── test_images
│   │   └── ...
│   └── test_segmentations
│       └── ...
└── Flowers
    ├── train_images
    │   └── ...
    ├── train_segmentations
    │   └── ...
    ├── test_images
    │   └── ...
    └── test_segmentations
        └── ...
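For example, if your datasets live under /path/to/datasets (a placeholder), you would add the following override to your segmentation commands (the remaining data_gen arguments are as in the Training section below):

PYTHONPATH=. python segmentation/main.py name=NAME data_seg.root=/path/to/datasets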

The datasets can be downloaded from their respective project pages.

Training

Before training, make sure you understand the general approach (explained above).

Note: All commands are called from within the src directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

Optimization

PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

This should take less than 5 minutes to run. The output will be saved in outputs/optimization/fixed-BigBiGAN-NAME/DATE/, with the final checkpoint in latest.pth.
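If you want a quick sanity check on this output, you can inspect the checkpoint directly with torch.load (the exact contents depend on latent_shift_model.py, so the keys here are not guaranteed):

import torch

ckpt = torch.load('outputs/optimization/fixed-BigBiGAN-NAME/DATE/latest.pth', map_location='cpu')
print(type(ckpt))
print(list(ckpt.keys()) if isinstance(ckpt, dict) else ckpt)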

Segmentation with precomputed generations

The recommended way of training is to generate the data first and train afterward. An example generation script would be:

PYTHONPATH=. python segmentation/generate_and_save.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.save_dir="YOUR_OUTPUT_DIR" \
data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128

This will generate 1 million image-label pairs and save them to YOUR_OUTPUT_DIR/images. Note that YOUR_OUTPUT_DIR should be an absolute path, not a relative one, because Hydra changes the working directory. You may also want to tune the generation_batch_size to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.

Once you have generated data, you can train a segmentation model:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact, you can probably get away with fewer steps; I would guess around 5000).

Segmentation with on-the-fly generations

Alternatively, you can generate data while training the segmentation model. An example script would be:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.kwargs.generation_batch_size=128

Evaluation

To evaluate, set the train argument to False. For example:

PYTHONPATH=. python segmentation/main.py \
name="eval" \
train=False \
eval_checkpoint=${checkpoint} \
data_seg.root=${DATASETS_DIR}

Pretrained models

  • ... are coming soon!

Available GANs

It should be possible to use any GAN from pytorch-pretrained-gans, including BigBiGAN, BigGAN, Self-Conditioned GAN, the StudioGAN model zoo, and StyleGAN 2.
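The generator names correspond to the YAML files under config/data_gen/generator/ shown above. For example, to run the optimization step with StyleGAN2 instead of BigBiGAN:

PYTHONPATH=. python optimization/main.py data_gen/generator=stylegan2 name=NAME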

Citation

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = {arXiv preprint},
  year      = {2021}
}
Comments
  • pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    Hi, is the repo in the pytorch-pretrained-gans step public, or is that the right URL for it? I got prompted for a username and password when I tried the pip install git+, and I don't see the repo at that URL: https://github.com/lukemelas/pytorch-pretrained-gans (I get a 404).

    Thanks.

    opened by ModMorph 2
  • Help producing results with the StyleGAN models

    Hi there!

    I'm having trouble producing meaningful results with StyleGAN2 on AFHQ. I've been using the default setup and hyperparameters. After 50 iterations (with the default batch size of 32), I get visualisations that initially look promising (https://i.imgur.com/eR79Wyd.png), but as training progresses, and by the time it reaches 300 iterations, these are the visualisation results: https://i.imgur.com/36zhBzT.png.

    I've tried playing with the learning rate and the number of iterations, with no success yet. Do you have any tips or ideas as to what might be going wrong here?

    Thanks! James.

    opened by james-oldfield 1
  • bug

    Firstly, I ran:

    PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

    Then I ran:

    PYTHONPATH=. python segmentation/generate_and_save.py \
    name=NAME \
    data_gen=generated \
    data_gen/generator=bigbigan \
    data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
    data_gen.save_dir="YOUR_OUTPUT_DIR" \
    data_gen.save_size=1000000 \
    data_gen.kwargs.batch_size=1 \
    data_gen.kwargs.generation_batch_size=128

    When I ran:

    PYTHONPATH=. python segmentation/main.py \
    name=NAME \
    data_gen=saved \
    data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

    an error occurred:

    Traceback (most recent call last):
      File "segmentation/main.py", line 98, in main
        kwargs = dict(images_dir=_cfg.images_dir, labels_dir=_cfg.labels_dir,
    omegaconf.errors.InterpolationResolutionError: KeyError raised while resolving interpolation:
    "Environment variable '/raid/name/gaochengli/segmentation/src/images' not found"
    full_key: data_seg.data[0].images_dir
    object_type=dict

    Following your instructions, I modified root on line 19 of config/segment.yaml to "/raid/name/gaochengli/segmentation/src/images"; that folder, named images, contains all the datasets. I wonder why this error happened.

    opened by Lee-Gao 1
Owner

Luke Melas-Kyriazi

I'm a student at Harvard University studying mathematics and computer science, always open to collaborating on interesting projects!