Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

Overview

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation (Salesforce Research)

This is the official PyTorch implementation of the ALBEF paper [Blog]. This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.

Requirements:

  • pytorch 1.8.0
  • transformers 4.8.1
  • timm 0.4.9

Download:

Visualization:

We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text. Here is an example visualization using the visual grounding checkpoint.

Pre-training on custom datasets:

  1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
  2. In configs/Pretrain.yaml, set the paths for the json files.
  3. Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
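Step 1 above fixes only the record layout. Below is a minimal standard-library sketch for producing and sanity-checking such a file. It assumes your data is already available as (image path, caption) pairs. The helpers `write_pretrain_json` and `validate_pretrain_json` and the file name `my_pretrain.json` are hypothetical and not part of the repository.

```python
import json
import tempfile
from pathlib import Path


def write_pretrain_json(pairs, out_path):
    """Write (image_path, caption) pairs as a JSON list of {'image': ..., 'caption': ...} dicts."""
    records = [{"image": str(image), "caption": str(caption)} for image, caption in pairs]
    Path(out_path).write_text(json.dumps(records, ensure_ascii=False), encoding="utf-8")
    return len(records)


def validate_pretrain_json(path):
    """Return a list of problems found in a pre-training JSON file (empty if it looks valid)."""
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(data, list):
        return ["top-level value must be a list"]
    problems = []
    for i, item in enumerate(data):
        # Each record must have exactly the two keys described in step 1.
        if not isinstance(item, dict) or set(item) != {"image", "caption"}:
            problems.append(f"item {i}: expected exactly the keys 'image' and 'caption'")
        elif not isinstance(item["caption"], str) or not item["caption"].strip():
            problems.append(f"item {i}: caption must be a non-empty string")
    return problems


if __name__ == "__main__":
    # Hypothetical example pairs; replace them with your own image paths and captions.
    pairs = [("images/0001.jpg", "a dog running on the beach"),
             ("images/0002.jpg", "two people riding bicycles")]
    with tempfile.TemporaryDirectory() as tmp:
        out = Path(tmp) / "my_pretrain.json"
        print(write_pretrain_json(pairs, out), validate_pretrain_json(out))
```

The validator does not check whether the image files exist. If the images live on the training machine, adding that check is a natural extension.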

Image-Text Retrieval:

  1. Download the MSCOCO or Flickr30k dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/Retrieval_flickr \
--checkpoint [Pretrained checkpoint]
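Every fine-tuning command in this README has the same shape. Each one calls `torch.distributed.launch` with `--nproc_per_node` and `--use_env`, then passes a task script and its `--config`, `--output_dir` and `--checkpoint` arguments. Below is a small sketch of a hypothetical helper, `build_launch_cmd` (not part of the repository), that assembles this argument list so it can be passed to `subprocess.run`. The checkpoint path in the example is a placeholder.

```python
import shlex


def build_launch_cmd(script, config, output_dir, checkpoint=None, nproc=8, extra_args=None):
    """Assemble the distributed launch command used in this README as an argv list."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={nproc}", "--use_env", script,
           "--config", config, "--output_dir", output_dir]
    # Task-specific flags, e.g. {"gradcam_mode": "itm", "block_num": 8} for Grounding.py.
    for key, value in (extra_args or {}).items():
        cmd += [f"--{key}", str(value)]
    if checkpoint is not None:
        cmd += ["--checkpoint", checkpoint]
    return cmd


if __name__ == "__main__":
    cmd = build_launch_cmd("Retrieval.py", "./configs/Retrieval_flickr.yaml",
                           "output/Retrieval_flickr", checkpoint="path/to/pretrained.pth")
    # Print a copy-pasteable shell command; to launch instead: subprocess.run(cmd, check=True)
    print(shlex.join(cmd))
```

Building an argv list rather than a single string avoids shell quoting problems. It also sidesteps the fragile backslash line continuations of the multi-line commands.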

VQA:

  1. Download the VQA v2 and Visual Genome datasets from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/VQA.yaml, set the paths for the json files and the image paths.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/vqa \
--checkpoint [Pretrained checkpoint]
  5. Evaluate the result using the official evaluation server.
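The official VQA v2 evaluation server takes predictions as a JSON list. Each entry holds an integer `question_id` and a string `answer`. Below is a minimal sketch for writing and sanity-checking such a file. It assumes your predictions are held in a dict that maps question ids to answer strings. The helper name `write_vqa_results` is hypothetical, and the README does not say whether `VQA.py` already writes a file in this form. Check the current challenge instructions before submitting.

```python
import json
import tempfile
from pathlib import Path


def write_vqa_results(predictions, out_path, expected_ids=None):
    """Write {question_id: answer} predictions as a VQA-style results list.

    If expected_ids is given, raise ValueError when any question lacks a
    prediction or an unexpected question id is present.
    """
    if expected_ids is not None:
        missing = set(expected_ids) - set(predictions)
        extra = set(predictions) - set(expected_ids)
        if missing or extra:
            raise ValueError(f"{len(missing)} missing and {len(extra)} unexpected question ids")
    results = [{"question_id": int(qid), "answer": str(answer)}
               for qid, answer in sorted(predictions.items(), key=lambda kv: int(kv[0]))]
    Path(out_path).write_text(json.dumps(results), encoding="utf-8")
    return results


if __name__ == "__main__":
    # Hypothetical predictions keyed by question id.
    preds = {1: "yes", 2: "2"}
    with tempfile.TemporaryDirectory() as tmp:
        print(write_vqa_results(preds, Path(tmp) / "vqa_result.json", expected_ids=preds.keys()))
```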

Visual Entailment:

  1. Download the SNLI-VE dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/VE.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/VE \
--checkpoint [Pretrained checkpoint]

Visual Grounding on RefCOCO+:

  1. Download the MSCOCO dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/Grounding.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir output/RefCOCO \
--gradcam_mode itm \
--block_num 8 \
--checkpoint [Pretrained checkpoint]
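The `--gradcam_mode itm` and `--block_num 8` flags suggest that grounding relies on Grad-CAM over the cross-attention maps of one block of the multimodal encoder, with the image-text matching (ITM) score as the target. Below is a pure-Python sketch of one common attention Grad-CAM formulation: each attention value is weighted by the positive part of its gradient, then averaged over heads and text tokens. Several things here are assumptions: the `[head][text_token][image_token]` layout, the leading [CLS] image token, the square patch grid, and the function names `attention_gradcam` and `peak_patch`. In practice the maps and gradients would be tensors taken from the model, and `Grounding.py` may differ in details such as masking and normalization.

```python
def attention_gradcam(attn, grad, grid_size):
    """Average attention * ReLU(gradient) over heads and text tokens.

    attn and grad are nested lists indexed [head][text_token][image_token].
    Image token 0 is assumed to be [CLS]; the remaining grid_size**2 tokens are
    patches in row-major order. Returns a grid_size x grid_size heatmap.
    """
    n_patches = grid_size * grid_size
    scores = [0.0] * n_patches
    count = 0
    for attn_head, grad_head in zip(attn, grad):
        for a_row, g_row in zip(attn_head, grad_head):
            if len(a_row) != n_patches + 1 or len(g_row) != n_patches + 1:
                raise ValueError("expected one [CLS] token plus grid_size**2 patch tokens")
            for p in range(n_patches):
                # Keep only positive gradients and use them to weight the attention (skip [CLS]).
                scores[p] += a_row[p + 1] * max(g_row[p + 1], 0.0)
            count += 1
    if count == 0:
        raise ValueError("no attention rows given")
    scores = [s / count for s in scores]
    return [scores[r * grid_size:(r + 1) * grid_size] for r in range(grid_size)]


def peak_patch(heatmap):
    """Return the (row, col) of the highest-scoring patch in the heatmap."""
    _, row, col = max((v, r, c) for r, values in enumerate(heatmap) for c, v in enumerate(values))
    return row, col


if __name__ == "__main__":
    # Toy example: 1 head, 1 text token, 2x2 patch grid plus one [CLS] token.
    attn = [[[0.5, 0.1, 0.2, 0.3, 0.4]]]
    grad = [[[1.0, 1.0, -1.0, 2.0, 0.5]]]
    heatmap = attention_gradcam(attn, grad, grid_size=2)
    print(heatmap, peak_patch(heatmap))
```

Clamping negative gradients keeps only the evidence that increases the ITM score. This is the usual Grad-CAM choice for localization.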

NLVR2:

NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. To perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/NLVR_pretrain \
--checkpoint [Pretrained checkpoint]
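In the ALBEF paper, the text-assignment task gives the model a pair of images and a text. The model must then assign the text to the first image, the second image, or neither of them, which makes it a three-way classification. Below is a hypothetical sketch of turning {'image', 'caption'} records, like the pre-training json files, into TA examples. The label convention (0 = first, 1 = second, 2 = neither), the uniform sampling scheme and the function names are assumptions, and `Pretrain_nlvr.py` may build its examples differently.

```python
import random


def _sample_other(rng, images, exclude):
    """Draw a random image that is not in `exclude` (rejection sampling)."""
    while True:
        image = rng.choice(images)
        if image not in exclude:
            return image


def make_ta_examples(records, seed=0):
    """Turn {'image', 'caption'} records into text-assignment examples.

    Returns (image_a, image_b, caption, label) tuples using the hypothetical
    convention 0 = caption belongs to image_a, 1 = image_b, 2 = neither.
    """
    images = sorted({r["image"] for r in records})
    if len(images) < 3:
        raise ValueError("need at least three distinct images to build 'neither' pairs")
    rng = random.Random(seed)
    examples = []
    for r in records:
        own = r["image"]
        label = rng.randrange(3)
        if label == 2:
            # Neither image is the caption's own image.
            a = _sample_other(rng, images, {own})
            b = _sample_other(rng, images, {own, a})
        else:
            other = _sample_other(rng, images, {own})
            a, b = (own, other) if label == 0 else (other, own)
        examples.append((a, b, r["caption"], label))
    return examples


if __name__ == "__main__":
    records = [{"image": f"img{i}.jpg", "caption": f"caption {i}"} for i in range(4)]
    for example in make_ta_examples(records):
        print(example)
```

Random negatives are cheap, but a caption in a "neither" pair may still loosely fit one of the sampled images. That noise is the price of not mining harder distractors.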

We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.

  1. Download the NLVR2 dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/NLVR.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/NLVR \
--checkpoint [TA pretrained checkpoint]

Citation

If you find this code to be useful for your research, please consider citing.

@article{ALBEF,
      title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation}, 
      author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
      year={2021},
      journal={arXiv preprint arXiv:2107.07651},
}