Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Overview

Seq2Seq Speech in JAX

A JAX/Flax repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model (e.g. GPT2, Bart) to yield a Speech Sequence-to-Sequence (Seq2Seq) model for automatic speech recognition.

The script run_flax_speech_recognition_seq2seq.py can be used to fine-tune a Speech Seq2Seq model on one of the official speech recognition datasets or a custom dataset. It makes use of the pmap JAX operator to provide model parallelism accross GPU/TPU devices.

The modelling files are based very heavily on those from Hugging Face Transformers 🤗 . This is a standalone repository to enable rapid prototyping and involvement with the community. The final modelling files and training script will be merged into Transformers 🤗 to be used with the rest of the open-source library. The final system weights will be made publicly available at huggingface.co 🚀

Seq2SeqModel Figure 1: Speech-encoder text-decoder style Seq2Seq model.

Example Usage

To instantiate a Wav2Vec2-2-Bart model with the FlaxSpeechEncoderDecoderModel framework, run the following Python script inside the cloned repo:

from transformers import AutoFeatureExtractor, AutoTokenizer
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
import numpy as np

# checkpoints to leverage
encoder_id = "facebook/wav2vec2-large-lv60"
decoder_id = "facebook/bart-large"

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_add_adapter=True, decoder_from_pt=True)

model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.pad_token_id
model.config.eos_token_id = model.config.decoder.eos_token_id
model.config.use_cache = False
model.config.processor_class = "Wav2Vec2Processor"

# check if generation works
out = model.generate(np.ones((1, 2000)))

model.save_pretrained("./")

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
feature_extractor.save_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
tokenizer.save_pretrained("./")

To train the model on Librispeech ASR in default precision, run the bash script provided below:

#!/usr/bin/env bash
python run_flax_speech_recognition_seq2seq.py \
        --dataset_name="librispeech_asr" \
        --model_name_or_path="./" \
        --dataset_config_name="clean" \
        --train_split_name="train.100" \
        --eval_split_name="validation" \
        --output_dir="./" \
        --preprocessing_num_workers="16" \
        --length_column_name="input_length" \
        --overwrite_output_dir \
        --num_train_epochs="5" \
        --per_device_train_batch_size="2" \
        --per_device_eval_batch_size="2" \
        --gradient_accumulation_steps="1" \
        --logging_steps="25" \
        --max_duration_in_seconds="15" \
        --max_target_length="128" \
        --generation_max_length="40" \
        --generation_num_beams="1" \
        --learning_rate="1e-4" \
        --warmup_steps="500" \
        --text_column_name="text" \
        --save_total_limit="1" \
        --freeze_feature_encoder \
        --predict_with_generate \
        --do_lower_case \
        --do_eval \
        --do_train
Owner
Sanchit Gandhi
Open-Source Speech @huggingface
Sanchit Gandhi
Code for evaluating Japanese pretrained models provided by NTT Ltd.

japanese-dialog-transformers 日本語の説明文はこちら This repository provides the information necessary to evaluate the Japanese Transformer Encoder-decoder dialo

NTT Communication Science Laboratories 216 Dec 22, 2022
The official code for “DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction”, ACM MM, Oral Paper, 2021.

Good news! Our new work exhibits state-of-the-art performances on DocUNet benchmark dataset: DocScanner: Robust Document Image Rectification with Prog

Hao Feng 231 Dec 26, 2022
Weakly-supervised Text Classification Based on Keyword Graph

Weakly-supervised Text Classification Based on Keyword Graph How to run? Download data Our dataset follows previous works. For long texts, we follow C

Hello_World 20 Dec 29, 2022
Original implementation of the pooling method introduced in "Speaker embeddings by modeling channel-wise correlations"

Speaker-Embeddings-Correlation-Pooling This is the original implementation of the pooling method introduced in "Speaker embeddings by modeling channel

Themos Stafylakis 10 Apr 30, 2022
Skipgram Negative Sampling in PyTorch

PyTorch SGNS Word2Vec's SkipGramNegativeSampling in Python. Yet another but quite general negative sampling loss implemented in PyTorch. It can be use

Jamie J. Seol 287 Dec 14, 2022
The NewSHead dataset is a multi-doc headline dataset used in NHNet for training a headline summarization model.

