Understanding the Generalization Benefit of Model Invariance from a Data Perspective

Overview

Understanding the Generalization Benefit of Model Invariance from a Data Perspective

This is the code for our NeurIPS2021 paper "Understanding the Generalization Benefit of Model Invariance from a Data Perspective". There are two major parts in our code: sample covering number estimation and generalization benefit evaluation.

Requirments

  • Python 3.8
  • PyTorch
  • torchvision
  • scikit-learn-extra
  • scipy
  • robustness package (already included in our code)

Our code is based on robustness package.

Dataset

  • CIFAR-10 Download and extract the data into /data/cifar10
  • R2N2 Download the ShapeNet rendered images and put the data into /data/r2n2

The randomly sampled R2N2 images used for computing sample covering numbers and indices of examples for different sample sizes could be found here.

Estimation of sample covering numbers

To estimate the sample covering numbers of different data transformations, run the following script in /scn.

CUDA_VISIBLE_DEVICES=0 python run_scn.py  --epsilon 3 --transformation crop --cover_number_method fast --data-path /path/to/dataset 

Note that the input is a N x C x H x W tensor where N is sample size.

Evaluation of generalization benefit

To train the model with data augmentation method, run the following script in /learn_invariance for R2N2 dataset

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name view

or the following script for CIFAR-10 dataset

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset cifar \
    --data ../data/cifar10 \
    --n-per-class all \
    --transforms crop  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name crop 

By setting --transforms to be one of {none, flip, crop, rotate, view}, the specific transformation will be considered.

To train the model with regularization method, run the following script. Currently, the code only support 3d-view transformation on R2N2 dataset.

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method reg \
    --inv-method-beta 1 \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name reg_view 

To evaluate the model with invariance loss and worst-case consistency accuracy, run the following script.

CUDA_VISIBLE_DEVICES=0 python main.py  \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --inv-method reg \
    --arch resnet18 \
    --resume /path/to/checkpoint.pt.best \
    --eval-only 1 \
    --transforms view  \
    --adv-eval 0 \
    --batch-size 2  \
    --no-store 

Note that to have the worst-case consistency accuracy we need to load 24 view images in R2N2RenderingsTorch class in dataset_3d.py.

Owner
PhD student at University of Maryland
Pointer-generator - Code for the ACL 2017 paper Get To The Point: Summarization with Pointer-Generator Networks

Note: this code is no longer actively maintained. However, feel free to use the Issues section to discuss the code with other users. Some users have u

Abi See 2.1k Jan 04, 2023
Pywonderland - A tour in the wonderland of math with python.

A Tour in the Wonderland of Math with Python A collection of python scripts for drawing beautiful figures and animating interesting algorithms in math

Zhao Liang 4.1k Jan 03, 2023
Deep learning based hand gesture recognition using LSTM and MediaPipie.

Hand Gesture Recognition Deep learning based hand gesture recognition using LSTM and MediaPipie. Demo video using PingPong Robot Files Pretrained mode

Brad 24 Nov 11, 2022
Official Implementation of DDOD (Disentangle your Dense Object Detector), ACM MM2021

Disentangle Your Dense Object Detector This repo contains the supported code and configuration files to reproduce object detection results of Disentan

loveSnowBest 51 Jan 07, 2023
Fully Automatic Page Turning on Real Scores

Fully Automatic Page Turning on Real Scores This repository contains the corresponding code for our extended abstract Henkel F., Schwaiger S. and Widm

Florian Henkel 7 Jan 02, 2022
A large-scale face dataset for face parsing, recognition, generation and editing.

CelebAMask-HQ [Paper] [Demo] CelebAMask-HQ is a large-scale face image dataset that has 30,000 high-resolution face images selected from the CelebA da

switchnorm 1.7k Dec 26, 2022
Vikrant Deshpande 1 Nov 17, 2022
Bayesian Neural Networks in PyTorch

We present the new scheme to compute Monte Carlo estimator in Bayesian VI settings with almost no memory cost in GPU, regardles of the number of sampl

Jurijs Nazarovs 7 May 03, 2022
Autotype on websites that have copy-paste disabled like Moodle, HackerEarth contest etc.

Autotype A quick and small python script that helps you autotype on websites that have copy paste disabled like Moodle, HackerEarth contests etc as it

Tushar 32 Nov 03, 2022
Modification of convolutional neural net "UNET" for image segmentation in Keras framework

ZF_UNET_224 Pretrained Model Modification of convolutional neural net "UNET" for image segmentation in Keras framework Requirements Python 3.*, Keras

209 Nov 02, 2022
PowerGridworld: A Framework for Multi-Agent Reinforcement Learning in Power Systems

PowerGridworld provides users with a lightweight, modular, and customizable framework for creating power-systems-focused, multi-agent Gym environments that readily integrate with existing training fr

National Renewable Energy Laboratory 37 Dec 17, 2022
Denoising Diffusion Implicit Models

Denoising Diffusion Implicit Models (DDIM) Jiaming Song, Chenlin Meng and Stefano Ermon, Stanford Implements sampling from an implicit model that is t

465 Jan 05, 2023
Neighbor2Seq: Deep Learning on Massive Graphs by Transforming Neighbors to Sequences

Neighbor2Seq: Deep Learning on Massive Graphs by Transforming Neighbors to Sequences This repository is an official PyTorch implementation of Neighbor

DIVE Lab, Texas A&M University 8 Jun 12, 2022
[ICCV'21] Learning Conditional Knowledge Distillation for Degraded-Reference Image Quality Assessment

CKDN The official implementation of the ICCV2021 paper "Learning Conditional Knowledge Distillation for Degraded-Reference Image Quality Assessment" O

Multimedia Research 50 Dec 13, 2022
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

20 May 28, 2022
A toolset of Python programs for signal modeling and indentification via sparse semilinear autoregressors.

SPAAR Description A toolset of Python programs for signal modeling via sparse semilinear autoregressors. References Vides, F. (2021). Computing Semili

Fredy Vides 0 Oct 30, 2021
AAAI 2022: Stationary diffusion state neural estimation

Stationary Diffusion State Neural Estimation Although many graph-based clustering methods attempt to model the stationary diffusion state in their obj

绽琨 33 Nov 24, 2022
基于PaddleClas实现垃圾分类,并转换为inference格式用PaddleHub服务端部署

百度网盘链接及提取码: 链接:https://pan.baidu.com/s/1HKpgakNx1hNlOuZJuW6T1w 提取码:wylx 一个垃圾分类项目带你玩转飞桨多个产品(1) 基于PaddleClas实现垃圾分类,导出inference模型并利用PaddleHub Serving进行服务

thomas-yanxin 22 Jul 12, 2022
E2C implementation in PyTorch

Embed to Control implementation in PyTorch Paper can be found here: https://arxiv.org/abs/1506.07365 You will need a patched version of OpenAI Gym in

Yicheng Luo 42 Dec 12, 2022