Scalable training for dense retrieval models.

Overview

Scalable implementation of dense retrieval.

Training on cluster

By default, training runs locally:

PYTHONPATH=.:$PYTHONPATH python dpr_scale/main.py trainer.gpus=1

SLURM Training

To train the model on SLURM, run:

PYTHONPATH=.:$PYTHONPATH python dpr_scale/main.py -m trainer=slurm trainer.num_nodes=2 trainer.gpus=2

Reproduce DPR on 8 GPUs

PYTHONPATH=.:$PYTHONPATH python dpr_scale/main.py -m --config-name nq.yaml +hydra.launcher.name=dpr_stl_nq_reproduce

Generate embeddings on Wikipedia

PYTHONPATH=.:$PYTHONPATH python dpr_scale/generate_embeddings.py -m --config-name nq.yaml datamodule=generate datamodule.test_path=psgs_w100.tsv +task.ctx_embeddings_dir=<CTX_EMBEDDINGS_DIR> +task.checkpoint_path=<CHECKPOINT_PATH>
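Before embedding a large passage file, it can help to confirm that it parses as expected. The sketch below assumes psgs_w100.tsv follows the layout of the original DPR release: a tab-separated file with a header row naming the columns id, text, and title. The helper name iter_passages and the sample rows are illustrative only and are not part of dpr-scale.

```python
import csv
import io


def iter_passages(tsv_file):
    """Yield (passage_id, title, text) tuples from a DPR-style passage TSV.

    Assumes a header row with the columns: id, text, title.
    """
    reader = csv.DictReader(tsv_file, delimiter="\t")
    for row in reader:
        yield row["id"], row["title"], row["text"]


# Tiny in-memory stand-in for psgs_w100.tsv (made-up rows for illustration).
sample = io.StringIO(
    "id\ttext\ttitle\n"
    "1\tAaron is a prophet in the Hebrew Bible.\tAaron\n"
    "2\tParis is the capital of France.\tParis\n"
)

passages = list(iter_passages(sample))
print(len(passages), "passages; first:", passages[0])
```

For the real file, pass an open file handle instead of the StringIO object, for example open("psgs_w100.tsv", newline="", encoding="utf-8").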

Get retrieval results

Retrieval currently runs on a single GPU. Reuse the <CTX_EMBEDDINGS_DIR> from the embedding generation step above.

PYTHONPATH=.:$PYTHONPATH python dpr_scale/run_retrieval.py --config-name nq.yaml trainer=gpu_1_host trainer.gpus=1 +task.output_path=<PATH_TO_OUTPUT_JSON> +task.ctx_embeddings_dir=<CTX_EMBEDDINGS_DIR> +task.checkpoint_path=<CHECKPOINT_PATH> +task.passages=psgs_w100.tsv datamodule.test_path=<PATH_TO_QUERIES_JSONL>
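Conceptually, dense retrieval scores every passage by the inner product between its embedding and the query embedding, then keeps the highest-scoring passages. The sketch below illustrates that idea with plain Python lists. It is not the code used by run_retrieval.py, and the function names and toy vectors are made up for illustration.

```python
import heapq


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k(query_emb, passage_embs, k):
    """Return the k best (passage_index, score) pairs, highest score first."""
    scored = ((dot(query_emb, p), idx) for idx, p in enumerate(passage_embs))
    best = heapq.nlargest(k, scored)
    return [(idx, score) for score, idx in best]


# Toy 2-dimensional embeddings (illustrative values only).
passage_embs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
query_emb = [1.0, 0.2]

results = top_k(query_emb, passage_embs, k=2)
print(results)
```

At Wikipedia scale an exhaustive Python loop like this is far too slow. It is only meant to show what the scores mean.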

Generate query embeddings

Alternatively, query embedding generation and retrieval can be separated. After query embeddings are generated using the following command, the run_retrieval_fb.py or run_retrieval_multiset.py script can be used to perform retrieval.

PYTHONPATH=.:$PYTHONPATH python dpr_scale/generate_query_embeddings.py -m --config-name nq.yaml trainer.gpus=1 datamodule.test_path=<PATH_TO_QUERIES_JSONL> +task.ctx_embeddings_dir=<CTX_EMBEDDINGS_DIR> +task.checkpoint_path=<CHECKPOINT_PATH> +task.query_emb_output_path=<OUTPUT_TO_QUERY_EMB>

Get evaluation metrics for a given JSON output file

python dpr_scale/eval_dpr.py --retrieval <PATH_TO_OUTPUT_JSON> --topk 1 5 10 20 50 100
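The values passed to --topk correspond to top-k retrieval accuracy: the fraction of questions for which at least one of the first k retrieved passages contains an answer. The sketch below shows that computation under an assumed record layout, where each question has a ranked "ctxs" list and each context carries a boolean "has_answer" flag, as in the original DPR output format. The actual eval_dpr.py may derive answer matches differently.

```python
def top_k_accuracy(results, ks):
    """Fraction of questions with an answer-bearing passage in the top k."""
    hits = {k: 0 for k in ks}
    for record in results:
        flags = [ctx["has_answer"] for ctx in record["ctxs"]]
        for k in ks:
            if any(flags[:k]):
                hits[k] += 1
    total = len(results)
    return {k: (hits[k] / total if total else 0.0) for k in ks}


# Two made-up questions with three ranked contexts each.
example_results = [
    {"ctxs": [{"has_answer": False}, {"has_answer": True}, {"has_answer": False}]},
    {"ctxs": [{"has_answer": False}, {"has_answer": False}, {"has_answer": False}]},
]

accuracy = top_k_accuracy(example_results, ks=[1, 2])
print(accuracy)
```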

Get evaluation metrics for MSMARCO

python dpr_scale/msmarco_eval.py ~data/msmarco/qrels.dev.small.tsv <PATH_TO_OUTPUT_JSON>
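MS MARCO passage ranking is usually reported as MRR@10, the mean over queries of the reciprocal rank of the first relevant passage within the top 10. The sketch below assumes that is the metric of interest and uses simplified in-memory inputs. The dictionaries stand in for the qrels file and the retrieval output, and they do not reflect the file formats that msmarco_eval.py reads.

```python
def mrr_at_k(qrels, run, k=10):
    """Mean reciprocal rank of the first relevant passage within the top k.

    qrels: dict mapping query id -> set of relevant passage ids
    run:   dict mapping query id -> ranked list of retrieved passage ids
    """
    total = 0.0
    for qid, relevant in qrels.items():
        for rank, pid in enumerate(run.get(qid, [])[:k], start=1):
            if pid in relevant:
                total += 1.0 / rank
                break
    return total / len(qrels) if qrels else 0.0


# Made-up judgments and rankings for two queries.
example_qrels = {"q1": {"p3"}, "q2": {"p9"}}
example_run = {"q1": ["p1", "p3", "p7"], "q2": ["p2", "p4"]}

score = mrr_at_k(example_qrels, example_run, k=10)
print(score)
```

Here q1 finds its relevant passage at rank 2 (reciprocal rank 0.5) and q2 finds none, so the mean is 0.25.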

Domain-matched Pre-training Tasks for Dense Retrieval

Paper: https://arxiv.org/abs/2107.13602

The sections below provide links to datasets and pretrained models, as well as instructions for preparing the datasets and for pretraining and fine-tuning the models.

Q&A Datasets

PAQ

Download the dataset from here

Conversational Datasets

You can download each dataset from its table below.

Reddit

File Download Link
train download
dev download

ConvAI2

File Download Link
train download
dev download

DSTC7

File Download Link
train download
dev download
test download

Prepare the data by downloading the tarball linked here and running the command below.

DSTC7_DATA_ROOT=<path_of_dir_where_the_data_is_extracted>
python dpr_scale/data_prep/prep_conv_datasets.py \
    --dataset dstc7 \
    --in_file_path $DSTC7_DATA_ROOT/ubuntu_train_subtask_1_augmented.json \
    --out_file_path $DSTC7_DATA_ROOT/ubuntu_train.jsonl

Ubuntu V2

File Download Link
train download
dev download
test download

Prepare the data by downloading the tarball linked here and running the command below.

UBUNTUV2_DATA_ROOT=<path_of_dir_where_the_data_is_extracted>
python dpr_scale/data_prep/prep_conv_datasets.py \
    --dataset ubuntu2 \
    --in_file_path $UBUNTUV2_DATA_ROOT/train.csv \
    --out_file_path $UBUNTUV2_DATA_ROOT/train.jsonl
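After running prep_conv_datasets.py, a quick sanity check on the resulting .jsonl file can catch truncated or malformed output before a long training job starts. The sketch below only verifies that every non-empty line is a JSON object and counts the records. It makes no assumption about the field names the script writes, and the helper name and sample records are hypothetical.

```python
import json
import os
import tempfile


def count_jsonl_records(path):
    """Count records in a JSONL file, raising ValueError on a malformed line."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{line_no}: invalid JSON ({exc})") from exc
            if not isinstance(record, dict):
                raise ValueError(f"{path}:{line_no}: expected a JSON object")
            count += 1
    return count


# Demonstrate on a throwaway file with two made-up records.
with tempfile.NamedTemporaryFile(
    "w", suffix=".jsonl", delete=False, encoding="utf-8"
) as tmp:
    tmp.write(json.dumps({"example_field": "hello"}) + "\n")
    tmp.write(json.dumps({"example_field": "world"}) + "\n")
    sample_path = tmp.name

num_records = count_jsonl_records(sample_path)
os.remove(sample_path)
print(num_records, "records")
```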

Pretraining DPR

Pretrained Checkpoints

Pretrained Model    Dataset    Download Link
BERT-base           PAQ        download
BERT-large          PAQ        download
BERT-base           Reddit     download
BERT-large          Reddit     download
RoBERTa-base        Reddit     download
RoBERTa-large       Reddit     download

Pretraining on PAQ dataset

DPR_ROOT=<path_of_your_repo's_root>
MODEL="bert-large-uncased"
NODES=8
BSZ=16
MAX_EPOCHS=20
LR=1e-5
TIMEOUT_MINS=4320
EXP_DIR=<path_of_the_experiment_dir>
TRAIN_PATH=<path_of_the_training_data_file>
mkdir -p ${EXP_DIR}/logs
PYTHONPATH=$DPR_ROOT python ${DPR_ROOT}/dpr_scale/main.py -m \
    --config-dir ${DPR_ROOT}/dpr_scale/conf \
    --config-name nq.yaml \
    hydra.launcher.timeout_min=$TIMEOUT_MINS \
    hydra.sweep.dir=${EXP_DIR} \
    trainer.num_nodes=${NODES} \
    task.optim.lr=${LR} \
    task.model.model_path=${MODEL} \
    trainer.max_epochs=${MAX_EPOCHS} \
    datamodule.train_path=$TRAIN_PATH \
    datamodule.batch_size=${BSZ} \
    datamodule.num_negative=1 \
    datamodule.num_val_negative=10 \
    datamodule.num_test_negative=50 > ${EXP_DIR}/logs/log.out 2> ${EXP_DIR}/logs/log.err &

Pretraining on Reddit dataset

# Use a batch size of 16 for BERT and RoBERTa base models.
BSZ=4
NODES=8
MAX_EPOCHS=5
WARMUP_STEPS=10000
LR=1e-5
MODEL="roberta-large"
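EXP_DIR=<path_of_the_experiment_dir>
# Assumes DPR_ROOT is set as in the PAQ example above.
# Create the log directory first; otherwise the output redirection below fails.
mkdir -p ${EXP_DIR}/logs
PYTHONPATH=. python dpr_scale/main.py -m \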
    --config-dir ${DPR_ROOT}/dpr_scale/conf \
    --config-name reddit.yaml \
    hydra.launcher.nodes=${NODES} \
    hydra.sweep.dir=${EXP_DIR} \
    trainer.num_nodes=${NODES} \
    task.optim.lr=${LR} \
    task.model.model_path=${MODEL} \
    trainer.max_epochs=${MAX_EPOCHS} \
    task.warmup_steps=${WARMUP_STEPS} \
    datamodule.batch_size=${BSZ} > ${EXP_DIR}/logs/log.out 2> ${EXP_DIR}/logs/log.err &

Fine-tuning DPR on downstream tasks/datasets

Fine-tune the pretrained PAQ checkpoint

# You can also try 2e-5 or 5e-5. Usually these 3 learning rates work best.
LR=1e-5
# Use a batch size of 32 for BERT and RoBERTa base models.
BSZ=12
MODEL="bert-large-uncased"
MAX_EPOCHS=40
WARMUP_STEPS=1000
NODES=1
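PRETRAINED_CKPT_PATH=<path_of_checkpoint_pretrained_on_paq>
EXP_DIR=<path_of_the_experiment_dir>
# NAME is passed to hydra.launcher.name below; set it to any job name.
NAME=<name_of_the_job>
# Create the log directory first; otherwise the output redirection below fails.
mkdir -p ${EXP_DIR}/logs
PYTHONPATH=. python dpr_scale/main.py -m \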
    --config-dir ${DPR_ROOT}/dpr_scale/conf \
    --config-name nq.yaml \
    hydra.launcher.name=${NAME} \
    hydra.sweep.dir=${EXP_DIR} \
    trainer.num_nodes=${NODES} \
    trainer.max_epochs=${MAX_EPOCHS} \
    datamodule.num_negative=1 \
    datamodule.num_val_negative=25 \
    datamodule.num_test_negative=50 \
    +trainer.val_check_interval=150 \
    task.warmup_steps=${WARMUP_STEPS} \
    task.optim.lr=${LR} \
    task.pretrained_checkpoint_path=$PRETRAINED_CKPT_PATH \
    task.model.model_path=${MODEL} \
    datamodule.batch_size=${BSZ} > ${EXP_DIR}/logs/log.out 2> ${EXP_DIR}/logs/log.err &
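To try the three suggested learning rates (1e-5, 2e-5, and 5e-5) without editing the script by hand, one option is to generate one command per value and give each run its own experiment directory. The sketch below is a dry run that only prints the commands. The directory naming scheme, the helper name, and the example paths are assumptions for illustration, and only a subset of the overrides from the command above is included.

```python
import shlex


def build_command(lr, exp_root, ckpt_path):
    """Build a fine-tuning command for one learning rate (dry run only)."""
    exp_dir = f"{exp_root}/lr_{lr}"  # hypothetical per-run directory layout
    return [
        "python", "dpr_scale/main.py", "-m",
        "--config-name", "nq.yaml",
        f"hydra.sweep.dir={exp_dir}",
        f"task.optim.lr={lr}",
        f"task.pretrained_checkpoint_path={ckpt_path}",
    ]


learning_rates = ("1e-5", "2e-5", "5e-5")
commands = [
    build_command(lr, "/tmp/experiments", "/tmp/pretrained.ckpt")
    for lr in learning_rates
]

# Print shell-quoted commands for review instead of executing them.
for cmd in commands:
    print(shlex.join(cmd))
```

Add the remaining overrides from the full command (batch size, warmup steps, model path, and so on) to the list before running the printed commands.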

Fine-tune the pretrained Reddit checkpoint

Batch sizes that worked on Volta 32GB GPUs for each model and dataset:

Model                 Dataset      Batch Size
BERT/RoBERTa base     ConvAI2      64
BERT/RoBERTa large    ConvAI2      16
BERT/RoBERTa base     DSTC7        24
BERT/RoBERTa large    DSTC7        8
BERT/RoBERTa base     Ubuntu V2    64
BERT/RoBERTa large    Ubuntu V2    16

# Change the config file name to convai2.yaml or dstc7.yaml for the respective datasets.
CONFIG_FILE_NAME=ubuntuv2.yaml
# You can also try 2e-5 or 5e-5. Usually these 3 learning rates work best.
LR=1e-5
BSZ=16
NODES=1
MAX_EPOCHS=5
WARMUP_STEPS=10000
MODEL="roberta-large"
PRETRAINED_CKPT_PATH=<path_of_checkpoint_pretrained_on_reddit>
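EXP_DIR=<path_of_the_experiment_dir>
# Create the log directory first; otherwise the output redirection below fails.
mkdir -p ${EXP_DIR}/logs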
PYTHONPATH=${DPR_ROOT} python ${DPR_ROOT}/dpr_scale/main.py -m \
    --config-dir=${DPR_ROOT}/dpr_scale/conf \
    --config-name=$CONFIG_FILE_NAME \
    hydra.launcher.nodes=${NODES} \
    hydra.sweep.dir=${EXP_DIR} \
    trainer.num_nodes=${NODES} \
    trainer.max_epochs=${MAX_EPOCHS} \
    +trainer.val_check_interval=150 \
    task.pretrained_checkpoint_path=$PRETRAINED_CKPT_PATH \
    task.warmup_steps=${WARMUP_STEPS} \
    task.optim.lr=${LR} \
    task.model.model_path=$MODEL \
    datamodule.batch_size=${BSZ} > ${EXP_DIR}/logs/log.out 2> ${EXP_DIR}/logs/log.err &

License

dpr-scale is currently licensed under CC-BY-NC 4.0.

Owner
Facebook Research