ZeroGen: Efficient Zero-shot Learning via Dataset Generation

Overview

This repository contains the code for our paper “ZeroGen: Efficient Zero-shot Learning via Dataset Generation”. Our implementation is built on the source code from dino. Thanks for their work.

If you use this code, please cite our paper:

@article{ye2022zerogen,
      title={ZeroGen: Efficient Zero-shot Learning via Dataset Generation}, 
      author={Jiacheng Ye and Jiahui Gao and Qintong Li and Hang Xu and Jiangtao Feng and Zhiyong Wu and Tao Yu and Lingpeng Kong},
      year={2022},
      eprint={2202.07922},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Setup

All requirements for ZEROGEN can be found in requirements.txt. You can install all required packages in a new environment with pip install -r requirements.txt.

Usage

The scripts/run_cls.sh and scripts/run_qa.sh scripts contain the running commands for the following settings:

  • supervised learning with human annotations (SUPERVISED)
  • prompt-based zero-shot learning (PROMPTING)
  • efficient zero-shot learning via dataset generation (ZEROGEN)

For text classification (TC) tasks (e.g., SST-2 and IMDb) and natural language inference (NLI) tasks (e.g., QNLI and RTE), run bash scripts/run_cls.sh. For question answering (QA) tasks, run bash scripts/run_qa.sh.

When generating X (i.e., the text in TC, the hypothesis in NLI, and the question in QA) in the final stage of the scripts, we also train the small model and evaluate it on human annotations. Specifically, each time log_every new examples have been generated, we train on the synthetic dataset and evaluate on the gold validation set. This gives us a trend graph similar to Figure 2 in the paper, displayed in wandb, a toolkit for tracking experiments.
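The interleaving of generation and evaluation described above can be pictured with a minimal sketch. It is not the code in the scripts: the function name generate_with_periodic_eval and the train_and_eval callback are hypothetical, and the "metric" in the toy usage is a stand-in rather than a real model score.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def generate_with_periodic_eval(
    example_stream: Iterable[Dict],
    train_and_eval: Callable[[List[Dict]], float],
    log_every: int,
) -> List[Tuple[int, float]]:
    """Collect generated examples and evaluate after every `log_every` of them.

    `train_and_eval` is a hypothetical callback that trains the small model
    on the examples collected so far and returns a validation metric.
    """
    if log_every <= 0:
        raise ValueError("log_every must be positive")
    dataset: List[Dict] = []
    history: List[Tuple[int, float]] = []
    for example in example_stream:
        dataset.append(example)
        if len(dataset) % log_every == 0:
            # One point on the trend graph: (dataset size, metric).
            history.append((len(dataset), train_and_eval(dataset)))
    return history


# Toy usage: 10 fake examples and a stand-in metric that grows with data size.
toy_stream = ({"X": f"text {i}", "Y": i % 2} for i in range(10))
history = generate_with_periodic_eval(
    toy_stream, lambda data: len(data) / 10, log_every=4
)
print(history)
```

Each entry in history pairs a dataset size with a metric, which is the kind of series a tracking tool would plot as a trend graph.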

Before running, set the following parameters to your own values:

  • home_dir: path to ZeroGen
  • gpu: gpu id
  • batch_size: the batch size for generation with the PLM. For SST-2, a batch size of 32 with gpt2-xl uses ~16G of memory. For SQuAD, the same batch size and PLM use ~60G because of the longer contexts. Decrease the batch size if needed.
  • WANDB_PROJECT: project name, by default ZeroGen
  • WANDB_ENTITY: your wandb username
  • WANDB_API_KEY: your api-key

By default, we use GPT2-XL as the pre-trained language model (PLM) and DistilBERT as the tiny task model (TAM). To change the PLM or the TAM, modify model_name and small_model_name in the run_xxx.sh scripts.

Run with a synthesized dataset

After dataset generation, we save the synthetic dataset at:

  • For TC and NLI: out-${task_name}-x2/${dataset}/${task_name}-dataset.jsonl (e.g., out-sst-2-x2/gpt2-xl_topk0_topp0.9_sst-2-x2/sst-2-dataset.jsonl). The file is in JSON Lines format (e.g., {"C": "The Book of Mormon Musical", "X": "The Book of Mormon Musical brings all the drama and excitement of a real revival of the Broadway production to the big screen.", "Y": 0}).
  • For QA: out-${task_name}-x2/${dataset}. We save the dataset in Hugging Face Dataset format.
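For a quick look at a generated TC or NLI file, the JSON Lines format can be read with the standard library alone. The following is a minimal sketch, assuming each line is a JSON object with the fields C, X, and Y as in the example above; the helper names load_synthetic_dataset and label_counts are hypothetical, and the demo writes a throwaway file with a toy record instead of reading real output.

```python
import json
import os
import tempfile
from collections import Counter
from typing import Dict, List


def load_synthetic_dataset(path: str) -> List[Dict]:
    """Read a JSON Lines file: one JSON object per non-empty line."""
    records = []
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def label_counts(records: List[Dict]) -> Counter:
    """Count how many examples carry each label Y."""
    return Counter(record["Y"] for record in records)


# Demo on a temporary file; the record below is a toy example, not real output.
sample = {"C": "A toy condition", "X": "A toy generated sentence.", "Y": 0}
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "sst-2-dataset.jsonl")
    with open(demo_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(sample) + "\n")
    records = load_synthetic_dataset(demo_path)

print(len(records), dict(label_counts(records)))
```

To inspect a real run, point load_synthetic_dataset at the dataset file saved under the output directory described above.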

To run DistilBERT given a generated dataset, you can use the scripts/run_distilbert.sh script.

To run an LSTM-based model given a generated dataset, you can use the scripts/run_cls_lstm.sh script. Before that, you have to download the datasets from the google drive link, which contains the standard test files.

Diversity and Correctness of a synthesized dataset

Diversity

We use Self-BLEU to measure the diversity of a synthesized dataset. To calculate Self-BLEU for a given dataset, see the example in the scripts/run_self_bleu.sh script.
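Self-BLEU scores each example against all the other examples as references and averages the results, so a lower value indicates a more diverse dataset. The sketch below is a simplified pure-Python illustration of that idea, not the implementation behind scripts/run_self_bleu.sh: whitespace tokenization, add-one smoothing of the n-gram precisions, and the function names are all assumptions.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    """Count the n-grams of a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(hypothesis: Sequence[str],
                  references: List[Sequence[str]],
                  max_n: int = 4) -> float:
    """BLEU with uniform weights and add-one smoothing (an assumption)."""
    if not hypothesis or not references:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = ngram_counts(hypothesis, n)
        # For each n-gram, the maximum count seen in any single reference.
        max_ref: Counter = Counter()
        for ref in references:
            for gram, count in ngram_counts(ref, n).items():
                max_ref[gram] = max(max_ref[gram], count)
        clipped = sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
        total = sum(hyp_counts.values())
        log_precision_sum += math.log((clipped + 1) / (total + 1))
    # Brevity penalty against the reference closest in length.
    hyp_len = len(hypothesis)
    ref_len = min((abs(len(r) - hyp_len), len(r)) for r in references)[1]
    penalty = 1.0 if hyp_len >= ref_len else math.exp(1 - ref_len / hyp_len)
    return penalty * math.exp(log_precision_sum / max_n)


def self_bleu(sentences: List[str], max_n: int = 4) -> float:
    """Average BLEU of each sentence against all the others."""
    tokenized = [s.lower().split() for s in sentences]
    if len(tokenized) < 2:
        raise ValueError("Self-BLEU needs at least two sentences")
    scores = [
        sentence_bleu(hyp, tokenized[:i] + tokenized[i + 1:], max_n)
        for i, hyp in enumerate(tokenized)
    ]
    return sum(scores) / len(scores)


repetitive = ["the movie is great and fun"] * 3
diverse = [
    "the movie is great and fun",
    "a dull plot ruins every scene",
    "actors shine despite weak writing",
]
print(self_bleu(repetitive), self_bleu(diverse))
```

Absolute values depend on the tokenizer and smoothing, so numbers from this sketch are not comparable with those reported in the paper; only the ordering (repetitive data scores higher) carries over.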

Correctness

To calculate the Correctness, you can take the following steps:

  1. Replace the following parameters in scripts/run_distilbert.sh script with:

    • small_model_name=roberta-large
    • dataset=: empty means using standard training set
    • limit=: empty means using full standard training set

    This will give you a RoBERTa-Large trained with full human annotations, which can be used as an evaluator.

  2. Replace the following parameters in scripts/run_distilbert.sh script with:

    • small_model_ckpt=tmp/checkpoint-xxx: the final RoBERTa-Large checkpoint saved in step 1.
    • limit=10000: the number of samples to use, by default 10000
    • dataset=xxx: the name of synthetic dataset (e.g., gpt2-xl_topk0_topp0.9_sst-2-x2)
    • no_train=true: disable training

    Run the script, and you will get Metric on standard dataset and Metric on synthetic dataset, which represent the Correctness of the standard dataset and the synthetic dataset, respectively.
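Conceptually, the two steps above score a dataset by how often a strong evaluator agrees with its labels. The sketch below illustrates that reading, assuming the metric is plain accuracy of the evaluator's predictions against the dataset's own labels; the correctness function and the keyword-based toy_evaluator are hypothetical stand-ins for the RoBERTa-Large evaluator and are not part of the repository.

```python
from typing import Callable, Hashable, Sequence


def correctness(texts: Sequence[str],
                labels: Sequence[Hashable],
                predict: Callable[[str], Hashable]) -> float:
    """Fraction of examples whose label matches the evaluator's prediction."""
    if len(texts) != len(labels):
        raise ValueError("texts and labels must have the same length")
    if not texts:
        raise ValueError("cannot score an empty dataset")
    agreements = sum(1 for text, label in zip(texts, labels)
                     if predict(text) == label)
    return agreements / len(texts)


def toy_evaluator(text: str) -> int:
    """Stand-in for a trained evaluator: predicts 1 if 'great' appears."""
    return 1 if "great" in text else 0


texts = ["great movie", "terrible plot", "great acting", "boring and terrible"]
labels = [1, 0, 0, 0]  # the third label disagrees with the toy evaluator
score = correctness(texts, labels, toy_evaluator)
print(score)
```

In the actual workflow, predict would wrap the checkpoint from step 1, and the same function would be applied once to the standard dataset and once to the synthetic one.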

Resources

We provide some synthetic datasets and standard datasets for training the LSTM in this google drive link. When training DistilBERT, the standard dataset is downloaded directly by the Hugging Face Dataset package. Note that we use the same prompt for IMDb and SST-2, and the same prompt for SQuAD and AdversarialQA, so the synthetic datasets within each pair are also the same.
