Shared code for training sentence embeddings with Flax / JAX

Overview

flax-sentence-embeddings

This repository will be used to share code for the Flax / JAX community event to train sentence embeddings on 1B+ training pairs.

You can add your code by creating a pull request.

Dataloading

Dowload data

You can download the data using this basic python script at the root of the project. Download should be completed in about 20 minutes given your connection speed. Total size on disk is arround 25G.

python dataset/download_data.py --dataset_list=datasets_list.tsv --data_path=PATH_TO_STORE_DATASETS

Dataloading

First implementation of the dataloader takes as input a single jsonl.gz file. It creates a pointer on the file such that samples are loaded one by one. The implementation is based on torch standard Dataloader and Dataset classes. The class supports num_worker>0 such that data loading is done in a background process on the CPU, i.e. the data is loaded and tokenized in parallel to training the network. This avoid to create a bottleneck from I/O and tokenization. The implementation currently return {'anchor': '...,' 'positive': '...'}

from dataset.dataset import IterableCorpusDataset

corpus_dataset = IterableCorpusDataset(
  file_path=os.path.join(PATH_TO_STORE_DATASETS, 'stackexchange_duplicate_questions_title_title.json.gz'), 
  batch_size=2,
  num_workers=2, 
  transform=None)

corpus_dataset_itr = iter(corpus_dataset)
next(corpus_dataset_itr)

# {'anchor': 'Can anyone explain all these Developer Options?',
#  'positive': 'what is the advantage of using the GPU rendering options in Android?'}

def collate(batch_input_str):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    batch = {'anchor': tokenizer.batch_encode_plus([b['anchor'] for b in batch_input_str], pad_to_max_length=True),
             'positive': tokenizer.batch_encode_plus([b['positive'] for b in batch_input_str], pad_to_max_length=True)}
    return batch

corpus_dataloader = DataLoader(
  corpus_dataset,
  batch_size=2,
  num_workers=2,
  collate_fn=collate,
  pin_memory=False,
  drop_last=True,
  shuffle=False)

print(next(iter(corpus_dataloader)))

# {'anchor': {'input_ids': [[101, 4531, 2019, 2523, 2090, 2048, 4725, 1997, 2966, 8830, 1998, 1037, 7142, 8023, 102, 0, 0, 0], [101, 1039, 1001, 10463, 5164, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 2000, 3058, 7292, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}, 'positive': {'input_ids': [[101, 1045, 2031, 2182, 2007, 2033, 1010, 2048, 4725, 1997, 8830, 1025, 1037, 3115, 2729, 4118, 1010, 1998, 1037, 17009, 8830, 1012, 2367, 3633, 4374, 2367, 4118, 1010, 2049, 2035, 18154, 11095, 1012, 1045, 2572, 2667, 2000, 2424, 1996, 2523, 1997, 1996, 17009, 8830, 1998, 1037, 1005, 2092, 2108, 3556, 1005, 2029, 2003, 1037, 15973, 3643, 1012, 2054, 2003, 1996, 2190, 2126, 2000, 2424, 2151, 8924, 1029, 1041, 1012, 1043, 1012, 8833, 6553, 26237, 2944, 1029, 102], [101, 1045, 2572, 2667, 2000, 10463, 1037, 5164, 3058, 2046, 1037, 4289, 2005, 29296, 3058, 7292, 1012, 1996, 4289, 2003, 2066, 1024, 1000, 2297, 2692, 20958, 2620, 17134, 19317, 19317, 1000, 1045, 2228, 2023, 1041, 16211, 4570, 2000, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}}

=======

Installation

Poetry

A Poetry toml is provided to manage dependencies in a virtualenv. Check https://python-poetry.org/

Once you've installed poetry, you can connect to virtual env and update dependencies:

poetry shell
poetry update
poetry install

requirements.txt

Someone on your platform should generate it once with following command.

poetry export -f requirements.txt --output requirements.txt

Rust compiler for hugginface tokenizers

  • Hugginface tokenizers require a Rust compiler so install one.

custom libs

  • If you want a specific version of any library, edit the pyproject.toml, add it and/or replace "*" by it.
Owner
Nils Reimers
Nils Reimers
Translation for Trilium Notes. Trilium Notes 中文版.

Trilium Translation 中文说明 This repo provides a translation for the awesome Trilium Notes. Currently, I have translated Trilium Notes into Chinese. Test

743 Jan 08, 2023
ChainKnowledgeGraph, 产业链知识图谱包括A股上市公司、行业和产品共3类实体

ChainKnowledgeGraph, 产业链知识图谱包括A股上市公司、行业和产品共3类实体,包括上市公司所属行业关系、行业上级关系、产品上游原材料关系、产品下游产品关系、公司主营产品、产品小类共6大类。 上市公司4,654家,行业511个,产品95,559条、上游材料56,824条,上级行业480条,下游产品390条,产品小类52,937条,所属行业3,946条。

