Loopy belief propagation for factor graphs on discrete variables, in JAX!

Overview


PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.

See our blog post and companion paper for more details.

Installation | Getting started

Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/vicariousinc/PGMax.git

Developer

git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install

Install on GPU

By default the above commands install JAX for CPU. If you have access to a GPU, follow the official instructions here to install JAX for GPU.
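
As a rough illustration only (this command is an assumption about current JAX packaging; the exact package extra depends on your CUDA setup and JAX version, so always defer to the linked instructions), a GPU-enabled JAX install typically looks like:

pip install --upgrade "jax[cuda12]"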

Getting Started

A few self-contained Colab notebooks are available to help you get started with PGMax.
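
Independently of the notebooks, the following is a minimal end-to-end sketch of the high-level API described in the v0.4.0 release notes below. Treat it as an illustrative outline rather than verified tutorial code: in particular, the NDVarArray constructor arguments and the shape of the pairwise log-potential matrix are assumptions.

import numpy as np

from pgmax import fgraph, fgroup, infer, vgroup

# A tiny chain of 3 binary variables (assumed NDVarArray signature).
variables = vgroup.NDVarArray(num_states=2, shape=(3,))
fg = fgraph.FactorGraph(variable_groups=variables)

# Pairwise factors between neighboring variables, sharing one (2, 2) log-potential matrix.
pairwise = fgroup.PairwiseFactorGroup(
    variables_for_factors=[
        [variables[0], variables[1]],
        [variables[1], variables[2]],
    ],
    log_potential_matrix=np.zeros((2, 2)),
)
fg.add_factors(factors=pairwise)

# Run max-product LBP (temperature 0.0) and decode MAP states.
bp = infer.BP(fg.bp_state, temperature=0.0)
bp_arrays = bp.run_bp(bp.init(), num_iters=20, damping=0.5)
beliefs = bp.get_beliefs(bp_arrays)
map_states = infer.decode_map_states(beliefs)[variables]

Because the generated functions are pure JAX, bp.run_bp and bp.get_beliefs can in principle be composed with jax.vmap or jax.grad, as mentioned in the overview.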

Citing PGMax

Please consider citing our companion paper if you use PGMax in your work:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year={2022}
}

The first two authors contributed equally.

Comments
  • Incomplete documentation for add_factor

    Incomplete documentation for add_factor

    The documentation for add_factor currently says:

    log_potentials – Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized.

    However, in order to use it one would need to know the order in which the factors should be specified. Could this be added to the documentation?

    bug documentation 
    opened by nathanielvirgo 8
  • Add support for Python 3.9 and 3.10

    Add support for Python 3.9 and 3.10

    I don't know if this is a bug or a problem with my installation, but if I try to run a file containing only the line

    from pgmax.fg import graph
    

    I get the error

    % /opt/local/bin/python3 /Users/nathaniel/Dropbox/Code/PGMax/test01.py
    Traceback (most recent call last):
      File "/Users/nathaniel/Dropbox/Code/PGMax/test01.py", line 1, in <module>
        from pgmax.fg import graph
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/fg/graph.py", line 11, in <module>
        import pgmax.bp.infer as infer
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/infer.py", line 6, in <module>
        import pgmax.bp.bp_utils as bp_utils
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/bp_utils.py", line 11, in <module>
        @jax.partial(jax.jit, static_argnames="max_segment_length")
    AttributeError: module 'jax' has no attribute 'partial'
    

    This is with Python 3.9 installed using MacPorts on macOS 10.15.7. PGMax, jax and other prerequisites were installed with pip-3.9 install --user PGMax. I'm happy to give any other information about my installation if you can tell me how to obtain it.

    enhancement 
    opened by nathanielvirgo 6
  • Test sanity check example using new interface and inference modules, and put together the first unit test

    Test sanity check example using new interface and inference modules, and put together the first unit test

    The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.

    In the process, we should also:

    1. Deprecate the current contrib module and create a new examples directory to hold everything.
    2. Start figuring out what our user facing interface should look like.
    opened by StannisZhou 6
  • FactorGraph supports any type of factors + runs specialized inference for ORFactors

    FactorGraph supports any type of factors + runs specialized inference for ORFactors

    In this PR we

    1. Redefine the factor graph abstraction by introducing factor types: factors in a graph are clustered in factor groups, which are grouped according to their factor types. See fg/graph.py
    2. Specify two types of factors: EnumerationFactor (this class already existed) and ORFactor (this new class inherits from the new LogicalFactor). Each Factor class must have its own methods to compile and concatenate Wirings for inference. See factors/enumeration.py and factors/logical.py.
    3. Make running inference in a graph agnostic to the types of factors currently supported. New factor types can then be added without modifying graph.py.
    4. Implement a specialized inference for ORFactors (see pass_OR_fac_to_var_messages in factors/logical.py) and compare it with the existing one for EnumerationFactors in the unit test tests/factors/test_or.py
    opened by antoine-dedieu 5
  • RCN example

    RCN example

    This PR contains an example implementation of RCN using the pgmax package. We load an RCN model pre-trained on a very small subset of MNIST (20 examples) and test on a small subset of MNIST (20 examples). The reported accuracy is 0.80.

    opened by shrinuKushagra 5
  • Variables refactor

    Variables refactor

    We update the way of representing variables. In particular:

    • We get rid of variable names, as well as of the Variables and CompositeVariableGroup classes. A variable is now represented by a tuple (variable hash, variable num_states). In particular, a FactorGraph can then be instantiated directly as fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]). Similarly, Factors are defined by directly passing the variables involved, as in [hidden_variables[ii], visible_variables[jj]].
    • We rewrite NDVariableArray so that the user can access variables by relying on numpy arrays. We also optimize some follow-up computations.
    opened by antoine-dedieu 4
  • Numba speedup for wiring + log potentials

    Numba speedup for wiring + log potentials

    This PR is the continuation of https://github.com/vicariousinc/PGMax/pull/129 and part of our efforts to speed up the adding of FactorGroups and the wiring compilation.

    As https://github.com/vicariousinc/PGMax/pull/129 has moved most of the wiring computation to the FactorGroup level, we can now use numba for fast computation of these wirings.

    As a result:

    • adding factors for the RBM exp takes 3s, building run_bp takes 1s
    • adding factors for the convor exp takes 2s, building run_bp takes 1s
    opened by antoine-dedieu 4
  • RCN implementation on a small train and test set

    RCN implementation on a small train and test set

    This PR contains the first implementation of the RCN example using the PGMax package. The file to run is examples/rcn/inference_pgmax_small.py. The code contains the implementation and visualization on a small set, trained with 20 examples and tested on 20 examples.

    The inference has been separated from the model creation code. Saved models are added to /storage/users/skushagra/pgmax_rcn_artifacts/.

    Implementation on the full dataset will come in a later PR.

    opened by shrinuKushagra 4
  • Make `FactorGraph` mutable to support interactive model building

    Make `FactorGraph` mutable to support interactive model building

    Should implement interface for:

    1. Add factors
    2. Set evidence for variables
    3. Initialize messages by setting messages for factors
    4. Initialize messages by spreading beliefs from variables
    enhancement 
    opened by StannisZhou 4
  • Add customized class for pairwise factors; Default to have uniform potentials

    Add customized class for pairwise factors; Default to have uniform potentials

    Currently, users have to manually create an array of all possible configs and a uniform potential, but it would be nice to do this behind the scenes in some easy way. Maybe we can make it so that if either of these is None during init, then we assume all possible configs or a uniform potential respectively, and automatically create these.

    enhancement 
    opened by NishanthJKumar 4
  • Make BP closer to jax optimizer

    Make BP closer to jax optimizer

    Resolves https://github.com/vicariousinc/PGMax/issues/124

    We make graph.BP closer to JAX optimizers https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

    opened by antoine-dedieu 3
  • Provide high-level syntax for creating factors

    Provide high-level syntax for creating factors

    One of the speed bottlenecks in creating a FactorGraph is the time to create the variables_for_factors list, which is currently slow as we loop through the individual variables.

    However, in the case where all the variable groups are NDVarArr, we can speed up this step a lot by proposing a generic get_factors interface, where the user would define the general rule for the factors and the corresponding list would be generated with numba.

    One option is to have a first argument consisting of variable groups whose dimensions we loop over, and a second argument consisting of variable groups we do not loop over. For example, get_factors({x: (i, j), y: (k, l)}, {z: (i+k, j+l)}) would mean

    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]): 
                    factors.append((x[i, j], y[k, l], z[i+k, j+l]))
    
    opened by antoine-dedieu 0
  • Modify `vars_to_starts` representation

    Modify `vars_to_starts` representation

    Creating vars_to_starts as a dictionary mapping each variable to an int is expensive in the case where we have a lot of variables.

    Instead, it could map a variable group to an array (in the case of an NDVariableArray) or a list (for a VariableDict).

    opened by antoine-dedieu 0
  • Improve documentation for variables/variable groups

    Improve documentation for variables/variable groups

    Currently it's not clear what names PGMax assigns to different variables (e.g. https://github.com/vicariousinc/PGMax/issues/115). Add documentation to make this clearer.

    documentation 
    opened by StannisZhou 0
Releases (latest: v0.4.1)
  • v0.4.1(May 19, 2022)

    Highlights

    • Fixing two minor issues when running BP with variable groups defined with different numbers of states by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/144

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.4.0...v0.4.1

  • v0.4.0(May 9, 2022)

    Breaking changes

    ⚠️ This release changes the high-level API as well as import paths ⚠️

    This release makes several major breaking changes to improve usability and efficiency of the package.

    1. Interacting with variables through VarGroup objects

    We no longer refer to variables by names, but instead directly interact with VarGroup objects. This change has several implications.

    • We no longer have a Variable class. Instead, we access individual variables by indexing into VarGroup objects.

    • FactorGraph can no longer be initialized with a dictionary of variable groups (as we no longer have names for variables). Instead, we initialize a FactorGraph by

    from pgmax import fgraph
    fg = fgraph.FactorGraph(variable_groups=variable_groups)
    

    where variable_groups is either a VarGroup or a list of VarGroups.
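
    For instance (a hypothetical sketch; the NDVarArray constructor arguments are an assumption based on the module names in these notes), two variable groups could be declared and passed as a list:

    from pgmax import fgraph, vgroup

    # Two groups of 10 binary variables each (assumed constructor).
    hidden_variables = vgroup.NDVarArray(num_states=2, shape=(10,))
    visible_variables = vgroup.NDVarArray(num_states=2, shape=(10,))

    fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])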

    • We can directly construct Factor/FactorGroup using individual variables, and have a unified add_factors interface for adding Factors and FactorGroups to the FactorGraph.

    For example, we can create a PairwiseFactorGroup via:

    from pgmax import fgroup
    pairwise_factors = fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=log_potential_matrix,
    )
    

    where variables_for_factors is a list of lists of individual variables.
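
    As a concrete illustration (hypothetical, reusing the hidden_variables and visible_variables groups sketched above), pairing each hidden variable with the corresponding visible variable gives:

    variables_for_factors = [
        [hidden_variables[ii], visible_variables[ii]] for ii in range(10)
    ]

    We can then add factors to a FactorGraph fg by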

    fg.add_factors(factors=factors)
    

    where factors can be individual Factor, individual FactorGroup, or a list of Factors and FactorGroups.

    • We access LBP results by indexing with VarGroup. For example, after running BP, we can get the MAP decoding for the VarGroup visible_variables via
    beliefs = bp.get_beliefs(bp_arrays)
    map_states_visible = infer.decode_map_states(beliefs)[visible_variables]
    

    2. Efficient construction of FactorGroup

    We have implemented efficient construction of FactorGroup. Going forward, we always recommend constructing FactorGroups instead of individual Factors.

    3. Improved LBP interface

    We first create the functions used to run BP with temperature T via

    from pgmax import infer
    bp = infer.BP(fg.bp_state, temperature=T)
    

    where bp contains functions that initialize or update the arrays involved in LBP.

    We can initialize bp_arrays by

    bp_arrays = bp.init()
    

    apply log potentials, messages and evidence updates by

    bp_arrays = bp.update(
        bp_arrays=bp_arrays,
        log_potentials_updates=log_potentials_updates,
        ftov_msgs_updates=ftov_msgs_updates,
        evidence_updates=evidence_updates,
    )
    

    and run BP for a certain number of iterations by

    bp_arrays = bp.run_bp(bp_arrays, num_iters=num_iters, damping=damping)
    

    Note that we can arbitrarily interleave bp.update with bp.run_bp, which allows flexible control over how we run LBP.
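
    For example, one could alternate evidence updates with short runs of BP, roughly as follows (evidence_schedule is a hypothetical sequence of evidence updates, not part of the API):

    bp_arrays = bp.init()
    for evidence_updates in evidence_schedule:
        # Push new evidence into the BP arrays, then run a few damped iterations.
        bp_arrays = bp.update(bp_arrays=bp_arrays, evidence_updates=evidence_updates)
        bp_arrays = bp.run_bp(bp_arrays, num_iters=10, damping=0.5)
    beliefs = bp.get_beliefs(bp_arrays)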

    4. Improved high-level module organization

    Now we have 5 main high-level modules: fgraph for factor graphs, factor for factors, vgroup for variable groups, fgroup for factor groups, and infer for LBP.

    Details of what has changed:

    • Speed up the process of adding Factors and compiling wiring for a FactorGraph by moving all the computations to the FactorGroup level, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/129
    • Speed up the process of computing log potentials + wiring for FactorGroup with numba, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/133
    • Make the BP class behavior closer to JAX optimizers by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/135
    • Get rid of the Variables and CompositeVariableGroup classes + of the variable names + adopt a simpler representation for variables + rely on numpy arrays to make NDVarArray efficient, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/136
    • Overall module reorganization, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/140

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.3.0...v0.4.0

  • v0.3.0(Mar 25, 2022)

    Highlights

    • Refactors to support adding different factor types with specialized inference procedures by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122
    • Specialized logical AND/OR factors by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122 https://github.com/vicariousinc/PGMax/pull/126
    • New example on 2D binary blind deconvolution by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/127

    New Contributors

    • @antoine-dedieu made their first contribution in https://github.com/vicariousinc/PGMax/pull/122

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.3...v0.3.0

  • v0.2.3(Feb 19, 2022)

    What's Changed

    • Links to blog post and companion paper; Documentation updates by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/111
    • Get rid of redundant array shape for log_potentials by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/116
    • Support python 3.9/3.10; Improve documentation for add_factor; Bump up version for new release by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/120

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.2...v0.2.3

  • v0.2.2(Jan 22, 2022)

    What's Changed

    • Update README by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/103
    • RCN example by @shrinuKushagra in https://github.com/vicariousinc/PGMax/pull/96
    • Add support for sum-product with temperature by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/104
    • Include Grid Markov Random Field example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/107
    • Changes for blog post by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/109

    New Contributors

    • @shrinuKushagra made their first contribution in https://github.com/vicariousinc/PGMax/pull/96

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.1...v0.2.2

  • v0.2.1(Dec 1, 2021)

    What's Changed

    • Bump versions for publishing by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/63
    • Example notebook with PMAP sampling of RBMs trained on MNIST digits by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/80
    • Use functools.partial instead of jax.partial by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/83
    • First pass for speeding up graph and evidence operations by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/84
    • Moving to a functional interface by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/88
    • Update README in preparation for making repo public by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/89
    • Pre commit ci test by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/90
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/91
    • Update dependency requirements to be less aggresive by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/92
    • adds codecov badge to README! by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/94
    • Fix GPU memory leak that came up in RCN example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/97
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/95
    • Docs update by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/98
    • fixes bug where wrong conf.py path was specified by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/99
    • Rtd warning fix by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/100
    • includes minor changes to README and documentation by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/101
    • Bump up version; Fixes for docs by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/102

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.0...v0.2.1

  • v0.2.0(Sep 6, 2021)

    Features

    • Efficient and scalable max-product belief propagation using a fully flat representation
    • A general factor graph interface that supports easy specification of PGMs with pairwise factors and higher-order factors based on explicit enumeration
    • Mechanisms for evidence and message manipulation
    • 3 example notebooks showcasing the functionalities of PGMax