simpleT5 is built on top of PyTorch Lightning ⚡️ and Transformers 🤗 and lets you quickly train your T5 models.

Overview

simplet5

Quickly train T5 models in just 3 lines of code + ONNX support

T5 models can be used for several NLP tasks, such as summarization, question answering (QA), question generation (QG), translation, text generation, and more.

Here's a link to the Medium article, along with an example Colab notebook.

Install

pip install --upgrade simplet5

Usage

simpleT5 for a summarization task

# import
from simplet5 import SimpleT5

# instantiate
model = SimpleT5()

# load a pre-trained checkpoint: model_type ("t5") and model name ("t5-base")
model.from_pretrained("t5","t5-base")

# train (train_df and eval_df must already be defined)
model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
            eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len = 512, 
            target_max_token_len = 128,
            batch_size = 8,
            max_epochs = 5,
            use_gpu = True,
            outputdir = "outputs",
            early_stopping_patience_epochs = 0
            )

# load a fine-tuned T5 model from a directory saved during training
model.load_model("t5","path/to/trained/model/directory", use_gpu=False)

# predict
model.predict("input text for prediction")

# for faster inference on CPU, convert the model to ONNX and load it
model.convert_and_load_onnx_model("path/to/T5 model/directory")
model.onnx_predict("input text for prediction")
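The source_max_token_len and target_max_token_len arguments cap how many tokens of each input and target the model sees. As a rough, simplified sketch of what that usually means, assuming right-padding with T5's pad id 0 and the common Hugging Face convention of setting padded label positions to -100 so the loss ignores them, the truncate-and-pad step looks like this. The helpers are illustrative only, not simpleT5's internal code, and real tokenizers also take care of the end-of-sequence token when truncating.

```python
from typing import List, Tuple

PAD_ID = 0           # T5's pad token id
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_or_truncate(ids: List[int], max_len: int, pad_id: int = PAD_ID) -> Tuple[List[int], List[int]]:
    """Cut ids to max_len, right-pad with pad_id, and return (input_ids, attention_mask)."""
    ids = ids[:max_len]
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad


def make_labels(target_ids: List[int], max_len: int) -> List[int]:
    """Pad/truncate target ids, then mark padded positions with IGNORE_INDEX so the loss skips them."""
    ids, mask = pad_or_truncate(target_ids, max_len)
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]


if __name__ == "__main__":
    # Hypothetical token ids; in practice the tokenizer produces them from text.
    print(pad_or_truncate([37, 1782, 19, 1], max_len=6))  # ([37, 1782, 19, 1, 0, 0], [1, 1, 1, 1, 0, 0])
    print(make_labels([8, 9, 1], max_len=5))              # [8, 9, 1, -100, -100]
```

Every sequence in a batch is padded to the full max length, so larger max_token_len values raise memory use per batch; lowering them (or batch_size) is a common first step when training runs out of memory.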
Comments
  • Suppress the Output Models

    Hello there!

    I'd like to ask whether there is any way to delete all models except the last trained one. When I fine-tune a model for X epochs, it saves X different models. I only need the last model and couldn't find a way to prevent the others from being written to disk.

    Thanks!

    opened by bayismet 6
  • TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask' in onnx_predict function

    Hello, when I run the fine-tuned mT5 model with ONNX, I get the following error:

    TypeError                                 Traceback (most recent call last)
    ----> 1 model.onnx_predict(text)

    ~\AppData\Roaming\Python\Python38\site-packages\simplet5\simplet5.py in onnx_predict(self, source_text)
        469         """ generates prediction from ONNX model """
        470         token = self.onnx_tokenizer(source_text, return_tensors="pt")
    --> 471         tokens = self.onnx_model.generate(
        472             input_ids=token["input_ids"],
        473             attention_mask=token["attention_mask"],

    C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs)
         26         def decorate_context(*args, **kwargs):
         27             with self.__class__():
    ---> 28                 return func(*args, **kwargs)
         29         return cast(F, decorate_context)
         30

    C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
       1051                 input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
       1052             )
    -> 1053             return self.beam_search(
       1054                 input_ids,
       1055                 beam_scorer,

    C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
       1788             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
       1789
    -> 1790             outputs = self(
       1791                 **model_inputs,
       1792                 return_dict=True,

    C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
       1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []

    TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask'

    I tried downgrading transformers and onnxruntime, but the error remains.

    opened by farshadfiruzi 6
  • colab error

    When running in Google Colab, the line below produced an error:

    model.load_model("t5","outputs/SimpleT5-epoch-2-train-loss-0.9526", use_gpu=True)

    Error: 404 Client Error: Not Found for url: https://huggingface.co/outputs/SimpleT5-epoch-2-train-loss-0.9526/resolve/main/config.json

    Please help. Thanks a lot.

    opened by yzhang-github-pub 6
  • shows object has no attribute 'convert_and_load_onnx_model'

    Shows the error AttributeError: 'SimpleT5' object has no attribute 'convert_and_load_onnx_model' while running the example notebook provided in the repository

    https://github.com/Shivanandroy/simpleT5/blob/main/examples/simpleT5-summarization.ipynb

    opened by pradeepdev-1995 4
  • Adding logger

    With this change, users can provide a PyTorch Lightning logger object to the .train() method:

    from pytorch_lightning.loggers import WandbLogger
    
    wandb_logger = WandbLogger(project="my-project", name="run-name")
    
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        logger=wandb_logger
    )
    
    opened by versae 3
  • codet5 support added

    I needed to use this package for my experiments.

    Salesforce/codet5-base uses a RoBERTa tokenizer, hence this pull request.

    Users can now specify: model.from_pretrained("codet5","Salesforce/codet5-base")

    If you want to read more about CodeT5, here is the link: https://huggingface.co/Salesforce/codet5-base

    Kind regards, Mosh

    opened by mosh98 3
  • byT5 with version 0.1.2

    Hi there, it seems that the newest version of simpleT5 no longer works with ByT5: the line elif model_type == "byt5": is commented out. The newest version of transformers seems to use a new tokenizer type, T5TokenizerFast, and ByT5TokenizerFast does not exist. Any ideas on how to fix that?

    opened by kimgerdes 2
  • Is there any option for fine-tuning mt5 models instead of training from scratch?

    Hi, thanks for the amazing simpleT5 package. I use the following script to train an mT5 model for a summarization task.

    from simplet5 import SimpleT5

    model = SimpleT5()

    model.from_pretrained(model_type="mt5", model_name="google/mt5-small")

    model.train(
        train_df=train_df,
        eval_df=test_df,
        source_max_token_len=256,
        target_max_token_len=64,
        batch_size=8,
        max_epochs=3,
        use_gpu=True,
        outputdir="outputs",
        early_stopping_patience_epochs=0,
    )

    When I run this code, training starts from scratch. Is there a flag to fine-tune the mT5 model instead of training it from scratch?

    opened by farshadfiruzi 2
  • Kernel dies every time when I start training the model

    Hi Shiva, thank you very much for such a clean and neat wrapper for training ML models. I am using T5 (specifically t5-small) as the base to train my model for summarization, and I load the dataset with the Hugging Face datasets library. However, every time I start the training code, the kernel dies and restarts. Any help here is much appreciated!

    Here is my code.

    Import dependencies

    %%capture
    !pip install --user simplet5==0.1.4
    !pip install transformers
    !pip install wandb
    !pip install pandas
    !pip install datasets
    !pip install --user simpletransformers
    

    Load data using datasets from huggingface

    import pandas as pd
    import warnings
    warnings.filterwarnings("ignore")
    from datasets import load_dataset
    dataset = load_dataset("scitldr")
    

    Preparing the train and eval data

    train_df = dataset["train"].to_pandas().copy()
    train_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
    train_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
    train_df.count() ## No NaN found - zero 1992 dataset
    
    train_df['source_text'] = train_df['source_text'].astype('str').str.rstrip(']\'')
    train_df['source_text'] = train_df['source_text'].astype('str').str.lstrip('[\'')
    train_df['target_text'] = train_df['target_text'].astype('str').str.rstrip(']\'')
    train_df['target_text'] = train_df['target_text'].astype('str').str.lstrip('[\'')
    
    train_df["source_text"]=train_df["source_text"].str.replace('\'','')
    train_df["target_text"]=train_df["target_text"].str.replace('\'','')
    train_df["source_text"]="summarize: "+train_df["source_text"]
    train_df.to_csv("train.csv")
    
    eval_df = dataset["validation"].to_pandas().copy()
    eval_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
    eval_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
    eval_df.count() ## No NaN found - zero 1992 dataset
    
    eval_df['source_text'] = eval_df['source_text'].astype('str').str.rstrip(']\'')
    eval_df['source_text'] = eval_df['source_text'].astype('str').str.lstrip('[\'')
    eval_df['target_text'] = eval_df['target_text'].astype('str').str.rstrip(']\'')
    eval_df['target_text'] = eval_df['target_text'].astype('str').str.lstrip('[\'')
    
    eval_df["source_text"]=eval_df["source_text"].str.replace('\'','')
    eval_df["target_text"]=eval_df["target_text"].str.replace('\'','')
    eval_df["source_text"]="summarize: "+eval_df["source_text"]
    eval_df.to_csv("eval.csv")
    

    Load simpleT5 and the wandb logger, then load the model and run training

    from simplet5 import SimpleT5
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(project="ask-poc-logger")
    model = SimpleT5()
    model.from_pretrained("t5","t5-small")
    model.train(train_df=train_df[0:100], 
                eval_df=eval_df[0:100],
                source_max_token_len = 512, 
                target_max_token_len = 100,
                batch_size = 2,
                max_epochs = 3,
                use_gpu = True,
                outputdir = "outputs",
                logger = wandb_logger
                )
    

    I am running this code on a Vertex AI Workbench instance from Google Cloud: N1-Standard-16 machine type with 16 cores and 60 GB of memory, plus an added P100 GPU. Any help is much appreciated! Thanks in advance!

    opened by kkrishnan90 1
  • Is the task string necessary?

    Hi,

    I have fine-tuned the model to write a compliment for a person, given the person's profile, and it works pretty well. In the training examples, I haven't prepended the string 'summarize: ' to the source_text column entries. Is it necessary (does it lead to better results) to prepend a string indicating the task?

    opened by nikogamulin 1
  • Push finished model

    Is there a way to automatically push the checkpoints to the Hugging Face Hub? I run this mainly in Colab. It works great, but Colab often times out and the checkpoints are lost.

    opened by peregilk 2
  • Unicode Character training issue

    I tried to train my model to translate English to Bengali. After training, when I run the code, the output is not in Bengali Unicode characters.

    I Eat Rice (eng) => আমি ভাত খাই (Bn)

    Data of this form is fed to the model during training. After training completed, I tested the model with the input "I Eat Rice" and expected "আমি ভাত খাই" as output. Instead, the model gave me "Ich esse Reis." I don't know what language this is; it's not related to Bengali.

    opened by rahat10120141 5
  • ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

    import soundfile as sf
    from scipy.io import wavfile
    from IPython.display import Audio
    from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer

    import speech_recognition as sr
    import io
    import torch
    from pydub import AudioSegment

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    r = sr.Recognizer()
    with sr.Microphone(sample_rate=16000) as source:
        print("speak")
        while True:
            audio = r.listen(source)
            data = io.BytesIO(audio.get_wav_data())
            clip = AudioSegment.from_file(data)
            x = torch.FloatTensor(clip.get_array_of_samples())
            print(x)

            inputs = tokenizer(x, sampling_rate=16000, return_tensors='pt', padding='longest').input_values
            logits = model(inputs).logits
            tokens = torch.argmax(logits, axis=-1)
            text = tokenizer.batch_decode(tokens)

            print('you said: ', str(text).lower())
    
    opened by Ushanjay 1
  • Saved model name not customizable

    def training_epoch_end(self, training_step_outputs):
        """ save tokenizer and model on epoch end """
        self.average_training_loss = np.round(
            torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
            4,
        )
        path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"

    It would be very helpful if you could make the name customizable (note the 'path' assignment).

    Btw, SimpleT5 is simply cool!

    opened by ke-lara 0
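For "Suppress the Output Models": the v0.1.4 release notes list support for saving the model only at the last epoch (#21), so upgrading is the first thing to try. On older versions, a minimal workaround sketch is to delete every epoch folder except the newest once training has finished. It assumes folders are named like simplet5-epoch-<N>-..., as in the training_epoch_end code quoted in "Saved model name not customizable", and matches case-insensitively because the Colab report shows SimpleT5-epoch-.... prune_checkpoints is a hypothetical helper, not part of simpleT5.

```python
import re
import shutil
from pathlib import Path
from typing import Optional

# Assumed folder naming, e.g. "simplet5-epoch-2-train-loss-0.95-val-loss-1.02"
EPOCH_DIR = re.compile(r"^simplet5-epoch-(\d+)-", re.IGNORECASE)


def prune_checkpoints(outputdir: str) -> Optional[Path]:
    """Delete all epoch checkpoint folders except the highest epoch; return the kept folder."""
    candidates = []
    for path in Path(outputdir).iterdir():
        match = EPOCH_DIR.match(path.name)
        if path.is_dir() and match:
            candidates.append((int(match.group(1)), path))  # numeric sort: epoch 10 > epoch 9
    if not candidates:
        return None
    candidates.sort(key=lambda item: item[0])
    _, latest = candidates[-1]
    for _, path in candidates[:-1]:
        shutil.rmtree(path)  # irreversible: removes the whole checkpoint folder
    return latest
```

For example, kept = prune_checkpoints("outputs") followed by model.load_model("t5", str(kept), use_gpu=False). Other files in the output directory are left alone.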
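For the 'cross_attn_head_mask' traceback: generate() in the installed transformers passes a keyword argument that the ONNX model's forward() does not declare, which usually points to a version mismatch between transformers and the code wrapping the ONNX model. Pinning mutually compatible versions is the cleaner fix. As a generic illustration of the failure and of one defensive pattern, the sketch below forwards only the keyword arguments a callable's signature declares. call_with_supported_kwargs and legacy_forward are hypothetical stand-ins, and silently dropping arguments is only safe when the callee genuinely does not need them.

```python
import inspect
from typing import Any, Callable, Dict


def call_with_supported_kwargs(fn: Callable[..., Any], **kwargs: Any) -> Any:
    """Call fn with only the keyword arguments its signature declares."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**kwargs)  # fn accepts **kwargs, so pass everything through
    supported: Dict[str, Any] = {k: v for k, v in kwargs.items() if k in params}
    return fn(**supported)


def legacy_forward(input_ids, attention_mask=None, encoder_outputs=None):
    """Stand-in for an older forward() that predates cross_attn_head_mask."""
    return {"input_ids": input_ids, "attention_mask": attention_mask}


if __name__ == "__main__":
    newer_kwargs = {"input_ids": [1, 2], "attention_mask": [1, 1], "cross_attn_head_mask": None}
    try:
        legacy_forward(**newer_kwargs)
    except TypeError as err:
        print(err)  # ...got an unexpected keyword argument 'cross_attn_head_mask'
    print(call_with_supported_kwargs(legacy_forward, **newer_kwargs))
```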
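For the "colab error" report: the 404 URL (https://huggingface.co/outputs/...) suggests that outputs/SimpleT5-epoch-2-train-loss-0.9526 was not found as a local folder, so the string was looked up on the Hugging Face Hub as a model id instead. A hedged sketch for catching this before calling load_model: confirm the folder exists and contains a config.json, and otherwise list the closest folder names that actually exist in the output directory (for example, names with a different case or an extra -val-loss-... suffix). resolve_checkpoint is a hypothetical helper.

```python
import difflib
from pathlib import Path


def resolve_checkpoint(path: str) -> str:
    """Return path if it looks like a saved model folder; otherwise raise with suggestions."""
    folder = Path(path)
    if (folder / "config.json").is_file():
        return str(folder)
    parent = folder.parent
    if not parent.is_dir():
        raise FileNotFoundError(f"Output directory {parent} does not exist")
    available = sorted(p.name for p in parent.iterdir() if p.is_dir())
    close = difflib.get_close_matches(folder.name, available, n=3, cutoff=0.5)
    raise FileNotFoundError(
        f"{folder} is not a saved model folder (no config.json). "
        f"Closest matches in {parent}: {close or available}"
    )
```

Usage would be along the lines of model.load_model("t5", resolve_checkpoint("outputs/..."), use_gpu=True), so a typo fails locally with a readable message instead of a Hub 404.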
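In the "Kernel dies every time" report, the scitldr source and target columns are converted with astype('str') and then stripped of brackets and quotes, which suggests each cell holds a sequence of sentences. Assuming that, a more direct cleanup joins the sentences instead of editing their string representation, and so keeps apostrophes inside words such as "doesn't" (str.replace('\'','') removes them). clean_pair is a hypothetical helper, and keeping only the first target is an assumption for cells that hold several reference summaries.

```python
from typing import Iterable, Tuple, Union

Cell = Union[str, Iterable[str]]


def join_sentences(cell: Cell) -> str:
    """Join a list (or array) of sentences into one string; plain strings pass through."""
    if isinstance(cell, str):
        return cell.strip()
    return " ".join(str(sentence).strip() for sentence in cell)


def clean_pair(source: Cell, target: Cell) -> Tuple[str, str]:
    """Build one (source_text, target_text) pair from list-valued scitldr-style cells."""
    if not isinstance(target, str):
        targets = list(target)
        target = targets[0] if targets else ""  # assumption: keep only the first reference summary
    return join_sentences(source), join_sentences(target)
```

With pandas, this could be applied per split, e.g. pairs = [clean_pair(s, t) for s, t in zip(df["source"], df["target"])] followed by pd.DataFrame(pairs, columns=["source_text", "target_text"]), building the train and eval frames each from their own split.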
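On "Is the task string necessary?": the original T5 checkpoints were pre-trained on a multi-task mixture in which each input starts with a task prefix such as "summarize: " or "translate English to German: ". For a single custom task like the compliment writer, a prefix is generally not required, and whether one helps is best checked empirically; what matters is that training and prediction inputs follow the same convention. A small sketch of applying one prefix consistently (with_prefix and the chosen string are illustrative):

```python
from typing import Iterable, List

TASK_PREFIX = "summarize: "  # hypothetical choice; any fixed string works for a custom task


def with_prefix(texts: Iterable[str], prefix: str = TASK_PREFIX) -> List[str]:
    """Prepend the task prefix exactly once, so training and prediction inputs match."""
    return [text if text.startswith(prefix) else prefix + text for text in texts]
```

For example, train_df["source_text"] = with_prefix(train_df["source_text"]) before training, and model.predict(with_prefix([text])[0]) at prediction time.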
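In the "Unicode Character training issue" report, "Ich esse Reis." is German for "I eat rice", and English-to-German translation was one of the tasks the original T5 checkpoints were trained on. T5's English-centric vocabulary likely covers Bengali script poorly, which is one reason multilingual checkpoints such as mT5 or ByT5 (both mentioned on this page) tend to be better starting points for Bengali. Whatever the model, a quick sanity check is to measure how much of each prediction falls in the Bengali Unicode block (U+0980 to U+09FF); bengali_ratio is a hypothetical helper.

```python
BENGALI_START, BENGALI_END = 0x0980, 0x09FF  # Unicode "Bengali" block


def bengali_ratio(text: str) -> float:
    """Fraction of non-space characters that fall in the Bengali Unicode block."""
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return 0.0
    hits = sum(BENGALI_START <= ord(c) <= BENGALI_END for c in chars)
    return hits / len(chars)
```

Predictions whose ratio is well below 1.0 for a pure-Bengali target can then be flagged for inspection.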
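For "Saved model name not customizable": one possible design is a format-string template with the same fields the quoted path uses, defaulting to the current naming. checkpoint_path and name_template are hypothetical, not existing simpleT5 parameters.

```python
import os

DEFAULT_TEMPLATE = "simplet5-epoch-{epoch}-train-loss-{train_loss}-val-loss-{val_loss}"


def checkpoint_path(outputdir: str, epoch: int, train_loss: float, val_loss: float,
                    name_template: str = DEFAULT_TEMPLATE) -> str:
    """Build the per-epoch save path from a user-supplied format template."""
    # str.format ignores unused fields, so a template may use any subset of them
    name = name_template.format(epoch=epoch, train_loss=train_loss, val_loss=val_loss)
    return os.path.join(outputdir, name)
```

A template such as "my-model-{epoch}" would then give outputs/my-model-3 for epoch 3, while the default keeps today's folder names.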
Releases (v0.1.4)
  • v0.1.4 (Feb 15, 2022)

    SimpleT5 v0.1.4

    • Added support for all loggers supported by pytorch-lightning #11
    • Added support for saving the model only at the last epoch #21
    • Added a dataloader_num_workers parameter to the .train() method to specify the number of workers in the train/test/val dataloaders #19
    • Fixed warnings and made simpleT5 compatible with the latest transformers and pytorch-lightning
  • v0.1.3 (Sep 4, 2021)

  • version-0.1.0 (Jul 13, 2021)

    SimpleT5 - version 0.1.0

    • Supports ByT5 model - Thanks to @mapmeld for his contribution
    from simplet5 import SimpleT5
    model = SimpleT5()
    model.from_pretrained("byt5", "google/byt5-small")
    
    • Added precision flag to support mixed precision training
    # train
    model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
                eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
                source_max_token_len = 512, 
                target_max_token_len = 128,
                batch_size = 8,
                max_epochs = 5,
                use_gpu = True,
                outputdir = "outputs",
                early_stopping_patience_epochs = 0,
                precision = 32  # e.g. 16 for mixed precision training
                )
    
Owner
Shivanand Roy
Data Scientist.