A PyTorch implementation of Mugs proposed by our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework".

Mugs: A Multi-Granular Self-Supervised Learning Framework

This is a PyTorch implementation of Mugs, proposed in our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework" (arXiv:2203.14415).

Fig 1. Overall framework of Mugs. (a) For each image, two random crops are fed into the student and teacher backbones. Three granular supervisions are adopted to learn multi-granular representations: 1) instance discrimination supervision, 2) local-group discrimination supervision, and 3) group discrimination supervision. (b) The local-group module in the student/teacher averages all patch tokens, finds the top-k neighbors of this average in a memory buffer, and aggregates them with the average to obtain a local-group feature.
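The three steps of the local-group module in Fig 1(b) can be illustrated with a small, dependency-free Python sketch. It is not the repository's implementation. The cosine similarity used to rank neighbors and the plain mean used for aggregation are assumptions made for illustration, and the names `local_group_feature`, `mean_vector` and `cosine_similarity` are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]


def mean_vector(vectors: Sequence[Sequence[float]]) -> Vector:
    """Element-wise mean of equally sized vectors."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def local_group_feature(
    patch_tokens: Sequence[Sequence[float]],
    memory_buffer: Sequence[Sequence[float]],
    k: int = 2,
) -> Vector:
    # Step 1: average all patch tokens of the crop.
    average = mean_vector(patch_tokens)
    # Step 2: rank buffer entries by similarity to the average, keep the top-k.
    # (Assumption: cosine similarity is used as the ranking score.)
    ranked = sorted(
        memory_buffer,
        key=lambda entry: cosine_similarity(average, entry),
        reverse=True,
    )
    neighbors = [list(entry) for entry in ranked[:k]]
    # Step 3: aggregate the neighbors with the average.
    # (Assumption: a plain mean stands in for the real aggregation module.)
    return mean_vector([average] + neighbors)


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [1.0, 0.0]]
    buffer = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    print(local_group_feature(tokens, buffer, k=2))
```

In this toy example the two buffer entries most similar to the averaged token are selected and blended with it, while the orthogonal entry is ignored.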

Pretrained models on ImageNet-1K

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks.

Table 1. k-NN and linear probing accuracy, with the corresponding hyper-parameters, logs and model weights.

| arch | params | pretraining epochs | k-NN | linear | download |
|---|---|---|---|---|---|
| ViT-S/16 | 21M | 100 | 72.3% | 76.4% | backbone only, full ckpt, args, logs, eval logs |
| ViT-S/16 | 21M | 300 | 74.8% | 78.2% | backbone only, full ckpt, args, logs, eval logs |
| ViT-S/16 | 21M | 800 | 75.6% | 78.9% | backbone only, full ckpt, args, logs, eval logs |
| ViT-B/16 | 85M | 400 | 78.0% | 80.6% | backbone only, full ckpt, args, logs, eval logs |
| ViT-L/16 | 307M | 250 | 80.3% | 82.1% | backbone only, full ckpt, args, logs, eval logs |

Fig 2. Comparison of linear probing accuracy on ImageNet-1K.

Pretraining Settings

Environment

To reproduce our results, please install PyTorch and download the ImageNet dataset. This codebase was developed with Python 3.8, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. For the full environment, please refer to our Dockerfile.
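A quick way to compare a local setup against the versions listed above is sketched below. It only uses the Python standard library (`importlib.metadata`, available from Python 3.8) and does not import PyTorch itself. The helper names `installed_version` and `version_mismatches` are hypothetical and not part of this repository, and stripping a local build suffix such as `+cu110` before comparing is an assumption about how the wheel was installed.

```python
from importlib import metadata
from typing import Dict, Optional

# Package versions quoted in the Environment section.
EXPECTED = {"torch": "1.7.1", "torchvision": "0.8.2"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_mismatches(installed: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Return {package: message} for every package that deviates from EXPECTED."""
    problems = {}
    for package, wanted in EXPECTED.items():
        found = installed.get(package)
        if found is None:
            problems[package] = f"not installed (expected {wanted})"
        elif found.split("+")[0] != wanted:
            # Drop a local build tag such as "+cu110" before comparing.
            problems[package] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    current = {name: installed_version(name) for name in EXPECTED}
    for name, message in version_mismatches(current).items():
        print(f"{name}: {message}")
```

An empty result means the installed packages match the versions listed above; the CUDA version is not checked by this sketch.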

ViT pretraining 🍺

To pretrain each model, please find the exact hyper-parameter settings in the args column of Table 1. For the training logs and linear probing logs, please refer to the logs and eval logs columns of Table 1.

ViT-Small pretraining:

To run ViT-small for 100 epochs, we use two nodes with a total of 8 A100 GPUs (total minibatch size of 512) and the following command:

python -m torch.distributed.launch --nproc_per_node=8 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small \
--group_teacher_temp 0.04 --group_warmup_teacher_temp_epochs 0 --weight_decay_end 0.2 --norm_last_layer false --epochs 100

To run ViT-small for 300 epochs, we use two nodes with a total of 16 A100 GPUs (total minibatch size of 1024) and the following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small \
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 300

To run ViT-small for 800 epochs, we use two nodes with a total of 16 A100 GPUs (total minibatch size of 1024) and the following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small \
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 800

ViT-Base pretraining:

To run ViT-base for 400 epochs, we use two nodes with a total of 24 A100 GPUs (total minibatch size of 1024) and the following command:

python -m torch.distributed.launch --nproc_per_node=24 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_base \
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --min_lr 2e-06 --weight_decay_end 0.1 --freeze_last_layer 3 \
--norm_last_layer false --epochs 400

ViT-Large pretraining:

To run ViT-large for 250 epochs, we use two nodes with a total of 40 A100 GPUs (total minibatch size of 640) and the following command:

python -m torch.distributed.launch --nproc_per_node=40 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_large \
--lr 0.0015 --min_lr 1.5e-4 --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --weight_decay 0.025 \
--weight_decay_end 0.08 --norm_last_layer true --drop_path_rate 0.3 --freeze_last_layer 3 --epochs 250
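The commands above pass the total GPU count to `--nproc_per_node`, while the text mentions two nodes. With `torch.distributed.launch`, a multi-node job is normally started once per node, using the launcher's `--nnodes`, `--node_rank`, `--master_addr` and `--master_port` options. The sketch below builds such per-node commands for the 100-epoch ViT-small run. It assumes the GPUs are split evenly across nodes; the helper name `per_node_launch_command`, the address `10.0.0.1` and the port `29500` are placeholders, not values from the authors' setup.

```python
from typing import List


def per_node_launch_command(
    total_gpus: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int,
    script_args: List[str],
) -> List[str]:
    """Build the launcher command to run on one node of a multi-node job."""
    if total_gpus % num_nodes != 0:
        raise ValueError("total_gpus must be divisible by num_nodes")
    if not 0 <= node_rank < num_nodes:
        raise ValueError("node_rank must be in [0, num_nodes)")
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={total_gpus // num_nodes}",  # processes (GPUs) on this node
        f"--nnodes={num_nodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "main.py",
        *script_args,
    ]


# Training arguments copied from the 100-epoch ViT-small command.
VIT_SMALL_100 = [
    "--data_path", "DATASET_ROOT", "--output_dir", "OUTPUT_ROOT",
    "--arch", "vit_small", "--group_teacher_temp", "0.04",
    "--group_warmup_teacher_temp_epochs", "0", "--weight_decay_end", "0.2",
    "--norm_last_layer", "false", "--epochs", "100",
]

if __name__ == "__main__":
    for rank in range(2):
        command = per_node_launch_command(8, 2, rank, "10.0.0.1", 29500, VIT_SMALL_100)
        print(" ".join(command))
```

Each printed line would be run on the node with the matching rank, and both nodes must be able to reach the master address and port.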

Evaluation

We are cleaning up the evaluation code and will release it when it is ready.

Self-attention visualization

Here we provide the self-attention maps of the [CLS] token for the heads of the last layer.

Fig 3. Self-attention from a ViT-Base/16 trained with Mugs.
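For readers unfamiliar with how such maps are typically obtained: for each head, the [CLS] query is compared with every key, the scores are normalised with a softmax, and the weights assigned to the patch tokens are reshaped onto the patch grid. The sketch below computes this for a single head in dependency-free Python. It uses the standard scaled dot-product attention formula and is not taken from this repository; the function name `cls_attention_map` and the row-major patch ordering are assumptions.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def cls_attention_map(
    cls_query: Sequence[float],
    keys: Sequence[Sequence[float]],
    grid_width: int,
) -> List[List[float]]:
    """Attention of the [CLS] query over patch tokens for one head.

    keys[0] is the [CLS] key; keys[1:] are patch keys in row-major order
    (an assumption about the token layout).
    """
    scale = 1.0 / math.sqrt(len(cls_query))
    scores = [scale * sum(q * k for q, k in zip(cls_query, key)) for key in keys]
    weights = softmax(scores)
    patch_weights = weights[1:]  # drop the weight on the [CLS] token itself
    if len(patch_weights) % grid_width != 0:
        raise ValueError("number of patches must be a multiple of grid_width")
    return [
        patch_weights[start:start + grid_width]
        for start in range(0, len(patch_weights), grid_width)
    ]


if __name__ == "__main__":
    query = [1.0, 0.0]
    token_keys = [[0.0, 0.0], [2.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
    for row in cls_attention_map(query, token_keys, grid_width=2):
        print(row)
```

In the toy input, the first patch key is the only one aligned with the query, so it receives the largest weight in the resulting 2×2 map.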

t-SNE visualization

Here we provide a t-SNE visualization of the features learned by ViT-B/16. We show the fish classes in ImageNet-1K, i.e., the first six classes: tench, goldfish, white shark, tiger shark, hammerhead, and electric ray. See more examples in the Appendix.

Fig 4. t-SNE visualization of the features learned by ViT-B/16.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving it a star and citing our paper 🍺:

@inproceedings{mugs2022SSL,
  title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
  author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
  booktitle={arXiv preprint arXiv:2203.14415},
  year={2022}
}