Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

Overview

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit in contrastive learning to a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also chooses to adopt the specific transformer architecture from PaLM, for both the unimodal and multimodal transformers as well as the cross attention blocks (parallel SwiGLU feedforwards)

Yannic Kilcher presentation

Install

$ pip install coca-pytorch

Usage

First install the vit-pytorch for the image encoder, which needs to be pretrained

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# do your vision transformer training

vit = Extractor(vit, return_embeddings_only = True)

# extractor will enable it so the vision transformer returns its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as much text and images...
# then you can get the caption logits as so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Hi @lucidrains, thank you for the implementation. Just wanted to confirm this with you, based on your code we're normalizing the img embedding and text embedding respectively using a learnable Layer Norm transformation before applying the contrastive loss. But based on my understanding, for contrastive loss we typically maximize the relative cosine similarity so the embeddings should be L2-normed instead of layernormed? Thank you.

    opened by fedshyvana 2
  • Extractor in vit_pytorch will detach the tensor.

    Extractor in vit_pytorch will detach the tensor.

    Thanks for your code! I think I may find a little bug. The cloned tensor in Extractor will be detached (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/extractor.py#L39). So gradient may not propagate back to image encoder.

    opened by techkang 2
  • Maybe don't need this rearrange

    Maybe don't need this rearrange

    I think the logits before this line in shape (bsz, length, num_tokens) -> so I don't think here need one more rearrange https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461

    opened by CiaoHe 2
  • How to train the model using my own dataset?

    How to train the model using my own dataset?

    Can someone tell me how to train the model using my own dataset? is it like below?But I have many images and texts...

    # train by giving CoCa your text and images with `return_loss = True`
    loss = coca(
        text = text,
        images = images,
        return_loss = True  # set this to True to get the full caption + contrastive loss
    )
    
    opened by keepcodeandsmile 1
  • why train VIT visual encoder first?

    why train VIT visual encoder first?

    Hi, thanks for sharing this repo. In the CoCA paper, both the visual encoder and text encoder are end-to end trained. But in this repo, the vit is first pretrained then fixed to train CoCa.

    opened by Flowerfan 1
  • attn_mask

    attn_mask

    cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    
    attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
    sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
    

    Hello, I am confused of the implement of "attn_mask". I think this padding function only can mask the last row of "sim". Could you please explain it? Perhaps it's a very fool question. Thank you so much.

    opened by pldlgb 0
  • Reproducing the results in the paper

    Reproducing the results in the paper

    Thanks for this repo. Curious, is this an independent implementation of the CoCa paper? If yes, did you reproduce any result in the paper to ensure correctness of implementation?

    opened by GKIBMNY 0
  • Generating the caption of a given image

    Generating the caption of a given image

    Hello,

    Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea about how you would do it in this particular architecture?

    Thank you in advance.

    opened by claudiogreco 0
Releases(0.0.7)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Garbage classification using structure data.

垃圾分类模型使用说明 1.包含以下数据文件 文件 描述 data/MaterialMapping.csv 物体以及其归类的信息 data/TestRecords 光谱原始测试数据 CSV 文件 data/TestRecordDesc.zip CSV 文件描述文件 data/Boundaries.cs

wenqi 1 Dec 10, 2021
This is the official code for the paper "Tracker Meets Night: A Transformer Enhancer for UAV Tracking".

SCT This is the official code for the paper "Tracker Meets Night: A Transformer Enhancer for UAV Tracking" The spatial-channel Transformer (SCT) enhan

Intelligent Vision for Robotics in Complex Environment 27 Nov 23, 2022
Process JSON files for neural recording sessions using Medtronic's BrainSense Percept PC neurostimulator

percept_processing This code processes JSON files for streamed neural data using Medtronic's Percept PC neurostimulator with BrainSense Technology for

Maria Olaru 3 Jun 06, 2022
MADT: Offline Pre-trained Multi-Agent Decision Transformer

MADT: Offline Pre-trained Multi-Agent Decision Transformer A link to our paper can be found on Arxiv. Overview Official codebase for Offline Pre-train

Linghui Meng 51 Dec 21, 2022
Autoencoders pretraining using clustering

Autoencoders pretraining using clustering

IITiS PAN 2 Dec 16, 2021
Implementation for Curriculum DeepSDF

Curriculum-DeepSDF This repository is an implementation for Curriculum DeepSDF. Full paper is available here. Preparation Please follow original setti

Haidong Zhu 69 Dec 29, 2022
Automated image registration. Registrationimation was too much of a mouthful.

alignimation Automated image registration. Registrationimation was too much of a mouthful. This repo contains the code used for my blog post Alignimat

Ethan Rosenthal 9 Oct 13, 2022
YOLOv5 in PyTorch > ONNX > CoreML > TFLite

This repository represents Ultralytics open-source research into future object detection methods, and incorporates lessons learned and best practices evolved over thousands of hours of training and e

Ultralytics 34.1k Dec 31, 2022
[CVPR 2016] Unsupervised Feature Learning by Image Inpainting using GANs

Context Encoders: Feature Learning by Inpainting CVPR 2016 [Project Website] [Imagenet Results] Sample results on held-out images: This is the trainin

Deepak Pathak 829 Dec 31, 2022
Element selection for functional materials discovery by integrated machine learning of atomic contributions to properties

Element selection for functional materials discovery by integrated machine learning of atomic contributions to properties 8.11.2021 Andrij Vasylenko I

Leverhulme Research Centre for Functional Materials Design 4 Dec 20, 2022
[CVPR 2021] Released code for Counterfactual Zero-Shot and Open-Set Visual Recognition

Counterfactual Zero-Shot and Open-Set Visual Recognition This project provides implementations for our CVPR 2021 paper Counterfactual Zero-S

144 Dec 24, 2022
PyTorch implementation of "PatchGame: Learning to Signal Mid-level Patches in Referential Games" to appear in NeurIPS 2021

PatchGame: Learning to Signal Mid-level Patches in Referential Games This repository is the official implementation of the paper - "PatchGame: Learnin

Kamal Gupta 22 Mar 16, 2022
A real-time motion capture system that estimates poses and global translations using only 6 inertial measurement units

TransPose Code for our SIGGRAPH 2021 paper "TransPose: Real-time 3D Human Translation and Pose Estimation with Six Inertial Sensors". This repository

Xinyu Yi 261 Dec 31, 2022
Codes for AAAI22 paper "Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum"

Paper For more details, please see our paper Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum which has been accepted a

14 Sep 30, 2022
Unofficial pytorch implementation of the paper "Dynamic High-Pass Filtering and Multi-Spectral Attention for Image Super-Resolution"

DFSA Unofficial pytorch implementation of the ICCV 2021 paper "Dynamic High-Pass Filtering and Multi-Spectral Attention for Image Super-Resolution" (p

2 Nov 15, 2021
Implementation of fast algorithms for Maximum Spanning Tree (MST) parsing that includes fast ArcMax+Reweighting+Tarjan algorithm for single-root dependency parsing.

Fast MST Algorithm Implementation of fast algorithms for (Maximum Spanning Tree) MST parsing that includes fast ArcMax+Reweighting+Tarjan algorithm fo

Miloš Stanojević 11 Oct 14, 2022
PyTorch implementation(s) of various ResNet models from Twitch streams.

pytorch-resnet-twitch PyTorch implementation(s) of various ResNet models from Twitch streams. Status: ResNet50 currently not working. Will update in n

Daniel Bourke 3 Jan 11, 2022
A set of tests for evaluating large-scale algorithms for Wasserstein-2 transport maps computation.

Continuous Wasserstein-2 Benchmark This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Co

Alexander 22 Dec 12, 2022
Final term project for Bayesian Machine Learning Lecture (XAI-623)

Mixquality_AL Final Term Project For Bayesian Machine Learning Lecture (XAI-623) Youtube Link The presentation is given in YoutubeLink Problem Formula

JeongEun Park 3 Jan 18, 2022
[ICLR 2021] "Neural Architecture Search on ImageNet in Four GPU Hours: A Theoretically Inspired Perspective" by Wuyang Chen, Xinyu Gong, Zhangyang Wang

Neural Architecture Search on ImageNet in Four GPU Hours: A Theoretically Inspired Perspective [PDF] Wuyang Chen, Xinyu Gong, Zhangyang Wang In ICLR 2

VITA 156 Nov 28, 2022