Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

Overview

PAWS-TF 🐾

Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS) in TensorFlow (2.4.1).

PAWS introduces a simple way to combine a very small fraction of labeled data with a comparatively larger corpus of unlabeled data during pre-training. With its approach, it sets the state-of-the-art in semi-supervised learning (as of May 2021) beating methods like SimCLRV2, Meta Pseudo Labels that too with fewer parameters and a smaller pre-training schedule. For details, I recommend checking out the original paper as well as this blog post by the authors.

This repository implements and includes all the major bits proposed in PAWS in TensorFlow. The only major difference is that the pre-training and subsequent fine-tuning weren't run for the original number of epochs (600 and 30 respectively) to save compute. I have reused the utility components for PAWS loss from the original implementation.

Dataset βŒ—

The current code works with CIFAR10 and uses 4000 labeled samples (8%) during pre-training (along with the unlabeled samples).

Features ✨

  • Multi-crop augmentation strategy (originally introduced in SwAV)
  • Class stratified sampler (common in few-shot classification problems)
  • WarmUpCosine learning rate schedule (which is typical for self-supervised and semi-supervised pre-training)
  • LARS optimizer (comes from TensorFlow Model Garden)

The trunk portion (all, except the last classification layer) of a WideResNet-28-2 is used inside the encoder for CIFAR10. All the experimental configurations were followed from the Appendix C of the paper.

Setup and code structure πŸ’»

A GCP VM (n1-standard-8) with a single V100 GPU was used for executing the code.

  • paws_train.py runs the pre-training as introduced in PAWS.
  • fine_tune.py runs the fine-tuning part as suggested in Appendix C. Note that this is only required for CIFAR10.
  • nn_eval.py runs the soft nearest neighbor classification on CIFAR10 test set.

Pre-training and fine-tuning total take 1.4 hours to complete. All the logs are available in misc/logs.txt. Additionally, the indices that were used to sample the labeled examples from the CIFAR10 training set are available here.

Results πŸ“Š

Pre-training

PAWS minimizes the cross-entropy loss (as well as maximizes mean-entropy) during pre-training. This is what the training plot indicates too:

To evaluate the effectivity of the pre-training, PAWS performs soft nearest neighbor classification to report the top-1 accuracy score on a given test set.

Top-1 Accuracy

This repository gets to 73.46% top-1 accuracy on the CIFAR10 test set. Again, note that I only pre-trained for 50 epochs (as opposed to 600) and fine-tuned for 10 epochs (as opposed to 30). With the original schedule this score should be around 96.0%.

In the following PCA projection plot, we see that the embeddings of images (computed after fine-tuning) of PAWS are starting to be well separated:

Notebooks πŸ“˜

There are two Colab Notebooks:

Misc ⺟

  • Model weights are available here for reproducibility.
  • With mixed-precision training, the performance can further be improved. I am open to accepting contributions that would implement mixed-precision training in the current code.

Acknowledgements

  • Huge amount of thanks to Mahmoud Assran (first author of PAWS) for patiently resolving my doubts.
  • ML-GDE program for providing GCP credit support.

Paper Citation

@misc{assran2021semisupervised,
      title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
      author={Mahmoud Assran and Mathilde Caron and Ishan Misra and Piotr Bojanowski and Armand Joulin and Nicolas Ballas and Michael Rabbat},
      year={2021},
      eprint={2104.13963},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch

alias-free-gan-pytorch Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) This implementation

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Supplementary code for the paper
Supplementary code for the paper "Meta-Solver for Neural Ordinary Differential Equations" https://arxiv.org/abs/2103.08561

Meta-Solver for Neural Ordinary Differential Equations Towards robust neural ODEs using parametrized solvers. Main idea Each Runge-Kutta (RK) solver w

