PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Overview

Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

[Figure: SimCLR architecture]

See also PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.

Installation

$ conda env create --name simclr --file env.yml
$ conda activate simclr
$ python run.py

Config file

Before running SimCLR, make sure you choose the correct run configuration. You can change it by passing command-line arguments to run.py.

$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100 

If you want to run it on CPU (for debugging purposes) use the --disable-cuda option.

For 16-bit precision GPU training, there is NO need to install NVIDIA apex. Just use the --fp16_precision flag and this implementation will use PyTorch's built-in AMP training.
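
The flags mentioned in this section could be wired up roughly as follows. This is a minimal sketch using Python's standard argparse module, not the repository's actual run.py: the function name build_parser, the defaults, and the help strings are assumptions made for illustration.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the flags named in this README;
    # defaults and help texts are illustrative assumptions.
    parser = argparse.ArgumentParser(description="SimCLR training (sketch)")
    parser.add_argument("-data", default="./datasets", help="dataset root directory")
    parser.add_argument("--dataset-name", default="stl10", help="dataset to train on")
    parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--log-every-n-steps", type=int, default=100, help="logging interval")
    parser.add_argument("--disable-cuda", action="store_true", help="run on CPU")
    parser.add_argument("--fp16_precision", action="store_true", help="use built-in AMP")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
args = build_parser().parse_args(
    ["-data", "./datasets", "--dataset-name", "stl10", "--epochs", "100", "--disable-cuda"]
)
print(args.dataset_name, args.epochs, args.disable_cuda, args.fp16_precision)
```

Boolean switches such as --disable-cuda default to off and are enabled simply by being present on the command line.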

Feature Evaluation

Feature evaluation is done using a linear model protocol.

First, we learn features using SimCLR on the STL10 unsupervised set. Then, we train a linear classifier on top of the frozen SimCLR features. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.
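
The protocol can be sketched without any deep learning framework. The example below is a toy illustration, not the repository's evaluation notebook: frozen_encoder is a hypothetical stand-in for the pretrained SimCLR network, the data is synthetic, and the classifier is a binary logistic regression trained by plain gradient descent rather than Adam.

```python
import math
import random


def frozen_encoder(x):
    # Hypothetical stand-in for a pretrained SimCLR encoder: a fixed feature
    # map with no trainable parameters (2-D input -> 3-D feature).
    return [x[0], x[1], x[0] * x[1]]


def make_split(rng, n_per_class):
    # Synthetic two-class data: class 0 near (-2, -2), class 1 near (2, 2).
    points, labels = [], []
    for label, center in ((0, -2.0), (1, 2.0)):
        for _ in range(n_per_class):
            points.append([rng.gauss(center, 0.3), rng.gauss(center, 0.3)])
            labels.append(label)
    return points, labels


def train_linear(features, labels, epochs=200, lr=0.1):
    # Logistic regression on top of the frozen features; only w and b change.
    w, b, n = [0.0] * len(features[0]), 0.0, len(features)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * len(w), 0.0
        for f, y in zip(features, labels):
            z = sum(wi * fi for wi, fi in zip(w, f)) + b
            p = 1.0 / (1.0 + math.exp(-z))
            grad_b += (p - y) / n
            for i, fi in enumerate(f):
                grad_w[i] += (p - y) * fi / n
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def accuracy(w, b, features, labels):
    preds = [1 if sum(wi * fi for wi, fi in zip(w, f)) + b > 0 else 0 for f in features]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


rng = random.Random(0)
train_x, train_y = make_split(rng, 20)
test_x, test_y = make_split(rng, 20)

# Pretraining is assumed done; extract features with the frozen encoder.
train_f = [frozen_encoder(x) for x in train_x]
test_f = [frozen_encoder(x) for x in test_x]

# Fit the linear model on train features, then evaluate on test features.
w, b = train_linear(train_f, train_y)
print(accuracy(w, b, test_f, test_y))
```

The point of the design is that the encoder is never updated: the reported accuracy measures how linearly separable the learned features are, not how well the whole network can be fine-tuned.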

Check the Colab notebook for reproducibility.

Note that SimCLR benefits from longer training.

| Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top1 % |
|---|---|---|---|---|---|---|---|
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 74.45 |
| Logistic Regression (Adam) | CIFAR10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 69.82 |
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-50 | 2048 | 128 | 50 | 70.075 |

Comments
  • A question about the "labels"

    Hi! I have a question about the definition of "labels" in the script "simclr.py".

    On line 54 of "simclr.py", the authors defined:

    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    So all entries of "labels" are zero. But according to the paper, I think there should be an entry equal to 1 for the positive pair?

    Thanks in advance for your reply!

    opened by kekehia123 6
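
One way to read the all-zero labels: cross-entropy losses such as torch.nn.CrossEntropyLoss accept a class index per row rather than a one-hot vector, so a label of 0 says "the correct class is column 0". The plain-Python sketch below illustrates that reading. It assumes the positive similarity sits in column 0 of every logits row; the function cross_entropy and the example rows are made up for illustration and are not the repository's code.

```python
import math


def cross_entropy(logits, target_index):
    # Cross-entropy for one sample where the target is a class index,
    # computed as log-sum-exp(logits) minus the logit of the target class.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return log_sum_exp - logits[target_index]


# Each row: [similarity to the positive, similarities to the negatives...]
good_row = [5.0, 0.1, -0.3, 0.2]   # the positive scores highest
bad_row = [0.1, 5.0, -0.3, 0.2]    # a negative scores highest

# Label 0 selects column 0, i.e. the positive pair, as the correct class.
print(cross_entropy(good_row, 0))  # small loss
print(cross_entropy(bad_row, 0))   # large loss
```

Under this convention a vector of zeros is not "all targets are low"; it is the index of the positive column repeated once per row.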
  • size of tensors in cosine_similarity function

    Hi, I'm trying to understand the code in loss/nt_xent.py.

    We pass "representations" as both arguments:

        def forward(self, zis, zjs):
            representations = torch.cat([zjs, zis], dim=0)
            similarity_matrix = self.similarity_function(representations, representations)
    

    But inside the cosine similarity function the shapes are somehow x: (N, 1, C) and y: (1, 2N, C). How can y be twice the size of x if the same tensor is passed for both arguments?

        def _cosine_simililarity(self, x, y):
            # x shape: (N, 1, C)
            # y shape: (1, 2N, C)
            # v shape: (N, 2N)
            v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
            return v
    

    Thanks for your help.

    opened by BattashB 5
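
Going only by the snippet quoted in this issue, both arguments are the concatenation of zjs and zis, so each has 2N rows: x.unsqueeze(1) would have shape (2N, 1, C), y.unsqueeze(0) would have shape (1, 2N, C), and broadcasting yields a (2N, 2N) matrix. The "(N, ...)" shape comments therefore look inaccurate for this call. The sketch below reproduces the same pairwise computation in plain Python; the helper names and toy vectors are assumptions for illustration.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def pairwise_cosine(x, y):
    # Same result as broadcasting x[:, None, :] against y[None, :, :]:
    # entry [i][j] compares row i of x with row j of y.
    return [[cosine(xi, yj) for yj in y] for xi in x]


N, C = 2, 3
zis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # N embeddings of one view
zjs = [[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]   # N embeddings of the other view
representations = zjs + zis                # concatenation: 2N rows of size C

sim = pairwise_cosine(representations, representations)
print(len(sim), len(sim[0]))  # 4 4, i.e. (2N, 2N)
```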
  • How do I train the SimCLR model with my local dataset?

    Dear researcher, thank you for the open-source code you provided; it has been a great help to me in understanding contrastive learning. I am still confused about how to train the SimCLR model with my local dataset. Could you give me some guidance or tips? I would appreciate it if you could reply to this issue.

    opened by bestalllen 4
  • Question about CE Loss

    Hello,

    Thanks for sharing the code, nice implementation.

    The way you calculate the loss by using a mask is quite brilliant. But I have a question.

    logits = torch.cat((positives, negatives), dim=1)

    So if I'm not wrong, the first column of logits is the positive and the rest are negatives.

    labels = torch.zeros(2 * self.batch_size).to(self.device).long()

    But your labels are all zeros, which would mean that, positive or negative, the similarity should be low.

    So I wonder whether the first column of labels is supposed to be 1 instead of 0.

    Thanks for your help.

    opened by WShijun1991 4
  • Issue with batch-size

    In the function info_nce_loss, line 28 creates labels based on batch_size. The STL10 dataset has 100,000 images, which is divisible by a batch_size of 32, whereas a batch_size of 128 or 64 leaves a remainder of 32.

    With batch_size != 32, this causes an error at line 42, because the similarity matrix is built from the features while the labels are built from the batch size.

    For instance, with batch size = 128, the last iteration of data_loader contains the remaining 32 images. Since we create two views of each image, we have 64 images. Line 28 then gives 128 x 2 = 256 labels, while the similarity matrix is (64 x 128) x (128 x 64) => (64 x 64); applying the (256 x 256) mask causes a "dimension mismatch".

    Solution: Change Line 28 as below

    labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

    opened by Mayurji 3
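
The arithmetic in this issue can be checked with a few lines of plain Python. The sketch assumes the data loader keeps the final incomplete batch; last_batch_size and make_labels are hypothetical helpers, with make_labels being a list-based analogue of the proposed torch.cat/torch.arange fix.

```python
def last_batch_size(dataset_size, batch_size):
    # Size of the final batch when the loader keeps the incomplete batch.
    remainder = dataset_size % batch_size
    return remainder if remainder else batch_size


def make_labels(num_images, n_views=2):
    # One index per image, repeated once per view: [0..k-1, 0..k-1, ...].
    return [i for _ in range(n_views) for i in range(num_images)]


dataset_size, n_views = 100_000, 2
for batch_size in (32, 64, 128):
    last = last_batch_size(dataset_size, batch_size)
    num_features = last * n_views  # rows of the similarity matrix in the last step
    labels_from_batch_size = make_labels(batch_size, n_views)
    labels_from_features = make_labels(num_features // n_views, n_views)
    # Prints: batch size, last batch size, feature rows, label counts (old, fixed).
    print(batch_size, last, num_features, len(labels_from_batch_size), len(labels_from_features))
```

Only with batch size 32 do the label count and the number of feature rows agree in the last step; deriving the labels from the features removes the mismatch for the other sizes.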
  • 'CosineAnnealingLR' never works with the wrong position of 'scheduler.step()'

    Considering the setting 'scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)', I think 'scheduler.step()' should be called at every step inside 'for (xis,xjs),_ in train_loader'. Otherwise, the learning rate will not complete its schedule until 'len(train_loader)' epochs, rather than steps, have passed.

    opened by GuohongLi 3
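
The effect described in this issue can be quantified with the closed-form cosine annealing formula given in the PyTorch documentation for CosineAnnealingLR. The sketch below is plain Python, not the scheduler itself; the base learning rate, the number of steps per epoch, and the epoch count are assumed values chosen for illustration.

```python
import math


def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    # Closed-form cosine annealing: base_lr at t = 0, eta_min at t = t_max.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


base_lr = 3e-4          # assumed value, for illustration only
steps_per_epoch = 3125  # e.g. 100,000 images at batch size 32 (assumed)
epochs = 100

# Stepping once per epoch with T_max = steps_per_epoch: after 100 epochs the
# schedule has advanced only 100 / 3125 of the way.
lr_per_epoch_stepping = cosine_annealing_lr(base_lr, epochs, steps_per_epoch)

# Stepping once per epoch with T_max = epochs: the schedule completes.
lr_full_schedule = cosine_annealing_lr(base_lr, epochs, epochs)

print(lr_per_epoch_stepping / base_lr)  # still very close to 1
print(lr_full_schedule)
```

Either remedy makes the counter and T_max use the same unit: step the scheduler once per training step, as the issue proposes, or keep per-epoch stepping and set T_max to the number of epochs.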
  • Is there something wrong with the training model for CIFAR-10 experiments?

    Hi,

    I find that the ResNet20 model for the CIFAR-10 experiments is not fully correct. The head conv structure should be modified (stride=1 and no pooling) because CIFAR-10 images are very small.

    opened by timqqt 2
  • GPU utilization rate is low

    Hi, thanks for the code!

    When I tried to run it on a single GPU (V100), the utilization rate is very low (~0-10%) even if I increase num_worker. Do you know why this happens and how to solve it? Thanks!

    opened by LiJunnan1992 2
  • Why cos_sim after L2 norm?

    Hi, this code is really useful for me. Thanks! But I have a question about the NT-Xent loss. I noticed that you apply an L2 norm to z and then compute cos_similarity on the result. But cos_similarity already includes L2 normalization. Why apply the L2 norm first?

    opened by BoPang1996 2
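
The observation behind this question can be verified numerically: cosine similarity is invariant to rescaling its inputs, so normalizing first does not change the value, and for unit vectors it reduces to a dot product. The sketch below only demonstrates that mathematical equivalence with made-up vectors and helper names; it does not explain why the repository performs both steps.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def l2_normalize(v):
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]


def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


u, v = [3.0, 4.0, 0.0], [1.0, 2.0, 2.0]
raw = cosine_similarity(u, v)
after_norm = cosine_similarity(l2_normalize(u), l2_normalize(v))
dot_of_normalized = dot(l2_normalize(u), l2_normalize(v))

# All three values agree up to floating-point rounding.
print(raw, after_norm, dot_of_normalized)
```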
  • NT_Xent Loss function: all negatives are not being used?

    Hi @sthalles, thank you for sharing your code!

    Please correct me if I am wrong: at line 57 of loss/nt_xent.py (below), it seems you are not computing the contrastive loss over all negative pairs, because you reshape the negatives into a 2D array, i.e. only a part of the negative pairs is used for each positive pair. Is that right?

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        logits = torch.cat((positives, negatives), dim=1)
    

    Hope to hear from you soon.

    -Ishan

    opened by DAVEISHAN 2
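
A plain-Python sketch of the selection discussed in this issue may help. It assumes that the mask removes exactly two entries from each row of the (2N x 2N) similarity matrix, the diagonal and the positive pair, which is what the name mask_samples_from_same_repr suggests but is not confirmed by the quoted lines. Under that assumption, reshaping to 2N rows leaves 2N - 2 negatives per row, so no negative is discarded.

```python
def split_positives_negatives(sim, batch_size):
    # sim is a (2N x 2N) similarity matrix over the two concatenated views.
    # For row i the positive is the other view of the same image, the
    # diagonal (self-similarity) is dropped, and the rest are negatives.
    n = 2 * batch_size
    rows = []
    for i in range(n):
        pos_j = (i + batch_size) % n
        positive = sim[i][pos_j]
        negatives = [sim[i][j] for j in range(n) if j != i and j != pos_j]
        rows.append([positive] + negatives)  # positive first, as in the logits
    return rows


batch_size = 3
n = 2 * batch_size
# Toy matrix whose entries encode their own (row, column) position.
sim = [[10 * i + j for j in range(n)] for i in range(n)]
logits = split_positives_negatives(sim, batch_size)
print(len(logits), len(logits[0]))  # 6 5: one positive plus 2N - 2 negatives per row
```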
  • Validation Loss calculation

    First of all, thank you for your great work!

    Method _validate in simclr.py will raise ZeroDivisionError at line 148 if the validation data loader performs only one iteration (since counter starts from 0).

    opened by alessiamarcolini 2
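
The failure mode reported in this issue can be reproduced in isolation. The sketch below is an assumption about the pattern involved, not a copy of simclr.py: average_loss_buggy divides by the last zero-based loop index, and average_loss_fixed divides by the number of batches actually seen.

```python
def average_loss_buggy(batch_losses):
    # Assumed pattern: the divisor is the last zero-based loop index,
    # which is 0 when the loader yields exactly one batch.
    total = 0.0
    counter = 0
    for counter, loss in enumerate(batch_losses):
        total += loss
    return total / counter


def average_loss_fixed(batch_losses):
    # Divide by the number of batches actually seen.
    total = 0.0
    num_batches = 0
    for loss in batch_losses:
        total += loss
        num_batches += 1
    if num_batches == 0:
        raise ValueError("validation loader produced no batches")
    return total / num_batches


try:
    average_loss_buggy([0.7])
except ZeroDivisionError:
    print("single-batch loader: division by zero")

print(average_loss_fixed([0.7]))  # 0.7
```

Note that the buggy variant is also slightly off for longer loaders, since it divides by one less than the number of batches.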
  • evaluation code batch_size & validation process

    I really appreciate your good work :) I have a question because I got confused while studying your code.

    First, I wonder why the test_loader in "mini_batch_logistic_regression_valuator.ipynb" uses "batch_size=batch_size*2", unlike the train_loader. Is it related to creating 2 views during data augmentation?

    Also, in the last cell of this file, I'm not sure whether the second of the two inner "for" loops inside the outer epoch loop corresponds to testing or to validation. I thought it was testing, because the loss update, backpropagation, optimization, etc. are done only in the first "for" loop and the second only yields accuracy. Is that right? Or is the second "for" loop a validation step, since both loops run together in every epoch?

    opened by YejinS 0
  • Review Training | Fine-Tune | Test details

    Hi, I just want to check all the experiment details and make sure I didn't miss anything.

    1. Training Phase: use SimCLR (two encoder branches) to train on ImageNet for 1000 epochs to get initial pretrained weights.
    2. Fine-Tuning: load the initial pretrained weights into the ResNet-18 (50/101/...) with frozen parameters, attach a linear classifier, and train the classifier on the CIFAR10/STL10 training set for 100 epochs.
    3. Test Phase: freeze all encoder and classifier parameters, and test on the CIFAR10/STL10 test set.

    Is this how you get the top-1 accuracy in the README?

    opened by Howeng98 0
  • Confusion matrix

    Does anyone know how to add a confusion matrix to this code? I added one based on an example I found online, but something went wrong and I can't work out what. Please help me! Thanks.

        def confusion_matrix(output, labels, conf_matrix):
            preds = torch.argmax(output, dim=-1)
            for p, t in zip(preds, labels):
                conf_matrix[p, t] += 1
            return conf_matrix
    
    opened by here101 0
  • batch size effect

    Hi, I'm experimenting with CIFAR-10 using the default hyper-params, and it seems to yield a better score with a smaller batch size (e.g. 72% with batch size 256 yet 78% with batch size 128). Is anyone else seeing the same thing?

    opened by VietHoang1512 1
  • ModuleNotFoundError: No module named 'torch.cuda'

    I am using Python 3.7 on Win10, Anaconda Jupyter. I have successfully installed torch-1.10.0+cu113 torchaudio-0.10.0+cu113 torchvision-0.11.1+cu113. When trying to import torch, I get ModuleNotFoundError: No module named 'torch.cuda'. Detailed error:

    ModuleNotFoundError                       Traceback (most recent call last)
    <ipython-input-1-bfd2c657fa76> in <module>
          1 import numpy as np
          2 import pandas as pd
    ----> 3 import torch
          4 import torch.nn as nn
          5 from sklearn.model_selection import train_test_split
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\__init__.py in <module>
        603 
        604 # Shared memory manager needs to know the exact location of manager executable
    --> 605 _C._initExtension(manager_path())
        606 del manager_path
        607 
    
    ModuleNotFoundError: No module named 'torch.cuda'
    

    I found posts about a similar error, No module named 'torch.cuda.amp'. However, none of the suggested solutions worked. Please advise.

    opened by m-bor 0
Releases(v1.0.1)