We present a regularized self-labeling approach to improve the generalization and robustness properties of fine-tuning.

Overview

This repository provides the implementation of the paper "Improved Regularization and Robustness for Fine-tuning in Neural Networks", a poster paper at NeurIPS'21.

In this work, we propose a regularized self-labeling approach that combines regularization and self-training methods for improving the generalization and robustness properties of fine-tuning. Our approach includes two components:

  • First, we apply layer-wise regularization that penalizes the model weights separately at each layer of the network.
  • Second, we add self-labeling, which relabels data points based on the current network's predictions and reweights data points whose confidence is low.
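One way to picture the layer-wise regularization is as a distance constraint. After each update, a layer's weights are projected back into a Frobenius-norm ball around their pretrained values, which fits the `--reg_method constraint --reg_norm frob` flags used in the commands below. The plain-Python sketch that follows is illustrative only. The helper names are hypothetical, and the radius schedule is our assumption about how the flags interact, not a description of `train_constraint.py`: feature-extractor layer i gets `reg_extractor * scale_factor**i`, and the final classifier gets `reg_predictor`.

```python
import math
from typing import List

Matrix = List[List[float]]


def frobenius_distance(w: Matrix, w0: Matrix) -> float:
    """Frobenius norm of (w - w0) for two matrices of the same shape."""
    return math.sqrt(sum((a - b) ** 2
                         for row, row0 in zip(w, w0)
                         for a, b in zip(row, row0)))


def project_to_ball(w: Matrix, w0: Matrix, radius: float) -> Matrix:
    """Project w onto the set {v : ||v - w0||_F <= radius}.

    Inside the ball, w is returned unchanged (as a copy); outside, the
    difference w - w0 is shrunk so that its norm equals `radius`.
    """
    dist = frobenius_distance(w, w0)
    if dist <= radius:
        return [row[:] for row in w]
    scale = radius / dist
    return [[b + scale * (a - b) for a, b in zip(row, row0)]
            for row, row0 in zip(w, w0)]


def layer_radii(num_extractor_layers: int, reg_extractor: float,
                reg_predictor: float, scale_factor: float) -> List[float]:
    """HYPOTHETICAL schedule: extractor layer i (0-indexed) gets
    reg_extractor * scale_factor**i; the final predictor gets reg_predictor."""
    radii = [reg_extractor * scale_factor ** i
             for i in range(num_extractor_layers)]
    radii.append(reg_predictor)
    return radii


# Usage: the fine-tuned weights sit at distance 5 from the pretrained ones,
# so projecting onto a ball of radius 1 shrinks the difference by a factor of 5.
pretrained = [[1.0, 0.0], [0.0, 1.0]]
finetuned = [[4.0, 0.0], [0.0, 5.0]]
projected = project_to_ball(finetuned, pretrained, radius=1.0)
print(projected, frobenius_distance(projected, pretrained))
```

Because the projection only moves weights that have drifted too far, a layer with a small radius stays close to its pretrained values, while a larger radius gives that layer more freedom to adapt.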

Requirements

To install requirements:

pip install -r requirements.txt

Data Preparation

We use seven image datasets in our paper. Below, we list each dataset and describe how to prepare it to run our code.

  • Aircrafts: download and extract into ./data/aircrafts/
  • CUB-200-2011: download and extract into ./data/CUB_200_2011/
  • Caltech-256: download and extract into ./data/caltech256/
    • remove the 257.clutter class from the data directory
  • Stanford-Cars: download and extract into ./data/StanfordCars/
  • Stanford-Dogs: download and extract into ./data/StanfordDogs/
  • Flowers: download and extract into ./data/flowers/
  • MIT-Indoor: download and extract into ./data/Indoor/

Our code handles dataset splitting automatically.
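Before launching experiments, it can help to confirm that the dataset directories listed above exist. The hypothetical helper below is not part of the repository. It checks each expected path relative to the repository root and flags any leftover `257.clutter` folder, which is a Caltech-256 class. It looks both directly inside ./data/caltech256/ and one level down, because we are not sure which layout your extracted archive produces.

```python
from pathlib import Path
from typing import Dict, List

# Expected dataset locations, copied from the list above (relative to the repo root).
EXPECTED_DIRS: Dict[str, str] = {
    "Aircrafts": "data/aircrafts",
    "CUB-200-2011": "data/CUB_200_2011",
    "Caltech-256": "data/caltech256",
    "Stanford-Cars": "data/StanfordCars",
    "Stanford-Dogs": "data/StanfordDogs",
    "Flowers": "data/flowers",
    "MIT-Indoor": "data/Indoor",
}


def check_layout(root: Path = Path(".")) -> List[str]:
    """Return human-readable problems; an empty list means the layout looks fine."""
    problems = []
    for name, rel in EXPECTED_DIRS.items():
        if not (root / rel).is_dir():
            problems.append(f"{name}: missing directory ./{rel}/")
    caltech = root / "data" / "caltech256"
    if caltech.is_dir():
        # Look for a leftover 257.clutter folder at depth 1 or 2.
        for pattern in ("257.clutter", "*/257.clutter"):
            for path in caltech.glob(pattern):
                if path.is_dir():
                    problems.append(f"Caltech-256: remove {path}")
    return problems


if __name__ == "__main__":
    for problem in check_layout():
        print(problem)
```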

Usage

Our algorithm (RegSL) combines layer-wise regularization with self-labeling. Run the following commands to reproduce the experiments in the paper.

Fine-tuning ResNet-101 on image classification tasks.

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_indoor.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.136809975858091 --reg_predictor 6.40780158171339 --scale_factor 2.52883770643206\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_aircrafts.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.18330556653284 --reg_predictor 5.27713618808711 --scale_factor 1.27679969876201\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_birds.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.204403908747731 --reg_predictor 23.7850606577679 --scale_factor 4.73803591794678\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_caltech.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0867998872549272 --reg_predictor 9.4552942790218 --scale_factor 1.1785989596144\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_cars.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.3340347414257 --reg_predictor 8.26940794089601 --scale_factor 3.47676759842434\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_dogs.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0561320847651626 --reg_predictor 4.46281825974388 --scale_factor 1.58722606909531\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_flower.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.131991042311165 --reg_predictor 10.7674132173309 --scale_factor 4.98010215976503\
    --device 1
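To run all seven ResNet-101 experiments in sequence, a small wrapper can build the same command lines from a table of hyperparameters. The sketch below is a convenience script of our own, not part of the repository. It copies the values from the commands above verbatim (as strings, so no digits are lost), assumes it is run from the repository root, and only prints the commands unless `dry_run=False` is passed.

```python
import shlex
import subprocess
from typing import Dict, List, Tuple

# (reg_extractor, reg_predictor, scale_factor) per config name, copied from the commands above.
RESNET101_RUNS: Dict[str, Tuple[str, str, str]] = {
    "indoor": ("0.136809975858091", "6.40780158171339", "2.52883770643206"),
    "aircrafts": ("1.18330556653284", "5.27713618808711", "1.27679969876201"),
    "birds": ("0.204403908747731", "23.7850606577679", "4.73803591794678"),
    "caltech": ("0.0867998872549272", "9.4552942790218", "1.1785989596144"),
    "cars": ("1.3340347414257", "8.26940794089601", "3.47676759842434"),
    "dogs": ("0.0561320847651626", "4.46281825974388", "1.58722606909531"),
    "flower": ("0.131991042311165", "10.7674132173309", "4.98010215976503"),
}


def build_command(dataset: str, device: int = 1) -> List[str]:
    """Return the argv list for one train_constraint.py run."""
    ext, pred, scale = RESNET101_RUNS[dataset]
    return [
        "python", "train_constraint.py", "--model", "ResNet101",
        "--config", f"configs/config_constraint_{dataset}.json",
        "--reg_method", "constraint", "--reg_norm", "frob",
        "--reg_extractor", ext, "--reg_predictor", pred, "--scale_factor", scale,
        "--device", str(device),
    ]


def run_all(dry_run: bool = True, device: int = 1) -> None:
    """Print every command; also execute it (stopping on failure) if dry_run is False."""
    for dataset in RESNET101_RUNS:
        cmd = build_command(dataset, device)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

The `device` argument defaults to 1 to match the commands above; on a single-GPU machine you would likely pass `device=0`.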

Fine-tuning ResNet-18 under label noise.

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 7.80246991703043 --reg_predictor 14.077402847906 \
    --noise_rate 0.2 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 8.47139398080791 --reg_predictor 19.0191127114923 \
    --noise_rate 0.4 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 10.7576018531961 --reg_predictor 19.8157649727473 \
    --noise_rate 0.6 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 
    
python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 9.2031662757248 --reg_predictor 6.41568500472423 \
    --noise_rate 0.8 --train_correct_label --reweight_epoch 5 --reweight_temp 1.5 --correct_epoch 10 --correct_thres 0.9 
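The `--train_correct_label`, `--correct_epoch`, `--correct_thres`, `--reweight_epoch`, and `--reweight_temp` flags control the self-labeling component. The plain-Python sketch below shows one plausible reading of them, under assumptions we have not checked against `train_label_noise.py`. From `correct_epoch` on, a label is replaced by the model's predicted class when that class has probability at least `correct_thres`. From `reweight_epoch` on, each example's loss weight is a batch softmax of its confidence in its current label divided by `reweight_temp`, rescaled to mean 1. The function names and the exact reweighting formula are hypothetical.

```python
import math
from typing import List, Sequence

Probs = Sequence[Sequence[float]]


def correct_labels(probs: Probs, labels: Sequence[int], epoch: int,
                   correct_epoch: int, correct_thres: float) -> List[int]:
    """Replace a label by the predicted class when the model is confident enough.

    Labels are left unchanged before `correct_epoch` (assumed semantics).
    """
    if epoch < correct_epoch:
        return list(labels)
    corrected = []
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=lambda k: p[k])
        corrected.append(pred if p[pred] >= correct_thres else y)
    return corrected


def reweight(probs: Probs, labels: Sequence[int], epoch: int,
             reweight_epoch: int, reweight_temp: float) -> List[float]:
    """Per-example loss weights with mean 1 that favor confident examples.

    Before `reweight_epoch` every weight is 1. Afterwards, the weights are a
    softmax over the batch of (probability of the current label) / reweight_temp,
    rescaled by the batch size. The formula is an assumption for illustration.
    """
    n = len(labels)
    if epoch < reweight_epoch:
        return [1.0] * n
    scores = [math.exp(p[y] / reweight_temp) for p, y in zip(probs, labels)]
    total = sum(scores)
    return [n * s / total for s in scores]


# Usage with a batch of three examples and two classes.
probs = [[0.95, 0.05], [0.40, 0.60], [0.10, 0.90]]
labels = [1, 1, 1]  # the first label disagrees with a confident prediction
labels = correct_labels(probs, labels, epoch=12, correct_epoch=10, correct_thres=0.9)
weights = reweight(probs, labels, epoch=12, reweight_epoch=5, reweight_temp=2.0)
print(labels, [round(w, 3) for w in weights])
```

In this sketch a larger `reweight_temp` pushes the weights toward uniform, while a smaller one concentrates weight on the most confident examples.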

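Fine-tuning a Vision Transformer (ViT-B/16) under label noise. The first two commands run standard fine-tuning without regularization or self-labeling at noise rates 0.4 and 0.8; the last two run our approach at the same noise rates.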

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.8

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.7488074175044196 --reg_predictor 9.842955837419588 \
    --train_correct_label --reweight_epoch 24 --correct_epoch 18\
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.1568903647089986 --reg_predictor 1.407080880079702 \
    --train_correct_label --reweight_epoch 18 --correct_epoch 2\
    --lr 0.0001 --device 1 --noise_rate 0.8

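Please follow the instructions in ViT-pytorch to download the pre-trained models, and place the ViT-B_16 checkpoint at pretrained/imagenet21k_ViT-B_16.npz as the commands above expect.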
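Both the ResNet-18 and ViT label-noise experiments above set `--noise_rate`, which suggests that a fraction of the training labels is corrupted before fine-tuning. A common choice is symmetric (uniform) label noise, sketched below. Whether `train_label_noise.py` uses this noise model, this sampling scheme, or a fixed seed is an assumption on our part, and the function name is hypothetical. The usage example uses 67 classes, the number of categories in MIT-Indoor.

```python
import random
from typing import List, Sequence


def add_symmetric_noise(labels: Sequence[int], num_classes: int,
                        noise_rate: float, seed: int = 0) -> List[int]:
    """Flip round(noise_rate * n) randomly chosen labels to a different class,
    chosen uniformly among the other num_classes - 1 classes."""
    if num_classes < 2:
        raise ValueError("need at least two classes to flip labels")
    rng = random.Random(seed)
    noisy = list(labels)
    n_flip = round(noise_rate * len(noisy))
    for i in rng.sample(range(len(noisy)), n_flip):
        new = rng.randrange(num_classes - 1)
        # Shift past the original label so the flipped label is always wrong.
        noisy[i] = new if new < noisy[i] else new + 1
    return noisy


# Usage: 1,000 examples over 67 classes with 40% of labels flipped.
clean = [i % 67 for i in range(1000)]
noisy = add_symmetric_noise(clean, num_classes=67, noise_rate=0.4, seed=1)
print(sum(a != b for a, b in zip(clean, noisy)))  # prints 400
```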

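Fine-tuning ResNet-18 on the ChestX-ray14 dataset.

Run the ChestX-ray14 experiments from the reproduce-chexnet directory. The first command fine-tunes without regularization; the second applies layer-wise regularization: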

cd reproduce-chexnet

python retrain.py --reg_method None --reg_norm None --device 0

python retrain.py --reg_method constraint --reg_norm frob \
    --reg_extractor 5.728564437344309 --reg_predictor 2.5669480884876905 --scale_factor 1.0340072757925474 \
    --device 0

Citation

If you find this repository useful, please consider citing our paper, "Improved Regularization and Robustness for Fine-tuning in Neural Networks" (NeurIPS'21).

Acknowledgment

Thanks to the authors of the following repositories for making their implementations publicly available.

Owner
NEU-StatsML-Research
We are a group of faculty and students from the Computer Science College of Northeastern University