Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in PyTorch

Overview

COCO-LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in PyTorch. The authors were able to make contrastive learning work in a self-supervised manner for language model pretraining. It seems like a solid successor to ELECTRA.
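To illustrate the contrastive part, here is a minimal InfoNCE-style sketch over the [CLS] embeddings of two views of the same batch of sequences, where every other sequence in the batch serves as a negative. It is not the coco-lm-pytorch implementation: the function name `sequence_contrastive_loss`, the temperature of 0.1, and the symmetric averaging of both directions are assumptions, and how the two views are produced is left out.

```python
import torch
import torch.nn.functional as F

def sequence_contrastive_loss(cls_a: torch.Tensor, cls_b: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Hypothetical InfoNCE-style loss over [CLS] embeddings.

    cls_a, cls_b: (batch, dim) tensors; row i of each comes from the same source sequence.
    All other rows in the batch act as negatives.
    """
    # unit-normalize so that a @ b.T is a matrix of cosine similarities
    a = F.normalize(cls_a, dim=-1)
    b = F.normalize(cls_b, dim=-1)
    # assumption: temperature value is illustrative, not taken from the paper or library
    logits = a @ b.t() / temperature          # (batch, batch)
    targets = torch.arange(a.shape[0], device=a.device)  # positives lie on the diagonal
    # assumption: average both directions (a -> b and b -> a)
    loss_ab = F.cross_entropy(logits, targets)
    loss_ba = F.cross_entropy(logits.t(), targets)
    return (loss_ab + loss_ba) / 2
```

The loss is low when each view's embedding is closest to its own partner and high when it is closer to some other sequence in the batch, which is what pushes the [CLS] representation to encode sequence-level meaning.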

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure the generator is roughly a quarter to a half the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # fewer heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator whose output is used to predict whether each token is original or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens never to mask, e.g. [CLS], [SEP]
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, './pretrained-model.pt')
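The masking and replacement step configured by `mask_prob`, `mask_token_id`, `pad_token_id` and `mask_ignore_token_ids` can be sketched roughly as below, in the style of ELECTRA. This is an illustrative sketch, not the coco_lm_pytorch internals: `mask_tokens`, `corrupt` and `replaced_labels` are hypothetical helpers, and random token ids stand in for the generator's sampled predictions.

```python
import torch

def mask_tokens(seq, mask_prob, mask_token_id, pad_token_id, ignore_ids=()):
    """Hypothetical helper: pick positions to mask and replace them with the mask token.

    Padding and any ids in ignore_ids (e.g. [CLS], [SEP]) are never masked.
    Returns (masked_sequence, boolean_mask).
    """
    eligible = seq != pad_token_id
    for token_id in ignore_ids:
        eligible &= seq != token_id
    # independent Bernoulli(mask_prob) draw per eligible position
    mask = (torch.rand(seq.shape, device=seq.device) < mask_prob) & eligible
    return seq.masked_fill(mask, mask_token_id), mask

def corrupt(seq, mask, sampled_tokens):
    """Hypothetical helper: place generator samples at masked positions, keep the rest."""
    return torch.where(mask, sampled_tokens, seq)

def replaced_labels(original, corrupted):
    """Discriminator targets: 1.0 where the token was replaced, 0.0 where it is unchanged."""
    # a sampled token that happens to equal the original counts as unchanged
    return (original != corrupted).float()

# usage with the ids from the example above (0 = pad, 2 = mask)
seq = torch.randint(3, 20000, (1, 1024))
masked, mask = mask_tokens(seq, 0.15, mask_token_id=2, pad_token_id=0)
sampled = torch.randint(3, 20000, seq.shape)  # stand-in for generator samples
labels = replaced_labels(seq, corrupt(seq, mask, sampled))
```

Treating a sampled token that matches the original as unchanged follows ELECTRA's convention; whether coco-lm-pytorch does exactly the same is an assumption.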

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Question about corrective LM loss

    Hi @lucidrains ,

    Thanks for your great repo!

    I looked at your code in coco_lm_pytorch.py and see that line 242 combines three losses: weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

    cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the binary discrimination loss. I wonder where the corrective language modeling (CLM) loss is. Could you point me to it?

    Best, Abdul.

    opened by elmadany
  • What can v0.0.2 do?

    I'm quite excited to give COCO-LM a try! Thanks as always for the great speedy open source repo @lucidrains .

    Quick question: has this repository been tried on real data, and if so, with roughly what kind of setup? I'm trying to figure out whether, jumping into coco-lm-pytorch, I should expect to be one of the first beta testers or whether I'm looking at something that is already stable. Thanks!

    opened by dginev
Releases (0.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.