This repository contains the raw dataset used in NHNet [1] for the task of News Story Headline Generation. The code of data processing and training is available under Tensorflow Models - NHNet.

Google Research Datasets 31 Jul 15, 2022
Phrase-Based & Neural Unsupervised Machine Translation

Unsupervised Machine Translation This repository contains the original implementation of the unsupervised PBSMT and NMT models presented in Phrase-Bas

Facebook Research 1.5k Dec 28, 2022
Reproduction process of BERT on SST2 dataset

BERT-SST2-Prod Reproduction process of BERT on SST2 dataset 安装说明 下载代码库 git clone https://github.com/JunnYu/BERT-SST2-Prod 进入文件夹,安装requirements pip ins

yujun 1 Nov 18, 2021
DensePhrases provides answers to your natural language questions from the entire Wikipedia in real-time

DensePhrases provides answers to your natural language questions from the entire Wikipedia in real-time. While it efficiently searches the answers out of 60 billion phrases in Wikipedia, it is also v

Jinhyuk Lee 543 Jan 08, 2023
A script that automatically creates a branch name using google translation api and jira api

About google translation api와 jira api을 사용하여 자동으로 브랜치 이름을 만들어주는 스크립트 Setup 환경변수에 다음 3가지를 등록해야 한다. JIRA_USER : JIRA email (ex: hyunwook.kim 2 Dec 20, 2021

EMNLP'2021: Can Language Models be Biomedical Knowledge Bases?

BioLAMA BioLAMA is biomedical factual knowledge triples for probing biomedical LMs. The triples are collected and pre-processed from three sources: CT

DMIS Laboratory - Korea University 41 Nov 18, 2022
AllenNLP integration for Shiba: Japanese CANINE model

Allennlp Integration for Shiba allennlp-shiab-model is a Python library that provides AllenNLP integration for shiba-model. SHIBA is an approximate re

Shunsuke KITADA 12 Feb 16, 2022
Unsupervised text tokenizer focused on computational efficiency

YouTokenToMe YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently implements fast Byte Pair Encoding (BPE)

VK.com 847 Dec 19, 2022
Extract city and country mentions from Text like GeoText without regex, but FlashText, a Aho-Corasick implementation.

flashgeotext ⚡ 🌍 Extract and count countries and cities (+their synonyms) from text, like GeoText on steroids using FlashText, a Aho-Corasick impleme

Ben 57 Dec 16, 2022
Unofficial PyTorch implementation of Google AI's VoiceFilter system

VoiceFilter Note from Seung-won (2020.10.25) Hi everyone! It's Seung-won from MINDs Lab, Inc. It's been a long time since I've released this open-sour

MINDs Lab 881 Jan 03, 2023
A Chinese to English Neural Model Translation Project

ZH-EN NMT Chinese to English Neural Machine Translation This project is inspired by Stanford's CS224N NMT Project Dataset used in this project: News C

Zhenbang Feng 29 Nov 26, 2022
Text to speech converter with GUI made in Python.

Text-to-speech-with-GUI Text to speech converter with GUI made in Python. To run this download the zip file and run the main file or clone this repo.

SidTheMiner 1 Nov 15, 2021
Implementation of paper Does syntax matter? A strong baseline for Aspect-based Sentiment Analysis with RoBERTa.

RoBERTaABSA This repo contains the code for NAACL 2021 paper titled Does syntax matter? A strong baseline for Aspect-based Sentiment Analysis with RoB

106 Nov 28, 2022
Ptorch NLU, a Chinese text classification and sequence annotation toolkit, supports multi class and multi label classification tasks of Chinese long text and short text, and supports sequence annotation tasks such as Chinese named entity recognition, part of speech tagging and word segmentation.

Pytorch-NLU,一个中文文本分类、序列标注工具包,支持中文长文本、短文本的多类、多标签分类任务,支持中文命名实体识别、词性标注、分词等序列标注任务。 Ptorch NLU, a Chinese text classification and sequence annotation toolkit, supports multi class and multi label classifi

186 Dec 24, 2022
A simple Speech Emotion Recognition (SER) API created using Flask and running in a Docker container.

keyword_searching Steps to use this Python scripts: (1)Paste this script into the file folder containing the PDF files you need to search from; (2)Thi

2 Nov 11, 2022