Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation

Overview

The PyTorch code repository for "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation" [paper, to appear] [slides, to appear] [poster, to appear]. If you use any content of this repo in your work, please cite the following BibTeX entry:

@misc{Proto-CAT,
  author = {Yi-Kai Zhang},
  title = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZhangYikaii/Proto-CAT}},
  commit = {main}
}

Prototype-based Co-Adaptation with Transformer

Illustration of Proto-CAT. The model transforms the classification space based on two kinds of audio-visual prototypes (class centers): (1) those of the base training categories (colored blue, green, and pink); and (2) those of the additional novel test categories (colored with a burning-hue transition). Proto-CAT learns and generalizes to novel test categories from limited labeled examples while maintaining performance on the base training ones. The transformation includes prototype-based co-adaptation at both the audio-visual level and the category level. From left to right, greater coverage and brighter colors represent a more reliable classification space.
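To make the notion of a prototype (class center) concrete, here is a minimal sketch in plain Python. It is not the repository's code: it only averages toy 2-D embeddings per class and assigns a query to the nearest prototype, and it omits the Transformer-based co-adaptation step entirely. The function names, class labels, and embeddings are hypothetical.

```python
# Minimal sketch (NOT the repository's implementation): class-mean prototypes
# and nearest-prototype classification on toy 2-D embeddings.
from collections import defaultdict


def class_prototypes(embeddings, labels):
    """Average the embeddings of each class to get one prototype per class."""
    grouped = defaultdict(list)
    for vector, label in zip(embeddings, labels):
        grouped[label].append(vector)
    return {
        label: [sum(dim) / len(vectors) for dim in zip(*vectors)]
        for label, vectors in grouped.items()
    }


def nearest_prototype(query, prototypes):
    """Return the label whose prototype is closest in squared Euclidean distance."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(prototypes, key=lambda label: sq_dist(query, prototypes[label]))


# Hypothetical embeddings: two examples of a base class, two of a novel class.
support = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = ["base_word", "base_word", "novel_word", "novel_word"]

prototypes = class_prototypes(support, support_labels)
prediction = nearest_prototype([4.5, 3.0], prototypes)
print(prototypes, prediction)
```

In a generalized few-shot setting the prototype dictionary holds both base and novel classes, so a query can be assigned to either group.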

 

Results

Dataset            LRW                                     LRW-1000
Data source        Audio   Video   Audio-Video             Audio-Video
Measure            H-mean  H-mean  Base    Novel   H-mean  Base    Novel   H-mean
LSTM-based         32.20   8.00    97.09   23.76   37.22   71.34   0.03    0.07
GRU-based          37.01   10.58   97.44   27.35   41.71   71.34   0.05    0.09
MS-TCN-based       62.29   19.06   80.96   51.28   61.76   71.55   0.33    0.63
ProtoNet-GFSL      39.95   14.40   96.33   39.23   54.79   69.33   0.76    1.47
FEAT-GFSL          49.90   25.75   96.26   54.52   68.83   71.69   2.62    4.89
DFSL               72.13   42.56   66.10   84.62   73.81   31.68   68.72   42.56
CASTLE             75.48   34.68   73.50   90.20   80.74   11.13   54.07   17.84
Proto-CAT (Ours)   84.18   74.55   93.37   91.20   92.13   49.70   38.27   42.25
Proto-CAT+ (Ours)  -       -       93.18   90.16   91.49   54.55   38.16   43.88

Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on LRW and LRW-1000 datasets. The best result of each scenario is in bold font. The performance measure on both base and novel classes (Base, Novel in the table) is mean accuracy. Harmonic mean (i.e., H-mean) of the above two is a better generalized few-shot learning performance measure.
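As a small worked example of why H-mean is the stricter measure, the sketch below computes the harmonic mean of a base accuracy and a novel accuracy. The function name is hypothetical. The values reported in the table may differ slightly from this formula applied to the averaged Base and Novel columns, presumably because the measures are aggregated over evaluation rounds.

```python
# Minimal sketch: harmonic mean of base-class and novel-class accuracy (in %).
def harmonic_mean(base_acc, novel_acc):
    """Return 2 * base * novel / (base + novel), or 0.0 if both are zero."""
    total = base_acc + novel_acc
    if total == 0:
        return 0.0
    return 2 * base_acc * novel_acc / total


# A model that is strong on base classes but weak on novel ones scores low:
# the arithmetic mean would be 50.0, the harmonic mean is far smaller.
lopsided = harmonic_mean(90.0, 10.0)
balanced = harmonic_mean(50.0, 50.0)
print(lopsided, balanced)
```

The harmonic mean is high only when both accuracies are high, which is why a model that ignores novel classes cannot score well on it.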

 

Prerequisites

Environment

Please refer to requirements.txt and run:

pip install -r requirements.txt

Dataset

  • Use preprocessed data (suggested):

    LRW and LRW-1000 do not permit sharing the preprocessed data directly.

  • Use raw data and preprocess it:

    Download the LRW dataset and unzip it so that the directory looks like:

    /your data_path set in .sh file
    ├── lipread_mp4
    │   ├── [ALL CLASS FOLDER]
    │   ├── ...
    

    Run prepare_lrw_audio.py and prepare_lrw_video.py to preprocess the audio and video modalities, respectively. Modify the data path in these preprocessing scripts beforehand.

    Similarly, download the LRW-1000 dataset, unzip it, and run prepare_lrw1000_audio.py and prepare_lrw1000_video.py to preprocess it.

Pretrained Weights

We provide pretrained weights for the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and arrange them as:

/your init_weights set in .sh file
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
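
Before launching training, it can help to confirm that the four files listed above are in place. The following is a minimal sketch using only the standard library; the helper name missing_weights and the example directory are hypothetical and are not part of the repository.

```python
# Minimal sketch: report which pretrained weight files are absent from a directory.
from pathlib import Path

# File names taken from the layout above.
EXPECTED_WEIGHTS = [
    "Conv1dResNetGRU_LRW-pre.pth",
    "Conv3dResNetLSTM_LRW-pre.pth",
    "Conv1dResNetGRU_LRW1000-pre.pth",
    "Conv3dResNetLSTM_LRW1000-pre.pth",
]


def missing_weights(init_weights_dir):
    """Return the expected file names that do not exist under init_weights_dir."""
    root = Path(init_weights_dir)
    return [name for name in EXPECTED_WEIGHTS if not (root / name).is_file()]


# Hypothetical path; replace it with the init_weights directory set in the .sh file.
missing = missing_weights("./init_weights_placeholder_dir")
if missing:
    print("Missing pretrained weights:", ", ".join(missing))
```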

 

How to Train Proto-CAT

For the LRW dataset, adjust the parameters in run/protocat_lrw.sh and run:

cd ./Proto-CAT/run
bash protocat_lrw.sh

Similarly, run bash protocat_lrw1000.sh for the LRW-1000 dataset.

Run bash protocat_plus_lrw.sh / bash protocat_plus_lrw1000.sh to train Proto-CAT+.

How to Reproduce the Results of Proto-CAT

Download the trained models from Google Drive or Baidu Yun (password: swzd) and run:

bash test_protocat_lrw.sh

Run bash test_protocat_lrw1000.sh, bash test_protocat_plus_lrw.sh, or bash test_protocat_plus_lrw1000.sh to evaluate other models.

 

Code Structures

Proto-CAT's entry point is main.py. It calls the manager class Trainer in models/train.py, which contains the main training logic. Within Trainer, prepare_handle.prepare_dataloader, together with train_prepare_batch, loads and preprocesses data in the generalized few-shot style; fit_handle controls forward and backward propagation; and callbacks handles the behavior at each stage.
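
To illustrate what "generalized few-shot style data" can look like, here is a minimal sketch of sampling one episode: a few novel classes each contribute support and query examples, and extra query examples are drawn from base classes. This is an assumption-laden illustration, not the logic of prepare_dataloader or train_prepare_batch; the function name, the base_query parameter, and the toy clip identifiers are all hypothetical.

```python
# Minimal sketch (NOT the repository's data pipeline): sample one generalized
# few-shot episode from toy pools of clip identifiers.
import random


def sample_gfsl_episode(novel_pool, base_pool, way, shot, query, base_query, seed=None):
    """Return (support, queries) as lists of (example, class_label) pairs."""
    rng = random.Random(seed)
    classes = rng.sample(sorted(novel_pool), way)       # pick `way` novel classes
    support, queries = [], []
    for cls in classes:
        items = rng.sample(novel_pool[cls], shot + query)
        support += [(x, cls) for x in items[:shot]]      # `shot` labeled examples
        queries += [(x, cls) for x in items[shot:]]      # `query` novel queries
    # Generalized setting: queries may also come from base (seen) classes.
    base_items = [(x, cls) for cls in sorted(base_pool) for x in base_pool[cls]]
    queries += rng.sample(base_items, base_query)
    return support, queries


# Toy pools: 6 novel classes and 3 base classes with 5 clips each.
novel_pool = {f"novel_{c}": [f"novel_{c}_clip{i}" for i in range(5)] for c in range(6)}
base_pool = {f"base_{c}": [f"base_{c}_clip{i}" for i in range(5)] for c in range(3)}

support, queries = sample_gfsl_episode(
    novel_pool, base_pool, way=5, shot=1, query=2, base_query=4, seed=0
)
print(len(support), len(queries))
```

Here way, shot, and query play the role of the train_way, train_shot, and train_query settings; a 5-way 1-shot episode yields five support examples.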

Arguments

All parameters are defined in models/utils.py. We list the main ones below:

  • do_train, do_test: Store-true switch for whether to train or test.
  • data_path: Directory containing the data.
  • model_save_path: Directory in which the best model is saved.
  • init_weights: Path to the pretrained weights.
  • dataset: Option for the dataset.
  • model_class: Option for the top model.
  • backend_type: Option list for the backend type.
  • train_way, val_way, test_way, train_shot, val_shot, test_shot, train_query, val_query, test_query: Task settings for generalized few-shot learning.
  • gfsl_train, gfsl_test: Switches for whether to train or test in the generalized few-shot learning setting, i.e., whether additional base class data is included.
  • mm_list: Participating modalities.
  • lr_scheduler: List of learning rate schedulers.
  • loss_fn: Option for the loss function.
  • max_epoch: Maximum training epoch.
  • episodes_per_train_epoch, episodes_per_val_epoch, episodes_per_test_epoch: Number of sampled episodes per epoch.
  • num_tasks: Number of tasks per episode.
  • meta_batch_size: Batch size of each task.
  • test_model_filepath: Trained weights .pth file path when testing a model.
  • gpu: Multi-GPU option like --gpu 0,1,2,3.
  • logger_filename: Directory in which the log file is saved.
  • time_str: Token identifying each run; generated automatically if left empty.
  • acc_per_class: Switch for whether to measure per-class accuracy, reported as base, novel, and harmonic mean.
  • verbose, epoch_verbose: Switches for whether to print messages or show a progress bar.
  • torch_seed, cuda_seed, np_seed, random_seed: Seeds of random number generation.
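
As an illustration of how a few of these options fit together on the command line, the sketch below declares a small subset with argparse. It is not a copy of models/utils.py: the types and defaults are assumptions, and only the argument names and the --gpu 0,1,2,3 form come from the list above.

```python
# Minimal sketch (NOT models/utils.py): an illustrative subset of the arguments.
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative subset of Proto-CAT arguments")
    # Store-true switches, as described in the list above.
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_test", action="store_true")
    parser.add_argument("--gfsl_train", action="store_true")
    # Types and defaults below are assumptions made for this sketch.
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--train_way", type=int, default=5)
    parser.add_argument("--train_shot", type=int, default=1)
    parser.add_argument("--gpu", type=str, default="0")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    ["--do_train", "--gfsl_train", "--train_way", "5", "--train_shot", "1", "--gpu", "0,1,2,3"]
)
gpu_ids = [int(i) for i in args.gpu.split(",")]  # "0,1,2,3" -> [0, 1, 2, 3]
print(args.do_train, args.do_test, args.train_way, gpu_ids)
```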

 

Acknowledgment

We thank the following repos for providing helpful components/functions used in our work.
