Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch

Overview

COCO LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # less heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator, whose output would be used for predicting token is still the same or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to ignore for mask modeling ex. (cls, sep)
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
Big Bird: Transformers for Longer Sequences

BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. Moreover, BigBird comes along with a theoretical understanding of the capabilities of a complete transformer that the sparse model can handle.

Beyond Paragraphs: NLP for Long Sequences

Beyond Paragraphs: NLP for Long Sequences

Text-Summarization-using-NLP - Text Summarization using NLP  to fetch BBC News Article and summarize its text and also it includes custom article Summarization PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.
PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.

An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"

MILES is a multilingual text simplifier inspired by LSBert - A BERT-based lexical simplification approach proposed in 2018. Unlike LSBert, MILES uses the bert-base-multilingual-uncased model, as well as simple language-agnostic approaches to complex word identification (CWI) and candidate ranking. PyTorch implementation of the paper:  Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding
PyTorch implementation of the paper: Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding

Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding This repository contains the official PyTorch implementation of th

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation.  This is part of the CASL project: http://casl-project.ai/
Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/

Texar-PyTorch is a toolkit aiming to support a broad set of machine learning, especially natural language processing and text generation tasks. Texar

In this repository, I have developed an end to end Automatic speech recognition project. I have developed the neural network model for automatic speech recognition with PyTorch and used MLflow to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry.
Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks

Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks. It takes raw videos/images + text as inputs, and outputs task predictions. ClipBERT is designed based on 2D CNNs and transformers, and uses a sparse sampling strategy to enable efficient end-to-end video-and-language learning.

Comments
  • Question about corrective LM loss

    Question about corrective LM loss

    Hi @lucidrains ,

    Thanks for your great repo!

    I looked at your code: coco_lm_pytorch.py. I see there are three losses in line 242. weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

    cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the loss of binary discrimination. I wonder where the LM loss of corrective language modeling loss is. Could you point me?

    Best, Abdul.

    opened by elmadany 0
  • What can v0.0.2 do?

    What can v0.0.2 do?

    I'm quite excited to give COCO-LM a try! Thanks as always for the great speedy open source repo @lucidrains .

    Quick question: has this repository been tried on real data, and if so - loosely what type of setup? Trying to figure out whether jumping in coco-lm-pytorch I should have the expectation of being a first beta-tester, or I'm looking at something that is already stable. Thanks!

    opened by dginev 0
Releases(0.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
Long text token classification using LongFormer

Long text token classification using LongFormer

abhishek thakur 161 Aug 07, 2022
A high-level yet extensible library for fast language model tuning via automatic prompt search

ruPrompts ruPrompts is a high-level yet extensible library for fast language model tuning via automatic prompt search, featuring integration with Hugg

Sber AI 37 Dec 07, 2022
A CSRankings-like index for speech researchers

Speech Rankings This project mimics CSRankings to generate an ordered list of researchers in speech/spoken language processing along with their possib

Mutian He 19 Nov 26, 2022
Lumped-element impedance calculator and frequency-domain plotter.

fastZ: Lumped-Element Impedance Calculator fastZ is a small tool for calculating and visualizing electrical impedance in Python. Features include: Sup

Wesley Hileman 47 Nov 18, 2022
Azure Text-to-speech service for Home Assistant

Azure Text-to-speech service for Home Assistant The Azure text-to-speech platform uses online Azure Text-to-Speech cognitive service to read a text wi

Yassine Selmi 2 Aug 06, 2022
Proquabet - Convert your prose into proquints and then you essentially have Vogon poetry

Proquabet Turn your prose into a constant stream of encrypted and meaningless-so

Milo Fultz 2 Oct 10, 2022
Partially offline multi-language translator built upon Huggingface transformers.

Translate Command-line interface to translation pipelines, powered by Huggingface transformers. This tool can download translation models, and then us

Richard Jarry 8 Oct 25, 2022
Text Classification in Turkish Texts with Bert

You can watch the details of the project on my youtube channel Project Interface Project Second Interface Goal= Correctly guessing the classification

42 Dec 31, 2022
Python port of Google's libphonenumber

phonenumbers Python Library This is a Python port of Google's libphonenumber library It supports Python 2.5-2.7 and Python 3.x (in the same codebase,

David Drysdale 3.1k Dec 29, 2022
Toward Model Interpretability in Medical NLP

Toward Model Interpretability in Medical NLP LING380: Topics in Computational Linguistics Final Project James Cross ( 1 Mar 04, 2022

Visual Automata is a Python 3 library built as a wrapper for Caleb Evans' Automata library to add more visualization features.

Visual Automata Copyright 2021 Lewi Lie Uberg Released under the MIT license Visual Automata is a Python 3 library built as a wrapper for Caleb Evans'

Lewi Uberg 55 Nov 17, 2022
A simple word search made in python

Word Search Puzzle A simple word search made in python Usage $ python3 main.py -h usage: main.py [-h] [-c] [-f FILE] Generates a word s

Magoninho 16 Mar 10, 2022
The ability of computer software to identify words and phrases in spoken language and convert them to human-readable text

speech-recognition-py Speech recognition is the ability of computer software to identify words and phrases in spoken language and convert them to huma

Deepangshi 1 Apr 03, 2022
MRC approach for Aspect-based Sentiment Analysis (ABSA)

B-MRC MRC approach for Aspect-based Sentiment Analysis (ABSA) Paper: Bidirectional Machine Reading Comprehension for Aspect Sentiment Triplet Extracti

Phuc Phan 1 Apr 05, 2022
Chatbot with Pytorch, Python & Nextjs

Installation Instructions Make sure that you have Python 3, gcc, venv, and pip installed. Clone the repository $ git clone https://github.com/sahr

Rohit Sah 0 Dec 11, 2022
超轻量级bert的pytorch版本,大量中文注释,容易修改结构,持续更新

bert4pytorch 2021年8月27更新: 感谢大家的star,最近有小伙伴反映了一些小的bug,我也注意到了,奈何这个月工作上实在太忙,更新不及时,大约会在9月中旬集中更新一个只需要pip一下就完全可用的版本,然后会新添加一些关键注释。 再增加对抗训练的内容,更新一个完整的finetune

muqiu 317 Dec 18, 2022
Open solution to the Toxic Comment Classification Challenge

Starter code: Kaggle Toxic Comment Classification Challenge More competitions 🎇 Check collection of public projects 🎁 , where you can find multiple

minerva.ml 153 Jun 22, 2022
PyTorch Implementation of Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation

StyleSpeech - PyTorch Implementation PyTorch Implementation of Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation. Status (2021.06.09

Keon Lee 142 Jan 06, 2023
Text to speech for Vietnamese, ez to use, ez to update

Chào mọi người, đây là dự án mở nhằm giúp việc đọc được trở nên dễ dàng hơn. Rất cảm ơn đội ngũ Zalo đã cung cấp hạ tầng để mình có thể tạo ra app này

Trần Cao Minh Bách 32 Jul 29, 2022
LUKE -- Language Understanding with Knowledge-based Embeddings

LUKE (Language Understanding with Knowledge-based Embeddings) is a new pre-trained contextualized representation of words and entities based on transf

Studio Ousia 587 Dec 30, 2022