A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
serological measurements from multiplexed ELISA assays

pysero pysero enables serological measurements with multiplexed and standard ELISA assays. The project automates estimation of antibody titers from da

Chan Zuckerberg Biohub 5 Aug 06, 2022
APC Power Usage is an application which shows power consuption overtime for UPS units manufactured by APC.

APC Power Usage Introduction APC Power Usage is an application which shows power consuption overtime for UPS units manufactured by APC. Screenshoots G

Stefan Kondinski 3 Oct 08, 2021
Rates how pog a word or user is. Not random and does have *some* kind of algorithm to it.

PogRater :D Rates how pogchamp a word is :D A fun project coded by JBYT27 using Python3 Have you ever wondered how pog a word is? Well, congrats, you

an aspirin 2 Jun 25, 2022
This is a program for Carbon Emission calculator.

Summary This is a program for Carbon Emission calculator. Usage This will calculate the carbon emission by each person on various factors. Contributor

Ankit Rane 2 Feb 18, 2022
Hera is a Python framework for constructing and submitting Argo Workflows.

Hera is an Argo Workflows Python SDK. Hera aims to make workflow construction and submission easy and accessible to everyone! Hera abstracts away workflow setup details while still maintaining a cons

argoproj-labs 241 Jan 02, 2023
Set of tools to analyze Tinynuke samples

tinynuke-toolset You'll find in that repository a set of tools and scripts I developped to analyze Tinynuke samples. Dll extractor: script used to ext

Heat Miser 14 Aug 18, 2022
Estimate the Market Size for Electic and Plug-In Hybrid Vehicles In Africa

Estimate the Market Size for Electic and Plug-In Hybrid Vehicles In Africa The goal of this repository is to use open data repositories to answer the

Leonce Nshuti 0 Feb 21, 2022
Demo of connecting Rasa with Zalo

Demo of connecting Rasa with Zalo

6 Jul 25, 2022
A person does not exist image bot

A person does not exist image bot

Fayas Noushad 3 Dec 12, 2021
Easy way to build a SaaS application using Python and Dash

EasySaaS This project will be attempt to make a great starting point for your next big business as easy and efficent as possible. This project will cr

xianhu 3 Nov 17, 2022
Arabic to Roman Converter in Python

Arabic-to-Roman-Converter Made together with https://github.com/goltaraya . Arabic to Roman Converter in Python. -Instructions: 1 - Make sure you have

Pedro Lucas Tomazeti Fernandes 6 Oct 28, 2021
Attempt at creating organized collection of little handy snippets of code I'm receiving along the way

ChaosCode Attempt at creating organized collection of little handy snippets of code I'm receiving along the way I always considered coding and program

INFU 4 Nov 26, 2022
Decoupled Smoothing in Probabilistic Soft Logic

Decoupled Smoothing in Probabilistic Soft Logic Experiments for "Decoupled Smoothing in Probabilistic Soft Logic". Probabilistic Soft Logic Probabilis

Kushal Shingote 1 Feb 08, 2022
Pyfetch - Simple Fetch written in Python

pyfetch Simple Fetch written in Python Screenshots Install Clone this repository

2 Sep 02, 2022
Is a polybar module that will show you your progress in Hack The Box

HTB-Status for Polybar Is a polybar module that will show you your progress in Hack The Box indicating your current rank, global rank, points and resp

bitc0de 8 Jan 14, 2022
Transform Python source code into it's most compact representation

Python Minifier Transforms Python source code into it's most compact representation. Try it out! python-minifier currently supports Python 2.7 and Pyt

Daniel Flook 403 Jan 02, 2023
LOL英雄联盟云顶之弈挂机刷代币脚本,全自动操作,智能逻辑,功能齐全。

LOL云顶之弈挂机刷代币脚本 这是2019年全球总决赛写的一个云顶挂机脚本,python完成的。 功能: 自动拿牌卖牌 策略是高星策略,非固定阵容 自动登陆账号、打码、异常重启 战利品截图上传百度云 web中控发号,改密码,查看信息等 代码是三天赶出来的,所以有点混乱,WEB中控代码也不知道扔哪去了

77 Oct 10, 2022
Similarity checking of sign languages

Similarity checking of sign languages This repository checks for similarity betw

Tonni Das Jui 1 May 13, 2022
Replay Felica Exchange For Python

FelicaReplay Replay Felica Exchange Description Standalone Replay Module Usage Save FelicaRelay (=2.0) output to file, then python replay.py [FILE].

3 Jul 14, 2022
Backups made easy, automated, monitored and SECURED with an audited encryption

Backup Controller Backups made easy, automated, monitored and SECURED with an audited encryption. Schedules backup tasks executed by Backup Maker, upl

RiotKit 1 Jan 30, 2022