Repository for paper "Non-intrusive speech intelligibility prediction from discrete latent representations"

Overview

Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations

Official repository for paper "Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations".

This public repository is a work in progress! Results here bear no resemblance to results in the paper!

We predict the intelligibility of binaural speech signals by first extracting latent representations from raw audio. Then, a lightweight predictor over these latent representations can be trained. This results in improved performance over predicting on spectral features of the audio, despite the feature extractor not being explicitly trained for this task. In certain cases, a single layer is sufficient for strong correlations between the predictions and the ground-truth scores.

This repository contains:

  • vqcpc/ - Module for VQCPC model in PyTorch
  • stoi/ - Module for Small and SeqPool predictor model in PyTorch
  • data.py - File containing various PyTorch custom datasets
  • main-vqcpc.py - Script for VQCPC training
  • create-latents.py - Script for generating latent dataset from trained VQCPC
  • plot-latents.py - Script for visualizing extracted latent representations
  • main-stoi.py - Script for STOI predictor training
  • main-test.py - Script for evaluating models
  • compute-correlations.py - Script for computing metrics for many models
  • checkpoints/ - trained checkpoints of VQCPC and STOI predictor models
  • config/ - Directory containing various configuration files for experiments
  • results/ - Directory containing official results from experiments
  • dataset/ - Directory containing metadata files for the dataset
  • data-generator/ - Directory containing dataset generation scripts (MATLAB)

All models are implemented in PyTorch. The training scripts are implemented using ptpt - a lightweight framework around PyTorch.

Visualisation of binaural waveform, predicted per-frame STOI, and latent representation: Visualisation of binaural waveform, predicted per-frame STOI, and latent representation.

Usage

VQ-CPC Training

Begin VQ-CPC training using the configuration defined in config.toml:

python main-vqcpc.py --cfg-path config-path.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Latent Dataset Generation

Begin latent dataset generation using pre-trained VQCPC model-checkpoint.pt from dataset wav-dataset and output to latent-dataset using configuration defined in config.toml:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml

As above, but distributed across n processes with script rank r:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml --array-size n --array-rank r

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--detect-anomaly    # detect autograd anomalies and terminate if encountered
-n                  # alias for `--array-size`
-r                  # alias for `--array-rank`

Latent Plotting

Begin interactive VQCPC latent visualisation script using pre-trained model model-checkpoint.pt on dataset wav-dataset using configuration defined in config.toml:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml

If you additionally have a pre-trained, per-frame STOI score predictor (not SeqPool predictor) you can specify the checkpoint stoi-checkpoint.pt and additional configuration stoi-config.toml, you can plot per-frame scores alongside the waveform and latent features:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --stoi stoi-checkpoint.pt --stoi-cfg stoi-config.toml

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--cmap              # define matplotlib colourmap
--style             # define matplotlib style

STOI Predictor Training

Begin intelligibility score predictor training script using configuration in config.toml:

python main-stoi.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Predictor Evaluation

Begin evaluation of a pre-trained STOI score predictor using checkpoint stoi-checkpoint.pt on dataset dataset-root using configuration in stoi-config.toml:

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml

Other useful arguments:

--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--nb-workers        # number of workers for for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--batch-size        # control dataloader batch size
--seed              # random seed (default: 12345)

Overall Evaluation

Compare results from many results files produced by main-test.py based on dataset ground truth:

python compute-correlations.py ground-truth.csv pred-1.csv ... pred-n.csv --names pred-1 ... pred-n

Configuration

Examples configurations for all experiments can be found here

We use toml files to define configurations. Each one consists of three sections:

  • [trainer]: configuration options for ptpt.TrainerConfig.
  • [data]: configuration options for the dataset.
  • [vqcpc] or [stoi]: configuration options for the VQCPC and predictor models respectively.

Checkpoints

Pretrained checkpoints for all models can be found here

Citation

TODO: add citation once paper published / arXiv-ed :)

Owner
Alex McKinney
Final-year student at Durham University. Interested in generative models and unsupervised representation learning.
Alex McKinney
ML-Ensemble – high performance ensemble learning

A Python library for high performance ensemble learning ML-Ensemble combines a Scikit-learn high-level API with a low-level computational graph framew

Sebastian Flennerhag 764 Dec 31, 2022
Example for AUAV 2022 with obstacle avoidance.

