Labels4Free: Unsupervised Segmentation using StyleGAN

Overview

Labels4Free: Unsupervised Segmentation using StyleGAN

ICCV 2021

image Figure: Some segmentation masks predicted by Labels4Free Framework on real and synthetic images

We propose an unsupervised segmentation framework for StyleGAN generated objects. We build on two main observations. First, the features generated by StyleGAN hold valuable information that can be utilized towards training segmentation networks. Second, the foreground and background can often be treated to be largely independent and be swapped across images to produce plausible composited images. For our solution, we propose to augment the Style-GAN2 generator architecture with a segmentation branch and to split the generator into a foreground and background network. This enables us to generate soft segmentation masks for the foreground object in an unsupervised fashion. On multiple object classes, we report comparable results against state-of-the-art supervised segmentation networks, while against the best unsupervised segmentation approach we demonstrate a clear improvement, both in qualitative and quantitative metrics.

Labels4Free: Unsupervised Segmentation Using StyleGAN (ICCV 2021)
Rameen Abdal, Peihao Zhu, Niloy Mitra, Peter Wonka
KAUST, Adobe Research

[Paper] [Project Page] [Video]

Installation

Clone this repo.

git clone https://github.com/RameenAbdal/Labels4Free.git
cd Labels4Free/

This repo is based on the Pytorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). Refer to this repo for setting up the environment, preparation of LMDB datasets and downloading pretrained weights of the models.

Download the pretrained weights of Alpha Networks here

Training the models

The models were trained on 4 RTX 2080 (24 GB) GPUs. In order to train the models using the settings in the paper use the following commands for each dataset.

Checkpoints and samples are saved in ./checkpoint and ./sample folders.

FFHQ dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT]--loss_multiplier 1.2 --iter 1200 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Horse dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 500 --trunc 1.0 --lr 0.0002 --reproduce_model

LSUN-Cat dataset

python -m torch.distributed.launch --nproc_per_node=4 train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT]  --loss_multiplier 3 --iter 900 --trunc 0.5 --lr 0.0002 --reproduce_model

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 10 --iter 50 --trunc 0.3 --lr 0.002 --sat_weight 1.0 --model_save_freq 25 --reproduce_model --use_disc

In order to train your own models using different settings e.g on a single GPU, using different samples, iterations etc. use the following commands.

FFHQ dataset

python train.py --size 1024 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [FFHQ_CONFIG-F_CHECKPOINT] --loss_multiplier 1.2 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 3 --bg_coverage_value 0.4

LSUN-Horse dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_HORSE_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 1.0 --lr 0.0002 --bg_coverage_wt 6 --bg_coverage_value 0.6

LSUN-Cat dataset

python train.py --size 256 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAT_CONFIG-F_CHECKPOINT] --loss_multiplier 3 --iter 2000 --trunc 0.5 --lr 0.0002 --bg_coverage_wt 4 --bg_coverage_value 0.35

LSUN-Car dataset

python train.py --size 512 [LMDB_DATASET_PATH] --batch 2 --n_sample 8 --ckpt [LSUN_CAR_CONFIG-F_CHECKPOINT] --loss_multiplier 20 --iter 750 --trunc 0.3 --lr 0.0008 --sat_weight 0.1 --bg_coverage_wt 40 --bg_coverage_value 0.75 --model_save_freq 50

Sample from the pretrained model

Samples are saved in ./test_sample folder.

python test_sample.py --size [SIZE] --batch 2 --n_sample 100 --ckpt_bg_extractor [ALPHANETWORK_MODEL] --ckpt_generator [GENERATOR_MODEL] --th 0.9

Results on Custom dataset

Folder: Custom dataset, predicted and ground truth masks.

python test_customdata.py --path_gt [GT_Folder] --path_pred [PRED_FOLDER]

Citation

@InProceedings{Abdal_2021_ICCV,
    author    = {Abdal, Rameen and Zhu, Peihao and Mitra, Niloy J. and Wonka, Peter},
    title     = {Labels4Free: Unsupervised Segmentation Using StyleGAN},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {13970-13979}
}

Acknowledgments

This implementation builds upon the Pytorch implementation of StyleGAN2 (rosinality/stylegan2-pytorch). This work was supported by Adobe Research and KAUST Office of Sponsored Research (OSR).

Owner
PhD @ KAUST
Few-shot NLP benchmark for unified, rigorous eval

FLEX FLEX is a benchmark and framework for unified, rigorous few-shot NLP evaluation. FLEX enables: First-class NLP support Support for meta-training

AI2 85 Dec 03, 2022
Pytorch Implementation of Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic

