Repository for paper "Non-intrusive speech intelligibility prediction from discrete latent representations"

Overview

Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations

Official repository for paper "Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations".

This public repository is a work in progress! Results here bear no resemblance to results in the paper!

We predict the intelligibility of binaural speech signals by first extracting discrete latent representations from raw audio with a VQ-CPC model, then training a lightweight predictor on these latent representations. This outperforms predicting from spectral features of the audio, even though the feature extractor is not explicitly trained for this task. In certain cases, a single layer is sufficient to achieve strong correlations between the predictions and the ground-truth scores.
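As a rough illustration of what a lightweight predictor over latent representations can look like, the sketch below applies a single linear layer with a sigmoid to each frame's latent vector, giving a per-frame score in (0, 1), and then pools the frame scores into one utterance-level score, either by a plain mean or by attention-weighted pooling. It is a plain-Python sketch with hypothetical names (per_frame_scores, mean_pool, attention_pool) and made-up weights; it is not the Small or SeqPool model in stoi/, whose exact architectures are not described here.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def sigmoid(x: float) -> float:
    """Logistic function, squashing any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def per_frame_scores(latents: Sequence[Vector], weight: Vector, bias: float) -> List[float]:
    """Hypothetical single-layer head: linear map + sigmoid applied to each frame."""
    return [sigmoid(dot(z, weight) + bias) for z in latents]


def mean_pool(scores: Sequence[float]) -> float:
    """Utterance score as the average of the per-frame scores."""
    return sum(scores) / len(scores)


def attention_pool(latents: Sequence[Vector], scores: Sequence[float], attn_weight: Vector) -> float:
    """Utterance score as a softmax-weighted average of per-frame scores.

    Each frame gets a relevance logit from a (hypothetical) learned vector
    attn_weight; the largest logit is subtracted before exponentiating for
    numerical stability.
    """
    logits = [dot(z, attn_weight) for z in latents]
    peak = max(logits)
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return sum((e / total) * s for e, s in zip(exps, scores))


# Toy example: three frames with 2-dimensional latent vectors (made-up values).
latents = [[0.2, -0.1], [0.5, 0.3], [-0.4, 0.8]]
scores = per_frame_scores(latents, weight=[1.0, 0.5], bias=0.0)
print(mean_pool(scores), attention_pool(latents, scores, attn_weight=[2.0, 0.0]))
```

With all-zero attention weights the softmax is uniform, so attention_pool reduces to mean_pool; learned weights let the model emphasise the frames it considers most informative.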

This repository contains:

  • vqcpc/ - Module for VQCPC model in PyTorch
  • stoi/ - Module for Small and SeqPool predictor models in PyTorch
  • data.py - File containing various PyTorch custom datasets
  • main-vqcpc.py - Script for VQCPC training
  • create-latents.py - Script for generating latent dataset from trained VQCPC
  • plot-latents.py - Script for visualizing extracted latent representations
  • main-stoi.py - Script for STOI predictor training
  • main-test.py - Script for evaluating models
  • compute-correlations.py - Script for computing metrics for many models
  • checkpoints/ - Trained checkpoints of VQCPC and STOI predictor models
  • config/ - Directory containing various configuration files for experiments
  • results/ - Directory containing official results from experiments
  • dataset/ - Directory containing metadata files for the dataset
  • data-generator/ - Directory containing dataset generation scripts (MATLAB)

All models are implemented in PyTorch. The training scripts are implemented using ptpt - a lightweight framework around PyTorch.

Figure: Visualisation of binaural waveform, predicted per-frame STOI, and latent representation.

Usage

VQ-CPC Training

Begin VQ-CPC training using the configuration defined in config.toml:

python main-vqcpc.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)
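The "discrete" in the latent representations comes from the vector quantisation (VQ) step of VQ-CPC: each continuous encoder output is replaced by its nearest entry in a learned codebook. The sketch below shows only that nearest-neighbour lookup in plain Python, assuming squared Euclidean distance; it omits the encoder, the contrastive (CPC) objective, codebook learning and gradient handling, and it is not the code in vqcpc/. The codebook and frame values are made up.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(frames: Sequence[Vector], codebook: Sequence[Vector]) -> Tuple[List[int], List[List[float]]]:
    """Replace each continuous frame vector with its nearest codebook entry.

    Returns the discrete code indices and the corresponding quantised vectors.
    """
    indices: List[int] = []
    quantized: List[List[float]] = []
    for z in frames:
        best = min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        indices.append(best)
        quantized.append(list(codebook[best]))
    return indices, quantized


# Toy codebook with three 2-dimensional codes (made-up values).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
frames = [[0.9, 1.2], [0.1, -0.2], [-0.8, 0.7]]
indices, quantized = quantize(frames, codebook)
print(indices)  # [1, 0, 2]
```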

Latent Dataset Generation

Generate a latent dataset from the audio in wav-dataset using the pre-trained VQCPC checkpoint model-checkpoint.pt, writing the output to latent-dataset and using the configuration defined in config.toml:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml

As above, but with the work distributed across n processes, where this process has rank r:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml --array-size n --array-rank r

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--detect-anomaly    # detect autograd anomalies and terminate if encountered
-n                  # alias for `--array-size`
-r                  # alias for `--array-rank`
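One common way to split a dataset across an array of n processes is to sort the file list so every process sees the same order, then deal files out round-robin by rank. The sketch below illustrates that idea, assuming zero-based ranks 0 to n-1; shard() and the file names are hypothetical, and create-latents.py may partition its inputs differently.

```python
from typing import List, Sequence


def shard(items: Sequence[str], array_size: int, array_rank: int) -> List[str]:
    """Return the subset of items handled by process `array_rank` out of `array_size`.

    Items are sorted first so every process sees the same order, then dealt
    out round-robin: rank r takes items r, r + n, r + 2n, ...
    """
    if array_size < 1:
        raise ValueError("array_size must be at least 1")
    if not 0 <= array_rank < array_size:
        raise ValueError("array_rank must satisfy 0 <= array_rank < array_size")
    return sorted(items)[array_rank::array_size]


files = [f"utt-{i:03d}.wav" for i in range(10)]  # hypothetical file names
for r in range(3):
    print(r, shard(files, array_size=3, array_rank=r))
```

Because the shards are disjoint and together cover every file, each process can write its latents independently. If the processes are launched as a Slurm job array indexed from 0, the rank could be read from the SLURM_ARRAY_TASK_ID environment variable.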

Latent Plotting

Launch the interactive VQCPC latent visualisation script using the pre-trained model model-checkpoint.pt on dataset wav-dataset, with the configuration defined in config.toml:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml

If you also have a pre-trained per-frame STOI score predictor (not a SeqPool predictor), you can pass its checkpoint stoi-checkpoint.pt and configuration stoi-config.toml to plot per-frame scores alongside the waveform and latent features:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --stoi stoi-checkpoint.pt --stoi-cfg stoi-config.toml

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--cmap              # define matplotlib colourmap
--style             # define matplotlib style

STOI Predictor Training

Begin training the intelligibility score predictor using the configuration in config.toml:

python main-stoi.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Predictor Evaluation

Evaluate a pre-trained STOI score predictor from checkpoint stoi-checkpoint.pt on the dataset at dataset-root, using the configuration in stoi-config.toml:

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml

Other useful arguments:

--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--batch-size        # control dataloader batch size
--seed              # random seed (default: 12345)

Overall Evaluation

Compare the results files produced by multiple main-test.py runs against the dataset ground truth:

python compute-correlations.py ground-truth.csv pred-1.csv ... pred-n.csv --names pred-1 ... pred-n
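The exact metrics and CSV layout used by compute-correlations.py are not documented here. As an assumption, the sketch below computes three metrics commonly reported for intelligibility prediction: Pearson correlation, Spearman rank correlation and RMSE. It matches predictions to ground truth by utterance id, and the column names id and score are hypothetical placeholders.

```python
import csv
import math
from typing import Dict, Iterable, List, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson linear correlation coefficient."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0 or sy == 0:
        raise ValueError("correlation is undefined for constant inputs")
    return cov / (sx * sy)


def ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks, with tied values sharing the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            result[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return result


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman rank correlation: Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


def rmse(x: Sequence[float], y: Sequence[float]) -> float:
    """Root-mean-square error between two equal-length score lists."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)) / len(x))


def load_scores(lines: Iterable[str], id_col: str = "id", score_col: str = "score") -> Dict[str, float]:
    """Read CSV lines with hypothetical `id` and `score` columns into a dict."""
    return {row[id_col]: float(row[score_col]) for row in csv.DictReader(lines)}


def compare(truth: Dict[str, float], pred: Dict[str, float]) -> Dict[str, float]:
    """Score one prediction set against the ground truth on their shared ids."""
    keys = sorted(truth.keys() & pred.keys())
    t = [truth[k] for k in keys]
    p = [pred[k] for k in keys]
    return {"pearson": pearson(t, p), "spearman": spearman(t, p), "rmse": rmse(t, p)}


truth = load_scores(["id,score", "a,0.10", "b,0.40", "c,0.70", "d,0.90"])
pred = load_scores(["id,score", "a,0.20", "b,0.35", "c,0.80", "d,0.85"])
print(compare(truth, pred))
```

For real files, an open file object can be passed directly, e.g. `with open("ground-truth.csv", newline="") as f: truth = load_scores(f)`, and compare() called once per prediction file.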

Configuration

Example configurations for all experiments can be found in the config/ directory.

We use TOML files to define configurations. Each one consists of three sections:

  • [trainer]: configuration options for ptpt.TrainerConfig.
  • [data]: configuration options for the dataset.
  • [vqcpc] or [stoi]: configuration options for the VQCPC and predictor models respectively.

Checkpoints

Pretrained checkpoints for all models can be found in the checkpoints/ directory.

Citation

TODO: add citation once paper published / arXiv-ed :)

Owner
Alex McKinney
Final-year student at Durham University. Interested in generative models and unsupervised representation learning.