A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python 3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please cite our papers whenever this repository is used to help produce published results or is incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • Python 3.6
  • PyTorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

A scalable implementation of the batch of mini-batches (BoMb) scheme and of the conventional averaging scheme for several mini-batch transportation types, namely optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport, applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
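The following minimal NumPy sketch contrasts the two schemes. It makes several simplifying assumptions: squared Euclidean ground cost, uniformly weighted mini-batches, and Sinkhorn-based entropic OT at both levels, which is closest to the entropic BoMb variant. The averaging scheme pairs the i-th source mini-batch with the i-th target mini-batch and averages their costs. The BoMb scheme instead builds a k × k matrix of mini-batch costs and solves an outer OT problem between the mini-batches. All function names here are hypothetical and do not mirror the repository's API.

```python
import numpy as np


def sinkhorn_plan(a, b, cost, reg, n_iters=500):
    """Entropic OT plan between histograms a and b via Sinkhorn scaling."""
    kernel = np.exp(-cost / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def minibatch_cost(x, y, reg):
    """Transport cost <P, C> of the Sinkhorn plan between two uniform mini-batches."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    plan = sinkhorn_plan(a, b, cost, reg)
    return float((plan * cost).sum())


def averaged_mot(xs, ys, reg):
    """Averaging scheme: pair the i-th source and target mini-batches, average the costs."""
    return float(np.mean([minibatch_cost(x, y, reg) for x, y in zip(xs, ys)]))


def bomb_ot(xs, ys, reg, outer_reg):
    """BoMb-style scheme: entropic OT between mini-batches on a k x k cost matrix."""
    big_cost = np.array([[minibatch_cost(x, y, reg) for y in ys] for x in xs])
    w_x = np.full(len(xs), 1.0 / len(xs))
    w_y = np.full(len(ys), 1.0 / len(ys))
    outer_plan = sinkhorn_plan(w_x, w_y, big_cost, outer_reg)
    return float((outer_plan * big_cost).sum()), outer_plan
```

The outer plan shows how much weight each pair of mini-batches receives. In the averaging scheme, this weight is fixed to 1/k on the diagonal.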

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architectures of the generator and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : use Batch of Mini-batches

--ebomb : use entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
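The flags `--epsilon` and `--tau` parametrize the unbalanced (UOT) loss. The sketch below is illustrative only and is not the repository's implementation. It uses the generalized Sinkhorn scaling iterations for entropic UOT with KL marginal penalties. Here `epsilon` is the entropic regularization and `tau` weights the marginal penalties, so a larger `tau` pushes the plan back toward the balanced marginals. The name `unbalanced_sinkhorn` is hypothetical.

```python
import numpy as np


def unbalanced_sinkhorn(a, b, cost, epsilon, tau, n_iters=1000):
    """Entropic unbalanced OT plan with KL marginal penalties weighted by tau."""
    kernel = np.exp(-cost / epsilon)
    # Exponent of the scaling updates; it tends to 1 (balanced Sinkhorn) as tau grows.
    power = tau / (tau + epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (kernel @ v)) ** power
        v = (b / (kernel.T @ u)) ** power
    return u[:, None] * kernel * v[None, :]
```

With a small `tau`, a target point far from every source point receives almost no mass. Balanced OT, by contrast, would be forced to send it its full share.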

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create datasets.

evaluate.py : this file is used to evaluate models trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architectures of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling for the source dataset

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : use Batch of Mini-batches

--ebomb : use entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
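The flag `--mass` sets the fraction of mass moved by partial OT (POT). The sketch below computes an entropic partial OT plan with Dykstra's algorithm, alternating KL projections onto three sets: row sums at most `a`, column sums at most `b`, and total mass equal to `mass`. This is one standard solver and not necessarily the one the repository uses; its dependencies include POT and cvxpy. The name `entropic_partial_ot` is hypothetical.

```python
import numpy as np


def entropic_partial_ot(a, b, cost, mass, reg, n_iters=1000):
    """Entropic partial OT plan moving exactly `mass` units (Dykstra's algorithm).

    Alternates KL projections onto {P 1 <= a}, {P^T 1 <= b} and {sum(P) = mass},
    keeping one multiplicative correction term per constraint set.
    """
    plan = np.exp(-cost / reg)
    plan *= mass / plan.sum()
    corrections = [np.ones_like(plan) for _ in range(3)]

    def project(p, which):
        if which == 0:  # rows: scale down any row whose mass exceeds a
            return p * np.minimum(a / p.sum(axis=1), 1.0)[:, None]
        if which == 1:  # columns: scale down any column whose mass exceeds b
            return p * np.minimum(b / p.sum(axis=0), 1.0)[None, :]
        return p * (mass / p.sum())  # total mass: rescale to exactly `mass`

    for _ in range(n_iters):
        for which in range(3):
            prev = plan
            plan = project(plan * corrections[which], which)
            # Dykstra correction: q <- q * prev / new, guarding against 0 / 0.
            corrections[which] = corrections[which] * np.divide(
                prev, plan, out=np.ones_like(plan), where=plan > 0
            )
    return plan
```

Only `mass` units are transported, so points with no cheap counterpart, such as outliers in a mini-batch, can be left untransported.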

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the generator architectures for the CelebA and CIFAR10 datasets, and include methods that compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: reads saved models and produces 10000 images for computing the FID score.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of training epochs is this value multiplied by k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using the sliced approach

--bomb : use Batch of Mini-batches

--ebomb : use entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
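The flag `--L` controls the number of random projections in the sliced approach. Below is a minimal Monte Carlo sketch of the sliced Wasserstein distance between two equal-size, uniformly weighted mini-batches. It projects onto `L` random unit directions, sorts each one-dimensional projection, and averages the differences between matched sorted values. The parameter `n_projections` plays the role of `--L`, and the function name and signature are hypothetical.

```python
import numpy as np


def sliced_wasserstein(x, y, n_projections=50, p=2, seed=0):
    """Monte Carlo sliced Wasserstein-p distance between equal-size point clouds."""
    rng = np.random.default_rng(seed)
    # Random directions drawn uniformly on the unit sphere.
    directions = rng.normal(size=(n_projections, x.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # In 1D, OT between uniform empirical measures matches sorted values.
    x_proj = np.sort(x @ directions.T, axis=0)  # shape (n, n_projections)
    y_proj = np.sort(y @ directions.T, axis=0)
    return float(np.mean(np.abs(x_proj - y_proj) ** p) ** (1.0 / p))
```

Sorting costs O(m log m) per projection, so the cost grows linearly in `L`. Solving a full OT problem on the same mini-batch would cost far more.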

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CelebA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py
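The script `main.py` is run without arguments, and its settings are not documented here. The following is therefore only an illustrative sketch of a mini-batch OT gradient flow under assumed settings. At each step it samples `m` source particles and `m` target points, then computes an entropic OT plan with Sinkhorn, rescaling the cost to [0, 1] to avoid underflow. Each sampled particle then moves a fraction `lr` toward its barycentric projection onto the target mini-batch. For the squared Euclidean cost with the plan held fixed, this is a rescaled gradient step on the mini-batch transport cost. All names and default values are hypothetical.

```python
import numpy as np


def sinkhorn_plan(cost, reg, n_iters=200):
    """Entropic OT plan between two uniform mini-batches (cost rescaled to [0, 1])."""
    kernel = np.exp(-cost / (reg * max(cost.max(), 1e-12)))
    n, m = cost.shape
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def minibatch_gradient_flow(source, target, m=16, n_steps=500, lr=0.1, reg=0.05, seed=0):
    """Move source particles toward the target using mini-batch entropic OT plans."""
    rng = np.random.default_rng(seed)
    x = np.array(source, dtype=float)
    for _ in range(n_steps):
        idx_x = rng.choice(len(x), size=m, replace=False)
        idx_y = rng.choice(len(target), size=m, replace=False)
        xb, yb = x[idx_x], target[idx_y]
        cost = ((xb[:, None, :] - yb[None, :, :]) ** 2).sum(axis=-1)
        plan = sinkhorn_plan(cost, reg)
        # Barycentric projection of each sampled particle onto the target mini-batch.
        bary = (plan @ yb) / plan.sum(axis=1, keepdims=True)
        x[idx_x] = xb - lr * (xb - bary)
    return x
```

Each update only touches `m` particles, so the per-step cost depends on the mini-batch size rather than the full number of particles.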

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster : use K-means clustering to compress images

--palette : show the color palette

--source : path to the source image

--target : path to the target image
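The flag `--cluster` compresses the images with K-means before transport. Assuming this means clustering pixel colors, below is a minimal sketch of Lloyd's k-means with farthest-point initialization. The resulting centroids form a small palette whose colors could then be transported instead of every pixel. The names `kmeans_palette` and `n_colors` are hypothetical and are not functions or flags of `main.py`.

```python
import numpy as np


def kmeans_palette(pixels, n_colors=8, n_iters=20, seed=0):
    """Cluster (N, 3) pixel colors into n_colors centroids with Lloyd's k-means."""
    pixels = np.asarray(pixels, dtype=float)
    rng = np.random.default_rng(seed)
    # Farthest-point initialization: start from a random pixel, then repeatedly
    # add the pixel farthest from all centroids chosen so far.
    centroids = [pixels[rng.integers(len(pixels))]]
    for _ in range(1, n_colors):
        nearest = np.min([((pixels - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(pixels[nearest.argmax()])
    centroids = np.array(centroids)

    for _ in range(n_iters):
        # Assignment step: each pixel goes to its closest centroid.
        dists = ((pixels[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # Update step: move each non-empty cluster's centroid to its mean color.
        for c in range(n_colors):
            members = pixels[labels == c]
            if len(members) > 0:
                centroids[c] = members.mean(axis=0)
    dists = ((pixels[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return centroids, dists.argmin(axis=1)
```

Replacing each pixel by `centroids[labels]` gives the compressed image.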

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful to their authors for making their code open source.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.