This repository contains the code for our paper VDA (published in the EMNLP 2021 main conference).


Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models


Overview

We propose Virtual Data Augmentation (VDA), a general framework for robustly fine-tuning pre-trained language models on downstream tasks. VDA uses a masked language model with Gaussian noise to generate virtual augmented examples that improve robustness, and adopts regularized training to further ensure semantic relevance and diversity.
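The idea of a "virtual" example can be illustrated with a small, framework-free sketch: instead of replacing a token with one sampled word, the token is represented by the probability-weighted average of vocabulary embeddings, where the probabilities come from a masked language model perturbed by Gaussian noise. This is only an illustration under stated assumptions, not the repository's implementation: we add the noise to the MLM logits purely for simplicity (the paper may inject it elsewhere), the toy vocabulary and the function names are hypothetical, and the regularization term is omitted.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def virtual_embedding(
    mlm_logits: Sequence[float],
    embedding_table: Sequence[Sequence[float]],
    variance: float,
    rng: random.Random,
) -> List[float]:
    """Build one virtual token embedding (illustrative sketch).

    Assumption: Gaussian noise with the given variance is added to the MLM
    logits of a single position; the noisy substitution distribution then
    weights the rows of the (toy) embedding table.
    """
    std = math.sqrt(variance)  # the noise is specified by its variance
    noisy = [x + rng.gauss(0.0, std) for x in mlm_logits]
    probs = softmax(noisy)
    dim = len(embedding_table[0])
    # Expected embedding under the noisy substitution distribution.
    return [
        sum(p * row[d] for p, row in zip(probs, embedding_table))
        for d in range(dim)
    ]


# Toy example: a 3-word vocabulary with 2-dimensional embeddings.
table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = [2.0, 0.5, 0.1]
vec = virtual_embedding(logits, table, variance=0.01, rng=random.Random(0))
print(vec)
```

Because the result is a convex combination of real embeddings, it stays close to plausible substitutes while still varying from run to run, which is the intuition behind using it for augmentation.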

Train VDA

In the following section, we describe how to train a model with VDA by using our code.

Training

Data

To evaluate VDA, we use 6 text classification datasets: Yelp, IMDB, AGNews, MR, QNLI and MRPC. These datasets can be downloaded from Google Drive.

After downloading the two zipped files, unzip them. The data folder contains the training, validation and test sets of the 6 datasets, while the Robust folder contains the examples used to test robustness.
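If you prefer to extract the archives from Python, a minimal sketch with the standard-library zipfile module is shown below. The archive names data.zip and Robust.zip used in the example call are hypothetical placeholders; use the names of the files you actually downloaded.

```python
import zipfile
from pathlib import Path
from typing import Iterable, List


def extract_archives(archives: Iterable[Path], target_dir: Path) -> List[str]:
    """Extract each zip archive into target_dir; return its top-level entries."""
    target_dir.mkdir(parents=True, exist_ok=True)
    for archive in archives:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(target_dir)
    return sorted(p.name for p in target_dir.iterdir())
```

For example, extract_archives([Path("data.zip"), Path("Robust.zip")], Path(".")) would unpack both archives into the current directory.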

Training scripts

We release VDA with 4 base models. For single-sentence classification tasks, use the text_classifier_xxx.py files; for sentence-pair classification tasks, use the text_pair_classifier_xxx.py files:

  • text_classifier.py and text_pair_classifier.py: BERT-base+VDA

  • text_classifier_freelb.py and text_pair_classifier_freelb.py: FreeLB+VDA on BERT-base

  • text_classifier_smart.py and text_pair_classifier_smart.py: SMART+VDA on BERT-base, where we only use the smoothness-inducing adversarial regularization.

  • text_classifier_smix.py and text_pair_classifier_smix.py: Smix+VDA on BERT-base, where we remove the adversarial data augmentation for a fair comparison.
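The naming convention above can be captured in a small helper that picks the training script for a given base model and task type. The helper and its argument names are our own illustration; only the file names come from the list above. Among the 6 datasets, QNLI and MRPC are sentence-pair tasks, while Yelp, IMDB, AGNews and MR are single-sentence tasks.

```python
from typing import Dict

# File-name suffix for each base model, following the list above.
_SUFFIXES: Dict[str, str] = {
    "bert": "",           # BERT-base + VDA
    "freelb": "_freelb",  # FreeLB + VDA
    "smart": "_smart",    # SMART + VDA
    "smix": "_smix",      # Smix + VDA
}


def training_script(base_model: str, sentence_pair: bool) -> str:
    """Return the training script name for a base model and task type."""
    key = base_model.lower()
    if key not in _SUFFIXES:
        raise ValueError(f"unknown base model: {base_model!r}")
    prefix = "text_pair_classifier" if sentence_pair else "text_classifier"
    return f"{prefix}{_SUFFIXES[key]}.py"
```

For instance, training FreeLB+VDA on QNLI would use training_script("freelb", sentence_pair=True), which returns "text_pair_classifier_freelb.py".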

We provide example scripts for both training and testing VDA on the 6 datasets. The run_train.sh script provides 6 examples of training on the yelp and qnli datasets; it calls the text_classifier_xxx.py files for training (xxx refers to the base model). The arguments are explained below:

  • --dataset: Training file path.
  • --mlm_path: Pre-trained checkpoint to start from. Currently we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.).
  • --save_path: Path where the fine-tuned checkpoint is saved.
  • --max_length: Maximum sequence length. (We use 512 for Yelp/IMDB/AG and 256 for MR/QNLI/MRPC.)
  • --max_epoch: Maximum number of training epochs. (We use 10 for most datasets and models.)
  • --batch_size: Batch size. (We use the largest batch size that fits in GPU memory. Note that a batch size that is too small may cause the model to collapse.)
  • --num_label: Number of labels. (4 for AG; 2 for the other datasets.)
  • --lr: Learning rate.
  • --num_warmup: Proportion of training steps used for warm-up.
  • --variance: The variance of the Gaussian noise.
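The per-dataset settings stated above (maximum length, number of epochs and number of labels) can be encoded in a helper that assembles a training command. This is a sketch: the data file path, save path, learning rate, warm-up proportion, batch size and noise variance in the example call are hypothetical placeholders, not the values used in the paper, so take the real values from run_train.sh.

```python
from typing import Dict, List

# Settings stated in this README.
MAX_LENGTH: Dict[str, int] = {
    "yelp": 512, "imdb": 512, "ag": 512,
    "mr": 256, "qnli": 256, "mrpc": 256,
}
NUM_LABEL: Dict[str, int] = {"ag": 4}  # every other dataset uses 2 labels


def train_command(
    script: str,
    dataset: str,
    data_file: str,
    save_path: str,
    batch_size: int,
    lr: float,
    num_warmup: float,
    variance: float,
    mlm_path: str = "bert-base-uncased",
    max_epoch: int = 10,
) -> List[str]:
    """Build an argv list for one training run using the documented flags."""
    name = dataset.lower()
    return [
        "python", script,
        "--dataset", data_file,
        "--mlm_path", mlm_path,
        "--save_path", save_path,
        "--max_length", str(MAX_LENGTH[name]),
        "--max_epoch", str(max_epoch),
        "--batch_size", str(batch_size),
        "--num_label", str(NUM_LABEL.get(name, 2)),
        "--lr", str(lr),
        "--num_warmup", str(num_warmup),
        "--variance", str(variance),
    ]


# Hypothetical values, for illustration only.
cmd = train_command("text_classifier.py", "yelp", "data/yelp/train.txt",
                    "checkpoints/yelp_vda", batch_size=16, lr=2e-5,
                    num_warmup=0.1, variance=0.01)
print(" ".join(cmd))
```

Returning a list rather than a single string makes the command easy to pass to subprocess.run without shell quoting issues.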

For the results in the paper, we trained our models on NVIDIA Tesla V100 (32 GB) and NVIDIA RTX 3090 (24 GB) GPUs. Different devices or different versions of CUDA and other software may lead to slightly different performance.

Evaluation

During training, the script reports the original accuracy on the test set of each of the 6 datasets, which measures the model's accuracy. Our robustness evaluation code is based on a modified version of BERT-Attack and reports three metrics: Attack Accuracy, Query Numbers and Perturbation Ratio.
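For readers unfamiliar with these metrics, the sketch below computes them from per-example attack records under commonly used definitions: attack accuracy is the fraction of all examples still classified correctly after the attack, query number is the average number of model queries per attacked example, and perturbation ratio is the average fraction of words changed in successfully attacked examples. The AttackRecord type and these exact averaging choices are our assumptions; the modified BERT-Attack code in this repository may aggregate the metrics differently.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class AttackRecord:
    """Hypothetical per-example attack result."""
    originally_correct: bool    # model is right on the clean input
    correct_after_attack: bool  # model is still right on the perturbed input
    num_queries: int            # model queries spent by the attacker
    num_changed: int            # words replaced by the attacker
    num_words: int              # words in the original input


def robustness_metrics(records: List[AttackRecord]) -> Dict[str, float]:
    n = len(records)
    # Only examples the model gets right are attacked.
    attacked = [r for r in records if r.originally_correct]
    succeeded = [r for r in attacked if not r.correct_after_attack]
    query_num = (sum(r.num_queries for r in attacked) / len(attacked)
                 if attacked else 0.0)
    perturb_rate = (sum(r.num_changed / r.num_words for r in succeeded)
                    / len(succeeded) if succeeded else 0.0)
    return {
        "original_accuracy": len(attacked) / n,
        "attack_accuracy": sum(r.correct_after_attack for r in attacked) / n,
        "query_num": query_num,
        "perturb_rate": perturb_rate,
    }
```

A lower attack accuracy and a lower perturbation ratio both indicate that the model is easier to fool, so a more robust model should keep attack accuracy high.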

Before evaluation, download the robustness evaluation datasets from Google Drive. Then, following the commonly used setting, download and process the cosine similarity matrix as in TextFooler.
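TextFooler-style attacks restrict word substitutions to words whose embeddings are close in cosine similarity, which is why this matrix is needed. The sketch below shows how such a matrix is computed and used to pick candidate substitutes from a tiny embedding table. In practice the matrix is precomputed from the counter-fitted word vectors used by TextFooler; the toy vectors, the helper names and the k/min_sim values here are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity_matrix(vectors: Sequence[Sequence[float]]) -> List[List[float]]:
    """Pairwise cosine similarity between the rows of an embedding table."""
    norms = [math.sqrt(sum(x * x for x in v)) for v in vectors]
    # Normalize each row; zero vectors stay zero to avoid division by zero.
    unit = [[x / n for x in v] if n > 0 else [0.0] * len(v)
            for v, n in zip(vectors, norms)]
    return [[sum(a * b for a, b in zip(u, w)) for w in unit] for u in unit]


def candidate_substitutes(sim: List[List[float]], i: int, k: int,
                          min_sim: float) -> List[int]:
    """Indices of the k words most similar to word i with similarity >= min_sim."""
    ranked = sorted((j for j in range(len(sim)) if j != i),
                    key=lambda j: sim[i][j], reverse=True)
    return [j for j in ranked[:k] if sim[i][j] >= min_sim]


# Toy 2-D "word vectors" (hypothetical).
vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
sim = cosine_similarity_matrix(vectors)
print(candidate_substitutes(sim, 0, k=2, min_sim=0.5))
```

For a real vocabulary of tens of thousands of words, this pure-Python loop would be far too slow; the precomputed matrix exists precisely so that the attack only has to look up neighbors.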

Given the fine-tuned checkpoints, we use the run_test.sh script to test robustness on the yelp and qnli datasets. The script runs bert_robust.py. The arguments are explained below:

  • --data_path: Training file path.
  • --mlm_path: Pre-trained checkpoint to start from. Currently we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.).
  • --tgt_path: Path to the fine-tuned checkpoint.
  • --num_label: Number of labels. (4 for AG; 2 for the other datasets.)
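A robustness run can be assembled from the four documented flags in the same way as a training run. This is a sketch: the example paths are hypothetical, and run_test.sh may pass additional options to bert_robust.py that are not documented here.

```python
from typing import List


def robust_test_command(data_path: str, tgt_path: str, dataset: str,
                        mlm_path: str = "bert-base-uncased") -> List[str]:
    """Build an argv list for bert_robust.py using the flags documented above."""
    num_label = 4 if dataset.lower() == "ag" else 2  # 4 for AG, 2 otherwise
    return [
        "python", "bert_robust.py",
        "--data_path", data_path,
        "--mlm_path", mlm_path,
        "--tgt_path", tgt_path,
        "--num_label", str(num_label),
    ]


# Hypothetical paths, for illustration only.
cmd = robust_test_command("Robust/yelp", "checkpoints/yelp_vda", "yelp")
print(" ".join(cmd))
```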

The expected output looks like this:

original accuracy is 0.960000, attack accuracy is 0.533333, query num is 687.680556, perturb rate is 0.177204
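When comparing several runs, it can help to parse this summary line into numbers. The sketch below assumes the line format exactly matches the sample above; the function name is ours.

```python
import re
from typing import Dict

_RESULT_PATTERN = re.compile(
    r"original accuracy is (?P<original_accuracy>[\d.]+), "
    r"attack accuracy is (?P<attack_accuracy>[\d.]+), "
    r"query num is (?P<query_num>[\d.]+), "
    r"perturb rate is (?P<perturb_rate>[\d.]+)"
)


def parse_result_line(line: str) -> Dict[str, float]:
    """Turn one robustness summary line into a dict of floats."""
    match = _RESULT_PATTERN.search(line)
    if match is None:
        raise ValueError(f"unrecognized result line: {line!r}")
    return {key: float(value) for key, value in match.groupdict().items()}


sample = ("original accuracy is 0.960000, attack accuracy is 0.533333, "
          "query num is 687.680556, perturb rate is 0.177204")
print(parse_result_line(sample))
```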

Citation

Please cite our paper if you use VDA in your work:

@inproceedings{zhou2021vda,
  author    = {Kun Zhou and Wayne Xin Zhao and Sirui Wang and Fuzheng Zhang and Wei Wu and Ji-Rong Wen},
  title     = {Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models},
  booktitle = {{EMNLP} 2021},
  publisher = {The Association for Computational Linguistics},
}
Owner
RUCAIBox
An enthusiastic group that aims to create beautiful things with AI