
Code for reproducible experiments presented in KSD Aggregated Goodness-of-fit Test. NeurIPS 2022





Reproducibility code for KSDAgg: KSD Aggregated Goodness-of-fit Test

This GitHub repository contains the code for the reproducible experiments presented in our paper KSD Aggregated Goodness-of-fit Test:

  • Gamma distribution experiment,
  • Gaussian-Bernoulli Restricted Boltzmann Machine experiment,
  • MNIST Normalizing Flow experiment.

We provide the code to run the experiments and generate Figures 1-4 and Table 1 of our paper, which can be found in the directory figures.

To use our KSDAgg test in practice, we recommend using our ksdagg package; more details are available on the ksdagg repository.

Our implementation uses two quantile estimation methods (the wild bootstrap and the parametric bootstrap) with the IMQ (inverse multiquadric) kernel. The KSDAgg test aggregates over a collection of bandwidths, and uses one of the four types of weights proposed in MMD Aggregated Two-Sample Test.
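As a rough illustration of the ingredients named above, the following Numpy sketch computes the Stein kernel for an IMQ kernel and calibrates a single-bandwidth KSD test with a wild bootstrap. It is not the code of this repository: the function names, the exponent beta=0.5, the fixed bandwidth and the Gaussian toy data are all assumptions made for the example, and the aggregation over bandwidths and weights is left out.

```python
import numpy as np


def imq_stein_kernel(X, score_X, bandwidth, beta=0.5):
    """Stein kernel matrix for k(x, y) = (1 + ||x - y||^2 / bandwidth^2)^(-beta).

    X has shape (n, d); score_X holds the model score (gradient of the model
    log-density) evaluated at each row of X. beta=0.5 is an assumed default.
    """
    d = X.shape[1]
    diff = X[:, None, :] - X[None, :, :]              # diff[i, j] = x_i - x_j
    sq_dist = np.sum(diff**2, axis=-1)
    u = 1.0 + sq_dist / bandwidth**2
    k = u ** (-beta)
    # grad_y k(x, y) = c * (x - y) and grad_x k(x, y) = -c * (x - y)
    c = 2.0 * beta * u ** (-beta - 1.0) / bandwidth**2
    s_dot_s = score_X @ score_X.T                          # s(x_i)^T s(x_j)
    sx_dot_diff = np.einsum("id,ijd->ij", score_X, diff)   # s(x_i)^T (x_i - x_j)
    sy_dot_diff = np.einsum("jd,ijd->ij", score_X, diff)   # s(x_j)^T (x_i - x_j)
    # sum over coordinates of the mixed second derivatives of k
    trace = d * c - 4.0 * beta * (beta + 1.0) * sq_dist * u ** (-beta - 2.0) / bandwidth**4
    return s_dot_s * k + c * (sx_dot_diff - sy_dot_diff) + trace


def ksd_wild_bootstrap_test(X, score_X, bandwidth, alpha=0.05, B=500, seed=0):
    """Single-bandwidth KSD test; returns (reject, ksd_estimate, quantile)."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    H = imq_stein_kernel(X, score_X, bandwidth)
    np.fill_diagonal(H, 0.0)                          # U-statistic: drop i == j terms
    ksd = H.sum() / (n * (n - 1))
    eps = rng.choice([-1.0, 1.0], size=(B, n))        # Rademacher signs
    boot = ((eps @ H) * eps).sum(axis=1) / (n * (n - 1))
    all_stats = np.sort(np.append(boot, ksd))         # include the original statistic
    quantile = all_stats[int(np.ceil((B + 1) * (1 - alpha))) - 1]
    return bool(ksd > quantile), float(ksd), float(quantile)


rng = np.random.default_rng(0)
X_null = rng.standard_normal((200, 2))   # drawn from the model N(0, I)
X_alt = X_null + 1.0                     # mean-shifted samples: the model is wrong
# The score of N(0, I) is s(x) = -x, evaluated at the observed samples.
reject_null, ksd_null, _ = ksd_wild_bootstrap_test(X_null, -X_null, bandwidth=1.0)
reject_alt, ksd_alt, _ = ksd_wild_bootstrap_test(X_alt, -X_alt, bandwidth=1.0)
```

Only the samples and the model score evaluated at those samples are needed, so the model's normalising constant never has to be computed.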


  • python 3.9

The packages in requirements.txt are required to run our tests and the ones we compare against.

Additionally, the jax and jaxlib packages are required to run the Jax implementation of KSDAgg in ksdagg/


In a chosen directory, clone the repository and change to its directory by executing

git clone
cd ksdagg-paper

We then recommend creating and activating a virtual environment by either

  • using venv:
    python3 -m venv ksdagg-env
    source ksdagg-env/bin/activate
    # can be deactivated by running:
    # deactivate
  • or using conda:
    conda create --name ksdagg-env python=3.9
    conda activate ksdagg-env
    # can be deactivated by running:
    # conda deactivate

The packages required for reproducibility of the experiments can then be installed in the virtual environment by running

python -m pip install -r requirements.txt

To use the Jax implementation of our tests, Jax needs to be installed; we recommend using conda for this. This can be done by running

  • for GPU:
    conda install -c conda-forge -c nvidia pip cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
  • or, for CPU:
    conda install -c conda-forge -c nvidia pip jaxlib=0.4.1 jax

Generating or downloading the data

The data for the Gaussian-Bernoulli Restricted Boltzmann Machine experiment and for the MNIST Normalizing Flow experiment can

  • be obtained by executing
  • or, as running the above scripts can be computationally expensive, we also provide the option to download their outputs directly

Those scripts generate samples and compute their associated scores under the model for the different settings considered in our experiments. The data is saved in the new directory data.
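To make "samples and their associated scores under the model" concrete, here is a minimal one-dimensional sketch using a Gamma model, where the score has a closed form. This is not one of the data-generating scripts: the helper name gamma_score and the parameter values are chosen purely for illustration.

```python
import numpy as np


def gamma_score(x, shape, scale):
    """Score d/dx log p(x) of a Gamma(shape, scale) density, for x > 0.

    log p(x) = (shape - 1) * log(x) - x / scale + const, so the
    normalising constant drops out of the derivative.
    """
    return (shape - 1.0) / x - 1.0 / scale


rng = np.random.default_rng(0)
# Illustrative parameter values; samples are stored as an (n, d) array with d = 1.
X = rng.gamma(shape=5.0, scale=5.0, size=(500, 1))
score_X = gamma_score(X, shape=5.0, scale=5.0)   # same shape as X
```

The pair (X, score_X) is the kind of input a KSD-based goodness-of-fit test consumes: one score vector per sample.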

Reproducing the experiments of the paper

First, for our three experiments, we compute KSD values to be used for the parametric bootstrap and save them in the directory parametric. This can be done by running


For convenience, we directly provide the directory parametric obtained by running this script.
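The idea behind the parametric bootstrap can be sketched as follows: draw fresh samples from the model itself, compute the test statistic on each draw, and use an upper quantile of those simulated values as the rejection threshold. The sketch below is generic and hypothetical: the function name, its arguments and the toy statistic (a squared sample mean standing in for a KSD value) are assumptions, not the script of this repository.

```python
import numpy as np


def parametric_bootstrap_quantile(sample_model, statistic, n, alpha=0.05, B=500, seed=0):
    """Estimate the (1 - alpha)-quantile of a statistic under the model.

    sample_model(rng, n) must return n fresh samples from the model;
    statistic(sample) must return a number. Both are user-supplied.
    """
    rng = np.random.default_rng(seed)
    null_stats = np.array([statistic(sample_model(rng, n)) for _ in range(B)])
    return float(np.quantile(null_stats, 1.0 - alpha)), null_stats


def toy_statistic(sample):
    # Stand-in for a KSD value: squared sample mean (assumption for the demo).
    return float(np.mean(sample)) ** 2


def sample_standard_normal(rng, n):
    # Toy model: standard normal in one dimension.
    return rng.standard_normal(n)


quantile, null_stats = parametric_bootstrap_quantile(
    sample_standard_normal, toy_statistic, n=100
)

# Data that does not come from the model: the mean is shifted by 1.
observed = np.random.default_rng(1).standard_normal(100) + 1.0
reject = toy_statistic(observed) > quantile
```

Because the simulated values depend only on the model and the sample size, they can be computed once and stored for reuse across experiments.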

To run the three experiments, the following commands can be executed


Those commands run all the tests necessary for our experiments. The results are saved in dedicated .csv and .pkl files in the directory results (which is already provided for ease of use). Note that our experiments consist of 'embarrassingly parallel for loops', for which a significant speed-up can be obtained by using parallel computing libraries such as joblib or dask.
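The pattern behind such a speed-up can be sketched with the standard library alone: each repetition depends only on its own seed, so the loop can be mapped over a pool of workers. The sketch below uses concurrent.futures with threads as a stand-in for joblib or dask, and one_repetition is a hypothetical placeholder for one independent run of an experiment. For CPU-bound pure-Python work, a process pool (or joblib) would be the more practical choice.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def one_repetition(seed):
    """Placeholder for one independent repetition of an experiment.

    Each call builds its own random generator from the seed, so the
    repetitions share no state and can run in any order.
    """
    rng = random.Random(seed)
    sample = [rng.gauss(0.0, 1.0) for _ in range(1000)]
    return sum(sample) / len(sample)


def run_repetitions(seeds, max_workers=4):
    """Map one_repetition over the seeds with a pool of workers."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map returns results in the order of the input seeds.
        return list(pool.map(one_repetition, seeds))


sequential = [one_repetition(seed) for seed in range(20)]
parallel = run_repetitions(range(20))
```

Seeding each repetition separately keeps the parallel results identical to the sequential ones, which matters for reproducibility.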

The actual figures of the paper can be obtained from the saved dataframes in results by using the command


The figures are saved in the directory figures and correspond to the ones used in our paper.

How to use KSDAgg in practice?

The KSDAgg test is implemented as the function ksdagg in ksdagg/ for the Numpy version and in ksdagg/ for the Jax version.

For the Numpy implementation of our KSDAgg test, we only require the numpy and scipy packages.

For the Jax implementation of our KSDAgg test, we only require the jax and jaxlib packages.

To use our tests in practice, we recommend using our ksdagg package which is available on the ksdagg repository. It can be installed by running

pip install git+

Installation instructions and example code are available on the ksdagg repository.

We also provide code showing how to use our KSDAgg test in the demo_speed.ipynb notebook, which also contains speed comparisons between the Jax and Numpy implementations, as reported below.

Speed in ms   Numpy (CPU)   Jax (CPU)   Jax (GPU)
KSDAgg        12500         1470        22

In practice, we recommend using the Jax implementation as it runs considerably faster: in the above table, Jax on GPU is more than 500 times faster than Numpy on CPU, and Jax on CPU is about 8 times faster (see notebook demo_speed.ipynb).


Our KSDAgg code is based on our MMDAgg implementation for two-sample testing (MMD Aggregated Two-Sample Test) which can be found at

For the Gaussian-Bernoulli Restricted Boltzmann Machine experiment, we obtain the samples and scores in by relying on Wittawat Jitkrittum's implementation which can be found at under the MIT License. The relevant files we use are in the directory kgof.

For the MNIST Normalizing Flow experiment, we use in a multiscale Normalizing Flow trained on the MNIST dataset as implemented by Phillip Lippe in Tutorial 11: Normalizing Flows for image modeling as part of the UvA Deep Learning Tutorials under the MIT License.


For a computationally efficient version of KSDAgg which can run in linear time, check out our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics, with reproducible experiments in the agginc-paper repository and a package in the agginc repository.


If you have any issues running our code, please do not hesitate to contact Antonin Schrab.


Centre for Artificial Intelligence, Department of Computer Science, University College London

Gatsby Computational Neuroscience Unit, University College London

Inria London


  author    = {Antonin Schrab and Benjamin Guedj and Arthur Gretton},
  title     = {KSD Aggregated Goodness-of-fit Test},
  booktitle = {Advances in Neural Information Processing Systems 35: Annual Conference
               on Neural Information Processing Systems 2022, NeurIPS 2022},
  editor    = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year      = {2022},


MIT License (see

Related tests

  • mmdagg: MMD Aggregated MMDAgg test
  • agginc: Efficient MMDAggInc HSICAggInc KSDAggInc tests
  • mmdfuse: MMD-Fuse test
  • dpkernel: Differentially private dpMMD dpHSIC tests
  • dckernel: Robust to Data Corruption dcMMD dcHSIC tests