Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation

Overview

This is the PyTorch code repository for "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation" [paper, to appear] [slides, to appear] [poster, to appear]. If you use any content of this repo in your work, please cite the following BibTeX entry:

@misc{Proto-CAT,
  author = {Yi-Kai Zhang},
  title = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZhangYikaii/Proto-CAT}},
  commit = {main}
}

Prototype-based Co-Adaptation with Transformer

Illustration of Proto-CAT. The model transforms the classification space using [formula], based on two kinds of audio-visual prototypes (class centers): (1) the base training categories (shown in blue, green, and pink); and (2) the additional novel test categories (shown with a burning color transition). Proto-CAT learns and generalizes to novel test categories from a few labeled examples while maintaining performance on the base training categories. [formula] includes prototype-based co-adaptation at both the audio-visual level and the category level. From left to right, broader coverage and brighter colors indicate a more reliable classification space.
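The idea of prototypes as class centers can be illustrated with the standard ProtoNet-style nearest-center rule. This is a minimal sketch and not Proto-CAT's co-adaptation transformer. It assumes the embeddings are already fused audio-visual features, and that the base prototypes are given, as if learned during training. The functions compute_prototypes and nearest_prototype and the 2-D toy data are hypothetical. The example shows the generalized setting: queries are classified over base and novel prototypes jointly.

```python
import numpy as np

def compute_prototypes(support_emb, support_labels, num_classes):
    """Return one prototype (mean embedding) per class.

    support_emb: (n_support, dim) array; support_labels: (n_support,) ints.
    """
    protos = np.zeros((num_classes, support_emb.shape[1]))
    for c in range(num_classes):
        protos[c] = support_emb[support_labels == c].mean(axis=0)
    return protos

def nearest_prototype(query_emb, prototypes):
    """Assign each query to the prototype with the smallest squared Euclidean distance."""
    # dists[i, c] = ||query_i - proto_c||^2, computed via broadcasting
    dists = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

# Toy 2-D "fused audio-visual" embeddings (hypothetical).
base_protos = np.array([[10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]])  # assumed learned
novel_support = np.array([[0.0, -10.0], [10.0, 10.0]])            # 2-way 1-shot
novel_labels = np.array([0, 1])
novel_protos = compute_prototypes(novel_support, novel_labels, num_classes=2)

# Joint label space: indices 0-2 are base classes, 3-4 are novel classes.
all_protos = np.vstack([base_protos, novel_protos])
queries = np.array([[9.0, 1.0], [1.0, -9.0]])
preds = nearest_prototype(queries, all_protos)
print(preds.tolist())  # [0, 3]
```

With one shot per class, each novel prototype is simply that class's single support embedding.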

 

Results

                  | LRW     | LRW     | LRW                      | LRW-1000
                  | Audio   | Video   | Audio-Video              | Audio-Video
Method            | H-mean  | H-mean  | Base    Novel   H-mean   | Base    Novel   H-mean
------------------+---------+---------+--------------------------+-------------------------
LSTM-based        | 32.20   |  8.00   | 97.09   23.76   37.22    | 71.34    0.03    0.07
GRU-based         | 37.01   | 10.58   | 97.44   27.35   41.71    | 71.34    0.05    0.09
MS-TCN-based      | 62.29   | 19.06   | 80.96   51.28   61.76    | 71.55    0.33    0.63
ProtoNet-GFSL     | 39.95   | 14.40   | 96.33   39.23   54.79    | 69.33    0.76    1.47
FEAT-GFSL         | 49.90   | 25.75   | 96.26   54.52   68.83    | 71.69    2.62    4.89
DFSL              | 72.13   | 42.56   | 66.10   84.62   73.81    | 31.68   68.72   42.56
CASTLE            | 75.48   | 34.68   | 73.50   90.20   80.74    | 11.13   54.07   17.84
Proto-CAT (Ours)  | 84.18*  | 74.55*  | 93.37   91.20   92.13*   | 49.70   38.27   42.25
Proto-CAT+ (Ours) |   -     |   -     | 93.18   90.16   91.49    | 54.55   38.16   43.88*

Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on the LRW and LRW-1000 datasets. The best H-mean in each scenario is marked with *. Base and Novel are the mean accuracies on the base and novel classes, respectively. Their harmonic mean (H-mean) is the better generalized few-shot learning measure, because it stays low unless a model performs well on both sets.
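The H-mean is H = 2 · Base · Novel / (Base + Novel). The short sketch below computes it. The function name h_mean is illustrative and not taken from the repo.

The reported H-mean values sit slightly below H(Base, Novel) computed from the averaged columns; for example, 92.13 vs. 92.27 for Proto-CAT on LRW. This is consistent with H-mean being averaged per round, though the table alone does not prove it. The sketch shows why per-round averaging gives a lower number: the harmonic mean is concave, so by Jensen's inequality the mean of per-round H-means cannot exceed the H-mean of the mean accuracies. The per-round accuracies below are synthetic.

```python
import numpy as np

def h_mean(base_acc, novel_acc):
    """Elementwise harmonic mean of base and novel accuracy (0 where both are 0)."""
    base = np.asarray(base_acc, dtype=float)
    novel = np.asarray(novel_acc, dtype=float)
    total = base + novel
    safe_total = np.where(total > 0, total, 1.0)  # avoid division by zero
    return np.where(total > 0, 2.0 * base * novel / safe_total, 0.0)

# Proto-CAT on LRW (Audio-Video): Base 93.37, Novel 91.20
print(round(float(h_mean(93.37, 91.20)), 2))  # 92.27

# Synthetic per-round accuracies, NOT real results.
rng = np.random.default_rng(0)
base_rounds = rng.uniform(80.0, 100.0, size=10_000)
novel_rounds = rng.uniform(20.0, 100.0, size=10_000)
per_round = h_mean(base_rounds, novel_rounds).mean()
of_means = float(h_mean(base_rounds.mean(), novel_rounds.mean()))
print(per_round <= of_means)  # True, because the harmonic mean is concave
```

Note how a near-zero Novel accuracy drags H-mean toward zero. For the LSTM baseline on LRW-1000, h_mean(71.34, 0.03) is about 0.06, despite a Base accuracy above 70%.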

 

Prerequisites

Environment

Please refer to requirements.txt and run:

pip install -r requirements.txt

Dataset

  • Use preprocessed data (suggested):

    Note that LRW and LRW-1000 forbid directly sharing the preprocessed data.

  • Use raw data and do preprocess:

    Download the LRW dataset and unzip it into the following layout:

    /your data_path set in .sh file
    ├── lipread_mp4
    │   ├── [ALL CLASS FOLDER]
    │   ├── ...
    

    Run prepare_lrw_audio.py and prepare_lrw_video.py to preprocess the audio and video modalities, respectively. Set the data path in these preprocessing scripts beforehand.

    Similarly, download the LRW-1000 dataset and unzip it. Run prepare_lrw1000_audio.py and prepare_lrw1000_video.py to preprocess it.
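Before running the preprocessing scripts, it can help to confirm that the unzipped LRW data matches the layout above. The following is a hypothetical helper, not part of the repo. It only assumes the lipread_mp4/[ALL CLASS FOLDER] structure shown. The self-check builds a throwaway directory with two example word folders.

```python
import tempfile
from pathlib import Path

def list_lrw_classes(data_path):
    """Return the sorted class-folder names under <data_path>/lipread_mp4."""
    root = Path(data_path) / "lipread_mp4"
    if not root.is_dir():
        raise FileNotFoundError(f"expected LRW folder at {root}")
    return sorted(p.name for p in root.iterdir() if p.is_dir())

# Self-check on a throwaway directory that mimics the layout above.
with tempfile.TemporaryDirectory() as tmp:
    for word in ["ABSOLUTELY", "ABOUT"]:  # example class folders
        (Path(tmp) / "lipread_mp4" / word).mkdir(parents=True)
    demo_classes = list_lrw_classes(tmp)
print(demo_classes)  # ['ABOUT', 'ABSOLUTELY']
```

On a real download, call list_lrw_classes with the data_path from your .sh file. The full LRW release contains 500 word classes.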

Pretrained Weights

We provide pretrained weights on the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and place them as follows:

/your init_weights set in .sh file
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
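A quick way to check that the downloads landed where init_weights points is sketched below. The file names come from the tree above. The helper name missing_weights and the "LRW"/"LRW-1000" dictionary keys are hypothetical and may not match the values the dataset argument actually accepts.

```python
from pathlib import Path

# File names as listed in this README; keys are illustrative dataset labels.
EXPECTED_WEIGHTS = {
    "LRW": ["Conv1dResNetGRU_LRW-pre.pth", "Conv3dResNetLSTM_LRW-pre.pth"],
    "LRW-1000": ["Conv1dResNetGRU_LRW1000-pre.pth", "Conv3dResNetLSTM_LRW1000-pre.pth"],
}

def missing_weights(init_weights, dataset):
    """Return the expected .pth files that are not present under init_weights."""
    root = Path(init_weights)
    return [name for name in EXPECTED_WEIGHTS[dataset] if not (root / name).is_file()]

# An empty list means everything is in place, e.g.:
# missing_weights("/your/init_weights", "LRW")
```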

 

How to Train Proto-CAT

For the LRW dataset, adjust the parameters in run/protocat_lrw.sh as needed, and run:

cd ./Proto-CAT/run
bash protocat_lrw.sh

Similarly, run bash protocat_lrw1000.sh for the LRW-1000 dataset.

Run bash protocat_plus_lrw.sh / bash protocat_plus_lrw1000.sh to train Proto-CAT+.

How to Reproduce the Result of Proto-CAT

Download the trained models from Google Drive or Baidu Yun (password: swzd) and run:

bash test_protocat_lrw.sh

Run bash test_protocat_lrw1000.sh, bash test_protocat_plus_lrw.sh, or bash test_protocat_plus_lrw1000.sh to evaluate other models.

 

Code Structures

Proto-CAT's entry function is in main.py. It calls Trainer in models/train.py, which manages the main training logic. In Trainer, prepare_handle.prepare_dataloader, together with train_prepare_batch, loads and preprocesses data in the generalized few-shot format. fit_handle runs forward and backward propagation, and callbacks handle the behavior at each stage.
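The minimal sketch below shows one way to picture an episode of the generalized few-shot data that prepare_dataloader and train_prepare_batch produce. It is not the repo's sampler. The function sample_gfsl_episode, the dict-of-lists data pools, and the base_query parameter are hypothetical. The sketch also assumes that "generalized" means queries are drawn from every base class in addition to the sampled novel classes.

```python
import random

def sample_gfsl_episode(novel_pool, base_pool, way, shot, query, base_query, rng):
    """Sample one generalized few-shot episode (hypothetical sketch).

    novel_pool / base_pool: dict mapping class name -> list of sample ids.
    Support: `shot` examples for each of `way` sampled novel classes.
    Queries: `query` examples per sampled novel class, plus `base_query`
    examples from every base class (the additional base-class data).
    """
    novel_classes = rng.sample(sorted(novel_pool), way)
    support, queries = [], []
    for cls in novel_classes:
        picked = rng.sample(novel_pool[cls], shot + query)  # disjoint support/query
        support += [(s, cls) for s in picked[:shot]]
        queries += [(s, cls) for s in picked[shot:]]
    for cls in sorted(base_pool):
        queries += [(s, cls) for s in rng.sample(base_pool[cls], base_query)]
    return support, queries

rng = random.Random(0)
novel = {f"novel{i}": list(range(i * 100, i * 100 + 20)) for i in range(10)}
base = {f"base{i}": list(range(1000 + i * 100, 1000 + i * 100 + 20)) for i in range(3)}
support, queries = sample_gfsl_episode(novel, base, way=5, shot=1, query=2, base_query=2, rng=rng)
print(len(support), len(queries))  # 5 16
```

With gfsl_train or gfsl_test turned off, the base-class loop would be skipped, leaving a standard few-shot episode.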

Arguments

All parameters are defined in models/utils.py. We list the main ones below:

  • do_train, do_test: Store-true switches that enable training and testing, respectively.
  • data_path: Data directory.
  • model_save_path: Directory where the best model is saved.
  • init_weights: Directory of the pretrained weights.
  • dataset: Option for the dataset.
  • model_class: Option for the top model.
  • backend_type: List of backend types.
  • train_way, val_way, test_way, train_shot, val_shot, test_shot, train_query, val_query, test_query: Way, shot, and query settings of the generalized few-shot learning tasks for training, validation, and testing.
  • gfsl_train, gfsl_test: Switches for whether to train or test in the generalized few-shot learning setting, i.e., whether additional base-class data is included.
  • mm_list: Participating modalities.
  • lr_scheduler: List of learning rate schedulers.
  • loss_fn: Option for the loss function.
  • max_epoch: Maximum number of training epochs.
  • episodes_per_train_epoch, episodes_per_val_epoch, episodes_per_test_epoch: Number of episodes sampled per epoch.
  • num_tasks: Number of tasks per episode.
  • meta_batch_size: Batch size of each task.
  • test_model_filepath: Path to the trained weights (.pth file) when testing a model.
  • gpu: GPU IDs for multi-GPU use, e.g., --gpu 0,1,2,3.
  • logger_filename: Directory where log files are saved.
  • time_str: Token identifying each run; generated automatically if left empty.
  • acc_per_class: Switch for whether to measure per-class accuracy, reported as base, novel, and harmonic mean.
  • verbose, epoch_verbose: Switches for whether to print messages and whether to show a progress bar.
  • torch_seed, cuda_seed, np_seed, random_seed: Seeds for random number generation.
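Here is a hypothetical argparse parser that shows how some of these options might be passed on the command line. Only the argument names come from this README, plus the --gpu 0,1,2,3 form from the gpu entry. The types, default values, and mm_list choices are placeholders, not what models/utils.py defines.

```python
import argparse

def build_parser():
    """Hypothetical parser for a subset of the arguments listed above.

    Names follow this README; types and defaults are placeholder assumptions.
    """
    p = argparse.ArgumentParser(description="Proto-CAT arguments (sketch)")
    p.add_argument("--do_train", action="store_true")
    p.add_argument("--do_test", action="store_true")
    p.add_argument("--data_path", type=str, default="")
    p.add_argument("--init_weights", type=str, default="")
    p.add_argument("--dataset", type=str, default="LRW")  # placeholder value
    for split in ("train", "val", "test"):
        p.add_argument(f"--{split}_way", type=int, default=5)
        p.add_argument(f"--{split}_shot", type=int, default=1)
        p.add_argument(f"--{split}_query", type=int, default=15)  # placeholder
    p.add_argument("--gfsl_train", action="store_true")
    p.add_argument("--gfsl_test", action="store_true")
    p.add_argument("--mm_list", nargs="+", default=["audio", "video"])  # placeholder names
    p.add_argument("--max_epoch", type=int, default=100)  # placeholder
    p.add_argument("--gpu", type=str, default="0")
    return p

args = build_parser().parse_args(
    ["--do_train", "--gfsl_train", "--train_shot", "1", "--gpu", "0,1,2,3"]
)
gpu_ids = [int(g) for g in args.gpu.split(",")]  # "0,1,2,3" -> [0, 1, 2, 3]
print(args.do_train, args.do_test, args.train_way, gpu_ids)  # True False 5 [0, 1, 2, 3]
```

Store-true switches such as do_train and gfsl_train default to False and become True only when the flag is present.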

 

Acknowledgment

We thank the following repos for providing helpful components/functions used in our work.
