Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Overview

Seq2Seq Speech in JAX

A JAX/Flax repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model (e.g. GPT2, Bart) to yield a Speech Sequence-to-Sequence (Seq2Seq) model for automatic speech recognition.

The script run_flax_speech_recognition_seq2seq.py can be used to fine-tune a Speech Seq2Seq model on one of the official speech recognition datasets or a custom dataset. It makes use of the pmap JAX operator to provide model parallelism accross GPU/TPU devices.

The modelling files are based very heavily on those from Hugging Face Transformers 🤗 . This is a standalone repository to enable rapid prototyping and involvement with the community. The final modelling files and training script will be merged into Transformers 🤗 to be used with the rest of the open-source library. The final system weights will be made publicly available at huggingface.co 🚀

Seq2SeqModel Figure 1: Speech-encoder text-decoder style Seq2Seq model.

Example Usage

To instantiate a Wav2Vec2-2-Bart model with the FlaxSpeechEncoderDecoderModel framework, run the following Python script inside the cloned repo:

from transformers import AutoFeatureExtractor, AutoTokenizer
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
import numpy as np

# checkpoints to leverage
encoder_id = "facebook/wav2vec2-large-lv60"
decoder_id = "facebook/bart-large"

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_add_adapter=True, decoder_from_pt=True)

model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.pad_token_id
model.config.eos_token_id = model.config.decoder.eos_token_id
model.config.use_cache = False
model.config.processor_class = "Wav2Vec2Processor"

# check if generation works
out = model.generate(np.ones((1, 2000)))

model.save_pretrained("./")

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
feature_extractor.save_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
tokenizer.save_pretrained("./")

To train the model on Librispeech ASR in default precision, run the bash script provided below:

#!/usr/bin/env bash
python run_flax_speech_recognition_seq2seq.py \
        --dataset_name="librispeech_asr" \
        --model_name_or_path="./" \
        --dataset_config_name="clean" \
        --train_split_name="train.100" \
        --eval_split_name="validation" \
        --output_dir="./" \
        --preprocessing_num_workers="16" \
        --length_column_name="input_length" \
        --overwrite_output_dir \
        --num_train_epochs="5" \
        --per_device_train_batch_size="2" \
        --per_device_eval_batch_size="2" \
        --gradient_accumulation_steps="1" \
        --logging_steps="25" \
        --max_duration_in_seconds="15" \
        --max_target_length="128" \
        --generation_max_length="40" \
        --generation_num_beams="1" \
        --learning_rate="1e-4" \
        --warmup_steps="500" \
        --text_column_name="text" \
        --save_total_limit="1" \
        --freeze_feature_encoder \
        --predict_with_generate \
        --do_lower_case \
        --do_eval \
        --do_train
Owner
Sanchit Gandhi
Open-Source Speech @huggingface
Sanchit Gandhi
Implementation of Fast Transformer in Pytorch

Fast Transformer - Pytorch Implementation of Fast Transformer in Pytorch. This only work as an encoder. Yannic video AI Epiphany Install $ pip install

Phil Wang 167 Dec 27, 2022
Unsupervised text tokenizer for Neural Network-based text generation.

SentencePiece SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation systems where the vocabu

Google 6.4k Jan 01, 2023
Graphical user interface for Argos Translate

Argos Translate GUI Website | GitHub | PyPI Graphical user interface for Argos Translate. Install pip3 install argostranslategui

Argos Open Tech 16 Dec 07, 2022
Easy-to-use CPM for Chinese text generation

CPM 项目描述 CPM(Chinese Pretrained Models)模型是北京智源人工智能研究院和清华大学发布的中文大规模预训练模型。官方发布了三种规模的模型,参数量分别为109M、334M、2.6B,用户需申请与通过审核,方可下载。 由于原项目需要考虑大模型的训练和使用,需要安装较为复杂

382 Jan 07, 2023
Transformer related optimization, including BERT, GPT

