data2vec-pytorch

PyTorch implementation of "data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language" from Meta AI (FAIR)

According to Meta AI, Data2Vec is the first high-performance self-supervised algorithm that learns in the same way across multiple modalities, including speech, vision and text. Most machine learning systems learn exclusively from labeled data. Through self-supervised learning, however, they can learn about the world just by observing it and figuring out the structure of images, speech or text. This is a more scalable and efficient approach to tackling new complex tasks, such as understanding text in more spoken languages.

In summary, the method is as follows:

  1. The encoder (student) extracts features from the masked inputs. These features are the outputs of every transformer/linear layer.
  2. The teacher, an EMA (exponential moving average) instance of the encoder kept in eval mode, extracts features from the unmasked inputs.
  3. Optional normalizations are applied to the teacher's layer outputs.
  4. The encoder outputs are regressed by a projection block/layer.
  5. The loss is computed between the regressed encoder outputs and the teacher outputs.
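The five steps above can be sketched in PyTorch as follows. This is a minimal illustration, not this repo's Data2Vec class: `TinyEncoder`, `ema_update` and `data2vec_step` are hypothetical names, the teacher target is assumed to be the average of the top-k teacher layers (as described in the paper) normalized with a parameter-free layer norm, and a Smooth L1 loss over the masked positions is assumed.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Hypothetical stand-in encoder that returns the output of every layer."""

    def __init__(self, dim=16, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x):
        outputs = []
        for layer in self.layers:
            x = torch.relu(layer(x))
            outputs.append(x)
        return outputs  # list of [batch, seq, dim] tensors, one per layer


@torch.no_grad()
def ema_update(teacher, student, decay=0.999):
    # teacher <- decay * teacher + (1 - decay) * student
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(decay).add_(s, alpha=1 - decay)


def data2vec_step(student, teacher, regression_head, src, masked_src, mask, top_k=2):
    # 1. the student encodes the masked input; use its last layer output
    student_out = student(masked_src)[-1]
    # 2. the teacher (EMA copy in eval mode, no gradients) encodes the unmasked input
    with torch.no_grad():
        teacher_layers = teacher(src)[-top_k:]
        # 3. normalize each of the top-k layers, then average them into the target
        teacher_layers = [F.layer_norm(h, h.shape[-1:]) for h in teacher_layers]
        target = torch.stack(teacher_layers).mean(dim=0)
    # 4. regress the student outputs with a projection layer
    prediction = regression_head(student_out)
    # 5. compute the loss only on the masked positions
    return F.smooth_l1_loss(prediction[mask], target[mask])


torch.manual_seed(0)
student = TinyEncoder()
teacher = copy.deepcopy(student).eval()
for p in teacher.parameters():
    p.requires_grad_(False)
head = nn.Linear(16, 16)

src = torch.randn(2, 5, 16)
mask = torch.rand(2, 5) < 0.5
mask[0, 0] = True  # make sure at least one position is masked
masked_src = src.masked_fill(mask.unsqueeze(-1), 0.0)

loss = data2vec_step(student, teacher, head, src, masked_src, mask)
loss.backward()
ema_update(teacher, student)  # normally called after each optimizer step
```

Only the student and the regression head receive gradients; the teacher changes solely through `ema_update`.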

See the paper for more details.

Implementation

Data2Vec is already implemented in fairseq, which provides a separate implementation for each modality (text, vision, audio). According to the paper:

Our primary goal is to design a single learning mechanism for different modalities. Despite the unified learning regime, we still use modality-specific feature extractors and masking strategies. This makes sense given the vastly different nature of the input data.

This implementation differs in that it provides a single Data2Vec model, powered by a custom encoder (implemented using PyTorch + HuggingFace Transformers), that unifies the whole concept in one module. The key point is that feature extraction and masking strategies must still be modality-specific:

  • Masking: For each modality, the Dataset instance must return the masked source, the target and the mask tensor.

  • Feature Extraction: Features are the outputs of the transformer/attention layers, so the encoder's forward method must return the outputs of all encoder blocks of the transformer model. HuggingFace Transformers and fairseq models can return each transformer layer's output separately out of the box.
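On the masking side, a text dataset only needs to return the triple described above. The sketch below is a hypothetical `torch.utils.data.Dataset`, not this repo's dataset class: the class name, the `mask_token_id` value and the uniform random masking with probability `mask_prob` are assumptions for illustration. The paper uses modality-specific strategies (for example block masking for images and span masking for speech), which this toy does not reproduce.

```python
import torch
from torch.utils.data import Dataset


class MaskedTokenDataset(Dataset):
    """Hypothetical text dataset returning (masked source, target, mask)."""

    def __init__(self, token_ids, mask_token_id, mask_prob=0.15, seed=0):
        self.token_ids = token_ids  # list of 1-D LongTensors
        self.mask_token_id = mask_token_id
        self.mask_prob = mask_prob
        self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        target = self.token_ids[idx]  # unmasked input, fed to the teacher
        # sample a boolean mask over the token positions
        mask = torch.rand(target.shape, generator=self.generator) < self.mask_prob
        # replace masked positions with the mask token; the target stays intact
        src = target.masked_fill(mask, self.mask_token_id)
        return src, target, mask


tokens = [torch.arange(10, 20), torch.arange(20, 30)]
dataset = MaskedTokenDataset(tokens, mask_token_id=4, mask_prob=0.5)  # 4 is a placeholder id
src, target, mask = dataset[0]
```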

This implementation uses HuggingFace Transformers models as encoders for Data2Vec; you can inspect them in the encoder.py file of each modality. You can also provide your own encoder model; just make sure it is Transformer-based, as in the paper, and that it returns the outputs of every encoder layer.
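A custom encoder only has to keep every layer's output. The sketch below is a hypothetical encoder built from PyTorch's `nn.TransformerEncoderLayer`; the class name and the returned dictionary keys are illustrative, not the interface defined in this repo's encoder.py files. For HuggingFace models, the equivalent is calling the model with `output_hidden_states=True` and reading `outputs.hidden_states`, which also includes the embedding output as its first element.

```python
import torch
import torch.nn as nn


class AllLayersEncoder(nn.Module):
    """Hypothetical Transformer encoder that returns every layer's output."""

    def __init__(self, dim=32, num_heads=4, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=dim, nhead=num_heads, dim_feedforward=4 * dim, batch_first=True
            )
            for _ in range(num_layers)
        )

    def forward(self, x, padding_mask=None):
        layer_outputs = []
        for layer in self.layers:
            x = layer(x, src_key_padding_mask=padding_mask)
            layer_outputs.append(x)  # keep each block's output for the teacher targets
        return {"last_hidden_state": x, "layer_outputs": layer_outputs}


encoder = AllLayersEncoder().eval()
features = encoder(torch.randn(2, 7, 32))  # [batch, seq, dim]
```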

Note: The goal of this implementation is to provide the necessary building blocks of Data2Vec so anyone can easily adapt it to their own use case. To keep it as clean and simple as possible, functionalities such as mixed precision and distributed training are not included. If you only need to train a standard large-scale Data2Vec model, use the official repo.

Train

First things first, install the requirements:

pip install -r requirements.txt

NLP

Train a Language Model based on RoBERTa (HuggingFace) on WikiText103

Configure the related properties in text/configs/roberta-pretraining.yaml and run:

python train.py --config text/configs/roberta-pretraining.yaml 

Vision

Run masked image modeling pretraining based on BEiT (HuggingFace)

Pass the path to the image dataset in the config file vision/configs/beit-pretraining.yaml under dataset > path > train/test, modify the other properties as desired, and run the following:

python train.py --config vision/configs/beit-pretraining.yaml 

Speech

Audio pretraining based on Wav2Vec2 (HuggingFace) on the TIMIT dataset. To use other datasets such as LibriSpeech, add them to audio/dataset.py; minor changes to the TIMIT dataset class should suffice, since both are loaded from HuggingFace Datasets.

Configure other properties as you desire and run the following:

python train.py --config audio/configs/wav2vec2-pretraining.yaml 

Pre-trained Weights

The models are available on the HuggingFace Hub and can be loaded as follows:

RoBERTa

Data2Vec model trained with RoBERTa as the encoder (data2vec-roberta-base)

from transformers import AutoModel, AutoConfig
from transformers import RobertaModel

checkpoint = 'arxyzan/data2vec-roberta-base'

# Option 1: load using AutoModel
data2vec_roberta = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by RobertaModel
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)

BEiT

Data2Vec model trained with BEiT as the encoder (data2vec-beit-base)

from transformers import AutoModel, AutoConfig
from transformers import BeitModel

checkpoint = 'arxyzan/data2vec-beit-base'

# Option 1: load using AutoModel
data2vec_beit = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by BeitModel
data2vec_beit = BeitModel.from_pretrained(checkpoint)

Wav2Vec2

Data2Vec model trained with Wav2Vec2 as the encoder (data2vec-wav2vec2-base)

from transformers import AutoModel, AutoConfig
from transformers import Wav2Vec2Model

checkpoint = 'arxyzan/data2vec-wav2vec2-base'

# Option 1: load using AutoModel
data2vec_wav2vec2 = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by Wav2Vec2Model
data2vec_wav2vec2 = Wav2Vec2Model.from_pretrained(checkpoint)

Note: The above models' weights were carefully ported from the original checkpoints in the fairseq version.

Fine-tuning

  1. Fine-tune using the checkpoints mentioned above:
# Text classification using Roberta model from HuggingFace
from transformers import RobertaModel, RobertaForSequenceClassification

checkpoint = 'arxyzan/data2vec-roberta-base'
# this is exactly a roberta model but trained with data2vec
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
text_classifier = RobertaForSequenceClassification(data2vec_roberta.config)
# assign `data2vec-roberta` weights to the roberta block of the classifier
text_classifier.roberta = data2vec_roberta
...
  2. If you trained a model with this codebase, you can fine-tune it by extracting the encoder's state dict from the checkpoint. This gives you a regular HuggingFace model that you can fine-tune for any downstream task as you normally would.
# load a checkpoint for fine-tuning
import torch
from transformers import RobertaModel, RobertaConfig

# the config must match the one used for data2vec pretraining
roberta = RobertaModel(RobertaConfig())
checkpoint = torch.load('path/to/data2vec.pt', map_location='cpu')
roberta_state_dict = checkpoint['encoder']
# load roberta weights from the encoder part of the data2vec model
# (load_state_dict updates `roberta` in place and returns missing/unexpected keys)
roberta.load_state_dict(roberta_state_dict)

# Now fine-tune a regular HuggingFace RoBERTa model
...
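A single fine-tuning step can then look like the sketch below. It is an illustration under stated assumptions, not part of this repo: the tiny config sizes, the random inputs and the three-label task are hypothetical, and a freshly initialized `RobertaModel` stands in for the encoder restored from a data2vec checkpoint so the code runs offline. The classifier's internal RoBERTa has no pooler, so `strict=False` is used to ignore the encoder's pooler weights.

```python
import torch
from transformers import RobertaConfig, RobertaForSequenceClassification, RobertaModel

# hypothetical tiny config so the sketch runs quickly without downloads
config = RobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64, num_labels=3,
)
encoder = RobertaModel(config)  # stands in for the restored data2vec encoder

classifier = RobertaForSequenceClassification(config)
# copy the encoder weights into the classifier backbone (pooler keys are ignored)
result = classifier.roberta.load_state_dict(encoder.state_dict(), strict=False)

optimizer = torch.optim.AdamW(classifier.parameters(), lr=2e-5)
input_ids = torch.randint(3, 100, (4, 8))  # random token ids, no padding
labels = torch.tensor([0, 1, 2, 1])

classifier.train()
outputs = classifier(input_ids=input_ids, labels=labels)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```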

Contributions

Contributions regarding training, development and issues are welcome!

Owner
Aryan Shekarlaban
Deep Learning Developer & Researcher