Shared code for training sentence embeddings with Flax / JAX

Overview

flax-sentence-embeddings

This repository will be used to share code for the Flax / JAX community event to train sentence embeddings on 1B+ training pairs.

You can add your code by creating a pull request.

Dataloading

Dowload data

You can download the data using this basic python script at the root of the project. Download should be completed in about 20 minutes given your connection speed. Total size on disk is arround 25G.

python dataset/download_data.py --dataset_list=datasets_list.tsv --data_path=PATH_TO_STORE_DATASETS

Dataloading

First implementation of the dataloader takes as input a single jsonl.gz file. It creates a pointer on the file such that samples are loaded one by one. The implementation is based on torch standard Dataloader and Dataset classes. The class supports num_worker>0 such that data loading is done in a background process on the CPU, i.e. the data is loaded and tokenized in parallel to training the network. This avoid to create a bottleneck from I/O and tokenization. The implementation currently return {'anchor': '...,' 'positive': '...'}

from dataset.dataset import IterableCorpusDataset

corpus_dataset = IterableCorpusDataset(
  file_path=os.path.join(PATH_TO_STORE_DATASETS, 'stackexchange_duplicate_questions_title_title.json.gz'), 
  batch_size=2,
  num_workers=2, 
  transform=None)

corpus_dataset_itr = iter(corpus_dataset)
next(corpus_dataset_itr)

# {'anchor': 'Can anyone explain all these Developer Options?',
#  'positive': 'what is the advantage of using the GPU rendering options in Android?'}

def collate(batch_input_str):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    batch = {'anchor': tokenizer.batch_encode_plus([b['anchor'] for b in batch_input_str], pad_to_max_length=True),
             'positive': tokenizer.batch_encode_plus([b['positive'] for b in batch_input_str], pad_to_max_length=True)}
    return batch

corpus_dataloader = DataLoader(
  corpus_dataset,
  batch_size=2,
  num_workers=2,
  collate_fn=collate,
  pin_memory=False,
  drop_last=True,
  shuffle=False)

print(next(iter(corpus_dataloader)))

# {'anchor': {'input_ids': [[101, 4531, 2019, 2523, 2090, 2048, 4725, 1997, 2966, 8830, 1998, 1037, 7142, 8023, 102, 0, 0, 0], [101, 1039, 1001, 10463, 5164, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 2000, 3058, 7292, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}, 'positive': {'input_ids': [[101, 1045, 2031, 2182, 2007, 2033, 1010, 2048, 4725, 1997, 8830, 1025, 1037, 3115, 2729, 4118, 1010, 1998, 1037, 17009, 8830, 1012, 2367, 3633, 4374, 2367, 4118, 1010, 2049, 2035, 18154, 11095, 1012, 1045, 2572, 2667, 2000, 2424, 1996, 2523, 1997, 1996, 17009, 8830, 1998, 1037, 1005, 2092, 2108, 3556, 1005, 2029, 2003, 1037, 15973, 3643, 1012, 2054, 2003, 1996, 2190, 2126, 2000, 2424, 2151, 8924, 1029, 1041, 1012, 1043, 1012, 8833, 6553, 26237, 2944, 1029, 102], [101, 1045, 2572, 2667, 2000, 10463, 1037, 5164, 3058, 2046, 1037, 4289, 2005, 29296, 3058, 7292, 1012, 1996, 4289, 2003, 2066, 1024, 1000, 2297, 2692, 20958, 2620, 17134, 19317, 19317, 1000, 1045, 2228, 2023, 1041, 16211, 4570, 2000, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}}

=======

Installation

Poetry

A Poetry toml is provided to manage dependencies in a virtualenv. Check https://python-poetry.org/

Once you've installed poetry, you can connect to virtual env and update dependencies:

poetry shell
poetry update
poetry install

requirements.txt

Someone on your platform should generate it once with following command.

poetry export -f requirements.txt --output requirements.txt

Rust compiler for hugginface tokenizers

  • Hugginface tokenizers require a Rust compiler so install one.

custom libs

  • If you want a specific version of any library, edit the pyproject.toml, add it and/or replace "*" by it.
Owner
Nils Reimers
Nils Reimers
A Persian Image Captioning model based on Vision Encoder Decoder Models of the transformers🤗.

Persian-Image-Captioning We fine-tuning the Vision Encoder Decoder Model for the task of image captioning on the coco-flickr-farsi dataset. The implem

Hamtech-ai 15 Aug 25, 2022
Machine learning models from Singapore's NLP research community

SG-NLP Machine learning models from Singapore's natural language processing (NLP) research community. sgnlp is a Python package that allows you to eas

AI Singapore | AI Makerspace 21 Dec 17, 2022
BERT has a Mouth, and It Must Speak: BERT as a Markov Random Field Language Model

BERT has a Mouth, and It Must Speak: BERT as a Markov Random Field Language Model

