Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

Overview

Label-Efficient Semantic Segmentation with Diffusion Models

Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

This code is based on datasetGAN and guided-diffusion.

Note: use --recurse-submodules when clone.

 

Overview

The paper investigates the representations learned by the state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple segmentation approach that exploits these representations and outperforms the alternatives in the few-shot operating point in the context of semantic segmentation.

DDPM-based Segmentation

 

Dependencies

  • Python >= 3.7
  • Packages: see requirements.txt

 

Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-18, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The number corresponds to the number of semantic classes.

datasets.tar.gz (~47Mb)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. FFHQ-256 is trained by ourselves using the same model parameters as for the LSUN models.

LSUN-Bedroom: lsun_bedroom.pt
FFHQ-256: ffhq.pt
LSUN-Cat: lsun_cat.pt
LSUN-Horse: lsun_horse.pt

Run

  1. Download the datasets:
      bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh
  3. Check paths in experiments/ /ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

How to improve the performance

  1. Set input_activations=true in experiments/ /ddpm.json .
       In this case, the feature dimension is 18432.
  2. Tune for a particular task what diffusion steps and UNet blocks to use.

 

DatasetDDPM

Synthetic datasets

To download DDPM-produced synthetic datasets (50000 samples, ~7Gb):
bash synthetic-datasets/gan/download_synthetic_dataset.sh

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh
  2. Check paths in experiments/ /datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh

  3. Check paths in experiments/ /datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh
        Please specify the hyperparameters in this script for the available resources.
        On 8xA100 80Gb, it takes about 12 hours to generate 10000 samples.

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

SwAV

Pretrained SwAVs

We pretrain SwAV models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: lsun_bedroom.pth
FFHQ-256: ffhq.pth
LSUN-Cat: lsun_cat.pth
LSUN-Horse: lsun_horse.pth

Training setup:

Dataset epochs batch-size multi-crop num-prototypes
LSUN 200 1792 2x256 + 6x108 1000
FFHQ-256 400 2048 2x224 + 6x96 200

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/swav/download_checkpoint.sh
  3. Check paths in experiments/ /swav.json
  4. Run: bash scripts/swav/train_interpreter.sh

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

DatasetGAN

Opposed to the official implementation, more recent StyleGAN2(-ADA) models are used.

Synthetic datasets

To download GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh

Run

Since we almost fully adopt the official implementation, we don't provide our reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh
  2. Change paths in experiments/ /datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

Results

  • Performance in terms of mean IoU:
Method Bedroom-28 FFHQ-34 Cat-15 Horse-21 CelebA-19 ADE-Bedroom-30
ALAE 20.0 ± 1.0 48.1 ± 1.3 -- -- 49.7 ± 0.7 15.0 ± 0.5
VDVAE -- 57.3 ± 1.1 -- -- 54.1 ± 1.0 --
GAN Inversion 13.9 ± 0.6 51.7 ± 0.8 21.4 ± 1.7 17.7 ± 0.4 51.5 ± 2.3 11.1 ± 0.2
GAN Encoder 22.4 ± 1.6 53.9 ± 1.3 32.0 ± 1.8 26.7 ± 0.7 53.9 ± 0.8 15.7 ± 0.3
SwAV 41.0 ± 2.3 54.7 ± 1.4 44.1 ± 2.1 51.7 ± 0.5 53.2 ± 1.0 30.3 ± 1.5
DatasetGAN 31.3 ± 2.7 57.0 ± 1.0 36.5 ± 2.3 45.4 ± 1.4 -- --
DatasetDDPM 46.9 ± 2.8 56.0 ± 0.9 45.4 ± 2.8 60.4 ± 1.2 -- --
DDPM 46.1 ± 1.9 57.0 ± 1.4 52.3 ± 3.0 63.1 ± 0.9 57.0 ± 1.0 32.3 ± 1.5

 

  • Examples of segmentation masks predicted by the DDPM-based method:
DDPM-based Segmentation

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Owner
Yandex Research
Yandex Research
Intel® Nervana™ reference deep learning framework committed to best performance on all hardware

DISCONTINUATION OF PROJECT. This project will no longer be maintained by Intel. Intel will not provide or guarantee development of or support for this

Nervana 3.9k Dec 20, 2022
Library for time-series-forecasting-as-a-service.

TIMEX TIMEX (referred in code as timexseries) is a framework for time-series-forecasting-as-a-service. Its main goal is to provide a simple and generi

Alessandro Falcetta 8 Jan 06, 2023
Unofficial implementation (replicates paper results!) of MINER: Multiscale Implicit Neural Representations in pytorch-lightning

MINER_pl Unofficial implementation of MINER: Multiscale Implicit Neural Representations in pytorch-lightning. 📖 Ref readings Laplacian pyramid explan

