Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation

Overview

This is the PyTorch code repository for "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation" [paper, to appear] [slides, to appear] [poster, to appear]. If you use any content of this repo in your work, please cite the following BibTeX entry:

@misc{Proto-CAT,
  author = {Yi-Kai Zhang},
  title = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZhangYikaii/Proto-CAT}},
  commit = {main}
}

Prototype-based Co-Adaptation with Transformer

Illustration of Proto-CAT. The model transforms the classification space using [formula] based on two kinds of audio-visual prototypes (class centers): (1) the base training categories (colored blue, green, and pink); and (2) the additional novel test categories (colored with a burning transition). Proto-CAT learns and generalizes to novel test categories from limited labeled examples while maintaining performance on the base training categories. [formula] includes audio-visual-level and category-level prototype-based co-adaptation. From left to right, broader coverage and brighter colors represent a more reliable classification space.
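A prototype here is a class center. As a minimal sketch, assuming the standard ProtoNet-style definition (the mean embedding of a class's labeled examples, with each query assigned to the nearest center), the generalized setting amounts to classifying over base and novel prototypes together. This does not implement Proto-CAT's co-adaptation; class_prototypes, fuse, and classify are hypothetical helpers, and concatenating the audio and video embeddings is an assumed fusion, not the paper's.

```python
def fuse(audio_vec, video_vec):
    """Hypothetical audio-visual fusion (assumption): concatenate the two embeddings."""
    return list(audio_vec) + list(video_vec)


def class_prototypes(embeddings, labels):
    """Average the embeddings of each class into one prototype (class center)."""
    sums, counts = {}, {}
    for vec, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
            counts[label] = 0
        sums[label] = [s + v for s, v in zip(sums[label], vec)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


def classify(query, prototypes):
    """Return the label whose prototype has the smallest squared Euclidean distance."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(prototypes, key=lambda label: sq_dist(query, prototypes[label]))


if __name__ == "__main__":
    # Base prototypes (e.g. from training data) and a novel prototype from a 1-shot support set.
    base = class_prototypes([fuse([0.0], [0.0]), fuse([0.2], [0.0])], ["base_a", "base_a"])
    novel = class_prototypes([fuse([5.0], [5.0])], ["novel_x"])
    # Generalized few-shot: the query is classified over the union of base and novel classes.
    joint = {**base, **novel}
    print(classify(fuse([4.5], [5.2]), joint))  # prints novel_x
```

Keeping base and novel prototypes in one dictionary is what makes the sketch "generalized": a query can land on either kind of class, which is exactly the trade-off the Base, Novel, and H-mean columns below measure.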

Results

Method                LRW                                     LRW-1000
                      Audio   Video   Audio-Video             Audio-Video
                      H-mean  H-mean  Base    Novel   H-mean  Base    Novel   H-mean
LSTM-based            32.20   8.00    97.09   23.76   37.22   71.34   0.03    0.07
GRU-based             37.01   10.58   97.44   27.35   41.71   71.34   0.05    0.09
MS-TCN-based          62.29   19.06   80.96   51.28   61.76   71.55   0.33    0.63
ProtoNet-GFSL         39.95   14.40   96.33   39.23   54.79   69.33   0.76    1.47
FEAT-GFSL             49.90   25.75   96.26   54.52   68.83   71.69   2.62    4.89
DFSL                  72.13   42.56   66.10   84.62   73.81   31.68   68.72   42.56
CASTLE                75.48   34.68   73.50   90.20   80.74   11.13   54.07   17.84
Proto-CAT (Ours)      84.18   74.55   93.37   91.20   92.13   49.70   38.27   42.25
Proto-CAT+ (Ours)     -       -       93.18   90.16   91.49   54.55   38.16   43.88

Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on the LRW and LRW-1000 datasets. The best result in each scenario is in bold. Base and Novel are the mean accuracies on the base and novel classes, respectively. Their harmonic mean (H-mean) is a better measure of generalized few-shot learning performance.
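The H-mean of a base accuracy B and a novel accuracy N is 2 * B * N / (B + N), which stays low unless both accuracies are high. A minimal sketch follows; harmonic_mean and mean_harmonic_over_rounds are hypothetical names. Plugging the table's Base and Novel columns into the formula does not reproduce the reported H-mean exactly (the LSTM-based row on LRW gives about 38.18 rather than 37.22). That gap is consistent with averaging a per-round H-mean over the evaluation rounds, but the README does not state how the average is taken.

```python
def harmonic_mean(base_acc, novel_acc):
    """H-mean of base and novel accuracy (both in %); 0 if either accuracy is 0."""
    if base_acc <= 0 or novel_acc <= 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


def mean_harmonic_over_rounds(rounds):
    """Average the per-round H-mean over a list of (base_acc, novel_acc) pairs."""
    return sum(harmonic_mean(b, n) for b, n in rounds) / len(rounds)


if __name__ == "__main__":
    # Base/Novel columns of the LSTM-based row on LRW (audio-video).
    print(round(harmonic_mean(97.09, 23.76), 2))  # prints 38.18
```

The zero guard reflects the point of the metric: a model that ignores novel classes entirely (Novel = 0) gets H-mean 0 no matter how high its base accuracy is.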

Prerequisites

Environment

Please refer to requirements.txt and run:

pip install -r requirements.txt

Dataset

  • Use preprocessed data (suggested):

    LRW and LRW-1000 forbid sharing the preprocessed data directly.

  • Use raw data and do preprocess:

    Download the LRW dataset and unzip it into the following layout:

    /your data_path (set in the .sh file)
    ├── lipread_mp4
    │   ├── [ALL CLASS FOLDER]
    │   ├── ...
    

    Run prepare_lrw_audio.py and prepare_lrw_video.py to preprocess the audio and video modalities, respectively. Set the data path in these preprocessing scripts beforehand.

    Similarly, download the LRW-1000 dataset and unzip it. Run prepare_lrw1000_audio.py and prepare_lrw1000_video.py to preprocess it.
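Before running the preprocessing scripts, it can help to confirm that the unzipped LRW data matches the layout above. This is a minimal sketch that assumes only the lipread_mp4/[class folder] structure shown; summarize_lrw is a hypothetical helper, not part of the repository, and it counts .mp4 files anywhere below each class folder.

```python
from pathlib import Path


def summarize_lrw(data_path):
    """Count .mp4 clips under each class folder of <data_path>/lipread_mp4."""
    root = Path(data_path) / "lipread_mp4"
    if not root.is_dir():
        raise FileNotFoundError(f"expected directory {root}")
    # One entry per class folder; files directly under lipread_mp4 are ignored.
    return {
        class_dir.name: sum(1 for _ in class_dir.rglob("*.mp4"))
        for class_dir in sorted(root.iterdir())
        if class_dir.is_dir()
    }
```

For example, summarize_lrw("/your/data_path") should return one entry per word class (LRW has 500 word classes), and a class with a count of 0 points to an incomplete unzip.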

Pretrained Weights

We provide pretrained weights on the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and place them as follows:

/your init_weights (set in the .sh file)
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
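A quick check that the downloaded weights are in place is sketched below. It relies only on the four file names listed above; missing_weights is a hypothetical helper, and actually loading the .pth files (presumably with torch.load) is not shown.

```python
from pathlib import Path

# File names taken from the listing above.
EXPECTED_WEIGHTS = (
    "Conv1dResNetGRU_LRW-pre.pth",
    "Conv3dResNetLSTM_LRW-pre.pth",
    "Conv1dResNetGRU_LRW1000-pre.pth",
    "Conv3dResNetLSTM_LRW1000-pre.pth",
)


def missing_weights(init_weights_dir, expected=EXPECTED_WEIGHTS):
    """Return the expected weight files that are not present in init_weights_dir."""
    root = Path(init_weights_dir)
    return [name for name in expected if not (root / name).is_file()]
```

An empty list from missing_weights("/your/init_weights") means all four files were found.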

How to Train Proto-CAT

For the LRW dataset, adjust the parameters in run/protocat_lrw.sh and run:

cd ./Proto-CAT/run
bash protocat_lrw.sh

Similarly, run bash protocat_lrw1000.sh for the LRW-1000 dataset.

Run bash protocat_plus_lrw.sh / bash protocat_plus_lrw1000.sh to train Proto-CAT+.

How to Reproduce the Result of Proto-CAT

Download the trained models from Google Drive or Baidu Yun (password: swzd) and run:

bash test_protocat_lrw.sh

Run bash test_protocat_lrw1000.sh, bash test_protocat_plus_lrw.sh, or bash test_protocat_plus_lrw1000.sh to evaluate other models.

Code Structures

Proto-CAT's entry point is main.py. It calls the Trainer manager in models/train.py, which contains the main training logic. In Trainer, prepare_handle.prepare_dataloader, together with train_prepare_batch, loads and preprocesses data in the generalized few-shot format. fit_handle controls forward and backward propagation, and callbacks handle the behavior at each stage of training.
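Generalized few-shot data is organized into episodes. The following is a minimal sketch of sampling one episode, not the repository's sampler: sample_gfsl_episode and its arguments are hypothetical, and adding a fixed number of queries per base class is an assumed way of mixing in base class data.

```python
import random


def sample_gfsl_episode(novel_pool, base_pool, way, shot, query, base_query=0, rng=None):
    """Sample one way-way shot-shot episode, optionally adding base-class queries.

    novel_pool, base_pool: dicts mapping class name -> list of sample ids.
    Returns (support, queries), each a list of (sample_id, class_name) pairs.
    """
    rng = rng or random.Random()
    support, queries = [], []
    # Few-shot part: pick `way` novel classes, then split each class's samples
    # into `shot` support examples and `query` query examples without overlap.
    for cls in rng.sample(sorted(novel_pool), way):
        picked = rng.sample(novel_pool[cls], shot + query)
        support += [(s, cls) for s in picked[:shot]]
        queries += [(s, cls) for s in picked[shot:]]
    # Generalized part: base-class queries keep their own labels, so the model
    # must classify over base and novel classes jointly.
    for cls in sorted(base_pool):
        queries += [(s, cls) for s in rng.sample(base_pool[cls], base_query)]
    return support, queries
```

With way=5 and shot=1 this corresponds to the 5-way 1-shot setting used in the results; passing a seeded random.Random makes the episode reproducible.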

Arguments

All parameters are defined in models/utils.py. We list the main ones below:

  • do_train, do_test: Store-true switch for whether to train or test.
  • data_path: Data directory to be set.
  • model_save_path: Directory for saving the best model.
  • init_weights: Pretrained weights to be set.
  • dataset: Option for the dataset.
  • model_class: Option for the top model.
  • backend_type: Option list for the backend type.
  • train_way, val_way, test_way, train_shot, val_shot, test_shot, train_query, val_query, test_query: Task settings of generalized few-shot learning.
  • gfsl_train, gfsl_test: Switches for whether to train or test in the generalized few-shot learning way, i.e., whether additional base class data is included.
  • mm_list: Participating modalities.
  • lr_scheduler: List of learning rate schedulers.
  • loss_fn: Option for the loss function.
  • max_epoch: Maximum training epoch.
  • episodes_per_train_epoch, episodes_per_val_epoch, episodes_per_test_epoch: Number of sampled episodes per epoch.
  • num_tasks: Number of tasks per episode.
  • meta_batch_size: Batch size of each task.
  • test_model_filepath: Trained weights .pth file path when testing a model.
  • gpu: Multi-GPU option like --gpu 0,1,2,3.
  • logger_filename: Logger file save directory.
  • time_str: Token for each run; generated automatically if empty.
  • acc_per_class: Switch for whether to measure the accuracy of each class when reporting base, novel, and harmonic-mean results.
  • verbose, epoch_verbose: Switches for whether to print messages and whether to show a progress bar.
  • torch_seed, cuda_seed, np_seed, random_seed: Seeds of random number generation.
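Several of these arguments map naturally onto command-line flags, and the README shows the --gpu 0,1,2,3 form. This is a hedged sketch only: the real definitions live in models/utils.py, and the types, defaults, and the use of action="store_true" for switches other than do_train and do_test are assumptions here, as is the hypothetical gpu_ids helper.

```python
import argparse


def build_parser():
    """Hypothetical subset of the Proto-CAT arguments; types and defaults are assumed."""
    parser = argparse.ArgumentParser(description="Proto-CAT arguments (sketch)")
    # Store-true switches, as described for do_train / do_test.
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_test", action="store_true")
    parser.add_argument("--gfsl_train", action="store_true")  # assumed to be a switch too
    parser.add_argument("--gfsl_test", action="store_true")   # assumed to be a switch too
    # Generalized few-shot task settings (assumed defaults: 5-way 1-shot).
    parser.add_argument("--train_way", type=int, default=5)
    parser.add_argument("--train_shot", type=int, default=1)
    parser.add_argument("--train_query", type=int, default=15)
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--gpu", type=str, default="0")
    return parser


def gpu_ids(gpu_arg):
    """Turn a '0,1,2,3' style --gpu value into a list of device indices."""
    return [int(part) for part in gpu_arg.split(",") if part.strip()]
```

With this sketch, parsing --do_train --gpu 0,1 sets do_train to True and gpu_ids(args.gpu) to [0, 1].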

Acknowledgment

We thank the following repos for providing helpful components/functions used in our work.
