Loopy belief propagation for factor graphs on discrete variables, in JAX!

Overview

PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
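
To make the LBP bullet above concrete, here is a minimal pure-Python sketch of max-product message passing in the log domain for a single pairwise factor between two variables. It is not PGMax code: the function names, the nested-list layout, and the numbers are assumptions made for illustration, whereas PGMax generates JAX functions for the same kind of computation.

```python
# Minimal sketch: max-product messages in the log domain for one pairwise factor.
# All names and values are illustrative and are not part of the PGMax API.

def factor_to_var_messages(log_potential, msg_from_a, msg_from_b):
    """Return (message to a, message to b) for a pairwise factor over (a, b)."""
    n_a, n_b = len(msg_from_a), len(msg_from_b)
    # Message to a: maximize over the states of b.
    to_a = [max(log_potential[i][j] + msg_from_b[j] for j in range(n_b)) for i in range(n_a)]
    # Message to b: maximize over the states of a.
    to_b = [max(log_potential[i][j] + msg_from_a[i] for i in range(n_a)) for j in range(n_b)]
    return to_a, to_b


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)


evidence_a = [0.0, 1.0]                    # log evidence for variable a
evidence_b = [0.5, 0.0]                    # log evidence for variable b
log_potential = [[2.0, 0.0], [0.0, 2.0]]   # favours a == b

# With a single factor, each variable-to-factor message is just that variable's evidence.
to_a, to_b = factor_to_var_messages(log_potential, evidence_a, evidence_b)

# Belief = evidence + incoming factor-to-variable messages.
belief_a = [e + m for e, m in zip(evidence_a, to_a)]
belief_b = [e + m for e, m in zip(evidence_b, to_b)]
map_states = (argmax(belief_a), argmax(belief_b))
```

Because this toy graph is a tree, a single round of messages already gives the exact MAP assignment; on graphs with loops the same updates are iterated and the result is approximate.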

See our blog post and companion paper for more details.

Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/vicariousinc/PGMax.git

Developer

git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install

Install on GPU

By default the above commands install JAX for CPU. If you have access to a GPU, follow the official instructions here to install JAX for GPU.

Getting Started

Here are a few self-contained Colab notebooks to help you get started on using PGMax:

Citing PGMax

Please consider citing our companion paper if you use PGMax in your work:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year={2022}
}

First two authors contributed equally.

Comments
  • Incomplete documentation for add_factor

    The documentation for add_factor currently says:

    log_potentials – Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized.

    However, in order to use it one would need to know the order in which the factors should be specified. Could this be added to the documentation?
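
One way to picture the alignment being asked about is sketched below. It assumes, and this is an assumption rather than something the quoted documentation states, that entry k of log_potentials scores row k of an explicit array of valid configurations. The variable sizes and the scoring rule are made up for illustration.

```python
import itertools

num_states = (2, 3)  # hypothetical: a binary variable and a 3-state variable

# Row k lists the state of each variable in configuration k.
# itertools.product varies the last variable fastest.
configs = list(itertools.product(*(range(n) for n in num_states)))

# Assumed convention: log_potentials[k] is the log potential of configs[k].
log_potentials = [1.0 if a == b else 0.0 for a, b in configs]


def lookup(config):
    """Log potential of one configuration under the assumed row alignment."""
    return log_potentials[configs.index(tuple(config))]
```

Under this convention the order of log_potentials is whatever order the configuration rows were listed in, so the two arrays have to be built together.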

    bug documentation 
    opened by nathanielvirgo 8
  • Add support for Python 3.9 and 3.10

    I don't know if this is a bug or a problem with my installation, but if I try to run a file containing only the line

    from pgmax.fg import graph
    

    I get the error

    % /opt/local/bin/python3 /Users/nathaniel/Dropbox/Code/PGMax/test01.py
    Traceback (most recent call last):
      File "/Users/nathaniel/Dropbox/Code/PGMax/test01.py", line 1, in <module>
        from pgmax.fg import graph
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/fg/graph.py", line 11, in <module>
        import pgmax.bp.infer as infer
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/infer.py", line 6, in <module>
        import pgmax.bp.bp_utils as bp_utils
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/bp_utils.py", line 11, in <module>
        @jax.partial(jax.jit, static_argnames="max_segment_length")
    AttributeError: module 'jax' has no attribute 'partial'
    

    This is with Python 3.9 installed using MacPorts on macOS 10.15.7. PGMax, jax and other prerequisites were installed with pip-3.9 install --user PGMax. I'm happy to give any other information about my installation if you can tell me how to obtain it.
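
The v0.2.1 release notes further down list "Use functools.partial instead of jax.partial", which is the standard-library way to write the decorator that fails in the traceback. The sketch below shows the pattern in pure Python. fake_jit is a hypothetical stand-in for jax.jit so that the snippet runs without JAX installed, and truncate is an invented example function, not the PGMax one.

```python
import functools


def fake_jit(fun, static_argnames=None):
    """Hypothetical stand-in for jax.jit: calls fun unchanged and records static_argnames."""
    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)

    wrapped.static_argnames = static_argnames
    return wrapped


# functools.partial pre-binds the keyword argument, so the result can be used as a decorator.
@functools.partial(fake_jit, static_argnames="max_segment_length")
def truncate(values, max_segment_length):
    """Invented example: keep at most max_segment_length entries."""
    return values[:max_segment_length]
```

With real JAX, the same decorator line would use jax.jit in place of fake_jit.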

    enhancement 
    opened by nathanielvirgo 6
  • Test sanity check example using new interface and inference modules, and put together the first unit test

    The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.

    In the process, we should also:

    1. Deprecate the current contrib module and create a new examples directory to hold everything.
    2. Start figuring out what our user facing interface should look like.
    opened by StannisZhou 6
  • FactorGraph supports any type of factors + runs specialized inference for ORFactors

    In this PR we

    1. Redefine the factor graph abstraction by introducing factor types: factors in a graph are clustered in factor groups, which are grouped according to their factor types. See fg/graph.py
    2. Specify two types of factors: EnumerationFactor (this class already existed) and ORFactor (this new class inherits from the new LogicalFactor). Each Factor class must have its own methods to compile and concatenate Wirings for inference. See factors/enumeration.py and factors/logical.py.
    3. Make running inference in a graph agnostic to the current type of factors supported. New factors types can then be added without modifying graph.py.
    4. Implement a specialized inference for ORFactors (see pass_OR_fac_to_var_messages in factors/logical.py) and compare it with the existing one for EnumerationFactors in the unit test tests/factors/test_or.py
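
As a rough illustration of why an OR factor is a natural candidate for a dedicated factor type, the sketch below enumerates the valid configurations that an enumeration-based representation would have to list. The helper name and the (parents..., child) layout are assumptions for this example and are not taken from factors/enumeration.py or factors/logical.py.

```python
import itertools


def or_factor_valid_configs(num_parents):
    """Enumerate (parent_1, ..., parent_n, child) tuples with child == OR(parents)."""
    return [
        parents + (int(any(parents)),)
        for parents in itertools.product((0, 1), repeat=num_parents)
    ]


configs = or_factor_valid_configs(3)
```

The number of valid configurations is 2**num_parents, so an explicit enumeration grows exponentially with the number of parents.
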
    opened by antoine-dedieu 5
  • RCN example

    This PR contains an example implementation of RCN using the pgmax package. We load an RCN model pre-trained on a very small subset of MNIST (20 examples) and test it on a small subset of MNIST (20 examples). The reported accuracy is 0.80.

    opened by shrinuKushagra 5
  • Variables refactor

    We update the way of representing variables. In particular:

    • We get rid of variable names, as well as of the Variables and CompositeVariableGroup classes. A variable is now represented by a tuple (variable hash, variable num_states). In particular, a FactorGraph can then be instantiated directly as fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]). Similarly, Factors are defined by directly passing the variables involved, as [hidden_variables[ii], visible_variables[jj]].
    • We rewrite NDVariableArray so that the user can access variables by relying on the use of numpy arrays. We also optimize some follow-up computations.
    opened by antoine-dedieu 4
  • Numba speedup for wiring + log potentials

    This PR is the continuation of https://github.com/vicariousinc/PGMax/pull/129 and part of our efforts to speed up the adding of FactorGroups and the wiring compilation.

    As https://github.com/vicariousinc/PGMax/pull/129 has moved most of the wiring computation to the FactorGroup level, we can now use numba for fast computation of these wirings.

    As a result:

    • adding factors for the RBM exp takes 3s, building run_bp takes 1s
    • adding factors for the convor exp takes 2s, building run_bp takes 1s
    opened by antoine-dedieu 4
  • RCN implementation on a small train and test set

    This PR contains the first implementation of the RCN example using the PGMax package. The file to run is examples/rcn/inference_pgmax_small.py. The code contains the implementation and a visualization on a small set: the model is trained with 20 examples and tested on 20 examples.

    The inference has been separated from the model creation code. Saved models are added to /storage/users/skushagra/pgmax_rcn_artifacts/.

    An implementation on the full dataset will follow in a later PR.

    opened by shrinuKushagra 4
  • Make `FactorGraph` mutable to support interactive model building

    Should implement interface for:

    1. Add factors
    2. Set evidence for variables
    3. Initialize messages by setting messages for factors
    4. Initialize messages by spreading beliefs from variables
    enhancement 
    opened by StannisZhou 4
  • Add customized class for pairwise factors; Default to have uniform potentials

    Currently, users have to manually create an array of all possible configs and a uniform potential, but it would be nice to do this behind the scenes in some easy way. Maybe we can make it so that if either of these is None during init, we assume all possible configs or a uniform potential, respectively, and create them automatically.

    enhancement 
    opened by NishanthJKumar 4
  • Make BP closer to jax optimizer

    Resolves https://github.com/vicariousinc/PGMax/issues/124

    We make graph.BP closer to JAX optimizers https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

    opened by antoine-dedieu 3
  • Provide high-level syntax for creating factors

    One of the speed bottlenecks in creating a FactorGraph is the time to create the variables_for_factors list, which is currently slow as we loop through the individual variables.

    However, in the case where all the variable groups are NDVarArr, we can speed up this step a lot by proposing a generic get_factors interface, where the user would define the general rule for the factors and the corresponding list would be generated with numba.

    One option is to have a first argument consisting of the variable groups whose dimensions we loop over, and a second argument consisting of the variable groups we do not loop over. For example, get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}) would mean

    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]): 
                    factors.append((x[i, j], y[k, l], z[i+k, j+l]))
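
The loop above can be made runnable with stand-in variables. In the sketch below, each variable is represented by a (name, row, col) tuple rather than a real PGMax variable, and the shapes are hypothetical; the only point is to show which index triples the proposed rule would generate.

```python
import itertools

x_shape, y_shape = (2, 2), (3, 2)  # hypothetical shapes for x and y
# z must be large enough to hold every index (i + k, j + l).
z_shape = (x_shape[0] + y_shape[0] - 1, x_shape[1] + y_shape[1] - 1)


def get_factors_sketch(x_shape, y_shape):
    """List the (x, y, z) variable triples, one per factor, using tuples as stand-ins."""
    factors = []
    index_ranges = (
        range(x_shape[0]),
        range(x_shape[1]),
        range(y_shape[0]),
        range(y_shape[1]),
    )
    for i, j, k, l in itertools.product(*index_ranges):
        factors.append((("x", i, j), ("y", k, l), ("z", i + k, j + l)))
    return factors


factors = get_factors_sketch(x_shape, y_shape)
```

The number of factors is the product of all looped dimensions, which is why generating this list one variable at a time in Python becomes a bottleneck for large arrays.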
    
    opened by antoine-dedieu 0
  • Modify `vars_to_starts` representation

    Creating vars_to_starts as a dictionary mapping each variable to an int is expensive when we have a lot of variables.

    Instead, it could map a variable group to an array (in the case of an NDVariableArray) or a list (for a VariableDict).
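
A minimal sketch of the group-level alternative follows. It assumes that a "start" is the index of a variable's first state in one flat array, and that all variables in a 2D group share the same num_states and are laid out in row-major order. The function name, the offset argument, and the layout are assumptions for illustration only.

```python
def starts_for_nd_group(shape, num_states, offset=0):
    """Start indices for a 2D variable group, as a nested list mirroring its shape."""
    rows, cols = shape
    return [
        [offset + (r * cols + c) * num_states for c in range(cols)]
        for r in range(rows)
    ]


# One nested list per group replaces one dictionary entry per variable.
starts = starts_for_nd_group(shape=(2, 3), num_states=4, offset=10)
```

Here the start of variable (r, c) is computed arithmetically from its position, so no per-variable dictionary entry is needed.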

    opened by antoine-dedieu 0
  • Improve documentation for variables/variable groups

    Currently it's not clear what names PGMax assigns to different variables (e.g. https://github.com/vicariousinc/PGMax/issues/115). Add documentation to make this clearer.

    documentation 
    opened by StannisZhou 0
Releases (v0.4.1)
  • v0.4.1 (May 19, 2022)

    Highlights

    • Fixing two minor issues when running BP with variable groups defined with different numbers of states by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/144

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.4.0...v0.4.1

  • v0.4.0 (May 9, 2022)

    Breaking changes

    ⚠️ This release changes the high-level API as well as import paths ⚠️

    This release makes several major breaking changes to improve usability and efficiency of the package.

    1. Interacting with variables through VarGroup objects

    We no longer refer to variables by names, but instead directly interact with VarGroup objects. This change has several implications.

    • We no longer have a Variable class. Instead, we access individual variables by indexing into VarGroup objects.

    • FactorGraph can no longer be initialized with a dictionary of variable groups (as we no longer have names for variables). Instead, we initialize a FactorGraph by

    from pgmax import fgraph
    fg = fgraph.FactorGraph(variable_groups=variable_groups)
    

    where variable_groups is either a VarGroup or a list of VarGroups.

    • We can directly construct Factor/FactorGroup using individual variables, and have a unified add_factors interface for adding Factors and FactorGroups to the FactorGraph.

    For example, we can create a PairwiseFactorGroup via:

    from pgmax import fgroup
    pairwise_factors = fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=log_potential_matrix,
    )
    

    where variables_for_factors is a list of lists of individual variables. And we can add factors to a FactorGraph fg by

    fg.add_factors(factors=factors)
    

    where factors can be an individual Factor, an individual FactorGroup, or a list of Factors and FactorGroups.

    • We access LBP results by indexing with VarGroup. For example, after running BP, we can get the MAP decoding for the VarGroup visible_variables via

    beliefs = bp.get_beliefs(bp_arrays)
    map_states_visible = infer.decode_map_states(beliefs)[visible_variables]
    

    2. Efficient construction of FactorGroup

    We have implemented efficient construction of FactorGroup. Going forward, we always recommend constructing FactorGroup instead of individual Factor.

    3. Improved LBP interface

    We first create the functions used to run BP with temperature T via

    from pgmax import infer
    bp = infer.BP(fg.bp_state, temperature=T)
    

    where bp contains functions that initialize or update the arrays involved in LBP.
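
The role of the temperature can be sketched as a temperature-scaled log-sum-exp: at temperature 0 the reduction becomes a plain max (max-product), and at temperature 1 it becomes log-sum-exp (sum-product). The helper below is a pure-Python illustration under that reading; its name and signature are assumptions and it is not the PGMax implementation.

```python
import math


def temperature_reduce(scores, temperature):
    """Return temperature * log(sum(exp(score / temperature))), or max(scores) at temperature 0."""
    top = max(scores)
    if temperature == 0.0:
        return top
    # Subtract the maximum before exponentiating for numerical stability.
    total = sum(math.exp((s - top) / temperature) for s in scores)
    return top + temperature * math.log(total)


max_product_value = temperature_reduce([0.0, 1.0], temperature=0.0)
sum_product_value = temperature_reduce([0.0, 1.0], temperature=1.0)
```

As the temperature decreases towards 0, the result approaches the maximum score, so intermediate temperatures interpolate between the two message-passing variants.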

    We can initialize bp_arrays by

    bp_arrays = bp.init()
    

    apply log potentials, messages and evidence updates by

    bp_arrays = bp.update(
        bp_arrays=bp_arrays,
        log_potentials_updates=log_potentials_updates,
        ftov_msgs_updates=ftov_msgs_updates,
        evidence_updates=evidence_updates,
    )
    

    and run bp for a certain number of iterations by

    bp_arrays = bp.run_bp(bp_arrays, num_iters=num_iters, damping=damping)
    

    Note that we can arbitrarily interleave bp.update with bp.run_bp, which allows flexible control over how we run LBP.
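
The interleaving pattern can be illustrated with a toy, pure-Python stand-in for the state-passing style described above: every function takes a state and returns a new one. This is a sketch on a single pairwise factor with max-product updates. The dictionary layout, the fixed LOG_POTENTIAL, and the damping convention (new = damping * old + (1 - damping) * computed) are assumptions for this example, not the PGMax internals.

```python
# Toy stand-in for a functional BP interface; not the PGMax implementation.
LOG_POTENTIAL = [[2.0, 0.0], [0.0, 2.0]]  # one pairwise factor favouring equal states


def init():
    """Fresh state: zero evidence and zero factor-to-variable messages."""
    return {"evidence": [[0.0, 0.0], [0.0, 0.0]], "ftov_msgs": [[0.0, 0.0], [0.0, 0.0]]}


def update(bp_arrays, evidence_updates=None):
    """Return a new state with the evidence of selected variables replaced."""
    evidence = [list(row) for row in bp_arrays["evidence"]]
    for var, values in (evidence_updates or {}).items():
        evidence[var] = list(values)
    return {"evidence": evidence, "ftov_msgs": [list(m) for m in bp_arrays["ftov_msgs"]]}


def run_bp(bp_arrays, num_iters, damping=0.5):
    """Return a new state after num_iters damped max-product updates."""
    ev_a, ev_b = bp_arrays["evidence"]
    msgs = [list(m) for m in bp_arrays["ftov_msgs"]]
    for _ in range(num_iters):
        # With a single factor, variable-to-factor messages equal the evidence.
        new_a = [max(LOG_POTENTIAL[i][j] + ev_b[j] for j in range(2)) for i in range(2)]
        new_b = [max(LOG_POTENTIAL[i][j] + ev_a[i] for i in range(2)) for j in range(2)]
        msgs = [
            [damping * old + (1 - damping) * new for old, new in zip(msgs[0], new_a)],
            [damping * old + (1 - damping) * new for old, new in zip(msgs[1], new_b)],
        ]
    return {"evidence": [list(ev_a), list(ev_b)], "ftov_msgs": msgs}


def get_beliefs(bp_arrays):
    """Belief = evidence + incoming factor-to-variable message."""
    return [
        [e + m for e, m in zip(ev, msg)]
        for ev, msg in zip(bp_arrays["evidence"], bp_arrays["ftov_msgs"])
    ]


bp_arrays = init()
bp_arrays = run_bp(bp_arrays, num_iters=5)
# Inject evidence for variable 0 mid-run, then continue from the current messages.
bp_arrays = update(bp_arrays, evidence_updates={0: [0.0, 3.0]})
bp_arrays = run_bp(bp_arrays, num_iters=20)
beliefs = get_beliefs(bp_arrays)
map_states = [max(range(2), key=b.__getitem__) for b in beliefs]
```

Because no function mutates its input, earlier states remain valid and can be reused, which is what makes this style convenient to combine with transformations that expect pure functions.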

    4. Improved high-level module organization

    Now we have 5 main high-level modules: fgraph for factor graphs, factor for factors, vgroup for variable groups, fgroup for factor groups, and infer for LBP.

    Details of what has changed:

    • Speed up the process of adding Factors and compiling wiring for a FactorGraph by moving all the computations to the FactorGroup level, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/129
    • Speed up the process of computing log potentials + wiring for FactorGroup with numba, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/133
    • Make the BP class behavior closer to JAX optimizers by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/135
    • Get rid of the Variables and CompositeVariableGroup classes + of the variable names + adopt a simpler representation for variables + rely on numpy arrays to make NDVarArray efficient, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/136
    • Overall module reorganization, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/140

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.3.0...v0.4.0

  • v0.3.0 (Mar 25, 2022)

    Highlights

    • Refactors to support adding different factor types with specialized inference procedures by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122
    • Specialized logical AND/OR factors by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122 https://github.com/vicariousinc/PGMax/pull/126
    • New example on 2D binary blind deconvolution by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/127

    New Contributors

    • @antoine-dedieu made his first contribution in https://github.com/vicariousinc/PGMax/pull/122

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.3...v0.3.0

  • v0.2.3 (Feb 19, 2022)

    What's Changed

    • Links to blog post and companion paper; Documentation updates by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/111
    • Get rid of redundant array shape for log_potentials by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/116
    • Support python 3.9/3.10; Improve documentation for add_factor; Bump up version for new release by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/120

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.2...v0.2.3

  • v0.2.2 (Jan 22, 2022)

    What's Changed

    • Update README by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/103
    • RCN example by @shrinuKushagra in https://github.com/vicariousinc/PGMax/pull/96
    • Add support for sum-product with temperature by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/104
    • Include Grid Markov Random Field example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/107
    • Changes for blog post by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/109

    New Contributors

    • @shrinuKushagra made their first contribution in https://github.com/vicariousinc/PGMax/pull/96

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.1...v0.2.2

  • v0.2.1 (Dec 1, 2021)

    What's Changed

    • Bump versions for publishing by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/63
    • Example notebook with PMAP sampling of RBMs trained on MNIST digits by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/80
    • Use functools.partial instead of jax.partial by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/83
    • First pass for speeding up graph and evidence operations by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/84
    • Moving to a functional interface by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/88
    • Update README in preparation for making repo public by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/89
    • Pre commit ci test by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/90
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/91
    • Update dependency requirements to be less aggressive by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/92
    • adds codecov badge to README! by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/94
    • Fix GPU memory leak that came up in RCN example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/97
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/95
    • Docs update by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/98
    • fixes bug where wrong conf.py path was specified by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/99
    • Rtd warning fix by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/100
    • includes minor changes to README and documentation by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/101
    • Bump up version; Fixes for docs by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/102

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.0...v0.2.1

  • v0.2.0 (Sep 6, 2021)

    Features

    • Efficient and scalable max-product belief propagation using a fully flat representation
    • A general factor graph interface that supports easy specification of PGMs with pairwise factors and higher-order factors based on explicit enumeration
    • Mechanisms for evidence and message manipulation
    • 3 example notebooks showcasing the functionalities of PGMax