AI葵 51 Nov 28, 2022
pytorch implementation of trDesign

trdesign-pytorch This repository is a PyTorch implementation of the trDesign paper based on the official TensorFlow implementation. The initial port o

Learn Ventures Inc. 41 Dec 29, 2022
Summary Explorer is a tool to visually explore the state-of-the-art in text summarization.

Summary Explorer Summary Explorer is a tool to visually inspect the summaries from several state-of-the-art neural summarization models across multipl

Webis 42 Aug 14, 2022
Hierarchical Attentive Recurrent Tracking

Hierarchical Attentive Recurrent Tracking This is an official Tensorflow implementation of single object tracking in videos by using hierarchical atte

Adam Kosiorek 147 Aug 07, 2021
Official PyTorch implementation of RIO

Image-Level or Object-Level? A Tale of Two Resampling Strategies for Long-Tailed Detection Figure 1: Our proposed Resampling at image-level and obect-

NVIDIA Research Projects 17 May 20, 2022
Home for cuQuantum Python & NVIDIA cuQuantum SDK C++ samples

Welcome to the cuQuantum repository! This public repository contains two sets of files related to the NVIDIA cuQuantum SDK: samples: All C/C++ sample

NVIDIA Corporation 147 Dec 27, 2022
A benchmark dataset for emulating atmospheric radiative transfer in weather and climate models with machine learning (NeurIPS 2021 Datasets and Benchmarks Track)

ClimART - A Benchmark Dataset for Emulating Atmospheric Radiative Transfer in Weather and Climate Models Official PyTorch Implementation Using deep le

21 Dec 31, 2022
A scanpy extension to analyse single-cell TCR and BCR data.

Scirpy: A Scanpy extension for analyzing single-cell immune-cell receptor sequencing data Scirpy is a scalable python-toolkit to analyse T cell recept

ICBI 145 Jan 03, 2023
Certis - Certis, A High-Quality Backtesting Engine

Certis - Backtesting For y'all Certis is a powerful, lightweight, simple backtes

Yeachan-Heo 46 Oct 30, 2022
This is a project based on ConvNets used to identify whether a road is clean or dirty. We have used MobileNet as our base architecture and the weights are based on imagenet.

PROJECT TITLE: CLEAN/DIRTY ROAD DETECTION USING TRANSFER LEARNING Description: This is a project based on ConvNets used to identify whether a road is

Faizal Karim 3 Nov 06, 2022
Put blind watermark into a text with python

text_blind_watermark Put blind watermark into a text. Can be used in Wechat dingding ... How to Use install pip install text_blind_watermark Alice Pu

郭飞 164 Dec 30, 2022
Official Repsoitory for "Mish: A Self Regularized Non-Monotonic Neural Activation Function" [BMVC 2020]

Mish: Self Regularized Non-Monotonic Activation Function BMVC 2020 (Official Paper) Notes: (Click to expand) A considerably faster version based on CU

Xa9aX ツ 1.2k Dec 29, 2022
DziriBERT: a Pre-trained Language Model for the Algerian Dialect

DziriBERT DziriBERT is the first Transformer-based Language Model that has been pre-trained specifically for the Algerian Dialect. It handles Algerian

117 Jan 07, 2023
Asymmetric metric learning for knowledge transfer

Asymmetric metric learning This is the official code that enables the reproduction of the results from our paper: Asymmetric metric learning for knowl

20 Dec 06, 2022
Implementation of ICCV21 paper: PnP-DETR: Towards Efficient Visual Analysis with Transformers

Implementation of ICCV 2021 paper: PnP-DETR: Towards Efficient Visual Analysis with Transformers arxiv This repository is based on detr Recently, DETR

twang 113 Dec 27, 2022
Source code for paper "ATP: AMRize Than Parse! Enhancing AMR Parsing with PseudoAMRs" @NAACL-2022

ATP: AMRize Then Parse! Enhancing AMR Parsing with PseudoAMRs Hi this is the source code of our paper "ATP: AMRize Then Parse! Enhancing AMR Parsing w

Chen Liang 13 Nov 23, 2022
A simple, fully convolutional model for real-time instance segmentation.

You Only Look At CoefficienTs ██╗ ██╗ ██████╗ ██╗ █████╗ ██████╗████████╗ ╚██╗ ██╔╝██╔═══██╗██║ ██╔══██╗██╔════╝╚══██╔══╝ ╚██

Daniel Bolya 4.6k Dec 30, 2022
Visualizing lattice vibration information from phonon dispersion to atoms (For GPUMD)

Phonon-Vibration-Viewer (For GPUMD) Visualizing lattice vibration information from phonon dispersion for primitive atoms. In this tutorial, we will in

Liangting 6 Dec 10, 2022