PyTorch code for training Vision Transformers with the self-supervised learning method DINO


Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers (arXiv:2104.14294) and the accompanying blog post.


Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the training and evaluation logs.

arch       params  k-NN (top-1)  linear (top-1)  download
DeiT-S/16  21M     74.5%         77.0%           backbone only | full checkpoint | args | logs | eval logs
DeiT-S/8   21M     78.3%         79.7%           backbone only | full checkpoint | args | logs | eval logs
ViT-B/16   85M     76.1%         78.2%           backbone only | full checkpoint | args | logs | eval logs
ViT-B/8    85M     77.4%         80.1%           backbone only | full checkpoint | args | logs | eval logs
ResNet-50  23M     67.5%         75.3%           backbone only | full checkpoint | args | logs | eval logs
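The "/16" and "/8" suffixes in the arch column are the ViT patch size. As a rough illustration, assuming a 224x224 input (an assumption, the table does not state the resolution), halving the patch size roughly quadruples the number of tokens, and self-attention cost grows quadratically with the token count, so the /8 models are more expensive at the same parameter count:

```python
def num_tokens(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT processes for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    # One token per patch, plus the extra [CLS] token used as the image representation.
    return patches_per_side ** 2 + (1 if cls_token else 0)


for name, patch in [("DeiT-S/16", 16), ("DeiT-S/8", 8)]:
    print(name, num_tokens(224, patch))  # 197 tokens for /16, 785 tokens for /8
```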

The pretrained models are available on PyTorch Hub.

import torch
deits16 = torch.hub.load('facebookresearch/dino', 'dino_deits16')
deits8 = torch.hub.load('facebookresearch/dino', 'dino_deits8')
vitb16 = torch.hub.load('facebookresearch/dino', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino', 'dino_vitb8')
resnet50 = torch.hub.load('facebookresearch/dino', 'dino_resnet50')

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the args column of the pretrained models section. To see the full documentation of the DINO training options, run:

python main_dino.py --help

Vanilla DINO training 🦕

Run DINO with a DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training takes 1.75 days, and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
To train DINO with a ViT-base network:
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
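Going from one node (8 GPUs) to two nodes (16 GPUs) changes the total batch size. A minimal sketch of the linear learning-rate scaling rule, assuming a per-GPU batch size of 64, a base learning rate of 0.0005 and a reference batch size of 256 (all three are assumed values for illustration; check `python main_dino.py --help` for the real defaults and scaling behavior):

```python
def scaled_lr(base_lr: float, batch_size_per_gpu: int, num_gpus: int,
              reference_batch: int = 256) -> float:
    """Linear scaling rule: the peak learning rate grows with the total batch size."""
    total_batch = batch_size_per_gpu * num_gpus
    return base_lr * total_batch / reference_batch


# Hypothetical values, not read from main_dino.py.
BASE_LR = 0.0005
BATCH_PER_GPU = 64

for nodes, gpus_per_node in [(1, 8), (2, 8)]:
    world_size = nodes * gpus_per_node
    print(f"{world_size} GPUs: total batch {BATCH_PER_GPU * world_size}, "
          f"peak lr {scaled_lr(BASE_LR, BATCH_PER_GPU, world_size):.4g}")
```

Under these assumptions, doubling the GPU count doubles both the total batch (512 to 1024) and the peak learning rate.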

Boosting DINO performance 🦖

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300,
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30,
  • removing last layer normalization (only safe with --arch deit_small): --norm_last_layer false.
Full command:
python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide training and linear evaluation logs for this run to help reproducibility.
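The teacher-temperature flags above suggest a temperature that ramps up during the first epochs and then stays fixed. A minimal sketch, assuming a linear ramp starting from 0.04 (an assumed starting value; only the target 0.07 and the 30 warmup epochs come from the command), together with a temperature-scaled softmax showing why a lower temperature gives a sharper teacher distribution:

```python
from __future__ import annotations

import math


def teacher_temp_schedule(teacher_temp: float, warmup_temp: float,
                          warmup_epochs: int, total_epochs: int) -> list[float]:
    """Per-epoch teacher temperature: linear warmup, then constant (assumed behavior)."""
    if warmup_epochs > total_epochs:
        raise ValueError("warmup_epochs cannot exceed total_epochs")
    schedule = []
    for epoch in range(total_epochs):
        if epoch < warmup_epochs:
            # Epoch 0 uses warmup_temp; the last warmup epoch reaches teacher_temp.
            frac = epoch / (warmup_epochs - 1) if warmup_epochs > 1 else 0.0
            schedule.append(warmup_temp + frac * (teacher_temp - warmup_temp))
        else:
            schedule.append(teacher_temp)
    return schedule


def softmax_with_temperature(logits: list[float], temp: float) -> list[float]:
    """Softmax of logits / temp: a lower temp puts more mass on the largest logit."""
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# 0.04 is the assumed warmup start; 0.07 and 30 come from the flags above.
sched = teacher_temp_schedule(0.07, 0.04, warmup_epochs=30, total_epochs=300)
print(sched[0], sched[29], sched[-1])  # 0.04, then ~0.07 from the last warmup epoch on

logits = [0.2, 0.1, 0.05]
print(max(softmax_with_temperature(logits, 0.04)))  # sharper: larger top probability
print(max(softmax_with_temperature(logits, 0.07)))  # flatter
```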

ResNet-50 and other convnet training

This code also works for training DINO on convolutional networks, such as ResNet-50. In this case, we highly recommend adapting some optimization arguments. For example, here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs:

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
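The ResNet-50 command sets --weight_decay and --weight_decay_end to the same value. A minimal per-epoch sketch, assuming weight decay follows a cosine schedule from --weight_decay to --weight_decay_end (the 0.04 to 0.4 endpoints used for contrast are hypothetical ViT-style values), shows that equal endpoints give a constant weight decay:

```python
from __future__ import annotations

import math


def cosine_schedule(start: float, end: float, epochs: int) -> list[float]:
    """Per-epoch cosine interpolation from `start` (epoch 0) towards `end`."""
    return [end + 0.5 * (start - end) * (1 + math.cos(math.pi * e / epochs))
            for e in range(epochs)]


# Hypothetical ViT-style endpoints, for contrast only.
vit_wd = cosine_schedule(0.04, 0.4, epochs=100)
# Endpoints from the ResNet-50 command: --weight_decay 1e-4 --weight_decay_end 1e-4.
resnet_wd = cosine_schedule(1e-4, 1e-4, epochs=100)

print(round(vit_wd[0], 4), round(vit_wd[-1], 3))  # 0.04 0.4
print(set(resnet_wd))  # {0.0001}: weight decay stays constant
```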

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
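The k-NN evaluation labels each validation image from its nearest training images in feature space. A minimal sketch of a weighted k-NN vote, assuming L2-normalized features, cosine similarity, and votes weighted by exp(similarity / T); the defaults k=20 and T=0.07 are assumptions for illustration, not settings read from eval_knn.py:

```python
from __future__ import annotations

import math
from collections import defaultdict


def l2_normalize(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def knn_predict(train_feats, train_labels, query, k: int = 20, temperature: float = 0.07):
    """Predict a label for `query` by a similarity-weighted vote of its k nearest neighbors."""
    q = l2_normalize(query)
    # Cosine similarity is the dot product of L2-normalized vectors.
    sims = [sum(a * b for a, b in zip(q, l2_normalize(f))) for f in train_feats]
    top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    votes: dict[int, float] = defaultdict(float)
    for i in top:
        # Closer neighbors get exponentially larger votes.
        votes[train_labels[i]] += math.exp(sims[i] / temperature)
    return max(votes, key=votes.get)


# Toy 2-D "features": class 0 points along x, class 1 along y.
train_feats = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
train_labels = [0, 0, 1, 1]
print(knn_predict(train_feats, train_labels, [1.0, 0.0], k=3))  # 0
print(knn_predict(train_feats, train_labels, [0.0, 1.0], k=3))  # 1
```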

If you do not specify --pretrained_weights, the DINO reference weights are used by default. To evaluate a checkpoint from your own run instead, you can run, for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 
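A full checkpoint holds several networks, so --checkpoint_key selects which one to evaluate. A hedged sketch, assuming the checkpoint is a dict keyed by network name ("student", "teacher") whose parameter names may carry a "module." prefix from DistributedDataParallel and a "backbone." prefix from a wrapper around backbone and projection head; the helper name and the choice to drop "head." entries are hypothetical:

```python
def extract_backbone_state_dict(checkpoint: dict, checkpoint_key: str = "teacher") -> dict:
    """Hypothetical helper: pick one network from a full checkpoint and keep backbone weights."""
    # If the key is absent, treat the whole dict as a plain state dict.
    state_dict = checkpoint.get(checkpoint_key, checkpoint)
    backbone = {}
    for name, value in state_dict.items():
        if name.startswith("module."):  # assumed DistributedDataParallel prefix
            name = name[len("module."):]
        if name.startswith("head."):  # assumed projection-head prefix, not needed downstream
            continue
        if name.startswith("backbone."):  # assumed wrapper prefix around the backbone
            name = name[len("backbone."):]
        backbone[name] = value
    return backbone


# Fake checkpoint with strings standing in for tensors.
fake_ckpt = {
    "student": {"module.backbone.cls_token": "s_cls", "module.head.last_layer.weight": "s_head"},
    "teacher": {"backbone.cls_token": "t_cls", "head.last_layer.weight": "t_head"},
}
print(extract_backbone_state_dict(fake_ckpt, "teacher"))  # {'cls_token': 't_cls'}
print(extract_backbone_state_dict(fake_ckpt, "student"))  # {'cls_token': 's_cls'}
```

With PyTorch installed, the input would typically come from torch.load("/path/to/checkpoint.pth", map_location="cpu"), and the result could be passed to a model's load_state_dict(..., strict=False).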

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
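Linear evaluation keeps the pretrained backbone frozen and trains only a linear layer on its features. A toy sketch of that idea in plain Python, assuming the features have already been extracted and using full-batch gradient descent on a softmax (cross-entropy) classifier; eval_linear.py works on PyTorch tensors with its own optimizer and data pipeline, none of which is reproduced here:

```python
import math


def train_linear_probe(feats, labels, num_classes, lr=0.5, epochs=200):
    """Fit W (num_classes x dim) and b by gradient descent on cross-entropy; feats stay fixed."""
    dim = len(feats[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    n = len(feats)
    for _ in range(epochs):
        grad_W = [[0.0] * dim for _ in range(num_classes)]
        grad_b = [0.0] * num_classes
        for x, y in zip(feats, labels):
            logits = [sum(w * xj for w, xj in zip(W[c], x)) + b[c] for c in range(num_classes)]
            m = max(logits)
            exps = [math.exp(z - m) for z in logits]
            total = sum(exps)
            for c in range(num_classes):
                # d(cross-entropy)/d(logit_c) = softmax_c - 1[c == y]
                g = exps[c] / total - (1.0 if c == y else 0.0)
                grad_b[c] += g / n
                for j in range(dim):
                    grad_W[c][j] += g * x[j] / n
        for c in range(num_classes):
            b[c] -= lr * grad_b[c]
            for j in range(dim):
                W[c][j] -= lr * grad_W[c][j]
    return W, b


def predict(W, b, x):
    scores = [sum(w * xj for w, xj in zip(W[c], x)) + b[c] for c in range(len(W))]
    return max(range(len(W)), key=lambda c: scores[c])


# Toy "frozen features" for two linearly separable classes.
feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
W, b = train_linear_probe(feats, labels, num_classes=2)
print([predict(W, b, x) for x in feats])  # [0, 0, 1, 1]
```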

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
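Conceptually, each head's [CLS] attention map is the first row of that head's attention matrix (the [CLS] query attending to every token), with the [CLS] column dropped and the remaining patch weights reshaped into the patch grid. A sketch on nested Python lists, assuming attention is laid out as [heads][tokens][tokens] with [CLS] at index 0 and patches in row-major order (layout assumptions, not read from visualize_attention.py):

```python
def cls_attention_maps(attn, grid_h: int, grid_w: int):
    """Return one grid_h x grid_w map per head from attn[head][query][key]."""
    maps = []
    for head in attn:
        cls_row = head[0][1:]  # [CLS] query (row 0), patch keys only (drop column 0)
        if len(cls_row) != grid_h * grid_w:
            raise ValueError("number of patch tokens does not match the grid size")
        # Patches are assumed to be ordered row by row.
        maps.append([cls_row[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)])
    return maps


# Toy example: 1 head, [CLS] + 4 patches arranged in a 2x2 grid.
attn = [[
    [0.2, 0.5, 0.1, 0.1, 0.1],  # row 0: attention of [CLS] to [CLS] and the 4 patches
    [0.2, 0.2, 0.2, 0.2, 0.2],
    [0.2, 0.2, 0.2, 0.2, 0.2],
    [0.2, 0.2, 0.2, 0.2, 0.2],
    [0.2, 0.2, 0.2, 0.2, 0.2],
]]
print(cls_attention_maps(attn, grid_h=2, grid_w=2))  # [[[0.5, 0.1], [0.1, 0.1]]]
```

For a 224x224 image with 8x8 patches, the grid under these assumptions would be 28x28.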
Self-attention from a Vision Transformer with 8x8 patches trained with DINO

License

See the LICENSE file for more details.

Citation

If you find this repository useful, please consider giving it a star and citing it 🦖:

@article{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  journal={arXiv preprint arXiv:2104.14294},
  year={2021}
}
Owner
Facebook Research
This is an implementation of PIFuHD based on PyTorch

Open-PIFuhd This is an unofficial implementation of PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization

Lingteng Qiu 235 Dec 19, 2022
Unofficial PyTorch Implementation of AHDRNet (CVPR 2019)

AHDRNet-PyTorch This is the PyTorch implementation of Attention-guided Network for Ghost-free High Dynamic Range Imaging (CVPR 2019). The official cod

Yutong Zhang 4 Sep 08, 2022
Vector Quantized Diffusion Model for Text-to-Image Synthesis

Vector Quantized Diffusion Model for Text-to-Image Synthesis Due to company policy, I have to set microsoft/VQ-Diffusion to private for now, so I prov

Shuyang Gu 294 Jan 05, 2023
Official implementation of DreamerPro: Reconstruction-Free Model-Based Reinforcement Learning with Prototypical Representations in TensorFlow 2

DreamerPro Official implementation of DreamerPro: Reconstruction-Free Model-Based Reinforcement Learning with Prototypical Representations in TensorFl

22 Nov 01, 2022
nn_builder lets you build neural networks with less boilerplate code

nn_builder lets you build neural networks with less boilerplate code. You specify the type of network you want and it builds it. Install pip install n

Petros Christodoulou 157 Nov 20, 2022
An intelligent, flexible grammar of machine learning.

An English representation of machine learning. Modify what you want, let us handle the rest. Overview Nylon is a python library that lets you customiz

Palash Shah 79 Dec 02, 2022
An energy estimator for eyeriss-like DNN hardware accelerator

Energy-Estimator-for-Eyeriss-like-Architecture- An energy estimator for eyeriss-like DNN hardware accelerator This is an energy estimator for eyeriss-

HEXIN BAO 2 Mar 26, 2022
Official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021.

Introduction This repository is the official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021. Data-free Kno

NVIDIA Research Projects 50 Jan 05, 2023
Python Jupyter kernel using Poetry for reproducible notebooks

Poetry Kernel Use per-directory Poetry environments to run Jupyter kernels. No need to install a Jupyter kernel per Python virtual environment! The id

Pathbird 204 Jan 04, 2023
This repository contains the official code of the paper Equivariant Subgraph Aggregation Networks (ICLR 2022)

Equivariant Subgraph Aggregation Networks (ESAN) This repository contains the official code of the paper Equivariant Subgraph Aggregation Networks (IC

Beatrice Bevilacqua 59 Dec 13, 2022
WRENCH: Weak supeRvision bENCHmark

🔧 What is it? Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development

Jieyu Zhang 176 Dec 28, 2022
Score refinement for confidence-based 3D multi-object tracking

Score refinement for confidence-based 3D multi-object tracking Our video gives a brief explanation of our Method. This is the official code for the pa

Cognitive Systems Research Group 47 Dec 26, 2022
Template repository for managing machine learning research projects built with PyTorch-Lightning

Tutorial Repository with a minimal example for showing how to deploy training across various compute infrastructure.

Sidd Karamcheti 3 Feb 11, 2022
The 2nd place solution of 2021 google landmark retrieval on kaggle.

Leaderboard, taxonomy, and curated list of few-shot object detection papers.

229 Dec 13, 2022
This is an open source python repository for various python tests

Welcome to Py-tests This is an open source python repository for various python tests. This is in response to the hacktoberfest2021 challenge. It is a

Yada Martins Tisan 3 Oct 31, 2021
This code uses generative adversarial networks to generate diverse task allocation plans for Multi-agent teams.

Multi-agent task allocation This code uses generative adversarial networks to generate diverse task allocation plans for Multi-agent teams. To change

Biorobotics Lab 5 Oct 12, 2022
CoSMA: Convolutional Semi-Regular Mesh Autoencoder. From Paper "Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes"

Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes Implementation of CoSMA: Convolutional Semi-Regular Mesh Autoencoder arXiv p

Fraunhofer SCAI 10 Oct 11, 2022
Temporal Segment Networks (TSN) in PyTorch

TSN-Pytorch We have released MMAction, a full-fledged action understanding toolbox based on PyTorch. It includes implementation for TSN as well as oth

1k Jan 03, 2023
PyTea: PyTorch Tensor shape error analyzer

PyTea: PyTorch Tensor Shape Error Analyzer paper project page Requirements node.js = 12.x python = 3.8 z3-solver = 4.8 How to install and use # ins

ROPAS Lab. 240 Jan 02, 2023
PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

PyTorch/TorchScript compiler for NVIDIA GPUs using TensorRT

NVIDIA Corporation 1.8k Dec 30, 2022