DABO: Data Augmentation with Bilevel Optimization

DABO: Data Augmentation with Bilevel Optimization [Paper]

The goal is to automatically learn an efficient data augmentation regime for image classification.

Accepted at WACV2021

Overview

What's new: This method automatically learns data augmentation in order to improve image classification performance. It does not require hard-coding augmentation techniques, which might need domain knowledge or an expensive hyper-parameter search on the validation set.

Key insight: Our method efficiently trains a network that performs data augmentation. This network learns data augmentation from the gradient of the classifier's validation loss, which is obtained through an online version of bilevel optimization. We also perform truncated back-propagation in order to significantly reduce the computational cost of bilevel optimization.

How it works: Our method jointly trains a classifier and an augmentation network through the following steps:

  • For each mini-batch, a forward pass is made to calculate the training loss.
  • Based on the training loss and the gradient of the training loss, an optimization step is made for the classifier in the inner loop.
  • A forward pass is then made on the classifier with the new weights to calculate the validation loss.
  • The gradient from the validation loss is backpropagated to train the augmentation network.
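
The four steps above can be sketched on a toy problem. This is only an illustration, not the repository's implementation: a scalar `w` stands in for the classifier, a scalar `a` (which rescales the training input) stands in for the augmentation network, and the gradients are derived by hand rather than with autograd. All function names and learning rates here are assumptions made for the example. Back-propagation is truncated to a single inner step, so the gradient with respect to `a` only accounts for the most recent classifier update.

```python
# Toy online bilevel step. Scalars stand in for the classifier (w) and the
# augmentation network (a); names and learning rates are illustrative only.

def train_loss(w, a, x, y):
    """Squared error of the classifier w on the augmented input a * x."""
    return (w * a * x - y) ** 2


def val_loss(w, xv, yv):
    """Squared error of the classifier w on an un-augmented validation sample."""
    return (w * xv - yv) ** 2


def inner_step(w, a, x, y, lr_w):
    """One gradient step on the classifier, using the training loss."""
    grad_w = 2.0 * (w * a * x - y) * a * x  # d train_loss / d w
    return w - lr_w * grad_w


def hypergradient(w, a, x, y, xv, yv, lr_w):
    """d val_loss(w_new) / d a, back-propagated through one inner step."""
    w_new = inner_step(w, a, x, y, lr_w)
    dval_dw = 2.0 * (w_new * xv - yv) * xv  # d val_loss / d w_new
    # d w_new / d a = -lr_w * d^2 train_loss / (dw da)
    dw_da = -lr_w * (4.0 * w * a * x * x - 2.0 * y * x)
    return dval_dw * dw_da


def bilevel_step(w, a, train_batch, val_batch, lr_w=0.05, lr_a=0.05):
    """Update the classifier on the training loss, then the augmenter on the validation loss."""
    x, y = train_batch
    xv, yv = val_batch
    w_new = inner_step(w, a, x, y, lr_w)                # steps 1-2: inner update
    grad_a = hypergradient(w, a, x, y, xv, yv, lr_w)    # steps 3-4: validation gradient
    a_new = a - lr_a * grad_a
    return w_new, a_new


if __name__ == "__main__":
    w, a = 0.5, 1.2
    for _ in range(5):
        w, a = bilevel_step(w, a, train_batch=(2.0, 3.0), val_batch=(1.0, 1.5))
        print(f"w={w:.4f}  a={a:.4f}  val_loss={val_loss(w, 1.0, 1.5):.6f}")
```

With real networks, the hand-written derivatives would be replaced by automatic differentiation through the classifier update; the order of operations stays the same.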

Results: Our model obtains better results than carefully hand-engineered transformations and GAN-based approaches. Further, on the CIFAR10, CIFAR100, BACH, Tiny-Imagenet and Imagenet datasets, the results are competitive with methods that use a policy search.

Why it matters: Proper data augmentation can significantly improve generalization performance. Unfortunately, deriving these augmentations requires domain expertise or an extensive hyper-parameter search. Thus, an automatic and quick way of identifying efficient data augmentation has a big impact on obtaining better models.

Where to go from here: Performance can be improved by extending the set of learned transformations to non-differentiable transformations. The estimation of the validation loss could also be improved by further exploring the influence of the number of iterations in the inner loop. Finally, the method can be extended to other tasks such as object detection or image segmentation.

Experiments

1. Install requirements: Run this command to install the Haven library, which helps manage experiments.

pip install -r requirements.txt

2.1 CIFAR10 experiments: The following command runs the training and validation loop for CIFAR.

python trainval.py -e cifar -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.

2.2 BACH experiments: The following command runs the training and validation loop on the BACH dataset.

python trainval.py -e bach -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.

3. Results: Display the results by following the steps below.

Launch Jupyter by running the following in a terminal:

jupyter nbextension enable --py widgetsnbextension
jupyter notebook

Then, run the following script in a Jupyter cell:

from haven import haven_jupyter as hj
from haven import haven_results as hr
from haven import haven_utils as hu

# path to where the experiments got saved
savedir_base = ''
exp_list = None

# exp_list = hu.load_py().EXP_GROUPS[]
# get experiments
rm = hr.ResultManager(exp_list=exp_list, 
                      savedir_base=savedir_base, 
                      verbose=0
                     )
y_metrics = ['test_acc']
bar_agg = 'max'
mode = 'bar'
legend_list = ['model.netA.name']
title_list = 'dataset.name'
legend_format = 'Augmentation Network: {}'
filterby_list = {'dataset':{'name':'cifar10'}, 'model':{'netC':{'name':'resnet18_meta_2'}}}

# launch dashboard
hj.get_dashboard(rm, vars(), wide_display=True)

Citation

@article{mounsaveng2020learning,
  title={Learning Data Augmentation with Online Bilevel Optimization for Image Classification},
  author={Mounsaveng, Saypraseuth and Laradji, Issam and Ayed, Ismail Ben and Vazquez, David and Pedersoli, Marco},
  journal={arXiv preprint arXiv:2006.14699},
  year={2020}
}