A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

A scalable implementation of two mini-batch schemes, the batch of mini-batches (BoMb) scheme and the conventional averaging scheme, for several transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport. The applications covered are:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
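The difference between the two schemes can be illustrated with a small NumPy sketch. It is not the repository's code: the helper names (`sinkhorn_plan`, `minibatch_ot`), the toy data, and the use of entropic Sinkhorn at both levels are assumptions made for illustration. The averaging scheme is taken here to pair mini-batches by index, while the batch of mini-batches scheme solves a second transport problem over the k x k matrix of mini-batch costs.

```python
import numpy as np

def sinkhorn_plan(cost, reg, n_iters=300):
    """Entropic OT plan between two uniform measures for a given cost matrix."""
    n, m = cost.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-cost / reg)                      # Gibbs kernel
    u = np.ones(n)
    for _ in range(n_iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def minibatch_ot(x, y, reg=0.05):
    """Transport cost between two mini-batches under squared Euclidean cost."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return float((sinkhorn_plan(cost, reg) * cost).sum())

rng = np.random.default_rng(0)
k, m = 3, 16                                     # number of mini-batches, mini-batch size
centers = np.array([[0.1, 0.1], [0.5, 0.5], [0.9, 0.9]])
# Source mini-batch i sits near centers[i]; target mini-batches come in reverse
# order, so pairing by index is deliberately poor in this toy example.
xs = [c + rng.uniform(-0.05, 0.05, size=(m, 2)) for c in centers]
ys = [c + rng.uniform(-0.05, 0.05, size=(m, 2)) for c in centers[::-1]]

# Pairwise mini-batch costs D[i, j] = OT(X_i, Y_j).
D = np.array([[minibatch_ot(x, y) for y in ys] for x in xs])

mot_value = float(np.mean(np.diag(D)))           # averaging scheme: X_i paired with Y_i
gamma = sinkhorn_plan(D, reg=0.05)               # coupling between mini-batches
bomb_value = float((gamma * D).sum())            # batch of mini-batches scheme
print(f"m-OT: {mot_value:.4f}  BoMb-OT: {bomb_value:.4f}")
```

In this toy setup the coupling `gamma` concentrates on the well-matched mini-batch pairs, so the BoMb value comes out lower than the index-wise average.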

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the generator and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
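To make the roles of --epsilon and --tau concrete, here is a minimal NumPy sketch of entropic unbalanced OT solved with scaling iterations. It is a generic textbook-style version, not the code used in methods.py; the function name `unbalanced_sinkhorn` and the toy data are assumptions. The marginal constraints are relaxed with KL penalties of weight tau, so a larger tau keeps the plan's marginals closer to the prescribed ones.

```python
import numpy as np

def unbalanced_sinkhorn(a, b, cost, epsilon, tau, n_iters=2000):
    """Entropic UOT with KL marginal penalties of weight tau (scaling iterations)."""
    K = np.exp(-cost / epsilon)                  # Gibbs kernel
    power = tau / (tau + epsilon)                # power -> 1 recovers balanced Sinkhorn
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (K @ v)) ** power
        v = (b / (K.T @ u)) ** power
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 2))                     # toy source mini-batch
y = rng.uniform(size=(8, 2))                     # toy target mini-batch
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a = np.full(8, 1.0 / 8)
b = np.full(8, 1.0 / 8)

loose = unbalanced_sinkhorn(a, b, cost, epsilon=0.1, tau=0.1)
tight = unbalanced_sinkhorn(a, b, cost, epsilon=0.1, tau=100.0)
# L1 deviation of the plan's row marginal from the prescribed marginal a.
err_loose = float(np.abs(loose.sum(axis=1) - a).sum())
err_tight = float(np.abs(tight.sum(axis=1) - a).sum())
print(f"marginal error, tau=0.1: {err_loose:.4f}  tau=100: {err_tight:.4f}")
```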

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on the CelebA and CIFAR10 datasets, and include helper functions to compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: this file loads saved models and generates 10000 images for computing FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.
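For readers unfamiliar with the score computed by fid_score.py, the following is a minimal sketch of the Fréchet distance between Gaussians fitted to two feature sets. It is not the repository's implementation: the function name `frechet_distance` is hypothetical, and the random 4-dimensional features stand in for the Inception features that an actual FID computation would use.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two feature sets (rows = samples)."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^{1/2}) equals the sum of the square roots of the eigenvalues
    # of cov_a @ cov_b, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))                 # stand-in for features of real images
shift = np.array([2.0, 0.0, 0.0, 0.0])
fake = real + shift                              # same covariance, shifted mean
fd_same = frechet_distance(real, real)
fd_shift = frechet_distance(real, fake)
print(fd_same, fd_shift)
```

Because `fake` differs from `real` only by a constant shift, the covariance terms cancel and the distance reduces to the squared norm of the shift.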

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of epochs run is this value multiplied by k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
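The role of --L in the sliced method can be sketched as follows. This is a generic Monte Carlo estimator of the squared sliced 2-Wasserstein distance, not the loss code in the generator files; the function name `sliced_wasserstein2` and its arguments are assumptions, with `n_projections` playing the role of --L.

```python
import numpy as np

def sliced_wasserstein2(x, y, n_projections=50, seed=0):
    """Monte Carlo estimate of the squared sliced 2-Wasserstein distance.

    Assumes x and y hold the same number of equally weighted samples.
    """
    rng = np.random.default_rng(seed)
    dim = x.shape[1]
    theta = rng.normal(size=(n_projections, dim))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # random unit directions
    # Project onto each direction, then sort: in 1D, OT matches sorted samples.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    return float(np.mean((x_proj - y_proj) ** 2))

rng = np.random.default_rng(1)
x = rng.normal(size=(100, 2))
shift = np.array([3.0, 0.0])
sw_same = sliced_wasserstein2(x, x)
sw_shift = sliced_wasserstein2(x, x + shift, n_projections=200)
print(sw_same, sw_shift)
```

More projections reduce the variance of the estimate at a cost that grows linearly in the number of projections, since each projection only requires a sort.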

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: use K-means clustering to compress images

--palette: show color palette

--source: Path to the source image
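As an illustration of what K-means compression of an image involves, here is a minimal NumPy sketch of Lloyd's algorithm applied to pixel colors. How --cluster is implemented in main.py is not shown in this README, so the function name `kmeans_colors`, the number of clusters, and the random stand-in image are all assumptions.

```python
import numpy as np

def kmeans_colors(pixels, n_clusters=8, n_iters=20, seed=0):
    """Lloyd's algorithm on an (N, 3) array of RGB pixels; returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    init = rng.choice(len(pixels), size=n_clusters, replace=False)
    centers = pixels[init].astype(float)
    for _ in range(n_iters):
        dists = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(n_clusters):
            members = pixels[labels == c]
            if len(members) > 0:                 # keep the old center if a cluster empties
                centers[c] = members.mean(axis=0)
    # Final assignment against the updated centers.
    dists = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return centers, dists.argmin(axis=1)

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(32, 32, 3))  # stand-in for a loaded image
pixels = image.reshape(-1, 3)
centers, labels = kmeans_colors(pixels, n_clusters=8)
compressed = centers[labels].reshape(image.shape)  # every pixel replaced by its center
print(compressed.shape, len(np.unique(labels)))
```

After compression the image contains at most `n_clusters` distinct colors, which keeps the transport problem between source and target colors small.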

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open-source code.
