Code for the ICLR'22 paper: On Robust Prefix-Tuning for Text Classification

Overview

On Robust Prefix-Tuning for Text Classification

Prefix-tuning has drawn much attention as a parameter-efficient and modular alternative for adapting pretrained language models to downstream tasks. However, we find that prefix-tuning is vulnerable to adversarial attacks. Unfortunately, current robust NLP methods are unsuitable for prefix-tuning because they inevitably hamper its modularity. In our ICLR'22 paper, we propose robust prefix-tuning for text classification. Our method leverages the idea of test-time tuning, which preserves the strengths of prefix-tuning while improving its robustness. This repository contains the code for the proposed robust prefix-tuning method.

Prerequisite

PyTorch>=1.2.0, pytorch-transformers==1.2.0, OpenAttack==2.0.1, and GPUtil==1.4.0.
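Because several of these versions are pinned, it can help to check the environment before running any script. The following is a minimal sketch, not part of this repository: the names `REQUIRED`, `parse_version`, and `check_requirements` are illustrative, and it assumes PyTorch is installed under the distribution name `torch` and that Python 3.8+ is available for `importlib.metadata`.

```python
import re
from importlib import metadata

# Pins copied from the prerequisite list above; "torch" is assumed to be
# the installed distribution name for PyTorch.
REQUIRED = {
    "torch": (">=", "1.2.0"),
    "pytorch-transformers": ("==", "1.2.0"),
    "OpenAttack": ("==", "2.0.1"),
    "GPUtil": ("==", "1.4.0"),
}


def parse_version(text):
    """Turn a string such as '1.2.0+cu92' into (1, 2, 0), ignoring suffixes."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part or 0) for part in match.groups())


def check_requirements(required=REQUIRED, lookup=metadata.version):
    """Return a list of human-readable problems; an empty list means all pins hold."""
    problems = []
    for name, (op, wanted) in required.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed (need {op}{wanted})")
            continue
        if op == ">=":
            ok = parse_version(found) >= parse_version(wanted)
        else:
            ok = parse_version(found) == parse_version(wanted)
        if not ok:
            problems.append(f"{name}: found {found}, need {op}{wanted}")
    return problems
```

Calling `check_requirements()` with no arguments inspects the current environment; the `lookup` argument exists only so the check can be exercised without installing anything.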

Train the original prefix P_θ

For the training phase of standard prefix-tuning, the command is:

  source train.sh --preseqlen [A] --learning_rate [B] --tasks [C] --n_train_epochs [D] --device [E]

where

  • [A]: The length of the prefix P_θ.
  • [B]: The (initial) learning rate.
  • [C]: The benchmark. Default: sst.
  • [D]: The total number of training epochs.
  • [E]: The id of the GPU to be used.

We can also use adversarial training to improve the robustness of the prefix. For the training phase of adversarial prefix-tuning, the command is:

  source train_adv.sh --preseqlen [A] --learning_rate [B] --tasks [C] --n_train_epochs [D] --device [E] --pgd_ball [F]

where

  • [A]~[E] have the same meanings as above.
  • [F]: Whether the norm ball is word-wise or sentence-wise.

Note that the DATA_DIR and MODEL_DIR in train_adv.sh are different from those in train.sh. When experimenting with an adversarially trained prefix P_θ in the following steps, remember to switch the DATA_DIR and MODEL_DIR in the corresponding scripts as well.

Generate Adversarial Examples

We use the OpenAttack package to generate in-sentence adversaries. The command is:

  source generate_adv_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H]

where

  • [A],[B],[C],[E] have the same meanings as above.
  • [G]: Load the prefix P_θ parameters trained for [G] epochs for testing. We set G=D.
  • [H]: Generate adversarial examples based on clean test set with the in-sentence attack [H].

We also implement the Universal Adversarial Trigger attack. The command is:

  source generate_adv_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean-[H2] --uat_len [I] --uat_epoch [J]

where

  • [A],[B],[C],[E],[G] have the same meanings as above.
  • [H2]: We should search for UATs for each class in the benchmark, and H2 indicates the class id. H2=0/1 for SST, 0/1/2/3 for AG News, and 0/1/2 for SNLI.
  • [I]: The length of the UAT.
  • [J]: The number of epochs used to search for the UAT.
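Since one UAT search is needed per class, the command has to be repeated with a different `clean-[H2]` value each time. The sketch below shows one way to enumerate those invocations. The helper `uat_commands` is hypothetical and not part of this repository; its default values mirror the running example, and `--preseqlen` and `--learning_rate` are left to the script defaults, as they are there.

```python
def uat_commands(task, num_classes, device=0, test_ep=100, uat_len=3, uat_epoch=10):
    """Build one generate_adv_uat.sh invocation per class id (0 .. num_classes - 1)."""
    commands = []
    for class_id in range(num_classes):
        commands.append(
            "source generate_adv_uat.sh"
            f" --tasks {task} --device {device} --test_ep {test_ep}"
            f" --attack clean-{class_id} --uat_len {uat_len} --uat_epoch {uat_epoch}"
        )
    return commands


# SST has two classes, so this prints two commands (clean-0 and clean-1).
for line in uat_commands("sst", num_classes=2):
    print(line)
```

For AG News and SNLI, pass `num_classes=4` and `num_classes=3` respectively, together with whatever `--tasks` value the scripts expect for those benchmarks.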

Test the performance of P_θ

The command for performance testing of P_θ under clean data and in-sentence attacks is:

  source test_prefix_theta_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H] --test_batch_size [K]

Under UAT attack, the test command is:

  source test_prefix_theta_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean --uat_len [I] --test_batch_size [K]

where

  • [A]~[I] have the same meanings as above.
  • [K]: The test batch size. When K=0, the batch size is adaptive (determined by GPU memory); when K>0, the batch size is fixed.

Robust Prefix P'_ψ: Constructing the canonical manifolds

By constructing the canonical manifolds with PCA, we get the projection matrices. The command is:

  source get_proj.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G]

where [A]~[G] have the same meanings as above.
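To illustrate what a PCA-derived projection matrix looks like, here is a minimal NumPy sketch. It is not the code in get_proj.sh: the function name `pca_projection`, the explicit `n_components` argument, and the use of random data in place of real activations are all assumptions made for the example.

```python
import numpy as np


def pca_projection(activations, n_components):
    """Return (mean, Q), where Q projects centered row vectors onto the
    subspace spanned by the top `n_components` principal directions."""
    mean = activations.mean(axis=0, keepdims=True)
    centered = activations - mean
    # Rows of vt are the principal directions, ordered by singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    components = vt[:n_components]            # shape (k, d)
    projection = components.T @ components    # shape (d, d), symmetric
    return mean, projection


# Toy stand-in for activations: 200 vectors in 8-D lying on a 3-D affine subspace.
rng = np.random.default_rng(0)
data = rng.normal(size=(200, 3)) @ rng.normal(size=(3, 8)) + 1.0
mean, Q = pca_projection(data, n_components=3)
print(Q.shape)
```

A matrix built this way is idempotent (applying it twice equals applying it once), and points that already lie on the subspace are left unchanged, which is what makes the distance to the subspace a usable signal at test time.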

Robust Prefix P'_ψ: Test its performance

Under clean data and in-sentence attacks, the command is:

  source test_robust_prefix_psi_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H] --test_batch_size [K] --PMP_lr [L] --PMP_iter [M]

Under UAT attack, the test command is:

  source test_robust_prefix_psi_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean --uat_len [I] --test_batch_size [K] --PMP_lr [L] --PMP_iter [M]

where

  • [A]~[K] have the same meanings as above.
  • [L]: The learning rate for test-time P'_ψ tuning.
  • [M]: The number of iterations for test-time P'_ψ tuning.
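The roles of [L] and [M] can be seen in a toy version of test-time tuning. The sketch below is an illustration under strong simplifying assumptions and is not the repository's implementation: the language model is replaced by an additive toy model in which the tuned vector `psi` is simply added to each activation, and the objective is the squared distance of the shifted activations from a given linear subspace. The names `off_manifold_loss` and `tune_psi` are hypothetical; the default `lr` and `n_iter` are taken from the running example.

```python
import numpy as np


def off_manifold_loss(hidden, mean, projection):
    """Mean squared distance of the rows of `hidden` from the subspace."""
    complement = np.eye(projection.shape[0]) - projection
    residual = (hidden - mean) @ complement
    return float((residual ** 2).sum(axis=1).mean())


def tune_psi(batch_hidden, mean, projection, lr=0.15, n_iter=10):
    """Gradient descent on a shared shift `psi` for `n_iter` steps of size `lr`."""
    d = projection.shape[0]
    complement = np.eye(d) - projection
    psi = np.zeros(d)
    for _ in range(n_iter):
        hidden = batch_hidden + psi  # toy model: psi shifts every activation
        # Gradient of the loss above w.r.t. psi (complement is symmetric, idempotent).
        grad = 2.0 * ((hidden - mean) @ complement).mean(axis=0)
        psi -= lr * grad
    return psi


# Toy batch: the subspace is the first two axes; inputs are pushed off it.
rng = np.random.default_rng(0)
Q = np.diag([1.0, 1.0, 0.0, 0.0])
mean = np.zeros((1, 4))
batch = rng.normal(size=(16, 4)) + np.array([0.0, 0.0, 3.0, 3.0])
psi = tune_psi(batch, mean, Q)
print(off_manifold_loss(batch, mean, Q), off_manifold_loss(batch + psi, mean, Q))
```

In this toy setting a larger [L] or [M] moves the batch closer to the subspace per test batch, at the cost of more computation or less stable updates.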

Running Example

# Train the original prefix P_θ
source train.sh --tasks sst --n_train_epochs 100 --device 0
source train_adv.sh --tasks sst --n_train_epochs 100 --device 1 --pgd_ball word

# Generate Adversarial Examples
source generate_adv_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug
source generate_adv_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean-0 --uat_len 3 --uat_epoch 10
source generate_adv_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean-1 --uat_len 3 --uat_epoch 10

# Test the performance of P_θ
source test_prefix_theta_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug --test_batch_size 0
source test_prefix_theta_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean --uat_len 3 --test_batch_size 0

# Robust Prefix P'_ψ: Constructing the canonical manifolds
source get_proj.sh --tasks sst --device 0 --test_ep 100

# Robust Prefix P'_ψ: Test its performance
source test_robust_prefix_psi_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug --test_batch_size 0 --PMP_lr 0.15 --PMP_iter 10
source test_robust_prefix_psi_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean --uat_len 3 --test_batch_size 0 --PMP_lr 0.05 --PMP_iter 10

Released Data & Models

Training the original prefix P_θ and generating adversarial examples can be time-consuming. As shown in our paper, adversarial prefix-tuning is particularly slow. Generating adversaries also takes considerable effort, since different attacks have to be performed on the test set for each trained prefix. We also found that OpenAttack has since been upgraded to v2.1.1, which causes compatibility issues in our code (test_prefix_theta_insent.py).

In order to facilitate research on the robustness of prefix-tuning, we release the prefix checkpoints P_θ (with both std. and adv. training), the processed test sets that are perturbed by in-sentence attacks (including PWWS and TextBugger), as well as the generated projection matrices of the canonical manifolds in our runs for reproducibility and further enhancement. We have also hard-coded the exploited UAT tokens in test_prefix_theta_uat.py and test_robust_prefix_psi_uat.py. All the materials can be found here.

Acknowledgements

The implementation of robust prefix-tuning is based on the LAMOL repo, the code for LAMOL: LAnguage MOdeling for Lifelong Language Learning, which studies lifelong learning in NLP with GPT-style pretrained language models.

Bibtex

If you find this repository useful for your research, please consider citing our work:

@inproceedings{
  yang2022on,
  title={On Robust Prefix-Tuning for Text Classification},
  author={Zonghan Yang and Yang Liu},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=eBCmOocUejf}
}