Code for "Unsupervised Source Separation via Bayesian inference in the latent domain"

LQVAE-separation

Paper

Samples

GT       | Compressed       | Separated
Drums GT | Compressed Drums | Separated Drums
Bass GT  | Compressed Bass  | Separated Bass
Mix GT   | Compressed Mix   | Separated Mix

The separation is performed in a 64x-compressed latent domain. The results can be upsampled with the Jukebox upsamplers to increase perceptual quality (work in progress).
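The 64x factor comes from stacking the downsampling of the hierarchical VQ-VAE levels. The sketch below reproduces it under stated assumptions: each level i downsamples by strides_t[i] ** downs_t[i] (the Jukebox convention), downs_t=(2, 2, 2) as in the codebook precalculation command of this README, and strides_t=(2, 2, 2), which is an assumption not given here. hop_lengths is a hypothetical helper.

```python
def hop_lengths(downs_t, strides_t):
    """Cumulative downsampling factor (audio samples per code) at each VQ-VAE level.

    Assumption: level i downsamples by strides_t[i] ** downs_t[i], as in Jukebox.
    """
    hops = []
    factor = 1
    for downs, stride in zip(downs_t, strides_t):
        factor *= stride ** downs
        hops.append(factor)
    return hops


SAMPLE_RATE = 22050  # Hz, as required for the input WAV files
SECONDS = 3          # length of each separated excerpt

hops = hop_lengths(downs_t=(2, 2, 2), strides_t=(2, 2, 2))  # strides_t is assumed
top_hop = hops[-1]                                            # topmost level
n_samples = SAMPLE_RATE * SECONDS
n_codes = n_samples // top_hop
print(hops, n_samples, n_codes)  # [4, 16, 64] 66150 1033
```

Under these assumptions, a 3-second excerpt becomes about a thousand codes at the topmost level, which is the sequence length the separation works on.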

Install

Install the conda package manager from https://docs.conda.io/en/latest/miniconda.html

conda create --name lqvae-separation python=3.7.5
conda activate lqvae-separation
pip install mpi4py==3.0.3
pip install ffmpeg-python==0.2.0
pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
pip install -r requirements.txt
pip install -e .

Checkpoints

  • Enter the script/ folder and create the checkpoints/ and results/ folders inside it.
  • Download the checkpoints contained in this Google Drive folder and put them inside checkpoints/
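The two steps above can be scripted. This is a minimal sketch: prepare_script_dirs is a hypothetical helper, and the expected file names are an assumption copied from the bayesian_inference.py example in this README, so the Google Drive files may be named differently.

```python
from pathlib import Path

# Assumption: names copied from the bayesian_inference.py example in this README;
# the files in the Google Drive folder may be named differently.
EXPECTED_FILES = [
    "checkpoint_step_60001_latent.pth.tar",
    "checkpoint_drums_22050_latent_78_19k.pth.tar",
    "checkpoint_latest.pth.tar",
    "codebook_precalc_22050_latent.pt",
]


def prepare_script_dirs(script_dir):
    """Create checkpoints/ and results/ under script_dir; return missing checkpoint files."""
    script_dir = Path(script_dir)
    for name in ("checkpoints", "results"):
        (script_dir / name).mkdir(parents=True, exist_ok=True)
    ckpt_dir = script_dir / "checkpoints"
    return [name for name in EXPECTED_FILES if not (ckpt_dir / name).is_file()]
```

Calling prepare_script_dirs("script") from the project root creates both folders (it is safe to re-run) and lists whatever still has to be downloaded.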

Separation with checkpoints

  • Run the following to perform bs separations of 3 seconds each, starting at second shift, of the mixture created from the sources in path_1 and path_2. The sources must be WAV files sampled at 22.05 kHz (22050 Hz).
    PYTHONPATH=.. python bayesian_inference.py --shift=shift --path_1=path_1 --path_2=path_2 --bs=bs
    
  • The default value of bs is 64, which fits on an RTX 3080 with 16 GB of VRAM. Lower it if you get a CUDA out of memory error.
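A minimal sketch of how the input mixture can be formed from path_1, path_2 and shift, useful for checking your files before running the script. It assumes 16-bit PCM WAV sources and that the mixture is the plain sum of the two 3-second windows; the script's actual preprocessing (e.g. normalization) may differ. load_wav_mono and make_mixture are hypothetical helpers built only on Python's wave module and NumPy.

```python
import wave

import numpy as np

EXPECTED_SR = 22050  # Hz, required sample rate of the sources


def load_wav_mono(path):
    """Read a WAV file into a float32 array in [-1, 1], keeping the first channel.

    Assumption: 16-bit PCM; other sample widths are rejected in this sketch.
    """
    with wave.open(path, "rb") as f:
        sr = f.getframerate()
        if sr != EXPECTED_SR:
            raise ValueError(f"{path}: expected {EXPECTED_SR} Hz, got {sr} Hz")
        if f.getsampwidth() != 2:
            raise ValueError(f"{path}: only 16-bit PCM is handled in this sketch")
        n_channels = f.getnchannels()
        frames = f.readframes(f.getnframes())
    audio = np.frombuffer(frames, dtype="<i2").astype(np.float32) / 32768.0
    return audio.reshape(-1, n_channels)[:, 0]


def make_mixture(path_1, path_2, shift, seconds=3, sr=EXPECTED_SR):
    """Cut a `seconds`-long window starting at `shift` seconds and sum the two sources.

    Assumption: the mixture is the plain sum of the sources (no gain or normalization).
    """
    start = int(shift * sr)
    end = start + int(seconds * sr)
    s1 = load_wav_mono(path_1)[start:end]
    s2 = load_wav_mono(path_2)[start:end]
    if len(s1) != end - start or len(s2) != end - start:
        raise ValueError("the window extends past the end of a source")
    return s1, s2, s1 + s2
```

A ValueError here (wrong sample rate, or a shift too close to the end of a file) points to an input problem rather than a GPU memory problem.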

Training

LQ-VAE

  • The vqvae/vqvae.py file of Jukebox has been modified to include the linearization loss of the LQ-VAE. The loss is computed at all levels of the hierarchical VQ-VAE, but only the topmost level matters here, since separation is performed there. You can train a new LQ-VAE on custom data (here data/train for training and data/test for testing) by running the following from the project root:
    PYTHONPATH=. mpiexec -n 1 python jukebox/train.py --hps=vqvae --sample_length=131072 --bs=8 \
    --audio_files_dir=data/train/ --labels=False --train --test --aug_shift --aug_blend --name=lq_vae --test_audio_files_dir=data/test
  • The trained model uses the vqvae hyperparameters defined in hparams.py, so to change the number of levels or the downsampling factors you have to modify them there.
  • The only constraint when training the LQ-VAE is that the batch size must be even, since the loss is computed on pairs of samples.
  • Since L_lin enforces the sum operation in the latent domain, you can train on the data of both sources together (or on any other audio data).
  • Checkpoints are saved in logs/lq_vae (lq_vae is the name parameter).
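To illustrate why the batch size must be even, here is a minimal NumPy sketch of a pairwise linearity penalty: the batch is split into two halves paired element-wise, and the penalty measures how far encode(x1 + x2) is from encode(x1) + encode(x2). This is an assumption about the general shape of L_lin for illustration only; the actual loss in vqvae/vqvae.py may pair samples differently or be applied at a different point of the model. pair_up and linearization_loss are hypothetical names.

```python
import numpy as np


def pair_up(batch):
    """Split a batch into (x1, x2) halves; an odd batch size cannot be paired.

    Hypothetical pairing scheme: first half with second half, element-wise.
    """
    if batch.shape[0] % 2 != 0:
        raise ValueError(f"batch size must be even, got {batch.shape[0]}")
    half = batch.shape[0] // 2
    return batch[:half], batch[half:]


def linearization_loss(encode, batch):
    """Mean squared gap between encode(x1 + x2) and encode(x1) + encode(x2).

    Assumption: a generic pairwise linearity penalty, not the repository's exact L_lin.
    """
    x1, x2 = pair_up(batch)
    gap = encode(x1 + x2) - (encode(x1) + encode(x2))
    return float(np.mean(gap ** 2))
```

A linear encoder gives a loss of (numerically) zero, while a nonlinear one is penalized, which is the behavior the sum-based separation relies on.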

Priors

  • After training the LQ-VAE, train two priors, one for each source class, by running the following once per class (changing --name and the data paths):
    PYTHONPATH=. mpiexec -n 1 python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=prior_source \
    --audio_files_dir=data/source/train --test_audio_files_dir=data/source/test --labels=False --train --test --aug_shift \
    --aug_blend --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 --min_duration=24 --sample_length=1048576 \
    --bs=16 --n_ctx=8192 --sample=True --sample_iters=1000 --restore_vqvae=logs/lq_vae/checkpoint_lq_vae.pth.tar
  • Here the source data is located in data/source/train and data/source/test, and we assume the LQ-VAE has 3 levels (topmost level = 2).
  • The Transformer model is defined by the small_prior parameters in hparams.py and uses a context of n_ctx=8192 codes.
  • The checkpoint path of the LQ-VAE trained in the previous step must be passed to --restore_vqvae.
  • Checkpoints are saved in logs/prior_source (prior_source is the name parameter).
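Each prior is an autoregressive model over top-level code sequences, so it scores a sequence with the chain rule. The sketch below shows this with a smoothed bigram table standing in for the Transformer, which conditions on up to n_ctx previous codes rather than one. fit_bigram and sequence_log_prob are hypothetical helpers, not part of the repository; fitting one table per class mirrors training one prior per source.

```python
import numpy as np


def fit_bigram(sequences, n_bins, alpha=1.0):
    """Estimate add-alpha smoothed initial and transition log-probabilities.

    A bigram table is only a stand-in for the Transformer prior.
    """
    init = np.full(n_bins, alpha)
    trans = np.full((n_bins, n_bins), alpha)
    for seq in sequences:
        seq = np.asarray(seq)
        init[seq[0]] += 1
        # count each (previous code, next code) transition, including repeats
        np.add.at(trans, (seq[:-1], seq[1:]), 1)
    return np.log(init / init.sum()), np.log(trans / trans.sum(axis=1, keepdims=True))


def sequence_log_prob(codes, log_init, log_trans):
    """Chain rule: log p(z) = log p(z_0) + sum_t log p(z_t | z_{t-1})."""
    codes = np.asarray(codes)
    return float(log_init[codes[0]] + log_trans[codes[:-1], codes[1:]].sum())
```

A sequence resembling the training data of one class receives a higher log-probability under that class's prior, which is what lets the two priors pull the separated codes toward their respective sources.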

Codebook sums

  • Before separation, the sums between all pairs of codes must be precomputed with the LQ-VAE. This can be done with codebook_precalc.py in the script/ folder:
    PYTHONPATH=.. python codebook_precalc.py --save_path=checkpoints/codebook_sum_precalc.pt \
    --restore_vqvae=../logs/lq_vae/checkpoint_lq_vae.pth.tar --raw_to_tokens=64 --l_bins=2048 \
    --sample_rate=22050 --alpha="[0.5, 0.5]" --downs_t="(2, 2, 2)" --commit=1.0 --emb_width=64
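One plausible reading of this step, sketched under explicit assumptions: for every pair of codes (i, j), add their codebook vectors and store the index of the codebook entry nearest to the sum, giving an l_bins x l_bins lookup table. Whether codebook_precalc.py stores nearest indices, distances or something else, and how --alpha enters, is not stated here, so this illustrates the idea rather than the script's algorithm. codebook_sum_table is a hypothetical name. Rows are processed in chunks because the full (K, K, K) distance tensor for K = 2048 would not fit in memory.

```python
import numpy as np


def codebook_sum_table(codebook, chunk=16):
    """For each code pair (i, j), return the index of the code nearest to e_i + e_j.

    Assumption: plain (unweighted) sum and Euclidean nearest neighbour.
    codebook: (K, D) array of code vectors. Returns a (K, K) integer table.
    """
    codebook = np.asarray(codebook, dtype=np.float32)
    K = codebook.shape[0]
    table = np.empty((K, K), dtype=np.int64)
    code_sq = (codebook ** 2).sum(axis=1)  # ||e_k||^2, shape (K,)
    for start in range(0, K, chunk):
        # pairwise sums for a block of rows: (c, K, D)
        sums = codebook[start:start + chunk, None, :] + codebook[None, :, :]
        # ||s - e_k||^2 = ||s||^2 - 2 s.e_k + ||e_k||^2; ||s||^2 does not depend on k
        scores = code_sq - 2.0 * (sums @ codebook.T)  # (c, K, K)
        table[start:start + chunk] = scores.argmin(axis=-1)
    return table
```

Dropping the constant ||s||^2 term turns the nearest-neighbour search into one matrix product per chunk, which is what keeps a 2048-entry codebook tractable.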

Separation with trained checkpoints

  • Trained checkpoints can be passed to bayesian_inference.py as follows:
    PYTHONPATH=.. python bayesian_inference.py --shift=shift --path_1=path_1 --path_2=path_2 --bs=bs --restore_vqvae=checkpoints/checkpoint_step_60001_latent.pth.tar \
    --restore_priors 'checkpoints/checkpoint_drums_22050_latent_78_19k.pth.tar' 'checkpoints/checkpoint_latest.pth.tar' --sum_codebook=checkpoints/codebook_precalc_22050_latent.pt
    
  • --restore_priors takes two paths: the checkpoints of the first and the second prior, in that order.
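To make the roles of the two priors and the codebook-sum table concrete, here is a toy single-position sketch of the Bayesian step: the posterior over the source codes (z1, z2) given a mixture code m is proportional to p1(z1) p2(z2) p(m | z1, z2). The hard likelihood (1 if the sum table maps (z1, z2) to m, a tiny eps otherwise) and the independent per-position priors are simplifying assumptions; bayesian_inference.py works with autoregressive Transformer priors over whole sequences. pair_posterior is a hypothetical helper.

```python
import numpy as np


def pair_posterior(log_p1, log_p2, sum_table, mix_code, eps=1e-12):
    """Posterior over (z1, z2) for one latent position given the mixture code.

    Assumption: hard likelihood p(m | z1, z2) = 1 if sum_table[z1, z2] == m, else eps.
    log_p1, log_p2: (K,) prior log-probabilities; sum_table: (K, K) code indices.
    """
    log_lik = np.where(sum_table == mix_code, 0.0, np.log(eps))  # (K, K)
    log_post = log_p1[:, None] + log_p2[None, :] + log_lik
    log_post -= log_post.max()  # stabilize before exponentiating
    post = np.exp(log_post)
    return post / post.sum()
```

With a sum table where (0, 2), (1, 1) and (2, 0) all map to the observed mixture code, the likelihood alone cannot choose among them; the priors break the tie, which is why a prior trained on each source is needed.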

Evaluation

  • To evaluate the pre-trained checkpoints, put the full Slakh drums and bass validation splits inside data/drums/validation and data/bass/validation, then run bayesian_test.py.
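This README does not say which metric bayesian_test.py reports. As an assumption, a common choice for separation is the signal-to-distortion ratio; the sketch below uses the simple energy-ratio form SDR = 10 log10(||s||^2 / ||s - s_hat||^2), not the full BSS Eval version with distortion filters. sdr is a hypothetical helper.

```python
import numpy as np


def sdr(reference, estimate, eps=1e-9):
    """Simple signal-to-distortion ratio in dB (assumed metric, not BSS Eval).

    10 * log10(||s||^2 / ||s - s_hat||^2); eps avoids division by zero.
    """
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    num = np.sum(reference ** 2)
    den = np.sum((reference - estimate) ** 2)
    return float(10.0 * np.log10((num + eps) / (den + eps)))
```

Higher is better: an error with 1% of the reference energy gives 20 dB. Per-source scores (drums and bass) would typically be averaged over the validation excerpts.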

Future work

  • Training upsamplers to increase the quality of the separations
  • A better rejection sampling method (e.g. using verifiers as in https://arxiv.org/abs/2110.14168)

Citations

If you find the code useful for your research, please consider citing

@article{mancusi2021unsupervised,
  title={Unsupervised Source Separation via Bayesian Inference in the Latent Domain},
  author={Mancusi, Michele and Postolache, Emilian and Fumero, Marco and Santilli, Andrea and Cosmo, Luca and Rodol{\`a}, Emanuele},
  journal={arXiv preprint arXiv:2110.05313},
  year={2021}
}

as well as the Jukebox baseline:

  • Dhariwal, P., Jun, H., Payne, C., Kim, J. W., Radford, A., & Sutskever, I. (2020). Jukebox: A generative model for music. arXiv preprint arXiv:2005.00341.

Owner: Michele Mancusi, PhD student in Computer Science @ La Sapienza University of Rome, MSc in Quantum Information @ La Sapienza University of Rome