A concise but complete implementation of CLIP with various experimental improvements from recent papers

Overview

x-clip (wip)

Install

$ pip install x-clip

Usage

import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True   # whether to use fine-grained contrastive learning (FILIP)
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
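For reference, the objective in the CLIP paper (Radford et al., 2021) is a symmetric contrastive loss over the batch similarity matrix. The sketch below is a plain-PyTorch illustration of that idea, not x-clip's internal code. It assumes the text and image latents are already projected to dim_latent, and it uses a fixed temperature of 0.07 (the paper's initial value; CLIP itself learns it). clip_contrastive_loss is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(
    text_latents: torch.Tensor,
    image_latents: torch.Tensor,
    temperature: float = 0.07,
) -> torch.Tensor:
    """Symmetric InfoNCE loss; row i of each input is a matching text/image pair."""
    # L2-normalize so each dot product is a cosine similarity
    text_latents = F.normalize(text_latents, dim=-1)
    image_latents = F.normalize(image_latents, dim=-1)

    # (batch, batch) similarity matrix: entry [i, j] scores text i against image j
    logits = text_latents @ image_latents.t() / temperature

    # the matching image for text i sits in column i
    targets = torch.arange(logits.shape[0], device=logits.device)

    # average the text-to-image and image-to-text cross-entropy terms
    loss_t2i = F.cross_entropy(logits, targets)
    loss_i2t = F.cross_entropy(logits.t(), targets)
    return (loss_t2i + loss_i2t) / 2


if __name__ == "__main__":
    torch.manual_seed(0)
    text_latents = torch.randn(4, 512, requires_grad=True)
    image_latents = torch.randn(4, 512, requires_grad=True)
    loss = clip_contrastive_loss(text_latents, image_latents)
    loss.backward()
    print(loss.item())
```

Row i of the normalized similarity matrix scores every image in the batch against text i, so the same matrix also serves as a text/image similarity score for retrieval.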

Citations

@misc{radford2021learning,
    title   = {Learning Transferable Visual Models From Natural Language Supervision}, 
    author  = {Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
    year    = {2021},
    eprint  = {2103.00020},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yao2021filip,
    title   = {FILIP: Fine-grained Interactive Language-Image Pre-Training}, 
    author  = {Lewei Yao and Runhui Huang and Lu Hou and Guansong Lu and Minzhe Niu and Hang Xu and Xiaodan Liang and Zhenguo Li and Xin Jiang and Chunjing Xu},
    year    = {2021},
    eprint  = {2111.07783},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
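The use_all_token_embeds option in the usage example refers to FILIP's fine-grained contrastive learning (Yao et al., 2021), where similarity is computed token by token rather than from one pooled vector per modality: each token is matched to its most similar token in the other modality, and those maxima are averaged. A rough sketch of that token-wise late interaction follows, assuming per-token latents of shape (batch, tokens, dim) and a boolean text padding mask. filip_similarity is a hypothetical name, and x-clip may handle masking and reductions differently.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def filip_similarity(
    text_tokens: torch.Tensor,
    image_tokens: torch.Tensor,
    text_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Token-wise late-interaction similarity in the style of FILIP.

    text_tokens:  (batch_t, n_text, dim) per-token text latents
    image_tokens: (batch_i, n_patch, dim) per-patch image latents
    text_mask:    optional (batch_t, n_text) bool, True for real (non-padding) tokens
    Returns (text_to_image, image_to_text), each of shape (batch_t, batch_i).
    """
    text_tokens = F.normalize(text_tokens, dim=-1)
    image_tokens = F.normalize(image_tokens, dim=-1)

    # sim[t, i, m, n] = cosine similarity of token m of text t and patch n of image i
    sim = torch.einsum("tmd,ind->timn", text_tokens, image_tokens)

    if text_mask is None:
        text_mask = torch.ones(text_tokens.shape[:2], dtype=torch.bool, device=text_tokens.device)

    # text -> image: each real text token keeps its best patch, averaged over real tokens
    best_patch = sim.max(dim=-1).values                        # (t, i, m)
    mask_f = text_mask[:, None, :].float()                     # (t, 1, m)
    t2i = (best_patch * mask_f).sum(dim=-1) / mask_f.sum(dim=-1).clamp(min=1.0)

    # image -> text: each patch keeps its best real text token, averaged over patches
    masked = sim.masked_fill(~text_mask[:, None, :, None], float("-inf"))
    i2t = masked.max(dim=-2).values.mean(dim=-1)               # (t, i)

    return t2i, i2t
```

Each returned matrix can then go through the same cross-entropy as standard CLIP, one per direction.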
Comments
  • Model forward outputs to text/image similarity score

    Any insight on how to take the image/text embeddings (or the nominal model forward output) and compute a simple similarity score, as done in the Hugging Face implementation? HF example here

    In the original paper, I see that dot products of the image/text encoder outputs are used, but here I was having trouble with the dimensions of the outputs.

    opened by paulcjh 12
  • Using different encoders in CLIP

    Hi, I am wondering whether it is possible to use different encoders in CLIP, for example a ResNet instead of a ViT for images. Is it also possible to replace the text encoder with a feature encoder? If I have a feature vector for a given image and want to use x-clip, how should I do that? I wrote a code example that doesn't seem to work; here is what I did:

    import torch
    from x_clip import CLIP
    import torch.nn as nn
    from torchvision import models
    
    class Image_Encoder(torch.nn.Module):
        #output size is (bs,512)
        def __init__(self):
            super(Image_Encoder, self).__init__()
            self.model_pre = models.resnet18(pretrained=False)
            self.base=nn.Sequential(*list(self.model_pre.children()))
            self.base[0]=nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            self.resnet=self.base[:-1]
    
        def forward(self, x):
            out=self.resnet(x).squeeze()
            return out
    
    
    class features_encoder(torch.nn.Module):
        #output size is (bs,512)
        def __init__(self):
            super(features_encoder, self).__init__()
            self.model =nn.Linear(2048,512)
    
        def forward(self, x):
            out=self.model(x)
            return out
    
    images_encoder=Image_Encoder()
    features_encoder=features_encoder()
    
    clip = CLIP(
        image_encoder = images_encoder,
        text_encoder = features_encoder,
        dim_image = 512,
        dim_text = 512,
        dim_latent = 512
    )
    
    features= torch.randn(4,2048)
    images = torch.randn(4, 3, 256, 256)
    
    loss = clip(features, images, return_loss = True)
    loss.backward()
    

    But I got the following error: forward() takes 2 positional arguments but 3 were given

    Thanks

    opened by ethancohen123 8
  • Visual ssl with channels different than 3

    Hi, there seems to be a bug when using visual SSL with a number of channels other than 3. I think the error comes from the visual SSL code around line 280 here:

    # send a mock image tensor to instantiate parameters
    self.forward(torch.randn(1, 3, image_size, image_size))

    opened by ethancohen123 4
  • Allow other types of visual  SSL when initiating CLIP

    In the following code, which is part of CLIP.__init__:

            if use_visual_ssl:
                if visual_ssl_type == 'simsiam':
                    ssl_type = SimSiam
                elif visual_ssl_type == 'simclr':
                    ssl_type = partial(SimCLR, temperature = simclr_temperature)
                else:
                    raise ValueError(f'unknown visual_ssl_type')
    
                self.visual_ssl = ssl_type(
                    self.visual_transformer,
                    image_size = visual_image_size,
                    hidden_layer = visual_ssl_hidden_layer
                )
    

    the visual self-supervised learning method is hardcoded. I would suggest changing this to accept the visual SSL module as an argument when instantiating CLIP, allowing the same flexibility CLIP already offers for the image encoder and text encoder.

    Example:

    barlow = BarlowTwins(augmentation_fns)
    clip = CLIP(..., visual_ssl=barlow)
    
    opened by Froskekongen 4
  • Extract Text and Image Latents

    Hi, in the current implementation we can only extract the text and image embeddings (by setting return_encodings=True), which are obtained before the latent linear layers are applied. Wouldn't it be better to add an option to extract the latent embeddings? This also matters because, with the current code, it is impossible to extract the similarity matrix between a batch of images and a batch of texts.

    opened by mmsamiei 2
  • NaN with mock data

    Hi lucidrains,

    Try this and it will produce NaN within 100 steps (latest GitHub code). The loss looks fine before the NaN appears.

    import torch
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True    
    torch.backends.cudnn.benchmark = True
    
    import random
    import numpy as np
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    num_text_tokens = 10000
    batch_sz = 12
    text_seq_len = 256
    visual_image_size = 256
    
    # mock data
    
    data_sz = 1000
    all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
    all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()
    
    text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
    images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()
    
    ##########################################################################################
    
    import wandb
    import datetime
    wandb.init(project="Test", name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), save_code=False)
    
    from x_clip import CLIP
    
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = num_text_tokens,
        text_enc_depth = 6,
        text_seq_len = text_seq_len,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = visual_image_size,
        visual_patch_size = 32,
        visual_heads = 8,
        use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
        decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
        extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
        use_visual_ssl = True,                  # whether to do self supervised learning on images
        visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
        use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
        text_ssl_loss_weight = 0.05,            # weight for text MLM loss
        image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
    ).cuda()
    
    optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))
    
    for step in range(999999):
        for i in range(batch_sz):
            data_id = random.randrange(data_sz)  # sample from all data_sz items
            text[i] = all_text[data_id]
            images[i] = all_images[data_id]
    
        loss = clip(
            text,
            images,
            freeze_image_encoder = False,   # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
            return_loss = True              # needs to be set to True to return contrastive loss
        )
        clip.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
        optimizer.step()
    
        now_loss = loss.item()
        wandb.log({"loss": now_loss}, step = step)
        print(step, now_loss)
    
        if 'nan' in str(now_loss):
            break
    
    opened by BlinkDL 1
  • Unable to train to convergence (small dataset)

    Hi, nice work with x-clip. I'm hoping to play around with it and eventually combine it into your DALLE2 work.

    I'm currently having some trouble training on roughly 30k image-text pairs. The loss eventually goes negative and starts producing NaNs. I've lowered the learning rate (1e-4) and I'm clipping gradients (max_norm=0.5).

    Any thoughts on sane training params/configs for such a small dataset using x-clip?

    opened by jacobwjs 9
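The "NaN with mock data" configuration enables decoupled_contrastive_learning, whose comment describes it as removing positive pairs from the denominator of the InfoNCE loss. A minimal sketch of that change, assuming a fixed temperature and a batch of at least two pairs; decoupled_contrastive_loss is a hypothetical name, not the library's function.

```python
import torch
import torch.nn.functional as F


def decoupled_contrastive_loss(
    text_latents: torch.Tensor,
    image_latents: torch.Tensor,
    temperature: float = 0.07,
) -> torch.Tensor:
    """InfoNCE with the positive pair removed from the denominator (DCL-style).

    Requires batch size >= 2 so every row has at least one negative.
    """
    text_latents = F.normalize(text_latents, dim=-1)
    image_latents = F.normalize(image_latents, dim=-1)
    logits = text_latents @ image_latents.t() / temperature    # (b, b)

    pos = logits.diagonal()                                     # matching pairs
    eye = torch.eye(logits.shape[0], dtype=torch.bool, device=logits.device)
    neg = logits.masked_fill(eye, float("-inf"))                # drop positives from the denominator

    # -log(exp(pos) / sum over negatives of exp(neg)), in both directions
    loss_t2i = (torch.logsumexp(neg, dim=1) - pos).mean()
    loss_i2t = (torch.logsumexp(neg, dim=0) - pos).mean()
    return (loss_t2i + loss_i2t) / 2
```

Because the positive term no longer appears in the denominator, this objective is not bounded below by zero: four orthogonal, perfectly matched pairs give log(3) - 1/0.07, roughly -13.19. With this option enabled, a negative loss value is therefore not by itself evidence of a bug (the "Unable to train to convergence" report does not say whether DCL was on).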
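On the "Using different encoders in CLIP" error: "forward() takes 2 positional arguments but 3 were given" counts self, so CLIP called the custom encoder's forward with two arguments, x plus one more. A plausible reading, not confirmed against the x-clip source here, is that the extra argument is the text mask. Below is a hypothetical wrapper (FeatureTextEncoder is an illustrative name) that accepts and ignores an optional second argument and, as a further assumption about the layout CLIP expects, returns a (batch, 1, dim) sequence so a token axis exists.

```python
from typing import Optional

import torch
import torch.nn as nn


class FeatureTextEncoder(nn.Module):
    """Hypothetical text-encoder replacement that embeds a precomputed feature vector."""

    def __init__(self, dim_in: int = 2048, dim_out: int = 512):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # mask is accepted (positionally or by keyword) but unused:
        # a single feature vector per sample has no padding to mask
        out = self.proj(x)          # (batch, dim_out)
        return out.unsqueeze(1)     # (batch, 1, dim_out): one "token" per sample (assumed layout)


if __name__ == "__main__":
    encoder = FeatureTextEncoder()
    features = torch.randn(4, 2048)
    mask = torch.ones(4, 1, dtype=torch.bool)
    print(encoder(features, mask).shape)
```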
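On the NaN reports: the CLIP paper learns the softmax temperature and clips it so that the logits are never scaled by more than 100. Whether x-clip applies the same clamp, and whether a runaway temperature explains these NaNs, is not established here; the sketch below only shows that safeguard in isolation. LearnableTemperature is a hypothetical module name.

```python
import math

import torch
import torch.nn as nn


class LearnableTemperature(nn.Module):
    """Learnable logit scale stored in log space and clamped at max_scale."""

    def __init__(self, init_temperature: float = 0.07, max_scale: float = 100.0):
        super().__init__()
        # log(1 / temperature); exp() keeps the scale positive during training
        self.log_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_temperature)))
        self.max_scale = max_scale

    def forward(self, similarities: torch.Tensor) -> torch.Tensor:
        # clamp so the logits are never scaled by more than max_scale
        scale = self.log_scale.exp().clamp(max=self.max_scale)
        return similarities * scale
```

Storing the scale in log space keeps it positive without an explicit constraint; the clamp only caps how sharp the softmax over the similarity matrix can become.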
Releases (0.12.0)
Owner
Phil Wang
Working with Attention. It's all we need
How to Become More Salient? Surfacing Representation Biases of the Saliency Prediction Model

Bogdan Kulynych 49 Nov 05, 2022
RepVGG: Making VGG-style ConvNets Great Again

This repository is the code that needs to be submitted for the OpenMMLab Algorithm Ecological Challenge; the paper is RepVGG: Making VGG-style ConvNets Great Again

Ty Feng 62 May 21, 2022
Specification language for generating Generalized Linear Models (with or without mixed effects) from conceptual models

tisane Tisane: Authoring Statistical Models via Formal Reasoning from Conceptual and Data Relationships TL;DR: Analysts can use Tisane to author gener

Eunice Jun 11 Nov 15, 2022
Auto-updating data to assist in investment to NEPSE

Symbol Ratios Summary Sector LTP Undervalued Bonus % MEGA Strong Commercial Banks 368 5 10 JBBL Strong Development Banks 568 5 10 SIFC Strong Finance

Amit Chaudhary 16 Nov 01, 2022
Creating Multi Task Models With Keras

Creating Multi Task Models With Keras About The Project! I used the keras and Tensorflow Library, To build a Deep Learning Neural Network to Creating

Srajan Chourasia 4 Nov 28, 2022
Here is the diagnostic tool for BMVC 2021 paper Diagnosing Errors in Video Relation Detectors.

Here is the diagnostic tool for BMVC 2021 paper Diagnosing Errors in Video Relation Detectors. We provide a tiny ground truth file demo_gt.json, and t

Shuo Chen 3 Dec 26, 2022
code for paper "Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?"

Does Unsupervised Architecture Representation Learning Help Neural Architecture Search? Code for paper: Does Unsupervised Architecture Representation

39 Dec 17, 2022
A Research-oriented Federated Learning Library and Benchmark Platform for Graph Neural Networks. Accepted to ICLR'2021 - DPML and MLSys'21 - GNNSys workshops.

FedGraphNN: A Federated Learning System and Benchmark for Graph Neural Networks A Research-oriented Federated Learning Library and Benchmark Platform

FedML-AI 175 Dec 01, 2022
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the official implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022
Tom-the-AI - A compound artificial intelligence software for Linux systems.

Tom the AI (version 0.82) WARNING: This software is not yet ready to use, I'm still setting up the GitHub repository. Should be ready in a few days. T

2 Apr 28, 2022
Python Rapid Artificial Intelligence Ab Initio Molecular Dynamics

14 Nov 06, 2022
Detecting Human-Object Interactions with Object-Guided Cross-Modal Calibrated Semantics

[AAAI2022] Detecting Human-Object Interactions with Object-Guided Cross-Modal Calibrated Semantics Overall pipeline of OCN. Paper Link: [arXiv] [AAAI

13 Nov 21, 2022
Hierarchical Few-Shot Generative Models

Hierarchical Few-Shot Generative Models Giorgio Giannone, Ole Winther This repo contains code and experiments for the paper Hierarchical Few-Shot Gene

Giorgio Giannone 6 Dec 12, 2022
A python implementation of Deep-Image-Analogy based on pytorch.

Deep-Image-Analogy This project is a python implementation of Deep Image Analogy. https://arxiv.org/abs/1705.01088. Some results Requirements python 3

Peng Lu 171 Dec 14, 2022
Animate molecular orbital transitions using Psi4 and Blender

Molecular Orbital Transitions (MOT) Animate molecular orbital transitions using Psi4 and Blender Author: Maximilian Paradiz Dominguez, University of A

3 Feb 01, 2022
Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion"

DSPoint Official implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion". Paper link: https://arxiv.org/abs/2111.10

Ziyao Zeng 14 Feb 26, 2022
Towards Fine-Grained Reasoning for Fake News Detection

FinerFact This is the PyTorch implementation for the FinerFact model in the AAAI 2022 paper Towards Fine-Grained Reasoning for Fake News Detection (Ar

Ahren_Jin 15 Dec 15, 2022
The datasets and code of ACL 2021 paper "Aspect-Category-Opinion-Sentiment Quadruple Extraction with Implicit Aspects and Opinions".

Aspect-Category-Opinion-Sentiment (ACOS) Quadruple Extraction This repo contains the data sets and source code of our paper: Aspect-Category-Opinion-S

NUSTM 144 Jan 02, 2023
YOLOv5 + ROS2 object detection package

YOLOv5-ROS YOLOv5 + ROS2 object detection package This program changes the input of detect.py (ultralytics/yolov5) to sensor_msgs/Image of ROS2. Requi

Ar-Ray 23 Dec 19, 2022
Riemannian Convex Potential Maps

Modeling distributions on Riemannian manifolds is a crucial component in understanding non-Euclidean data that arises, e.g., in physics and geology. The budding approaches in this space are limited b

Facebook Research 61 Nov 28, 2022