303 Dec 17, 2022
Adversarial Examples for Extreme Multilabel Text Classification

Adversarial Examples for Extreme Multilabel Text Classification The code is adapted from the source codes of BERT-ATTACK [1], APLC_XLNet [2], and Atte

1 May 14, 2022
CCF BDCI BERT系统调优赛题baseline(Pytorch版本)

CCF BDCI BERT系统调优赛题baseline(Pytorch版本) 此版本基于Pytorch后端的huggingface进行实现。由于此实现使用了Oneflow的dataloader作为数据读入的方式,因此也需要安装Oneflow。其它框架的数据读取可以参考OneflowDataloade

Ziqi Zhou 9 Oct 13, 2022
ConvBERT-Prod

ConvBERT 目录 0. 仓库结构 1. 简介 2. 数据集和复现精度 3. 准备数据与环境 3.1 准备环境 3.2 准备数据 3.3 准备模型 4. 开始使用 4.1 模型训练 4.2 模型评估 4.3 模型预测 5. 模型推理部署 5.1 基于Inference的推理 5.2 基于Serv

yujun 7 Apr 08, 2022
Implementation of the Hybrid Perception Block and Dual-Pruned Self-Attention block from the ITTR paper for Image to Image Translation using Transformers

ITTR - Pytorch Implementation of the Hybrid Perception Block (HPB) and Dual-Pruned Self-Attention (DPSA) block from the ITTR paper for Image to Image

Phil Wang 17 Dec 23, 2022
DAGAN - Dual Attention GANs for Semantic Image Synthesis

Contents Semantic Image Synthesis with DAGAN Installation Dataset Preparation Generating Images Using Pretrained Model Train and Test New Models Evalu

Hao Tang 104 Oct 08, 2022
Words-per-minute - A terminal app written in python utilizing the curses module that tests the user's ability to type

words-per-minute A terminal app written in python utilizing the curses module th

Tanim Islam 1 Jan 14, 2022
Pre-training BERT masked language models with custom vocabulary

Pre-training BERT Masked Language Models (MLM) This repository contains the method to pre-train a BERT model using custom vocabulary. It was used to p

Stella Douka 14 Nov 02, 2022
Machine learning classifiers to predict American Sign Language .

ASL-Classifiers American Sign Language (ASL) is a natural language that serves as the predominant sign language of Deaf communities in the United Stat

Tarek idrees 0 Feb 08, 2022
Conversational text Analysis using various NLP techniques

Conversational text Analysis using various NLP techniques

Rita Anjana 159 Jan 06, 2023
本项目是作者们根据个人面试和经验总结出的自然语言处理(NLP)面试准备的学习笔记与资料,该资料目前包含 自然语言处理各领域的 面试题积累。

【关于 NLP】那些你不知道的事 作者:杨夕、芙蕖、李玲、陈海顺、twilight、LeoLRH、JimmyDU、艾春辉、张永泰、金金金 介绍 本项目是作者们根据个人面试和经验总结出的自然语言处理(NLP)面试准备的学习笔记与资料,该资料目前包含 自然语言处理各领域的 面试题积累。 目录架构 一、【

1.4k Dec 30, 2022
The SVO-Probes Dataset for Verb Understanding

The SVO-Probes Dataset for Verb Understanding This repository contains the SVO-Probes benchmark designed to probe for Subject, Verb, and Object unders

DeepMind 20 Nov 30, 2022
An extension for asreview implements a version of the tf-idf feature extractor that saves the matrix and the vocabulary.

Extension - matrix and vocabulary extractor for TF-IDF and Doc2Vec An extension for ASReview that adds a tf-idf extractor that saves the matrix and th

ASReview 4 Jun 17, 2022
Smart discord chatbot integrated with Dialogflow

academic-NLP-chatbot Smart discord chatbot integrated with Dialogflow to interact with students naturally and manage different classes in a school. De

Tom Huynh 5 Oct 24, 2022
Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing

Introduction Funnel-Transformer is a new self-attention model that gradually compresses the sequence of hidden states to a shorter one and hence reduc

GUOKUN LAI 197 Dec 11, 2022
Code for producing Japanese GPT-2 provided by rinna Co., Ltd.

japanese-gpt2 This repository provides the code for training Japanese GPT-2 models. This code has been used for producing japanese-gpt2-medium release

rinna Co.,Ltd. 491 Jan 07, 2023
PRAnCER is a web platform that enables the rapid annotation of medical terms within clinical notes.

PRAnCER (Platform enabling Rapid Annotation for Clinical Entity Recognition) is a web platform that enables the rapid annotation of medical terms within clinical notes. A user can highlight spans of

Sontag Lab 39 Nov 14, 2022
Contains descriptions and code of the mini-projects developed in various programming languages

TexttoSpeechAndLanguageTranslator-project introduction A pleasant application where the client will be given buttons like play,reset and exit. The cli

Adarsh Reddy 1 Dec 22, 2021