Pytorch Implementation of Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic [Paper] [Colab is coming soon] Approach Example Usage To r

170 Jan 03, 2023
Certis - Certis, A High-Quality Backtesting Engine

Certis - Backtesting For y'all Certis is a powerful, lightweight, simple backtes

Yeachan-Heo 46 Oct 30, 2022
Deep Learning applied to Integral data analysis

DeepIntegralCompton Deep Learning applied to Integral data analysis Module installation Move to the root directory of the project and execute : pip in

Thomas Vuillaume 1 Dec 10, 2021
Using Clinical Drug Representations for Improving Mortality and Length of Stay Predictions

Using Clinical Drug Representations for Improving Mortality and Length of Stay Predictions Usage Clone the code to local. https://github.com/tanlab/MI

Computational Biology and Machine Learning lab @ TOBB ETU 3 Oct 18, 2022
A Human-in-the-Loop workflow for creating HD images from text

A Human-in-the-Loop? workflow for creating HD images from text DALLĀ·E Flow is an interactive workflow for generating high-definition images from text

Jina AI 2.5k Jan 02, 2023
Tacotron 2 - PyTorch implementation with faster-than-realtime inference

Tacotron 2 (without wavenet) PyTorch implementation of Natural TTS Synthesis By Conditioning Wavenet On Mel Spectrogram Predictions. This implementati

NVIDIA Corporation 4.1k Jan 03, 2023
A rule learning algorithm for the deduction of syndrome definitions from time series data.

README This project provides a rule learning algorithm for the deduction of syndrome definitions from time series data. Large parts of the algorithm a

0 Sep 24, 2021
Adversarial Learning for Semi-supervised Semantic Segmentation, BMVC 2018

Adversarial Learning for Semi-supervised Semantic Segmentation This repo is the pytorch implementation of the following paper: Adversarial Learning fo

Wayne Hung 464 Dec 19, 2022
Rust bindings for the C++ api of PyTorch.

tch-rs Rust bindings for the C++ api of PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorc

Laurent Mazare 2.3k Dec 30, 2022
Unified learning approach for egocentric hand gesture recognition and fingertip detection

Unified Gesture Recognition and Fingertip Detection A unified convolutional neural network (CNN) algorithm for both hand gesture recognition and finge

Mohammad 227 Dec 25, 2022
AI Virtual Calculator: This is a simple virtual calculator based on Artificial intelligence.

AI Virtual Calculator: This is a simple virtual calculator that works with gestures using OpenCV. We will use our hand in the air to click on the calc

Md. Rakibul Islam 1 Jan 13, 2022
CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation

CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation We propose a novel approach to translate unpaired contrast computed

Nicolae Catalin Ristea 13 Jan 02, 2023
A large-image collection explorer and fast classification tool

IMAX: Interactive Multi-image Analysis eXplorer This is an interactive tool for visualize and classify multiple images at a time. It written in Python

Matias Carrasco Kind 23 Dec 16, 2022
Hybrid Neural Fusion for Full-frame Video Stabilization

FuSta: Hybrid Neural Fusion for Full-frame Video Stabilization Project Page | Video | Paper | Google Colab Setup Setup environment for [Yu and Ramamoo

Yu-Lun Liu 430 Jan 04, 2023
Fake-user-agent-traffic-geneator - Python CLI Tool to generate fake traffic against URLs with configurable user-agents

Fake traffic generator for Gartner Demo Generate fake traffic to URLs with custo

New Relic Experimental 3 Oct 31, 2022
Python Assignments for the Deep Learning lectures by Andrew NG on coursera with complete submission for grading capability.

Python Assignments for the Deep Learning lectures by Andrew NG on coursera with complete submission for grading capability.

Utkarsh Agiwal 1 Feb 03, 2022
TGRNet: A Table Graph Reconstruction Network for Table Structure Recognition

TGRNet: A Table Graph Reconstruction Network for Table Structure Recognition Xue, Wenyuan, et al. "TGRNet: A Table Graph Reconstruction Network for Ta

Wenyuan 68 Jan 04, 2023
WebUAV-3M: A Benchmark Unveiling the Power of Million-Scale Deep UAV Tracking

WebUAV-3M: A Benchmark Unveiling the Power of Million-Scale Deep UAV Tracking [Paper Link] Abstract In this work, we contribute a new million-scale Un

25 Jan 01, 2023
Unsupervised Attributed Multiplex Network Embedding (AAAI 2020)

Unsupervised Attributed Multiplex Network Embedding (DMGI) Overview Nodes in a multiplex network are connected by multiple types of relations. However

Chanyoung Park 114 Dec 06, 2022