PyTorch code of my WACV 2022 paper Improving Model Generalization by Agreement of Learned Representations from Data Augmentation

Overview

Improving Model Generalization by Agreement of Learned Representations from Data Augmentation (WACV 2022)

Paper

ArXiv

Why it matters?

When data augmentation is applied on an input image, a model is forced to learn invariant features to improve model generalization (Figure 1).

Since data augmentation incurs little overhead, why not generate 2 data augmented images (also known as 2 positive samples) from a given input. Then, force the model to agree on the common invariant features to support the correct label (Figure 2). It turns out that maximizing this agreement further improves model model generalization. We call our method AgMax.

Unlike label smoothing, AgMax consistently improves model accuracy. For example on ImageNet1k for 90 epochs, the ResNet50 performance is as follows:

Data Augmentation Baseline Label Smoothing AgMax (Ours)
Standard 76.4 76.8 76.9
CutOut 76.2 76.5 77.1
MixUp 76.5 76.7 77.6
CutMix 76.3 76.4 77.4
AutoAugment (AA) 76.2 76.2 77.1
CutOut+AA 75.7 75.7 76.6
MixUp+AA 75.9 76.5 77.1
CutMix+AA 75.5 75.5 77.0

The figure below demonstrates consistent improvement across different data augmnentation methods:

Install requirements

pip3 install -r requirements.txt

Train

For example, train ResNet50 with AgMax on 2 GPUs for 90 epochs, SGD with lr=0.1 and multistep learning rate scheduler:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard-agmax --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Compare the results without AgMax:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Test

Using a pre-trained model:

ResNet101 trained with CutMix, AutoAugment and AgMax:

mkdir checkpoints
cd checkpoints
wget https://github.com/roatienza/agmax/releases/download/agmax-0.1.0/imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth
cd ..
python3 main.py --config=ResNet101-auto_augment-cutmix-agmax --eval \
--dataset=imagenet \
--resume imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth

ResNet50 trained with CutMix, AutoAugment and AgMax:

python3 main.py --config=ResNet50-auto_augment-cutmix-agmax --eval --n-units=2048 \
--dataset=imagenet --resume imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth

Other pre-trained models (Baselines):

Citation

If you find this work useful, please cite:

@inproceedings{atienza2022agmax,
  title={Improving Model Generalization by Agreement of Learned Representations from Data Augmentation},
  author={Atienza, Rowel},
  booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision},
  year={2022},
  pubstate={published},
  tppubtype={inproceedings}
}
You might also like...
Code for the AAAI-2022 paper: Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification

Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification (AAAI 2022) Prerequisite PyTorch = 1.2.0 P

Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).
Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).

Relation Prediction as an Auxiliary Training Objective for Knowledge Base Completion This repo provides the code for the paper Relation Prediction as

Sharpness-Aware Minimization for Efficiently Improving Generalization
Sharpness-Aware Minimization for Efficiently Improving Generalization

Sharpness-Aware-Minimization-TensorFlow This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minim

ImageNet-CoG is a benchmark for concept generalization. It provides a full evaluation framework for pre-trained visual representations which measure how well they generalize to unseen concepts.

The ImageNet-CoG Benchmark Project Website Paper (arXiv) Code repository for the ImageNet-CoG Benchmark introduced in the paper "Concept Generalizatio

The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022
The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022

DG-TrajGen The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022. Our Meth

Code for
Code for "ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on", accepted at WACV 2021 Generation of Human Behavior Workshop.

ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on [ Paper ] [ Project Page ] This repository contains the code fo

This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild with Dense 3D Representations and A Benchmark. (CVPR 2022)"

Gait3D-Benchmark This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild

Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

[WACV 2020] Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints

Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints Official implementation for Reducing Footskate in Human Motion Recon

Comments
  • Question about the mutual information in the agmax loss

    Question about the mutual information in the agmax loss

    The agmax loss is computed by "agreement_loss, dl = agmax_loss(y, target, self.args.dl_weight)" , I do not know what does the argeement_loss and dl mean, and do not know the relationship between these two loss and the calculation of mutual information. Could you give me some help?

    opened by YananGu 4
Owner
Rowel Atienza
Rowel Atienza
Hide screen when boss is approaching.

BossSensor Hide your screen when your boss is approaching. Demo The boss stands up. He is approaching. When he is approaching, the program fetches fac

Hiroki Nakayama 6.2k Jan 07, 2023
Global-Local Attention for Emotion Recognition

Global-Local Attention for Emotion Recognition Requirements Python 3 Install tensorflow (or tensorflow-gpu) = 2.0.0 Install some other packages pip i

Minh Nhat Le 15 Apr 21, 2022
Code and data of the EMNLP 2021 paper "Mind the Style of Text! Adversarial and Backdoor Attacks Based on Text Style Transfer"