AUAV 2022 Sample This is a sample PX4 based quadrotor path planning framework based on Ubuntu 20.04 and ROS noetic for the IEEE Autonomous UAS 2022 co

James Goppert 11 Sep 16, 2022
Source code for From Stars to Subgraphs

GNNAsKernel Official code for From Stars to Subgraphs: Uplifting Any GNN with Local Structure Awareness Visualizations GNN-AK(+) GNN-AK(+) with Subgra

44 Dec 19, 2022
Code accompanying the paper "Wasserstein GAN"

Wasserstein GAN Code accompanying the paper "Wasserstein GAN" A few notes The first time running on the LSUN dataset it can take a long time (up to an

3.1k Jan 01, 2023
GNPy: Optical Route Planning and DWDM Network Optimization

GNPy is an open-source, community-developed library for building route planning and optimization tools in real-world mesh optical networks

Telecom Infra Project 140 Dec 19, 2022
PyTorch implementation of Towards Accurate Alignment in Real-time 3D Hand-Mesh Reconstruction (ICCV 2021).

Towards Accurate Alignment in Real-time 3D Hand-Mesh Reconstruction Introduction This is official PyTorch implementation of Towards Accurate Alignment

TANG Xiao 96 Dec 27, 2022
Weakly supervised medical named entity classification

Trove Trove is a research framework for building weakly supervised (bio)medical named entity recognition (NER) and other entity attribute classifiers

60 Nov 18, 2022
A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)

A PyTorch Implementation of GGNN This is a PyTorch implementation of the Gated Graph Sequence Neural Networks (GGNN) as described in the paper Gated G

Ching-Yao Chuang 427 Dec 13, 2022
Boundary-aware Transformers for Skin Lesion Segmentation

Boundary-aware Transformers for Skin Lesion Segmentation Introduction This is an official release of the paper Boundary-aware Transformers for Skin Le

Jiacheng Wang 79 Dec 16, 2022
Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning

Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning This is the official repository for Conservative and Adaptive Penalty fo

7 Nov 22, 2022
Code for the paper "TadGAN: Time Series Anomaly Detection Using Generative Adversarial Networks"

TadGAN: Time Series Anomaly Detection Using Generative Adversarial Networks This is a Python3 / Pytorch implementation of TadGAN paper. The associated

Arun 92 Dec 03, 2022
Data manipulation and transformation for audio signal processing, powered by PyTorch

torchaudio: an audio library for PyTorch The aim of torchaudio is to apply PyTorch to the audio domain. By supporting PyTorch, torchaudio follows the

1.9k Dec 28, 2022
Code basis for the paper "Camera Condition Monitoring and Readjustment by means of Noise and Blur" (2021)

Camera Condition Monitoring and Readjustment by means of Noise and Blur This repository contains the source code of the paper: Wischow, M., Gallego, G

7 Dec 22, 2022
CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

Frederick Wang 3 Apr 26, 2022
Blender scripts for computing geodesic distance

GeoDoodle Geodesic distance computation for Blender meshes Table of Contents Overivew Usage Implementation Overview This addon provides an operator fo

20 Jun 08, 2022
[ECCVW2020] Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DiMP)

Feel free to visit my homepage Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DIMP) [ECCVW2020 paper] Presentation

Seokeon Choi 35 Oct 26, 2022
Learning Representations that Support Robust Transfer of Predictors

Transfer Risk Minimization (TRM) Code for Learning Representations that Support Robust Transfer of Predictors Prepare the Datasets Preprocess the Scen

Yilun Xu 15 Dec 07, 2022
This repo contains the pytorch implementation for Dynamic Concept Learner (accepted by ICLR 2021).

DCL-PyTorch Pytorch implementation for the Dynamic Concept Learner (DCL). More details can be found at the project page. Framework Grounding Physical

Zhenfang Chen 31 Jan 06, 2023
Official PyTorch Implementation of Mask-aware IoU and maYOLACT Detector [BMVC2021]

The official implementation of Mask-aware IoU and maYOLACT detector. Our implementation is based on mmdetection. Mask-aware IoU for Anchor Assignment

Kemal Oksuz 46 Sep 29, 2022
A new test set for ImageNet

ImageNetV2 The ImageNetV2 dataset contains new test data for the ImageNet benchmark. This repository provides associated code for assembling and worki

186 Dec 18, 2022