BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Overview

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@misc{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download external source files, and apply a patch. Our implementation uses the following.

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and the other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in a linear-probe fashion. This example uses SPCV2 (Speech Commands dataset v2).

  1. Preprocess the metadata (.csv file) and audio files. Processed files will be stored under a folder named work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This first converts all .wav audio to representation embeddings, then trains a linear layer network, and finally calculates the accuracy.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2

You can also run an evaluation multiple times and take the average result. The following evaluates on UrbanSound8K with a unit audio duration of 4.0 seconds, repeated 10 times.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10

Evaluating Representations In Your Tasks

This is an example to calculate a feature vector for an audio sample.

from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020


device = torch.device('cuda')
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195,  5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights.
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Let's convert the audio sampling rate in advance, or do it here online."

# Convert to a log-mel spectrogram, then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
features = model(lms.unsqueeze(0))
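
The `stats` pair used above is a mean and a standard deviation taken over log-mel values. As a rough, framework-free illustration of what that normalization amounts to, the sketch below assumes plain standardization, (x - mean) / std. The helper names `log_mel`, `calc_stats`, and `normalize` are hypothetical and not part of this repository, and the population standard deviation is an assumption made here; see calc_norm_stats in evaluate.py and PrecomputedNorm for the actual computation.

```python
import math
import statistics

# float32 machine epsilon, standing in for torch.finfo(torch.float).eps
EPS = 1.1920929e-07


def log_mel(values):
    """Take the log of non-negative mel-spectrogram values (hypothetical helper)."""
    return [math.log(v + EPS) for v in values]


def calc_stats(log_values):
    """Return [mean, std] over all log-mel values (hypothetical helper)."""
    return [statistics.fmean(log_values), statistics.pstdev(log_values)]


def normalize(log_values, stats):
    """Standardize log-mel values with precomputed [mean, std] (assumed behavior)."""
    mean, std = stats
    return [(v - mean) / std for v in log_values]


# Toy "mel" values; real inputs would be a 2-D spectrogram.
mel = [0.0, 0.5, 1.0, 2.0, 4.0]
lms = log_mel(mel)
stats = calc_stats(lms)
normalized = normalize(lms, stats)
```

In practice the statistics would be computed once over a dataset and reused for every sample, which is why the example above hard-codes them.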

Training From Scratch

You can also train models. The following is an example of training on FSD50K.

  1. Convert all samples to 16 kHz. This converts all FSD50K files into the folder work/16k/fsd50k while preserving the folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains with all development set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include 3 pretrained weights of our encoder network.

Method  Dim.    Filename                           NSynth  US8K   VoxCeleb1  VoxForge  SPCV2/12  SPCV2  Average
BYOL-A  512-d   AudioNTT2020-BYOLA-64x96d512.pth   69.1%   78.2%  33.4%      83.5%     86.5%     88.9%  73.3%
BYOL-A  1024-d  AudioNTT2020-BYOLA-64x96d1024.pth  72.7%   78.2%  38.0%      88.5%     90.1%     91.4%  76.5%
BYOL-A  2048-d  AudioNTT2020-BYOLA-64x96d2048.pth  74.1%   79.1%  40.1%      90.2%     91.0%     92.2%  77.8%

License

This implementation is provided for evaluating the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing such sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}

Comments
  • Question for reproducing results

    Hi,

    Thanks for sharing this great work! I tried to reproduce the results by following the official instructions, but I failed.

    After processing the data, I ran the following commands:

    CUDA_VISIBLE_DEVICES=0 python -W ignore train.py work/16k/fsd50k/FSD50K.dev_audio
    cp lightning_logs/version_4/checkpoints/epoch\=99-step\=16099.ckpt AudioNTT2020-BYOLA-64x96d2048.pth
    CUDA_VISIBLE_DEVICES=4 python evaluate.py AudioNTT2020-BYOLA-64x96d2048.pth spcv2
    

    However, the results are far from the reported results.

    Did I miss something important? Thank you very much.

    question 
    opened by ChenyangLEI 15
  • Evaluation on voxforge

    Hi,

    Thank you so much for your contribution. This work is very interesting, and your code is easy to follow. However, one of the downstream datasets, VoxForge, is missing from preprocess_ds.py. Could you please release the code for that dataset, too?

    Thank you again for your time.

    Best regards

    opened by Huiimin5 9
  • A mistake in RunningMean

    Thank you for the fascinating paper and the code to reproduce it!

    I think there might be a problem in RunningMean. The current formula (the same in v1 and v2) looks like this:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n - 1}, $$

    which is inconsistent with the correct formula listed on StackOverflow:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n}. $$

    The problem is that self.n is incremented after the new mean is computed. Could you please either correct me if I am wrong or correct the code?

    opened by WhiteTeaDragon 4
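
The difference between the two formulas in this thread comes down to when the counter is incremented. The following is a minimal standalone sketch of the second formula, $m_n = m_{n-1} + (a_n - m_{n-1}) / n$; the class name `IncrementalMean` is hypothetical and this is not the repository's RunningMean.

```python
class IncrementalMean:
    """Incremental mean: m_n = m_{n-1} + (a_n - m_{n-1}) / n."""

    def __init__(self):
        self.n = 0
        self.mu = 0.0

    def put(self, x):
        # Increment first, so the divisor is n rather than n - 1.
        self.n += 1
        self.mu += (x - self.mu) / self.n

    def __call__(self):
        return self.mu


rm = IncrementalMean()
for a in [1.0, 2.0, 3.0, 4.0]:
    rm.put(a)

print(rm())  # 2.5, the plain average of the four values
```

Incrementing the counter after the update instead would divide by n - 1, which is the behavior the issue describes.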
  • A basic question: torch.randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3

    Traceback (most recent call last):
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2066, in <module>
        main()
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2060, in main
        globals = debugger.run(setup['file'], None, None, is_module)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1411, in run
        return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1418, in _exec
        pydev_imports.execfile(file, globals, locals)  # execute the script
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "E:/pythonSpace/byol-a/train.py", line 132, in <module>
        main(audio_dir=base_path + '1/', epochs=100)
      File "E:/pythonSpace/byol-a/train.py", line 112, in main
        learner = BYOLALearner(model, cfg.lr, cfg.shape,
      File "E:/pythonSpace/byol-a/train.py", line 56, in __init__
        self.learner = BYOL(model, image_size=shape, **kwargs)
      File "D:\min\envs\torch1_7_1\lib\site-packages\byol_pytorch\byol_pytorch.py", line 211, in __init__
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    TypeError: randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3
    
    Not_an_issue 
    opened by a1030076395 3
  • Question about comments in the train.py

    https://github.com/nttcslab/byol-a/blob/master/train.py

    At line 67, there is a comment about the shape of the input.

            # in fact, it should be (B, 1, F, T), e.g. (256, 1, 64, 96) where 64 is the number of mel bins
            paired_inputs = torch.cat(paired_inputs) # [(B,1,T,F), (B,1,T,F)] -> (2*B,1,T,F)
    

    However, it differs from the description in the config.yaml file:

    # Shape of loh-mel spectrogram [F, T].
    shape: [64, 96]
    
    bug 
    opened by ChenyangLEI 2
  • Doubt in paper

    Hi there,

    Section 4, subsection A, part 1 from your paper says:

     The number of frames, T, in one segment was 96 in pretraining, which corresponds to 1,014ms. 
    

    However, the previous line says the hop size used was 10 ms. So, according to this, 96 frames would mean 960 ms?

    Am I understanding something wrong here?

    Thank You in advance!

    question 
    opened by Sreyan88 2
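
One reading that reconciles the two numbers: if every frame spans a full analysis window and no padding is applied, T frames cover (T - 1) hops plus one window rather than T hops. The 64 ms window length used below is an assumption, since it is not stated in this thread, and the function name is hypothetical. Under that assumption, 96 frames with a 10 ms hop come to 1,014 ms:

```python
def segment_duration_ms(num_frames, hop_ms, window_ms):
    """Duration covered by num_frames frames when no padding is applied."""
    return (num_frames - 1) * hop_ms + window_ms


# Assumed 64 ms window with the 10 ms hop mentioned in the question.
duration = segment_duration_ms(num_frames=96, hop_ms=10, window_ms=64)
print(duration)  # 1014

# Counting hops only (window equal to the hop) gives the 960 ms figure.
print(segment_duration_ms(num_frames=96, hop_ms=10, window_ms=10))  # 960
```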
  • Random crop is not working.

    https://github.com/nttcslab/byol-a/blob/60cebdc514951e6b42e18e40a2537a01a39ad47b/byol_a/dataset.py#L80-L82

    If len(wav) > self.unit_length, length_adj will be negative, so start will be 0. If wav (before padding) is shorter than the unit length, length_adj == 0 after padding, so start is again 0. As a result, the crop always covers the same region from 0 to self.unit_length (cropped_wav == wav[0:self.unit_length]) and is not random.

    So I think line 80 should be changed to length_adj = len(wav) - self.unit_length.

    bug 
    opened by JUiscoming 2
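
The change proposed in this issue can be sketched on plain Python lists. This is an illustration only and not the repository's dataset code: the function name `random_crop`, the symmetric zero-padding of short inputs, and the `rng` parameter are all assumptions made for the example.

```python
import random


def random_crop(wav, unit_length, rng=random):
    """Return a unit_length-long slice of wav taken at a random offset.

    Inputs shorter than unit_length are zero-padded on both sides first.
    """
    wav = list(wav)
    shortfall = unit_length - len(wav)
    if shortfall > 0:
        left = shortfall // 2
        wav = [0.0] * left + wav + [0.0] * (shortfall - left)
    # Slack beyond one unit; after padding this is always >= 0.
    length_adj = len(wav) - unit_length
    start = rng.randint(0, length_adj) if length_adj > 0 else 0
    return wav[start:start + unit_length]


rng = random.Random(0)
long_wav = list(range(100))

# The first sample of each crop reveals the offset that was chosen.
starts = {random_crop(long_wav, 10, rng)[0] for _ in range(200)}
short = random_crop([1.0, 2.0], 6, rng)
```

With the slack computed as len(wav) - unit_length, long inputs yield varying offsets, while padded short inputs still start at 0 because no slack remains.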
  • Doubt in RunningNorm

    Hi There, great repo!

    I think I may have misunderstood something about the RunningNorm function. The function expects the size of an epoch; however, your implementation passes the size of the entire dataset.

    Is it a bug? Or is there a problem with my understanding?

    Thank You!

    question 
    opened by Sreyan88 2
  • How to interpret the performance

    Hi, this is great work, but how should I interpret the performance metric? For example, VoxCeleb1 is usually used for speaker verification; shouldn't we measure EER?

    opened by ranchlai 2
  • Finetuning of BYOL-A

    Hi,

    your paper is super interesting. I have a question regarding the downstream tasks. If I understand the paper correctly, you used a single linear layer for the downstream tasks which only used the sum of mean and max of the representation over time as input.

    Did you try to finetune BYOL-A end-to-end on the downstream tasks after pretraining? In the case of TRILL, they were able to improve the performance even further by finetuning the whole model end-to-end. Is there a specific reason why this is not possible with BYOL-A?

    questions 
    opened by mschiwek 1
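
The "sum of mean and max of the representation over time" mentioned in this question can be illustrated on a small list of per-frame feature vectors. This is a sketch of that pooling idea only; the function name `pool_over_time` is hypothetical, and the actual model operates on tensors rather than Python lists.

```python
def pool_over_time(frames):
    """Collapse a (T, D) list of frame features into one D-dim vector as mean + max."""
    num_frames = len(frames)
    dims = range(len(frames[0]))
    mean = [sum(f[d] for f in frames) / num_frames for d in dims]
    peak = [max(f[d] for f in frames) for d in dims]
    return [m + p for m, p in zip(mean, peak)]


# Three time frames, two feature dimensions.
frames = [[1.0, 0.0], [3.0, -2.0], [2.0, 4.0]]
pooled = pool_over_time(frames)
# First dimension: mean 2.0 + max 3.0 = 5.0
```

The result has a fixed size regardless of the number of frames, which is what allows a single linear layer to consume clips of different lengths.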
  • Missing scaling of validation samples in evaluate.py

    https://github.com/nttcslab/byol-a/blob/master/evaluate.py#L112

    It also needs X_val = scaler.transform(X_val); otherwise the validation accuracy and loss will be invalid. This may be one of the reasons why I saw lower performance when I tried to reproduce the official results.

    bug 
    opened by daisukelab 0
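
The point of this issue is that a scaler fitted on the training embeddings has to be applied to every split, including validation. The sketch below uses a small hypothetical `Standardizer` class as a stand-in for whatever scaler evaluate.py uses; the toy data and the population standard deviation are assumptions for illustration.

```python
import statistics


class Standardizer:
    """Per-feature standardization fitted on training data only (hypothetical stand-in)."""

    def fit(self, rows):
        cols = list(zip(*rows))
        self.mean = [statistics.fmean(c) for c in cols]
        # Fall back to 1.0 so constant features do not divide by zero.
        self.std = [statistics.pstdev(c) or 1.0 for c in cols]
        return self

    def transform(self, rows):
        return [
            [(v - m) / s for v, m, s in zip(row, self.mean, self.std)]
            for row in rows
        ]


X_train = [[0.0, 10.0], [2.0, 30.0]]
X_val = [[1.0, 20.0]]
X_test = [[4.0, 50.0]]

scaler = Standardizer().fit(X_train)  # statistics come from the training split only
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)       # the step this issue reports as missing
X_test = scaler.transform(X_test)
```

If the validation split is left unscaled, the classifier sees inputs on a different scale from the one it was trained on, so validation metrics stop being comparable to training and test metrics.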
Releases(v2.0.0)
Owner
NTT Communication Science Laboratories