Code for paper "A Critical Assessment of State-of-the-Art in Entity Alignment" (https://arxiv.org/abs/2010.16314)

A Critical Assessment of State-of-the-Art in Entity Alignment This repository contains the source code for the paper A Critical Assessment of State-of

Code for the paper: Learning Adversarially Robust Representations via Worst-Case Mutual Information Maximization (https://arxiv.org/abs/2002.11798)

Representation Robustness Evaluations Our implementation is based on code from MadryLab's robustness package and Devon Hjelm's Deep InfoMax. For all t

ISTR: End-to-End Instance Segmentation with Transformers (https://arxiv.org/abs/2105.00637)

This is the project page for the paper: ISTR: End-to-End Instance Segmentation via Transformers, Jie Hu, Liujuan Cao, Yao Lu, ShengChuan Zhang, Yan Wa

Releases(v1.0.0)
Owner
Sayak Paul
Trying to learn how machines learn.
Sayak Paul
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.

Object-Placement-Assessment-Dataset-OPA Object-Placement-Assessment (OPA) is to verify whether a composite image is plausible in terms of the object p

BCMI 53 Nov 15, 2022
A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization

MADGRAD Optimization Method A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization pip install madgrad Try it out! A best

Meta Research 774 Dec 31, 2022
Model-based Reinforcement Learning Improves Autonomous Racing Performance

Racing Dreamer: Model-based versus Model-free Deep Reinforcement Learning for Autonomous Racing Cars In this work, we propose to learn a racing contro

Cyber Physical Systems - TU Wien 38 Dec 06, 2022
Code for "Continuous-Time Meta-Learning with Forward Mode Differentiation" (ICLR 2022)

Continuous-Time Meta-Learning with Forward Mode Differentiation ICLR 2022 (Spotlight) - Installation - Example - Citation This repository contains the

Tristan Deleu 25 Oct 20, 2022
Suite of 500 procedurally-generated NLP tasks to study language model adaptability

TaskBench500 The TaskBench500 dataset and code for generating tasks. Data The TaskBench dataset is available under wget http://web.mit.edu/bzl/www/Tas

Belinda Li 20 May 17, 2022
OpenGAN: Open-Set Recognition via Open Data Generation

OpenGAN: Open-Set Recognition via Open Data Generation ICCV 2021 (oral) Real-world machine learning systems need to analyze novel testing data that di

Shu Kong 90 Jan 06, 2023
SafePicking: Learning Safe Object Extraction via Object-Level Mapping, ICRA 2022

SafePicking Learning Safe Object Extraction via Object-Level Mapping Kentaro Wad

Kentaro Wada 49 Oct 24, 2022
Nicely is a real-time Feedback and Intervention Program Depression is a prevalent issue across all age groups, socioeconomic classes, and cultural identities.

Nicely is a real-time Feedback and Intervention Program Depression is a prevalent issue across all age groups, socioeconomic classes, and cultural identities.

1 Jan 16, 2022
Rethinking the Importance of Implementation Tricks in Multi-Agent Reinforcement Learning

RIIT Our open-source code for RIIT: Rethinking the Importance of Implementation Tricks in Multi-AgentReinforcement Learning. We implement and standard

405 Jan 06, 2023
EigenGAN Tensorflow, EigenGAN: Layer-Wise Eigen-Learning for GANs

Gender Bangs Body Side Pose (Yaw) Lighting Smile Face Shape Lipstick Color Painting Style Pose (Yaw) Pose (Pitch) Zoom & Rotate Flush & Eye Color Mout

Zhenliang He 321 Dec 01, 2022
This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement"

DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE

Guochen Yu 68 Dec 16, 2022
PyTorch3D is FAIR's library of reusable components for deep learning with 3D data

Introduction PyTorch3D provides efficient, reusable components for 3D Computer Vision research with PyTorch. Key features include: Data structure for

Facebook Research 6.8k Jan 01, 2023
RuDOLPH: One Hyper-Modal Transformer can be creative as DALL-E and smart as CLIP

[Paper] [Π₯Π°Π±Ρ€] [Model Card] [Colab] [Kaggle] RuDOLPH 🦌 πŸŽ„ β˜ƒοΈ One Hyper-Modal Tr

Sber AI 230 Dec 31, 2022
FL-WBC: Enhancing Robustness against Model Poisoning Attacks in Federated Learning from a Client Perspective

FL-WBC: Enhancing Robustness against Model Poisoning Attacks in Federated Learning from a Client Perspective Official implementation of "FL-WBC: Enhan

Jingwei Sun 26 Nov 28, 2022
YOLOv7 - Framework Beyond Detection

πŸ”₯πŸ”₯πŸ”₯πŸ”₯ YOLO with Transformers and Instance Segmentation, with TensorRT acceleration! πŸ”₯πŸ”₯πŸ”₯

JinTian 3k Jan 01, 2023
Polynomial-time Meta-Interpretive Learning

Louise - polynomial-time Program Learning Getting help with Louise Louise's author can be reached by email at Stassa Patsantzis 64 Dec 26, 2022

Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Jamie Townsend 42 Dec 12, 2022
EqGAN - Improving GAN Equilibrium by Raising Spatial Awareness

EqGAN - Improving GAN Equilibrium by Raising Spatial Awareness Improving GAN Equilibrium by Raising Spatial Awareness Jianyuan Wang, Ceyuan Yang, Ying

GenForce: May Generative Force Be with You 149 Dec 19, 2022
CVPR '21: In the light of feature distributions: Moment matching for Neural Style Transfer

In the light of feature distributions: Moment matching for Neural Style Transfer (CVPR 2021) This repository provides code to recreate results present

Nikolai Kalischek 49 Oct 13, 2022
Predicting path with preference based on user demonstration using Maximum Entropy Deep Inverse Reinforcement Learning in a continuous environment

Preference-Planning-Deep-IRL Introduction Check my portfolio post Dependencies Gym stable-baselines3 PyTorch Usage Take Demonstration python3 record.

Tianyu Li 9 Oct 26, 2022