PyTorch code for training Vision Transformers with the self-supervised learning method DINO


Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers.
[blogpost] [arXiv]


Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint, which contains the backbone and projection head weights for both the student and teacher networks. We also provide the training and evaluation logs.

arch       params  k-NN   linear  download
DeiT-S/16  21M     74.5%  77.0%   backbone only, full checkpoint, args, logs, eval logs
DeiT-S/8   21M     78.3%  79.7%   backbone only, full checkpoint, args, logs, eval logs
ViT-B/16   85M     76.1%  78.2%   backbone only, full checkpoint, args, logs, eval logs
ViT-B/8    85M     77.4%  80.1%   backbone only, full checkpoint, args, logs, eval logs
ResNet-50  23M     67.5%  75.3%   backbone only, full checkpoint, args, logs, eval logs

The pretrained models are available on PyTorch Hub.

import torch
deits16 = torch.hub.load('facebookresearch/dino', 'dino_deits16')
deits8 = torch.hub.load('facebookresearch/dino', 'dino_deits8')
vitb16 = torch.hub.load('facebookresearch/dino', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino', 'dino_vitb8')
resnet50 = torch.hub.load('facebookresearch/dino', 'dino_resnet50')

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the args column of the pretrained models section. To see the full documentation of the DINO training options, run:

python main_dino.py --help

Vanilla DINO training 🦕

Run DINO with a DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training takes 1.75 days, and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
To train DINO with a ViT-base network:
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
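Moving from one node to several also changes the global batch size, and with it the learning rate if a linear scaling rule is used. The arithmetic below is a minimal sketch, assuming 64 images per GPU and a base learning rate of 0.0005 scaled by global batch / 256; both values are our assumption about the main_dino.py defaults (--batch_size_per_gpu and --lr), so confirm them with python main_dino.py --help. The helper name is hypothetical.

```python
def effective_batch_and_lr(nodes, gpus_per_node, batch_size_per_gpu=64, base_lr=0.0005):
    """Return (world_size, global_batch, peak_lr) for a multi-GPU run.

    Assumptions: the defaults mirror main_dino.py's --batch_size_per_gpu and
    --lr, and the peak learning rate (reached after warmup) follows the
    linear scaling rule base_lr * global_batch / 256.
    """
    world_size = nodes * gpus_per_node          # one training process per GPU
    global_batch = world_size * batch_size_per_gpu
    peak_lr = base_lr * global_batch / 256
    return world_size, global_batch, peak_lr


print(effective_batch_and_lr(1, 8))  # (8, 512, 0.001)
print(effective_batch_and_lr(2, 8))  # (16, 1024, 0.002)
```

Under these assumptions, doubling the GPU count doubles both the global batch and the peak learning rate, so a 2-node run is not a drop-in replacement for a single-node one unless the per-GPU batch or base rate is adjusted.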

Boosting DINO performance 🦖

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300;
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30;
  • removing last-layer normalization (only safe with --arch deit_small): --norm_last_layer false.
Full command:
python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide training and linear evaluation logs for this run to help reproducibility.
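The teacher temperature controls how sharp the teacher's softmax output is: lower values give more peaked distributions. The sketch below assumes the warmup ramps the temperature linearly, epoch by epoch, from an assumed starting value of 0.04 up to --teacher_temp over --warmup_teacher_temp_epochs, then holds it constant. Both helper names are hypothetical; this illustrates the idea and is not the repository's code.

```python
import math


def teacher_temp_schedule(teacher_temp, warmup_teacher_temp, warmup_epochs, epochs):
    """Per-epoch teacher temperatures: linear ramp, then constant (assumed shape)."""
    steps = max(warmup_epochs - 1, 1)
    ramp = [warmup_teacher_temp + (teacher_temp - warmup_teacher_temp) * i / steps
            for i in range(warmup_epochs)]
    return ramp + [teacher_temp] * (epochs - warmup_epochs)


def softmax(logits, temp):
    """Temperature-scaled softmax; a smaller temp gives a sharper distribution."""
    scaled = [x / temp for x in logits]
    top = max(scaled)                           # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


schedule = teacher_temp_schedule(0.07, 0.04, 30, 300)  # 0.04 start is an assumption
print(len(schedule), round(schedule[0], 4), round(schedule[29], 4), schedule[-1])
# 300 0.04 0.07 0.07

logits = [0.2, 0.1, 0.0]
print(max(softmax(logits, 0.04)) > max(softmax(logits, 0.07)))  # True
```

Raising the final temperature to 0.07 makes the teacher targets softer than at 0.04, and the warmup avoids starting training with those softer targets.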

Training ResNet-50 and other convnets

This code also works for training DINO on convolutional networks such as ResNet-50. In this case, we highly recommend adapting some of the optimization arguments. For example, here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs:

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
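Setting --weight_decay and --weight_decay_end to the same value is easiest to read if the weight decay follows a cosine schedule between the two, which is our assumption here. The minimal sketch below (hypothetical helper, with assumed ViT-style endpoints of 0.04 and 0.4 for comparison) shows that equal endpoints simply give a constant weight decay.

```python
import math


def cosine_schedule(base_value, final_value, total_iters):
    """Cosine interpolation from base_value (iteration 0) toward final_value."""
    return [
        final_value
        + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / total_iters))
        for i in range(total_iters)
    ]


vit_wd = cosine_schedule(0.04, 0.4, 1000)      # assumed ViT-style endpoints
resnet_wd = cosine_schedule(1e-4, 1e-4, 1000)  # values from the ResNet-50 command above
print(round(vit_wd[0], 4), round(vit_wd[-1], 4))  # 0.04 0.4
print(set(resnet_wd))  # {0.0001}
```

With equal endpoints the cosine term is multiplied by zero, so the ResNet-50 run keeps weight decay fixed at 1e-4 throughout.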

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
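The k-NN evaluation labels each validation image by a vote among its nearest training images in feature space. Below is a minimal pure-Python sketch of a weighted k-NN vote in the style described for DINO: cosine similarity on L2-normalized features, with each of the k neighbours voting with weight exp(similarity / T). The defaults k=20 and T=0.07 are assumptions, and knn_predict is a hypothetical helper, not eval_knn.py's API.

```python
import math
from collections import defaultdict


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def knn_predict(query, train_feats, train_labels, k=20, temperature=0.07):
    """Predict a label by a similarity-weighted vote of the k nearest neighbours."""
    q = l2_normalize(query)
    sims = []
    for feat, label in zip(train_feats, train_labels):
        f = l2_normalize(feat)
        sims.append((sum(a * b for a, b in zip(q, f)), label))  # cosine similarity
    top = sorted(sims, key=lambda s: s[0], reverse=True)[:k]
    votes = defaultdict(float)
    for sim, label in top:
        votes[label] += math.exp(sim / temperature)  # closer neighbours weigh more
    return max(votes, key=votes.get)


train = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.8]]  # toy "features"
labels = ["cat", "cat", "dog", "dog"]
print(knn_predict([0.8, 0.2], train, labels, k=3))  # cat
```

No classifier is trained here: the frozen features alone decide the label, which is why k-NN accuracy is a quick proxy for feature quality.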

If you do not specify --pretrained_weights, the DINO reference weights are used by default. To evaluate checkpoints from one of your own runs instead, you can run, for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 
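The --checkpoint_key flag selects which network's weights to read from a full checkpoint. The sketch below assumes the checkpoint is a dict with "student" and "teacher" entries whose parameter names may carry a "module." prefix (from DistributedDataParallel) and a "backbone." prefix (from the wrapper around backbone and head); extract_backbone is a hypothetical helper operating on a plain dict, and dropping the "head." keys is our own choice.

```python
def extract_backbone(checkpoint, checkpoint_key="teacher"):
    """Return backbone-only weights from a DINO-style checkpoint dict (assumed layout)."""
    state = checkpoint.get(checkpoint_key, checkpoint)  # fall back to a bare state dict
    cleaned = {}
    for key, value in state.items():
        for prefix in ("module.", "backbone."):  # strip wrapper prefixes, outermost first
            if key.startswith(prefix):
                key = key[len(prefix):]
        if key.startswith("head."):              # skip projection-head weights
            continue
        cleaned[key] = value
    return cleaned


ckpt = {  # toy stand-in for a loaded checkpoint; values are placeholders
    "teacher": {"module.backbone.cls_token": 1, "module.head.last_layer.weight": 2},
    "student": {"module.backbone.cls_token": 3},
}
print(extract_backbone(ckpt))  # {'cls_token': 1}
```

With a real file, the dict would come from torch.load(path, map_location="cpu"), and the cleaned result could be passed to model.load_state_dict(..., strict=False).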

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on top of frozen weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
(Figure: self-attention from a Vision Transformer with 8x8 patches trained with DINO.)
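Each attention map is one row of a head's attention matrix: the [CLS] query scored against every token, softmax-normalized, with the patch entries reshaped into the image's patch grid. The pure-Python sketch below assumes standard scaled dot-product attention, [CLS] at token index 0, and patches in row-major order; cls_attention_maps and the toy queries and keys are hypothetical, not part of visualize_attention.py.

```python
import math


def cls_attention_maps(cls_q, k, grid_h, grid_w):
    """Per-head attention from [CLS] to the patch tokens, shaped [heads][grid_h][grid_w].

    cls_q: [heads][head_dim] query of the [CLS] token.
    k:     [heads][tokens][head_dim] keys, token 0 being [CLS] (assumed layout).
    """
    maps = []
    for q_h, k_h in zip(cls_q, k):
        scale = 1.0 / math.sqrt(len(q_h))
        scores = [scale * sum(a * b for a, b in zip(q_h, key)) for key in k_h]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        patch_weights = [e / total for e in exps[1:]]  # softmax over all tokens, keep patches
        maps.append([patch_weights[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)])
    return maps


cls_q = [[1.0, 0.0], [0.0, 1.0]]  # toy 2-head [CLS] queries
zero = [0.0, 0.0]
k = [[zero, [3.0, 0.0], zero, zero, zero],   # head 0 keys: [CLS] + 2x2 patches
     [zero, zero, zero, zero, [0.0, 3.0]]]   # head 1 keys
maps = cls_attention_maps(cls_q, k, 2, 2)
for head, grid in enumerate(maps):
    flat = [w for row in grid for w in row]
    print(head, flat.index(max(flat)))
# 0 0
# 1 3
```

For a 224x224 input with 8x8 patches, the same reshape would give a 28x28 map per head.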

License

See the LICENSE file for more details.

Citation

If you find this repository useful, please consider giving it a star and citing it 🦖:

@article{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  journal={arXiv preprint arXiv:2104.14294},
  year={2021}
}
Owner
Facebook Research