A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Writeup of NilbinSec's participation in the Winja CTF for c0c0n 2021

Winja-CTF-c0c0n-2021-Writeup NilbinSec's participation in the Winja CTF for c0c0n 2021 This repo covers NilbinSec's participation in the Winja CTF dur

1 Nov 15, 2021
Automation in socks label validation

This is a project for socks card label validation where the socks card is validated comparing with the correct socks card whose coordinates are stored in the database. When the test socks card is com

1 Jan 19, 2022
Developer guide for Hivecoin project

Hivecoin-developer Developer guide for Hivecoin project. Install Content are writen in reStructuredText (RST) and rendered with Sphinx. Much of the co

tweetyf 1 Nov 22, 2021
A python package that adds "docs" command to disnake

About This extension's purpose is of adding a "docs" command, its purpose is to help documenting in chat. How To Load It from disnake.ext import comma

7 Jan 03, 2023
You can easily send campaigns, e-marketing have actually account using cash will thank you for using our tools, and you can support our Vodafone Cash +201090788026

*** Welcome User Sorry I Mean Hello Brother ✓ Devolper and Design : Mokhtar Abdelkreem ========================================== You Can Follow Us O

Mo Code 1 Nov 03, 2021
Adds a Bake node to Blender's shader node system

Bake to Target This Blender Addon adds a new shader node type capable of reducing the texture-bake step to a single button press. Please note that thi

Thomas 8 Oct 04, 2022
Create a program for generator Truth Table

Python-Truth-Table-Ver-1.0 Create a program for generator Truth Table in here you have to install truth-table-generator module for python modules inst

JehanKandy 10 Jul 13, 2022
Check broken access control exists in the Java web application

javaEeAccessControlCheck Check broken access control exists in the Java web application. 检查 Java Web 应用程序中是否存在访问控制绕过问题。 使用 python3 javaEeAccessControl

kw0ng 3 May 04, 2022
Basic infrastructure for writing scripts in Python

Base Script Python is an excellent language that makes writing scripts very straightforward. Over the course of writing many scripts, we realized that

Deep Compute, LLC 9 Jan 07, 2023
A replacement of qsreplace, accepts URLs as standard input, replaces all query string values with user-supplied values and stdout.

Bhedak A replacement of qsreplace, accepts URLs as standard input, replaces all query string values with user-supplied values and stdout. Works on eve

Eshan Singh 84 Dec 31, 2022
A web application (with multiple API project options) that uses MariaDB HTAP!

Bookings Bookings is a web application that, backed by the power of the MariaDB Connectors and the MariaDB X4 Platform, unleashes the power of smart t

MariaDB Corporation 4 Dec 28, 2022
A python library with various gambling and gaming classes

gamble is a simple library that implements a collection of some common gambling-related classes Features die, dice, d-notation cards, decks, hands pok

Jacobi Petrucciani 16 May 24, 2022
JD扫码获取Cookie 本地版

JD扫码获取Cookie 本地版 请无视手机上的提示升级京东版本的提示! 下载链接 https://github.com/Zy143L/jd_cookie/releases 使用Python实现 代码很烂 没有做任何异常捕捉 但是能用 请不要将获取到的Cookie发送给任何陌生人 如果打开闪退 请使

Zy143L 420 Dec 11, 2022
This Curve Editor, written by Jehee Lee in 2015

Splines Abstract This Curve Editor, written by Jehee Lee in 2015, is a freeware. You can use, modify, redistribute the code without restriction. This

Movement Research Lab 8 Mar 11, 2022
A curated collection of Amazing Python scripts from Basics to Advance with automation task scripts

📑 Introduction A curated collection of Amazing Python scripts from Basics to Advance with automation task scripts. This is your Personal space to fin

Amitesh kumar mishra 1 Jan 22, 2022
To check my COVID-19 vaccine appointment, I wrote an infinite loop that sends me a Whatsapp message hourly using Twilio and Selenium. It works on my Raspberry Pi computer.

COVID-19_vaccine_appointment To check my COVID-19 vaccine appointment, I wrote an infinite loop that sends me a Whatsapp message hourly using Twilio a

Ayyuce Demirbas 24 Dec 17, 2022
Social reading and reviewing, decentralized with ActivityPub

BookWyrm Social reading and reviewing, decentralized with ActivityPub Contents Joining BookWyrm Contributing About BookWyrm What it is and isn't The r

BookWyrm 1.4k Jan 08, 2023
Gerador de dafaces

🎴 DefaceGenerator Obs: esse script foi criado com a intenção de ajudar pessoas iniciantes no hacking que ainda não conseguem criar suas próprias defa

LordShinigami 3 Jan 09, 2022
Fithub is a website application for athletes and fitness enthusiasts of all ages and experience levels.

Fithub is a website application for athletes and fitness enthusiasts of all ages and experience levels. Our website allows users to easily search, filter, and sort our comprehensive database of over

Andrew Wu 1 Dec 13, 2021
A project for Perotti's MGIS350 for incorporating Flask

MGIS350_5 This is our project for Perotti's MGIS350 for incorporating Flask... RIT Dev Biz Apps Web Project A web-based Inventory system for company o

1 Nov 07, 2021