This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder component, and it is tested and maintained by NVIDIA.

NVIDIA Corporation 1.7k Jan 04, 2023
A python script that will use hydra to get user and password to login to ssh, ftp, and telnet

Hydra-Auto-Hack A python script that will use hydra to get user and password to login to ssh, ftp, and telnet Project Description This python script w

2 Jan 16, 2022
A framework for training and evaluating AI models on a variety of openly available dialogue datasets.

ParlAI (pronounced “par-lay”) is a python framework for sharing, training and testing dialogue models, from open-domain chitchat, to task-oriented dia

Facebook Research 9.7k Jan 09, 2023
Club chatbot

Chatbot Club chatbot Instructions to get the Chatterbot working Step 1. First make sure you are using a version of Python 3 or newer. To check your ve

5 Mar 07, 2022
This converter will create the exact measure for your cappuccino recipe from the grandiose Rafaella Ballerini!

About CappuccinoJs This converter will create the exact measure for your cappuccino recipe from the grandiose Rafaella Ballerini! Este conversor criar

Arthur Ottoni Ribeiro 48 Nov 15, 2022
A simple Streamlit App to classify swahili news into different categories.

Swahili News Classifier Streamlit App A simple app to classify swahili news into different categories. Installation Install all streamlit requirements

Davis David 4 May 01, 2022
Code for text augmentation method leveraging large-scale language models

HyperMix Code for our paper GPT3Mix and conducting classification experiments using GPT-3 prompt-based data augmentation. Getting Started Installing P

NAVER AI 47 Dec 20, 2022
spaCy-wrap: For Wrapping fine-tuned transformers in spaCy pipelines

spaCy-wrap: For Wrapping fine-tuned transformers in spaCy pipelines spaCy-wrap is minimal library intended for wrapping fine-tuned transformers from t

Kenneth Enevoldsen 32 Dec 29, 2022
Leon is an open-source personal assistant who can live on your server.

Leon Your open-source personal assistant. Website :: Documentation :: Roadmap :: Contributing :: Story 👋 Introduction Leon is an open-source personal

Leon AI 11.7k Dec 30, 2022
Ελληνικά νέα (Python script) / Greek News Feed (Python script)

Ελληνικά νέα (Python script) / Greek News Feed (Python script) Ελληνικά English Το 2017 είχα υλοποιήσει ένα Python script για να εμφανίζει τα τωρινά ν

Loren Kociko 1 Jun 14, 2022
Code for "Semantic Role Labeling as Dependency Parsing: Exploring Latent Tree Structures Inside Arguments".

Code for "Semantic Role Labeling as Dependency Parsing: Exploring Latent Tree Structures Inside Arguments".

Yu Zhang 50 Nov 08, 2022
Understanding the Difficulty of Training Transformers

Admin Understanding the Difficulty of Training Transformers Guided by our analyses, we propose Adaptive Model Initialization (Admin), which successful

Liyuan Liu 300 Dec 29, 2022
🗣️ NALP is a library that covers Natural Adversarial Language Processing.

NALP: Natural Adversarial Language Processing Welcome to NALP. Have you ever wanted to create natural text from raw sources? If yes, NALP is for you!

Gustavo Rosa 21 Aug 12, 2022
A collection of models for image - text generation in ACM MM 2021.

Bi-directional Image and Text Generation UMT-BITG (image & text generator) Unifying Multimodal Transformer for Bi-directional Image and Text Generatio

Multimedia Research 63 Oct 30, 2022
Beyond Masking: Demystifying Token-Based Pre-Training for Vision Transformers

beyond masking Beyond Masking: Demystifying Token-Based Pre-Training for Vision Transformers The code is coming Figure 1: Pipeline of token-based pre-

Yunjie Tian 23 Sep 27, 2022
null

CP-Cluster Confidence Propagation Cluster aims to replace NMS-based methods as a better box fusion framework in 2D/3D Object detection, Instance Segme

Yichun Shen 41 Dec 08, 2022