liuhuanyong 415 Jan 06, 2023
Code for CodeT5: a new code-aware pre-trained encoder-decoder model.

CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation This is the official PyTorch implementation

Salesforce 564 Jan 08, 2023
Pretrained language model and its related optimization techniques developed by Huawei Noah's Ark Lab.

Pretrained Language Model This repository provides the latest pretrained language models and its related optimization techniques developed by Huawei N

HUAWEI Noah's Ark Lab 2.6k Jan 08, 2023
Data and code to support "Applied Natural Language Processing" (INFO 256, Fall 2021, UC Berkeley)

anlp21 Course materials for "Applied Natural Language Processing" (INFO 256, Fall 2021, UC Berkeley) Syllabus: http://people.ischool.berkeley.edu/~dba

David Bamman 48 Dec 06, 2022
Extract city and country mentions from Text like GeoText without regex, but FlashText, a Aho-Corasick implementation.

flashgeotext ⚡ 🌍 Extract and count countries and cities (+their synonyms) from text, like GeoText on steroids using FlashText, a Aho-Corasick impleme

Ben 57 Dec 16, 2022
A demo for end-to-end English and Chinese text spotting using ABCNet.

ABCNet_Chinese A demo for end-to-end English and Chinese text spotting using ABCNet. This is an old model that was trained a long ago, which serves as

Yuliang Liu 45 Oct 04, 2022
ElasticBERT: A pre-trained model with multi-exit transformer architecture.

This repository contains finetuning code and checkpoints for ElasticBERT. Towards Efficient NLP: A Standard Evaluation and A Strong Baseli

fastNLP 48 Dec 14, 2022
An Explainable Leaderboard for NLP

ExplainaBoard: An Explainable Leaderboard for NLP Introduction | Website | Download | Backend | Paper | Video | Bib Introduction ExplainaBoard is an i

NeuLab 319 Dec 20, 2022
SentAugment is a data augmentation technique for semi-supervised learning in NLP.

SentAugment SentAugment is a data augmentation technique for semi-supervised learning in NLP. It uses state-of-the-art sentence embeddings to structur

Meta Research 363 Dec 30, 2022
Bidirectional Variational Inference for Non-Autoregressive Text-to-Speech (BVAE-TTS)

Bidirectional Variational Inference for Non-Autoregressive Text-to-Speech (BVAE-TTS) Yoonhyung Lee, Joongbo Shin, Kyomin Jung Abstract: Although early

LEE YOON HYUNG 147 Dec 05, 2022
NLPShala , the best IDE for all Natural language processing tasks.

The revolutionary IDE for all NLP (Natural language processing) stuffs on the internet.

Abhi 3 Aug 08, 2021
Revisiting Pre-trained Models for Chinese Natural Language Processing (Findings of EMNLP 2020)

This repository contains the resources in our paper "Revisiting Pre-trained Models for Chinese Natural Language Processing", which will be published i

Yiming Cui 463 Dec 30, 2022
official ( API ) for the zAmericanEnglish app in [ Google play ] and [ App store ]

official ( API ) for the zAmericanEnglish app in [ Google play ] and [ App store ]

Plugin 3 Jan 12, 2022
构建一个多源(公众号、RSS)、干净、个性化的阅读环境

2C 构建一个多源(公众号、RSS)、干净、个性化的阅读环境 作为一名微信公众号的重度用户,公众号一直被我设为汲取知识的地方。随着使用程度的增加,相信大家或多或少会有一个比较头疼的问题——广告问题。 假设你关注的公众号有十来个,若一个公众号两周接一次广告,理论上你会面临二十多次广告,实际上会更多,运

howie.hu 678 Dec 28, 2022
Segmenter - Transformer for Semantic Segmentation

Segmenter - Transformer for Semantic Segmentation

592 Dec 27, 2022
PyTorch Implementation of the paper Single Image Texture Translation for Data Augmentation

SITT The repo contains official PyTorch Implementation of the paper Single Image Texture Translation for Data Augmentation. Authors: Boyi Li Yin Cui T

Boyi Li 52 Jan 05, 2023
Beautiful visualizations of how language differs among document types.

Scattertext 0.1.0.0 A tool for finding distinguishing terms in corpora and displaying them in an interactive HTML scatter plot. Points corresponding t

Jason S. Kessler 2k Dec 27, 2022
Python SDK for working with Voicegain Speech-to-Text

Voicegain Speech-to-Text Python SDK Python SDK for the Voicegain Speech-to-Text API. This API allows for large vocabulary speech-to-text transcription

Voicegain 3 Dec 14, 2022
Chinese segmentation library

What is loso? loso is a Chinese segmentation system written in Python. It was developed by Victor Lin ( Fang-Pen Lin 82 Jun 28, 2022