Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from DeepMind, in PyTorch

Overview

Bootstrap Your Own Latent (BYOL), in PyTorch


Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.

This repository offers a module with which one can easily wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that batch statistics are needed for BYOL to work

Update 3: Finally, we have some analysis of why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
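The wrapper has to read the output of the hidden layer you name. A minimal sketch of one common way to do that, a forward hook on a module looked up by name or by index into the network's top-level children, is shown below. `capture_hidden` is a hypothetical helper, not the library's code, and the tiny backbone is a stand-in.

```python
import torch
from torch import nn

def capture_hidden(net, layer, x):
    # Hypothetical helper: run `net` on `x` and return the output of `layer`,
    # looked up by name (str) or by index into net's top-level children (int)
    if isinstance(layer, str):
        module = dict(net.named_modules())[layer]
    else:
        module = list(net.children())[layer]

    captured = {}

    def hook(_module, _inputs, output):
        # flatten everything except the batch dimension, e.g. (B, C, 1, 1) -> (B, C)
        captured['hidden'] = output.flatten(1)

    handle = module.register_forward_hook(hook)
    try:
        net(x)
    finally:
        handle.remove()  # detach the hook so later forward passes are unaffected
    return captured['hidden']

# tiny stand-in backbone: conv -> global average pool -> flatten -> classifier
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
x = torch.randn(4, 3, 32, 32)
print(capture_hidden(net, -3, x).shape)   # torch.Size([4, 8]), the pooled features
print(capture_hidden(net, '1', x).shape)  # the same layer, looked up by name
```

For torchvision's ResNet-50, the name 'avgpool' and the index -2 (the second-to-last top-level child, just before fc) refer to the same layer, which is why both forms appear in this README.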

BYOL → SimSiam

A new paper from Xinlei Chen and Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. If you go this route, you will no longer need to invoke update_moving_average, as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
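For comparison, when use_momentum is left on, the target encoder is kept as an exponential moving average (EMA) of the online encoder, and the update_moving_average call in the first example advances it. A minimal sketch of that update rule, with beta standing in for moving_average_decay; `ema_update` is a hypothetical name, not the library's API. In the SimSiam-style variant this step disappears and the online encoder itself, with gradients stopped, serves as the target.

```python
import copy
import torch
from torch import nn

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, beta: float = 0.99) -> None:
    # Hypothetical helper: target <- beta * target + (1 - beta) * online, per parameter
    for t, o in zip(target.parameters(), online.parameters()):
        t.mul_(beta).add_(o, alpha=1 - beta)

online = nn.Linear(4, 2)
target = copy.deepcopy(online)   # the target starts as an exact copy
target.requires_grad_(False)     # and never receives gradients

with torch.no_grad():
    online.weight.add_(1.0)      # stand-in for an optimizer step on the online network
ema_update(target, online, beta=0.99)
# each target weight has now moved 1% of the way toward its online counterpart
```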

Advanced

While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
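The BYOL paper describes both the projector and the predictor as a small MLP: a linear layer to a hidden size of 4096, batch normalization, ReLU, then a linear layer to 256 outputs. A sketch of such an MLP, with argument names chosen to mirror projection_hidden_size and projection_size; the function `mlp` is hypothetical and the library's internal module may differ in detail.

```python
import torch
from torch import nn

def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Sequential:
    # Linear -> BatchNorm -> ReLU -> Linear, the projector/predictor shape described in the BYOL paper
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    )

projector = mlp(2048)   # 2048 = channel width of ResNet-50's avgpool output
predictor = mlp(256)    # the predictor maps a projection to another 256-d vector
z = projector(torch.randn(8, 2048))
p = predictor(z)
print(z.shape, p.shape)  # torch.Size([8, 256]) torch.Size([8, 256])
```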

By default, this library will use the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the augment_fn keyword.

import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
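Going by how augment_fn is used above, any module that maps an image batch to an augmented batch of the same shape should work here (an assumption, not something this README states). As a sketch without kornia, a random horizontal flip applied independently to each image can be written in plain PyTorch; `RandomHFlip` is a hypothetical class.

```python
import torch
from torch import nn

class RandomHFlip(nn.Module):
    # Hypothetical plain-PyTorch augmentation: flip each image left-right with probability p
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, channels, height, width); width is the last dim
        flip = torch.rand(x.shape[0], device=x.device) < self.p
        out = x.clone()
        out[flip] = x[flip].flip(-1)
        return out

augment_fn = nn.Sequential(RandomHFlip(p=0.5))
images = torch.randn(4, 3, 32, 32)
print(augment_fn(images).shape)  # torch.Size([4, 3, 32, 32])
```

It would then be passed as augment_fn = augment_fn, exactly like the kornia pipeline.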

In the paper, they seem to ensure that one of the two augmentation pipelines has a higher Gaussian blur probability than the other. You can also adjust this to your heart's delight by passing a second pipeline with the augment_fn2 keyword.

import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
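The BYOL paper's augmentation table lists a Gaussian blur probability of 1.0 for one view and 0.1 for the other. A self-contained sketch of a blur that fires per image with probability p, built from a depthwise convolution instead of kornia; `RandomGaussianBlur` and `gaussian_kernel` are hypothetical, and the 3x3 kernel with sigma 1.5 simply mirrors the kornia example above rather than the paper's settings.

```python
import torch
import torch.nn.functional as F
from torch import nn

def gaussian_kernel(size: int = 3, sigma: float = 1.5) -> torch.Tensor:
    # 1-D Gaussian, outer-producted into a normalized 2-D kernel
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return g[:, None] * g[None, :]

class RandomGaussianBlur(nn.Module):
    # Hypothetical module: blur each image in the batch with probability p
    def __init__(self, p: float, size: int = 3, sigma: float = 1.5):
        super().__init__()
        self.p = p
        self.size = size
        self.register_buffer('kernel', gaussian_kernel(size, sigma))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        # depthwise conv: the same kernel applied to every channel separately
        weight = self.kernel.to(x.dtype)[None, None].repeat(c, 1, 1, 1)
        padded = F.pad(x, [self.size // 2] * 4, mode='reflect')
        blurred = F.conv2d(padded, weight, groups=c)
        # per-image coin flip: 1 -> use the blurred image, 0 -> keep the original
        mask = (torch.rand(b, 1, 1, 1, device=x.device) < self.p).to(x.dtype)
        return mask * blurred + (1 - mask) * x

augment_fn = nn.Sequential(RandomGaussianBlur(p=1.0))   # view 1: always blurred
augment_fn2 = nn.Sequential(RandomGaussianBlur(p=0.1))  # view 2: rarely blurred
```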

To fetch the embeddings or the projections, simply pass a return_embedding = True flag to the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
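When embedding a whole dataset, it is usually worth switching to eval mode (so batch-norm layers use their running statistics) and turning off gradient tracking. A sketch with a hypothetical `embed_dataset` helper; `encode` is any callable that returns embeddings, and a tiny stand-in encoder is used so the sketch runs on its own.

```python
import torch
from torch import nn

@torch.no_grad()
def embed_dataset(model: nn.Module, images: torch.Tensor, batch_size: int = 64, encode=None) -> torch.Tensor:
    # Hypothetical helper: embed `images` batch by batch without tracking gradients.
    # `encode` defaults to calling `model` directly.
    if encode is None:
        encode = model
    was_training = model.training
    model.eval()  # batch norm uses its running statistics, dropout is disabled
    try:
        chunks = [encode(batch) for batch in images.split(batch_size)]
    finally:
        model.train(was_training)  # restore whatever mode the model was in
    return torch.cat(chunks)

# stand-in encoder so the sketch runs without a pretrained network
encoder = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.AdaptiveAvgPool2d(1), nn.Flatten())
images = torch.randn(100, 3, 32, 32)
embeddings = embed_dataset(encoder, images, batch_size=32)
print(embeddings.shape)  # torch.Size([100, 16])
```

With the learner above, this might look like embed_dataset(learner, imgs, encode = lambda x: learner(x, return_embedding = True)[1]), taking the second element of the returned pair as the embedding.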

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{chen2020exploring,
    title = {Exploring Simple Siamese Representation Learning},
    author = {Xinlei Chen and Kaiming He},
    year = {2020},
    eprint = {2011.10566},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Negative Loss, Transfer Learning/Fine-Tuning Question

    Hi! Thanks for sharing this repo; it's really clean and easy to use.

    When training with the PyTorch Lightning script from the repo, my loss is negative (and gets more negative over time). Is this expected? [Screenshot 2020-06-22]

    I'm curious to know whether you've fine-tuned a pretrained model with BYOL, as the README example suggests. If so, how were the results? Any intuition on how many epochs to fine-tune for?

    Thanks!

    opened by rsomani95 13
  • AssertionError: hidden layer never emitted an output with multi-gpu training

    I tried your library with a WideResnet40-2 model and used layer_index=-2.

    The Lightning example works fine on a single GPU, but I get the error with multiple GPUs.

    opened by reactivetype 7
  • How to transfer the trained ckpt to pytorch.pth model?

    I used the example script to train a model and got a ckpt file, but how can I extract the trained resnet50.pth instead of the whole SelfSupervisedLearner? Sorry, I am new to the PyTorch Lightning library. What I want is the self-supervised resnet50.pth, because I want it to replace the original ImageNet-pretrained one. Thanks a lot.

    opened by knaffe 5
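A PyTorch Lightning .ckpt file is a dictionary whose 'state_dict' entry holds the weights of the whole LightningModule, with each key prefixed by its attribute path. A sketch of pulling out just the backbone by prefix; `extract_backbone` is a hypothetical helper, and the prefix 'learner.net.' is an assumption about how the example script names its attributes, so check it against the real keys first.

```python
import torch

def extract_backbone(state_dict: dict, prefix: str) -> dict:
    # Keep only the keys under `prefix` and strip it, so the names match a bare backbone
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# Intended use (the path and the 'learner.net.' prefix are assumptions; print
# ckpt['state_dict'].keys() to find the real prefix in your checkpoint):
#   ckpt = torch.load('path/to/checkpoint.ckpt', map_location='cpu')
#   resnet = torchvision.models.resnet50()
#   resnet.load_state_dict(extract_backbone(ckpt['state_dict'], 'learner.net.'))
#   torch.save(resnet.state_dict(), 'resnet50-byol.pth')

# toy state dict standing in for a checkpoint's 'state_dict' entry
fake = {
    'learner.net.fc.weight': torch.zeros(2, 2),
    'learner.online_encoder.projector.0.weight': torch.zeros(1),
}
print(list(extract_backbone(fake, 'learner.net.')))  # ['fc.weight']
```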
  • Training loss decreased and then increased

    Hi, I used your example on my own data. The training loss decreased and then increased after 100 epochs, which is weird. Have you run into something similar? Is the model hard to train? The batch size is 128/256, lr is 0.1/0.2, and weight_decay is 1e-6.

    opened by easonyang1996 4
  • Can't load ckpt

    I used byol-pytorch-master/examples/lightning/train.py to generate a ckpt locally after training, but when I load the ckpt, I get the following errors. How should I load it? Thanks a lot! [Screenshot 2020-11-18]

    opened by AndrewTal 4
  • BYOL uses different augmentations for view1 and view2

    opened by OlivierDehaene 4
  • Transferring results on Cifar and other datasets

    Thanks for open-sourcing this!

    I notice that BYOL shows a large gap over SimCLR when transferring to downstream datasets: e.g., SimCLR reaches 71.6% on CIFAR-100, while BYOL reaches 78.4%.

    I understand that this might depend on the downstream training protocol. Could you provide sample code for it, especially for the L-BFGS-optimized logistic regression classifier?

    opened by jacobswan1 4
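A minimal sketch of a logistic-regression probe on frozen features, optimized with torch.optim.LBFGS. The features are synthetic, the L2 penalty weight_decay is an illustrative choice, `fit_logistic_regression` is a hypothetical helper, and none of this is claimed to be the protocol used in the BYOL paper.

```python
import torch
import torch.nn.functional as F
from torch import nn

def fit_logistic_regression(features, labels, num_classes, weight_decay=1e-4, steps=100):
    # multinomial logistic regression = one linear layer trained with cross-entropy
    clf = nn.Linear(features.shape[1], num_classes)
    opt = torch.optim.LBFGS(clf.parameters(), lr=1.0, max_iter=steps, line_search_fn='strong_wolfe')

    def closure():
        # L-BFGS re-evaluates the loss several times per step, so it needs a closure
        opt.zero_grad()
        loss = F.cross_entropy(clf(features), labels) + weight_decay * clf.weight.pow(2).sum()
        loss.backward()
        return loss

    opt.step(closure)  # one .step() runs up to max_iter L-BFGS iterations
    return clf

# synthetic "frozen features": two well-separated Gaussian blobs
torch.manual_seed(0)
features = torch.cat([torch.randn(50, 16) + 3, torch.randn(50, 16) - 3])
labels = torch.cat([torch.zeros(50, dtype=torch.long), torch.ones(50, dtype=torch.long)])
clf = fit_logistic_regression(features, labels, num_classes=2)
accuracy = (clf(features).argmax(dim=1) == labels).float().mean().item()
print(accuracy)
```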
  • The saved network is same as the initial one?

    Firstly, thank you so much for this clean implementation!!

    The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

    The code I tested:

    import torch
    from net.byol import BYOL
    from torchvision import models
     
           
    resnet = models.resnet50(pretrained=True)
    param_1 = resnet.parameters()
    
    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool'
    )
    
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    
    def sample_unlabelled_images():
        return torch.randn(20, 3, 256, 256)
    
    for _ in range(2):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of target encoder
    
    # save your improved network
    torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')
    
    # restore the model      
    resnet2 = models.resnet50()
    resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
    param_2 = resnet2.parameters()
    
    # test whether two models are the same 
    for p1, p2 in zip(param_1, param_2):
        if p1.data.ne(p2.data).sum() > 0:
            print('They are different.')
    print('They are same.')
    
    opened by KimMeen 3
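A likely cause in the snippet above: resnet.parameters() returns a generator over the live parameter tensors rather than a copy, and the optimizer updates those tensors in place, so by the time param_1 is iterated it already holds the trained values. The final print('They are same.') also runs regardless of the loop's outcome. A minimal sketch of a comparison that snapshots the weights first, using a small stand-in model rather than BYOL:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
# snapshot: detach().clone() copies the values instead of keeping live references
before = {name: p.detach().clone() for name, p in model.named_parameters()}

opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

changed = [name for name, p in model.named_parameters() if not torch.equal(p, before[name])]
print('They are different.' if changed else 'They are the same.')
```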
  • the maximum batch size can only be set to 32

    When I run the code on a 2080 Ti GPU with 10 GB of memory, the maximum batch size I can use is 32. Is there any part of the code that uses a lot of GPU memory?

    opened by cuixianheng 3
  • Pretrained network

    Hi, thanks for sharing the code and making it so easy to use. I see in the example you set resnet = models.resnet50(pretrained=True). Is this what is done in the paper? Shouldn't self-supervised-learned networks be trained from scratch?

    Thanks again, P.

    opened by pmorerio 3
  • Singleton Class Members

    Forgive my unfamiliarity with software design, but I'm wondering why it is necessary to write a singleton wrapper for projector and target_encoder. Is there any disadvantage to initializing them in __init__?

    opened by wentaoyuan 3
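One plausible reason for such a wrapper (an assumption, not confirmed in this thread): the projector's input size equals the width of the hidden layer's output, which is only known after a first forward pass, so the module has to be created lazily and then reused. A sketch of a caching decorator in that spirit; `singleton` and `LazyProjector` are hypothetical names.

```python
import functools
import torch
from torch import nn

def singleton(cache_key):
    # Build the value on the first call, store it on the instance, return the stored one afterwards
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            cached = getattr(self, cache_key, None)
            if cached is None:
                cached = fn(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return wrapper
    return decorator

class LazyProjector(nn.Module):
    def __init__(self, projection_size: int = 256):
        super().__init__()
        self.projection_size = projection_size
        self.projector = None  # input width unknown until the first forward pass

    @singleton('projector')
    def get_projector(self, hidden):
        # the input width is read off the first hidden representation seen
        return nn.Linear(hidden.shape[-1], self.projection_size).to(hidden.device)

    def forward(self, hidden):
        return self.get_projector(hidden)(hidden)

m = LazyProjector()
out = m(torch.randn(3, 2048))  # the projector is created here, with in_features = 2048
print(out.shape)  # torch.Size([3, 256])
```

One consequence of lazy creation is that an optimizer built before the first forward pass would not see these parameters, so running a dummy forward pass inside the constructor is a common remedy.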
  • Increase EMA-parameter during training

    Hi, I noticed that the EMA parameter (called beta in the code, τ in the paper) is not updated during training. The paper describes increasing τ from its base value to 1 during training: "Specifically, we set τ = 1 − (1 − τ_base) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps." This makes a huge difference to the validation loss at the end of training.

    [Plots: without_tau_update, with_tau_update]

    opened by Benjamin-Hansson 1
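The schedule quoted above can be sketched as a plain function. τ_base = 0.996 is the value given in the BYOL paper (the wrapper's default moving_average_decay in this README is 0.99), and `ema_decay` is a hypothetical name; the returned value would be fed into the EMA update at each step.

```python
import math

def ema_decay(k: int, K: int, tau_base: float = 0.996) -> float:
    # tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2
    # starts at tau_base when k = 0 and rises to 1 when k = K
    return 1 - (1 - tau_base) * (math.cos(math.pi * k / K) + 1) / 2

K = 1000
for k in (0, 500, 1000):
    print(k, round(ema_decay(k, K), 4))
# 0 0.996
# 500 0.998
# 1000 1.0
```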
  • Why the loss is different from BYOL authors'

    I found that the loss differs from the one in the BYOL paper, which should be an L2 loss, and I didn't find an explanation. The loss in this repo is a cosine loss, and I just want to know why. BTW, thanks for this great repo!

    opened by Jing-XING 2
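The two formulations agree: the BYOL paper's loss is the squared L2 distance between L2-normalized vectors, and for unit vectors ||q̄ − z̄||² = 2 − 2⟨q̄, z̄⟩ = 2 − 2·cos(q, z). It therefore always lies in [0, 4]. A quick numerical check of that identity on random vectors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(8, 256)   # stand-in online predictions
z = torch.randn(8, 256)   # stand-in target projections

# squared L2 distance between the normalized vectors, as written in the BYOL paper
l2 = (F.normalize(q, dim=-1) - F.normalize(z, dim=-1)).pow(2).sum(dim=-1)
# the same quantity expressed through cosine similarity
cos = 2 - 2 * F.cosine_similarity(q, z, dim=-1)

print(torch.allclose(l2, cos, atol=1e-5))  # True
```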
  • How to cluster/predict images?

    Hi, I have trained using the examples given with PyTorch Lightning. I couldn't find code for clustering images after training. How can I find which image falls into which cluster? Is there any predictor API? I want to do something like this:

    [image]

    opened by laxmimerit 1
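One way to do this, offered as a sketch rather than a documented API of this repo: extract embeddings (e.g. with learner(imgs, return_embedding = True) as in the README) and cluster them with any off-the-shelf algorithm. Below is a minimal k-means in plain PyTorch with farthest-point initialization; the helper name `kmeans` and the synthetic embeddings are illustrative assumptions.

```python
import torch

def kmeans(x: torch.Tensor, k: int, iters: int = 50, seed: int = 0):
    # plain k-means (Lloyd's algorithm) on the rows of x, shape (n, d)
    g = torch.Generator().manual_seed(seed)
    # farthest-point initialization: a random first centroid, then repeatedly
    # the point farthest from all centroids chosen so far
    centroids = [x[torch.randint(x.shape[0], (1,), generator=g).item()]]
    for _ in range(1, k):
        dist = torch.cdist(x, torch.stack(centroids)).min(dim=1).values
        centroids.append(x[dist.argmax()])
    centroids = torch.stack(centroids)
    for _ in range(iters):
        assign = torch.cdist(x, centroids).argmin(dim=1)  # nearest centroid per point
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties out
                centroids[j] = members.mean(dim=0)
    return torch.cdist(x, centroids).argmin(dim=1), centroids

# synthetic stand-ins for image embeddings: two well-separated groups
torch.manual_seed(0)
embeddings = torch.cat([torch.randn(30, 8) + 5, torch.randn(30, 8) - 5])
labels, centroids = kmeans(embeddings, k=2)
print(labels)
```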
  • BN layer weights and biases are not updated

    Thanks for sharing this repo, great work!

    I trained BYOL on my data and noticed that the weights and biases of the BN layers are not updated in the saved model. I used resnet18 without pretrained weights, resnet = models.resnet50(pretrained=False). After training for multiple epochs, the saved model has bn1.weight all equal to 1.0 and bn1.bias all equal to 0.0.

    Is this the expected behavior, or am I missing something? I'd appreciate your response!

    opened by kregmi 1
  •  Warning: grad and param do not obey the gradient layout contract.

    Has anybody gotten a similar warning when using it?

    Warning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. grad.sizes() = [512, 256, 1, 1], strides() = [256, 1, 1, 1] param.sizes() = [512, 256, 1, 1], strides() = [256, 1, 256, 256] (function operator())

    opened by mohaEs 3
Releases (0.6.0)
Owner
Phil Wang
Working with Attention. It's all we need