StyleAttack Code and data of the EMNLP 2021 paper "Mind the Style of Text! Adversarial and Backdoor Attacks Based on Text Style Transfer" Prepare Pois

THUNLP 19 Nov 20, 2022
FishNet: One Stage to Detect, Segmentation and Pose Estimation

FishNet FishNet: One Stage to Detect, Segmentation and Pose Estimation Introduction In this project, we combine target detection, instance segmentatio

1 Oct 05, 2022
Implementation for "Exploiting Aliasing for Manga Restoration" (CVPR 2021)

[CVPR Paper](To appear) | [Project Website](To appear) | BibTex Introduction As a popular entertainment art form, manga enriches the line drawings det

133 Dec 15, 2022
Writeups for the challenges from DownUnderCTF 2021

cloud Challenge Author Difficulty Release Round Bad Bucket Blue Alder easy round 1 Not as Bad Bucket Blue Alder easy round 1 Lost n Found Blue Alder m

DownUnderCTF 161 Dec 31, 2022
Densely Connected Convolutional Networks, In CVPR 2017 (Best Paper Award).

Densely Connected Convolutional Networks (DenseNets) This repository contains the code for DenseNet introduced in the following paper Densely Connecte

Zhuang Liu 4.5k Jan 03, 2023
TRACER: Extreme Attention Guided Salient Object Tracing Network implementation in PyTorch

TRACER: Extreme Attention Guided Salient Object Tracing Network This paper was accepted at AAAI 2022 SA poster session. Datasets All datasets are avai

Karel 118 Dec 29, 2022
BERT model training impelmentation using 1024 A100 GPUs for MLPerf Training v1.1

Pre-trained checkpoint and bert config json file Location of checkpoint and bert config json file This MLCommons members Google Drive location contain

SAIT (Samsung Advanced Institute of Technology) 12 Apr 27, 2022
Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-wise Distributed Data based on Pytorch Framework

VFedPCA+VFedAKPCA This is the official source code for the Paper: Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-

John 9 Sep 18, 2022
Python Tensorflow 2 scripts for detecting objects of any class in an image without knowing their label.

Tensorflow-Mobile-Generic-Object-Localizer Python Tensorflow 2 scripts for detecting objects of any class in an image without knowing their label. Ori

Ibai Gorordo 11 Nov 15, 2022
The code for paper "Learning Implicit Fields for Generative Shape Modeling".

implicit-decoder The tensorflow code for paper "Learning Implicit Fields for Generative Shape Modeling", Zhiqin Chen, Hao (Richard) Zhang. Project pag

Zhiqin Chen 353 Dec 30, 2022
Models Supported: AlbUNet [18, 34, 50, 101, 152] (1D and 2D versions for Single and Multiclass Segmentation, Feature Extraction with supports for Deep Supervision and Guided Attention)

AlbUNet-1D-2D-Tensorflow-Keras This repository contains 1D and 2D Signal Segmentation Model Builder for AlbUNet and several of its variants developed

Sakib Mahmud 1 Nov 15, 2021
Data Engineering ZoomCamp

Data Engineering ZoomCamp I'm partaking in a Data Engineering Bootcamp / Zoomcamp and will be tracking my progress here. I can't promise these notes w

Aaron 61 Jan 06, 2023
A simple code to perform canny edge contrast detection on images.

CECED-Canny-Edge-Contrast-Enhanced-Detection A simple code to perform canny edge contrast detection on images. A simple code to process images using c

Happy N. Monday 3 Feb 15, 2022
This solves the autonomous driving issue which is supported by deep learning technology. Given a video, it splits into images and predicts the angle of turning for each frame.

Self Driving Car An autonomous car (also known as a driverless car, self-driving car, and robotic car) is a vehicle that is capable of sensing its env

Sagor Saha 4 Sep 04, 2021
EMNLP 2021 paper The Devil is in the Detail: Simple Tricks Improve Systematic Generalization of Transformers.

Codebase for training transformers on systematic generalization datasets. The official repository for our EMNLP 2021 paper The Devil is in the Detail:

Csordás Róbert 57 Nov 21, 2022
Official code for Spoken ObjectNet: A Bias-Controlled Spoken Caption Dataset

Official code for our Interspeech 2021 - Spoken ObjectNet: A Bias-Controlled Spoken Caption Dataset [1]*. Visually-grounded spoken language datasets c

Ian Palmer 3 Jan 26, 2022
Project page for the paper Semi-Supervised Raw-to-Raw Mapping 2021.

Project page for the paper Semi-Supervised Raw-to-Raw Mapping 2021.

Mahmoud Afifi 22 Nov 08, 2022
A JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short.

BraVe This is a JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short. The model provided in this package wa

DeepMind 44 Nov 20, 2022