DABO: Data Augmentation with Bilevel Optimization

DABO: Data Augmentation with Bilevel Optimization [Paper]

The goal is to automatically learn an efficient data augmentation regime for image classification.

Accepted at WACV2021

Overview

What's new: This method automatically learns data augmentation to improve image classification performance. It does not require hard-coding augmentation techniques, which may demand domain knowledge or an expensive hyper-parameter search on the validation set.

Key insight: Our method efficiently trains a network that performs data augmentation. This network learns by using the gradient of the classifier's validation loss, computed with an online version of bilevel optimization. We also use truncated back-propagation to significantly reduce the computational cost of bilevel optimization.

How it works: Our method jointly trains a classifier and an augmentation network through the following steps:

  • For each mini-batch, a forward pass is made to calculate the training loss.
  • Using the gradient of the training loss, an optimization step is made for the classifier in the inner loop.
  • A forward pass is then made on the classifier with the updated weights to calculate the validation loss.
  • The gradient of the validation loss is backpropagated to train the augmentation network.
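
The four steps above can be sketched on a toy problem. This is a minimal illustration, not the repository's implementation: it assumes a scalar "classifier" weight `w`, a scalar "augmentation" parameter `a` that rescales the inputs, a squared-error loss, and hand-derived gradients in place of automatic differentiation. All names, data, and learning rates are hypothetical. The hypergradient is truncated to a single inner step, meaning the current `w` is treated as a constant with respect to `a`.

```python
def mean(values):
    values = list(values)
    return sum(values) / len(values)

def inner_step(w, a, train, lr):
    """Steps 1-2: training loss mean((w*a*x - y)^2), then one SGD step on w."""
    grad_w = mean(2 * (w * a * x - y) * a * x for x, y in train)
    return w - lr * grad_w

def val_loss(w, val):
    """Step 3: validation loss on un-augmented data."""
    return mean((w * x - y) ** 2 for x, y in val)

def hypergradient(w, a, train, val, lr):
    """Step 4: d val_loss(w_new) / d a, back-propagated through one inner step."""
    w_new = inner_step(w, a, train, lr)
    dval_dw = mean(2 * (w_new * x - y) * x for x, y in val)
    # d w_new / d a = -lr * d^2 L_train / (dw da)
    d2_train = mean(4 * w * a * x * x - 2 * x * y for x, y in train)
    return dval_dw * (-lr * d2_train)

# Hypothetical toy data: (input, target) pairs.
train = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]
val = [(1.5, 3.0), (2.5, 5.0)]

w, a = 0.5, 1.0          # classifier weight, augmentation parameter
lr, aug_lr = 0.05, 0.01  # inner (classifier) and outer (augmenter) step sizes

for _ in range(100):
    grad_a = hypergradient(w, a, train, val, lr)  # uses the looked-ahead weight
    w = inner_step(w, a, train, lr)               # classifier update
    a -= aug_lr * grad_a                          # augmenter update

print("w =", w, "a =", a, "validation loss =", val_loss(w, val))
```

In a real setting the second-order term would come from an automatic differentiation framework rather than a closed form. The single-step truncation is what keeps the cost of each outer update close to that of a regular training step.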

Results: Our model obtains better results than carefully hand-engineered transformations and GAN-based approaches. Further, the results are competitive with methods that use a policy search on the CIFAR10, CIFAR100, BACH, Tiny-ImageNet and ImageNet datasets.

Why it matters: Proper data augmentation can significantly improve generalization performance. Unfortunately, deriving these augmentations requires domain expertise or an extensive hyper-parameter search. An automatic and quick way of identifying efficient data augmentation therefore makes it much easier to obtain better models.

Where to go from here: Performance can be improved by extending the set of learned transformations to non-differentiable transformations. The estimation of the validation loss could also be improved by further exploring the influence of the number of iterations in the inner loop. Finally, the method can be extended to other tasks such as object detection or image segmentation.

Experiments

1. Install requirements: Run this command to install the Haven library, which helps manage experiments.

pip install -r requirements.txt

2.1 CIFAR10 experiments: The following command runs the training and validation loop on CIFAR10.

python trainval.py -e cifar -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.

2.2 BACH experiments: The following command runs the training and validation loop on the BACH dataset.

python trainval.py -e bach -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.
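
As a rough illustration of how the flags in both commands could be parsed, here is a small `argparse` sketch. It is not taken from `trainval.py`: the long option names are hypothetical, and since this README does not explain `-r`, the sketch treats it as an integer reset flag, which is an assumption.

```python
import argparse

def build_parser():
    # Hypothetical parser mirroring the short flags used in the commands above.
    parser = argparse.ArgumentParser(description="Sketch of a trainval-style CLI")
    parser.add_argument("-e", "--exp_group", required=True,
                        help="experiment group to run, e.g. cifar or bach")
    parser.add_argument("-sb", "--savedir_base", required=True,
                        help="directory where results are saved")
    parser.add_argument("-d", "--datadir", required=True,
                        help="directory containing the datasets")
    parser.add_argument("-r", "--reset", type=int, default=0,
                        help="assumed meaning: 1 restarts the experiment from scratch")
    return parser

# Parse the CIFAR10 command line from section 2.1.
args = build_parser().parse_args(
    ["-e", "cifar", "-sb", "../results", "-d", "../data", "-r", "1"]
)
print(args.exp_group, args.savedir_base, args.datadir, args.reset)
```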

3. Results: Display the results by following the steps below:

Launch Jupyter by running the following in a terminal:

jupyter nbextension enable --py widgetsnbextension
jupyter notebook

Then, run the following script in a Jupyter cell:

from haven import haven_jupyter as hj
from haven import haven_results as hr
from haven import haven_utils as hu

# path to where the experiments got saved
savedir_base = ''
exp_list = None

# exp_list = hu.load_py().EXP_GROUPS[]
# get experiments
rm = hr.ResultManager(exp_list=exp_list, 
                      savedir_base=savedir_base, 
                      verbose=0
                     )
y_metrics = ['test_acc']
bar_agg = 'max'
mode = 'bar'
legend_list = ['model.netA.name']
title_list = 'dataset.name'
legend_format = 'Augmentation Network: {}'
filterby_list = {'dataset':{'name':'cifar10'}, 'model':{'netC':{'name':'resnet18_meta_2'}}}

# launch dashboard
hj.get_dashboard(rm, vars(), wide_display=True)
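
To make the role of `filterby_list` more concrete, the sketch below shows one way a nested filter could select experiments whose configuration contains the given keys and values. This is only an illustration of the idea, not Haven's implementation: the `matches` helper and the example experiment dictionaries (including the `netA` names) are hypothetical.

```python
def matches(exp_dict, filter_dict):
    """Return True if every key/value in filter_dict appears in exp_dict (recursively)."""
    for key, expected in filter_dict.items():
        if key not in exp_dict:
            return False
        actual = exp_dict[key]
        if isinstance(expected, dict):
            if not isinstance(actual, dict) or not matches(actual, expected):
                return False
        elif actual != expected:
            return False
    return True

# Hypothetical experiment configurations.
experiments = [
    {'dataset': {'name': 'cifar10'},
     'model': {'netC': {'name': 'resnet18_meta_2'}, 'netA': {'name': 'aug_a'}}},
    {'dataset': {'name': 'cifar100'},
     'model': {'netC': {'name': 'resnet18_meta_2'}, 'netA': {'name': 'aug_a'}}},
    {'dataset': {'name': 'cifar10'},
     'model': {'netC': {'name': 'other_net'}, 'netA': {'name': 'aug_b'}}},
]

# Same filter as in the dashboard script above.
filterby = {'dataset': {'name': 'cifar10'},
            'model': {'netC': {'name': 'resnet18_meta_2'}}}

selected = [e for e in experiments if matches(e, filterby)]
print(len(selected), selected[0]['model']['netA']['name'])  # prints: 1 aug_a
```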

Citation

@article{mounsaveng2020learning,
  title={Learning Data Augmentation with Online Bilevel Optimization for Image Classification},
  author={Mounsaveng, Saypraseuth and Laradji, Issam and Ayed, Ismail Ben and Vazquez, David and Pedersoli, Marco},
  journal={arXiv preprint arXiv:2006.14699},
  year={2020}
}

Owner: ElementAI