PyTorch code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised DA

PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation

Viraj Prabhu, Shivam Khare, Deeksha Kartik, Judy Hoffman

Many existing approaches for unsupervised domain adaptation (UDA) focus on adapting under only data distribution shift and offer limited success under additional cross-domain label distribution shift. Recent work based on self-training using target pseudolabels has shown promise, but on challenging shifts pseudolabels may be highly unreliable and using them for self-training may cause error accumulation and domain misalignment. We propose Selective Entropy Optimization via Committee Consistency (SENTRY), a UDA algorithm that judges the reliability of a target instance based on its predictive consistency under a committee of random image transformations. Our algorithm then selectively minimizes predictive entropy to increase confidence on highly consistent target instances, while maximizing predictive entropy to reduce confidence on highly inconsistent ones. In combination with pseudolabel-based approximate target class balancing, our approach leads to significant improvements over the state-of-the-art on 27/31 domain shifts from standard UDA benchmarks as well as benchmarks designed to stress-test adaptation under label distribution shift.
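
To make the two mechanisms described above concrete, here is a minimal, framework-free sketch in plain Python. The repository itself uses PyTorch tensors and autograd. The sketch assumes softmax outputs are given as lists of probabilities. It treats an instance as consistent when a strict majority of committee members agree with the prediction on the unaugmented image, and it computes the entropy term on the unaugmented prediction. The helper names (entropy, is_consistent, selective_entropy_loss, class_balancing_weights) are hypothetical. The transformations, committee size, and exact loss weighting in the actual code may differ.

```python
import math
from collections import Counter
from typing import List, Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (natural log) of one softmax output."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def argmax(probs: Sequence[float]) -> int:
    return max(range(len(probs)), key=lambda c: probs[c])


def is_consistent(orig: Sequence[float], committee: List[Sequence[float]]) -> bool:
    """Consistent if a strict majority of the committee (predictions on randomly
    transformed views) agrees with the prediction on the original image."""
    pred = argmax(orig)
    votes = sum(1 for probs in committee if argmax(probs) == pred)
    return votes > len(committee) / 2


def selective_entropy_loss(batch_orig, batch_committee) -> float:
    """Mean of +H for consistent instances (minimizing it raises confidence)
    and -H for inconsistent ones (minimizing it lowers confidence)."""
    terms = []
    for orig, committee in zip(batch_orig, batch_committee):
        h = entropy(orig)  # assumption: entropy taken on the unaugmented view
        terms.append(h if is_consistent(orig, committee) else -h)
    return sum(terms) / len(terms)


def class_balancing_weights(pseudolabels: List[int]) -> List[float]:
    """Sampling weights inversely proportional to pseudolabel frequency,
    so each predicted class is drawn roughly equally often."""
    counts = Counter(pseudolabels)
    raw = [1.0 / counts[y] for y in pseudolabels]
    total = sum(raw)
    return [w / total for w in raw]


# Toy batch: two target images, a 2-class problem, a committee of 3 views.
orig = [[0.9, 0.1], [0.6, 0.4]]
committee = [
    [[0.8, 0.2], [0.7, 0.3], [0.4, 0.6]],    # 2 of 3 agree -> consistent
    [[0.3, 0.7], [0.2, 0.8], [0.55, 0.45]],  # 1 of 3 agrees -> inconsistent
]
loss = selective_entropy_loss(orig, committee)  # about -0.174
```

In a real training loop the same logic would be applied to batched tensors so that gradients flow through the entropy terms. The balancing weights could feed a weighted sampler over the target set.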

Setup and Dependencies

  1. Create an Anaconda environment with Python 3.6 (conda create -n sentry python=3.6.8) and activate it (conda activate sentry)
  2. Navigate to the code directory: cd code/
  3. Install dependencies: pip install -r requirements.txt

And you're all set up!

Usage

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. Data for the other benchmarks can be downloaded from the following links (the splits used in our experiments are already included in the data/ folder):

  1. DomainNet
  2. OfficeHome
  3. VisDA2017 (only train and validation needed)

Pretrained checkpoints

To reproduce the numbers reported in the paper, we include a few pretrained checkpoints. Checkpoints (source and adapted) for SVHN to MNIST (DIGITS) are included in the checkpoints directory. Source and adapted checkpoints for Clipart to Sketch adaptation (from DomainNet) and Real_World to Product adaptation (from OfficeHome RS-UT) can be downloaded from this link, and should be saved to the checkpoints/source and checkpoints/SENTRY directories as appropriate.

Train and adapt model

  • Natural label distribution shift: Adapt a model from <source> to <target> for a given <benchmark> (where benchmark may be DomainNet, OfficeHome, VisDA, or DIGITS) as follows:
python train.py --id <experiment_id> \
                --source <source> \
                --target <target> \
                --img_dir <image_directory> \
                --LDS_type <LDS_type> \
                --load_from_cfg True \
                --cfg_file 'config/<benchmark>/<cfg_file>.yml' \
                --use_cuda True

SENTRY hyperparameters are provided via a sentry.yml config file in the corresponding config/<benchmark> folder (on DIGITS, we also provide a config for baseline adaptation via DANN). The valid source/target domains for each benchmark are:

  • DomainNet: real, clipart, sketch, painting
  • OfficeHome_RS_UT: Real_World, Clipart, Product
  • OfficeHome: Real_World, Clipart, Product, Art
  • VisDA2017: visda_train, visda_test
  • DIGITS: Only svhn (source) to mnist (target) adaptation is currently supported.
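
When running several source/target pairs, it can help to build the command programmatically. The sketch below is a hypothetical helper, not part of the repository. It assembles the argv for train.py using only the flags shown in the command template above, and it rejects domains that are not in this list. It assumes the config file lives at config/<benchmark>/sentry.yml, with the folder named exactly like the DOMAINS keys. That assumption may not match the repository's actual folder names (for example, VisDA vs. VisDA2017).

```python
import shlex
from typing import Dict, List, Optional

# Valid source/target domains per benchmark, copied from the README list.
DOMAINS: Dict[str, List[str]] = {
    "DomainNet": ["real", "clipart", "sketch", "painting"],
    "OfficeHome_RS_UT": ["Real_World", "Clipart", "Product"],
    "OfficeHome": ["Real_World", "Clipart", "Product", "Art"],
    "VisDA2017": ["visda_train", "visda_test"],
    "DIGITS": ["svhn", "mnist"],
}


def build_train_command(exp_id: str, benchmark: str, source: str, target: str,
                        img_dir: str, lds_type: str = "natural",
                        cfg_file: Optional[str] = None,
                        use_cuda: bool = True) -> List[str]:
    """Return an argv list for train.py after checking the domain pair."""
    domains = DOMAINS[benchmark]
    if source not in domains or target not in domains or source == target:
        raise ValueError(f"invalid {benchmark} pair: {source} -> {target}")
    if benchmark == "DIGITS" and (source, target) != ("svhn", "mnist"):
        raise ValueError("DIGITS only supports svhn -> mnist")
    # Assumption (hypothetical layout): config folder named after the benchmark.
    cfg_file = cfg_file or f"config/{benchmark}/sentry.yml"
    return [
        "python", "train.py",
        "--id", exp_id,
        "--source", source,
        "--target", target,
        "--img_dir", img_dir,
        "--LDS_type", lds_type,
        "--load_from_cfg", "True",
        "--cfg_file", cfg_file,
        "--use_cuda", str(use_cuda),
    ]


cmd = build_train_command("clipart2sketch", "DomainNet", "clipart", "sketch",
                          "/data/DomainNet")
print(" ".join(shlex.quote(arg) for arg in cmd))
```

The list can be passed directly to subprocess.run(cmd, check=True). No shell is involved in that case, so a path such as '~/data/DomainNet' is not tilde-expanded; wrap it in os.path.expanduser first.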

Pass the path to the parent folder containing the dataset images via the --img_dir <name_of_directory> flag (e.g. --img_dir '~/data/DomainNet'). Pass the label distribution shift type via the --LDS_type flag: for DomainNet, OfficeHome (standard), and VisDA2017, pass --LDS_type 'natural' (the default); for OfficeHome RS-UT, pass --LDS_type 'RS_UT'; for DIGITS, pass one of IF1, IF20, IF50, or IF100 to load a manually long-tailed target training split with the given imbalance factor (IF), as described in Table 4 of the paper.
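
Under the usual definition, the imbalance factor is the ratio between the largest and smallest class sizes in the split, so IF1 means no imbalance. The long-tailed DIGITS splits already ship in data/, so the sketch below is only illustrative. It assumes an exponential decay profile, a common convention for long-tailed benchmarks, which may differ from how the paper's splits were actually built. The names long_tailed_counts and subsample_long_tailed are hypothetical.

```python
import random
from collections import defaultdict
from typing import List, Sequence


def long_tailed_counts(n_max: int, num_classes: int, imbalance_factor: float) -> List[int]:
    """Per-class counts decaying exponentially from n_max (class 0) to
    n_max / imbalance_factor (last class)."""
    if num_classes < 2 or imbalance_factor < 1:
        raise ValueError("need num_classes >= 2 and imbalance_factor >= 1")
    return [round(n_max * imbalance_factor ** (-c / (num_classes - 1)))
            for c in range(num_classes)]


def subsample_long_tailed(labels: Sequence[int], counts: List[int],
                          seed: int = 0) -> List[int]:
    """Pick dataset indices so that class c keeps at most counts[c] examples."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    keep: List[int] = []
    for c, n in enumerate(counts):
        pool = by_class[c]
        keep.extend(rng.sample(pool, min(n, len(pool))))
    return sorted(keep)


labels = [c for c in range(10) for _ in range(1000)]  # balanced toy labels
counts = long_tailed_counts(1000, 10, 100)            # IF100: 1000 down to 10
subset = subsample_long_tailed(labels, counts)
```

Which class ends up as the head and which as the tail is arbitrary here (class 0 is largest). A real split might shuffle the class order.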

To load a pretrained DA checkpoint instead of training your own, additionally pass --load_da True and --id <benchmark_name> to the command above. The training script logs performance metrics (average and aggregate accuracy) to the console, and also plots and saves per-class performance statistics to the results/ folder.
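
Average and aggregate accuracy can differ substantially under label shift. The hypothetical helper below assumes the usual conventions: aggregate accuracy is the fraction of all target examples classified correctly, and average accuracy is the mean of the per-class accuracies. It computes both, along with the kind of per-class breakdown that the script saves to results/. It is a sketch, not the repository's evaluation code.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def accuracy_stats(y_true: Sequence[int],
                   y_pred: Sequence[int]) -> Tuple[float, float, Dict[int, float]]:
    """Return (aggregate accuracy, average per-class accuracy, per-class accuracy)."""
    correct: Dict[int, int] = defaultdict(int)
    total: Dict[int, int] = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    per_class = {c: correct[c] / total[c] for c in sorted(total)}
    aggregate = sum(correct.values()) / sum(total.values())
    average = sum(per_class.values()) / len(per_class)
    return aggregate, average, per_class


# A model that always predicts the majority class looks good on aggregate
# accuracy but poor on average accuracy.
agg, avg, per_class = accuracy_stats([0, 0, 0, 0, 1], [0, 0, 0, 0, 0])
# agg = 0.8, avg = 0.5, per_class = {0: 1.0, 1: 0.0}
```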

Note: By default this code runs on the GPU. To run on the CPU, pass --use_cuda False.

Reference

If you found this code useful, please consider citing:

@article{prabhu2020sentry,
   author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
   title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
   year = {2020},
   journal = {arXiv preprint: 2012.11460},
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, in addition to the numerous contributors to all the open-source packages we use.

License

MIT
