vicariousinc / pgmax

Loopy belief propagation for factor graphs on discrete variables, in JAX!

Home Page: https://pgmax.readthedocs.io

License: MIT License

Languages: Python 100.00%
Topics: jax, python

pgmax's People

Contributors

antoine-dedieu, nishanthjkumar, pre-commit-ci[bot], shrinukushagra, stanniszhou


pgmax's Issues

Add RBM example

RBM is relevant as it is a standard model with two classes of variables (which are nicely supported here)

Maybe examples could be divided into two subfolders:

  • standard_models: Ising, RBM
  • complex_models: heretic, cut model, CMRF(?)

New factor graph interface that is aware of group structures in variables/factors

Currently the factor graph works with individual variables/factors, but in many cases factor graphs are constructed from variable groups/factor groups.

The factor graphs can exploit such structures to make certain operations more efficient/convenient. Some examples include:

  1. Each variable group can implement a customized flattening function for a given evidence array, so that we no longer need to go through the vars_to_evidence dictionary when setting the evidence (see the sketch below).
  2. We can name and organize factors according to factor groups, and set messages, potentially for an entire factor group.
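
As an illustration of point 1, here is a minimal sketch of a group-level flattening function. The names (NDVariableGroup, flatten_evidence) are assumptions for illustration, not the current API:

import numpy as np

class NDVariableGroup:
    """A grid-shaped variable group that knows its own evidence layout."""

    def __init__(self, shape, num_states):
        self.shape = shape
        self.num_states = num_states

    def flatten_evidence(self, evidence):
        # Accept evidence in the group's natural (grid) layout and flatten
        # it to the internal vector, with no vars_to_evidence dict involved.
        assert evidence.shape == self.shape + (self.num_states,)
        return evidence.reshape(-1)

grid = NDVariableGroup(shape=(3, 3), num_states=2)
flat = grid.flatten_evidence(np.zeros((3, 3, 2)))  # shape (18,)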

Experiments on PMAP learning of RBM

Setup

Learning RBMs on MNIST digits from Sec. 5.5 of the PMAP paper

Metrics

  1. Sampling quality (Fig. 3(a) of the PMAP paper)
  2. Visualization of samples (Fig. 3(b) of the PMAP paper)
  3. Speed of inference

Concrete TODOs

Design and implement core inference functions

Migrate the best solution implemented in the contrib module, taking into account new message flows.

To quote @lazarox:

I’ve never used clipping so I don’t know about that one. Maybe I would just clip the unaries and hope that everything else remains contained.

Normalization should not make any mathematical difference in the algorithm, it’s there only for numerical robustness. You should be able to normalize and renormalize at will without changing the results. So just normalize as needed so that all operations are numerically stable.

For the rest, this is how I do it:
outgoing messages: Factor-to-variable messages
incoming messages: Variable-to-factor messages
messages: vector of log-max-marginals, including all possible assignments of the variable. I.e., for binary variables it's a two-dimensional vector.

Then:
1. You have a set of outgoing messages. For each message, the maximum value should be zero. Renormalize if not. These messages are the key quantity we are updating; everything else is a derived quantity.
2. Compute the beliefs by adding the messages at each variable (parallel op)
3. Compute the incoming messages by subtracting the outgoing messages from the beliefs (parallel op)
4. For each factor compute the new outgoing messages new_outgoing (parallel op)
5. Renormalize new_outgoing so that the max of each message is 0
6. Compute the message deltas: delta=new_outgoing-outgoing (parallel op)
7. If you now updated the outgoing messages of a single factor to the new ones, you'd get serial MP. Instead do outgoing += eta*delta (parallel op). eta is the step size, i.e., 1 - damping.
8. Renormalize outgoing so that the max of each message is 0
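
For concreteness, here is a minimal NumPy sketch of steps 1-8 above, restricted to pairwise factors so the factor update (step 4) stays short. All names and shapes are illustrative assumptions, not PGMax internals; the actual implementation would operate on the flat message array and be JIT-compiled with JAX.

import numpy as np

def damped_max_product_step(outgoing, log_pots, var_of, num_vars, eta):
    # outgoing: (num_factors, 2, k) factor-to-variable messages, one per
    #           side of each pairwise factor, each max-normalized to 0
    # log_pots: (num_factors, k, k) pairwise log-potentials
    # var_of:   (num_factors, 2) index of the variable on each side
    # eta:      step size, i.e. 1 - damping
    num_factors, _, k = outgoing.shape
    # Step 2: beliefs = sum of factor-to-variable messages at each variable.
    beliefs = np.zeros((num_vars, k))
    np.add.at(beliefs, var_of, outgoing)
    # Step 3: variable-to-factor messages = belief minus the edge's own message.
    incoming = beliefs[var_of] - outgoing
    # Step 4: new outgoing messages, maximizing over the other side's state.
    new_out = np.empty_like(outgoing)
    new_out[:, 0] = (log_pots + incoming[:, 1][:, None, :]).max(axis=2)
    new_out[:, 1] = (log_pots + incoming[:, 0][:, :, None]).max(axis=1)
    # Step 5: renormalize so the max of each message is 0.
    new_out -= new_out.max(axis=2, keepdims=True)
    # Steps 6-7: damped parallel update, outgoing += eta * delta.
    outgoing = outgoing + eta * (new_out - outgoing)
    # Step 8: renormalize again.
    return outgoing - outgoing.max(axis=2, keepdims=True)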

Add HowTo example notebook

This would explain for instance:

  • how to create a factor graph
  • how to use add_factor (with its different uses)
  • how to give evidence
  • how to parallelize over scenes (see the sketch below)
  • ...

Currently this has to be derived from existing examples.
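
As a placeholder for the "parallelize over scenes" item, here is a minimal JAX sketch: if inference is a pure function of the evidence, jax.vmap maps it over a batch of scenes. run_inference below is a hypothetical stand-in for a full BP-plus-decoding pipeline, not PGMax's API.

import jax
import jax.numpy as jnp

def run_inference(evidence):
    # Placeholder for a full run_bp + decode_map_states pipeline; here we
    # just pick the highest-evidence state per variable.
    return jnp.argmax(evidence, axis=-1)

batched = jax.vmap(run_inference)        # map over a leading "scenes" axis
evidence_batch = jnp.zeros((8, 100, 2))  # 8 scenes, 100 binary variables
map_states = batched(evidence_batch)     # shape (8, 100)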

Construct `FactorGraph`s using `FactorGroup`s

Change the interface of FactorGraph to take as input a sequence of FactorGroups instead of the current flat list of individual factors. Do the various expansions (getting the flat lists of variables and factors) inside the FactorGraph class.
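
A minimal sketch of the proposed construction, using hypothetical dataclasses (the names match the issue, not necessarily the final code):

from dataclasses import dataclass, field
from typing import Any, List, Sequence

@dataclass
class FactorGroup:
    factors: Sequence[Any]  # the individual factors in this group

@dataclass
class FactorGraph:
    factor_groups: Sequence[FactorGroup]
    factors: List[Any] = field(init=False)

    def __post_init__(self):
        # The expansion to a flat factor list happens here, not at the call site.
        self.factors = [f for g in self.factor_groups for f in g.factors]

fg = FactorGraph([FactorGroup(["f0", "f1"]), FactorGroup(["f2"])])
# fg.factors == ["f0", "f1", "f2"]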

Add customized class for pairwise factors; Default to have uniform potentials

Currently, users have to manually create an array of all possible configs and a uniform potential, but it would be nice to do this behind the scenes in some easy way. Maybe we can make it so that if either of these is None at init, we assume all possible configs or a uniform potential respectively and create them automatically (sketched below).
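
A sketch of the proposed defaults; pairwise_defaults is a hypothetical helper, not existing API:

import itertools
import numpy as np

def pairwise_defaults(num_states0, num_states1, configs=None, log_potentials=None):
    if configs is None:
        # Enumerate all (num_states0 * num_states1) possible configs.
        configs = np.array(list(itertools.product(range(num_states0), range(num_states1))))
    if log_potentials is None:
        # Uniform potential: all-zero log potentials, one per config.
        log_potentials = np.zeros(len(configs))
    return configs, log_potentials

configs, log_pots = pairwise_defaults(2, 3)  # 6 configs, uniform potential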

Keep track of mapping from factors to starting indices in the flat message array

Once #45 is resolved, also keep track of mapping from factors to starting indices in the flat message array to provide finer control over things like message initialization. Concretely:

  • Keep a mapping from FactorGroups to starting indices in the flat message array, inside the refactored FactorGraph which takes a sequence of FactorGroups at initialization.
  • Within each FactorGroup, keep a mapping from individual factors (indexed using a tuple of involved variable indices) to starting indices in the flat message array.
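
The bookkeeping itself amounts to a cumulative sum over per-factor message block sizes; a sketch with illustrative numbers:

import numpy as np

message_sizes = np.array([4, 4, 6])  # messages occupied by each factor
starts = np.concatenate(([0], np.cumsum(message_sizes)[:-1]))
factor_to_start = {f: int(s) for f, s in enumerate(starts)}
# factor_to_start == {0: 0, 1: 4, 2: 8}; the same idea applies one level up,
# mapping each FactorGroup to the start of its block.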

Make `decode_map_states` output a mapping from variable keys to MAP states

Currently, the decode_map_states function outputs a mapping from Variables to integers corresponding to each variable's MAP state. However, users don't actually have access to Variables; they only use keys to index Variables through a VariableGroup, so a mapping from Variable to int is unintuitive and cumbersome. It should instead map keys to ints.
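
A sketch of the proposed translation, assuming the VariableGroup exposes its key -> Variable mapping (plain dicts stand in below; only decode_map_states is a real name):

def map_states_by_key(variable_group, var_to_state):
    # Invert the group's key -> Variable mapping once, then translate the
    # output of decode_map_states from Variables to keys.
    variable_to_key = {v: k for k, v in variable_group.items()}
    return {variable_to_key[v]: s for v, s in var_to_state.items()}

group = {("grid", 0): "v0", ("grid", 1): "v1"}  # key -> Variable stand-ins
print(map_states_by_key(group, {"v0": 1, "v1": 0}))
# {('grid', 0): 1, ('grid', 1): 0}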

Update naming/docs based on feedback from internal beta

From @lawchekun :

  • key_tuple in GenericVariableGroup -> variable_names
  • In general maybe move away from using keys and instead use names
  • Would probably be helpful if there's a short example showing that init_msgs can be used to set up/trigger the belief propagation too in run_bp

Check whether a key is a sequence before checking key length

@StannisZhou: Wanted to check the expected format for the segment

np.zeros(valid_configs_dict[edge[-1]], dtype=float)

As I got a

Traceback (most recent call last):
  File "rcn_example.py", line 153, in <module>
    np.zeros(valid_configs_dict[edge[-1]], dtype=float),  # This line causes issues
ValueError: maximum supported dimension for an ndarray is 32, found 65

Here's the shape of valid_configs_dict[edge[-1]]

In [2]: valid_configs_dict[edge[-1]].shape
Out[2]: (65, 2)

I tried

np.zeros_like(valid_configs_dict[edge[-1]], dtype=float)

But got

Traceback (most recent call last):
  File "rcn_example.py", line 154, in <module>
    np.zeros_like(valid_configs_dict[edge[-1]], dtype=float),
  File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/graph.py", line 129, in add_factor
    self._variable_group, *new_args, **kwargs
  File "<string>", line 7, in __init__
  File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 380, in __post_init__
    self, "_keys_to_factors", MappingProxyType(self._get_keys_to_factors())
  File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 504, in _get_keys_to_factors
    for ii in range(len(self.connected_var_keys))
  File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 504, in <dictcomp>
    for ii in range(len(self.connected_var_keys))
  File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 174, in __getitem__
    if len(curr_key) < 2:
TypeError: object of type 'numpy.int32' has no len()

So I think I'm not parsing it right...

Originally posted by @lawchekun in #73 (comment)
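
A sketch of the proposed guard, to be placed before the existing len(curr_key) check in groups.__getitem__ (the surrounding logic is elided):

from collections.abc import Sequence

def check_key(curr_key):
    # Scalar keys (e.g. a numpy.int32 from iterating over an array) have no
    # len(); fail with a clear message instead of an opaque TypeError.
    if isinstance(curr_key, str) or not isinstance(curr_key, Sequence):
        raise ValueError(
            f"Expected a sequence key, got {type(curr_key).__name__}: {curr_key!r}"
        )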

Benchmark against pomegranate

Notes about pomegranate

Package Focus: easy to stack and sequence probabilistic models by considering them to just be an underlying probability distribution. Also supports parallelization and GPU computation.

Overlap: pomegranate implements loopy belief propagation on factor graphs. However, only MAP inference is implemented (so they only seem to support max-product and not sum-product).

Comparison we can make: Speed of running either sum-product or max-product, ease of specifying model (maybe in terms of number of lines required or something like that?)

Demonstrating our advantage: We should be able to specify models much more easily with PGMax than with pomegranate, since pomegranate only lets you add one variable/factor at a time. Also, since PGMax's inference is JIT-compiled end-to-end, I expect it to be more efficient than pomegranate's inference (even though pomegranate leverages GPUs).
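
A sketch of a timing harness for the speed comparison; time_inference and its arguments are illustrative, and run_bp stands in for whatever inference entry point each library exposes:

import time

def time_inference(run_bp, evidence, repeats=10):
    run_bp(evidence)  # warm-up run (also triggers JIT compilation for PGMax)
    start = time.perf_counter()
    for _ in range(repeats):
        run_bp(evidence)
    return (time.perf_counter() - start) / repeats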

Experiments on GMRF

Reproduce experiments on learning a GMRF on BO dataset.

Requires #17 and #68

Demonstrate how PGMax can be used as part of a larger NN, by implementing this experiment using Trax, Flax and Haiku.

Experiments on RCN

Reproduce RCN experiments from the Science paper.

Demonstrate inference with a learned model with 100 templates.

Implement forward and backward pass within a single PGM.

Benchmark against PGMPy

Notes about PGMPy

Package Focus: easy to use from a user perspective, with wide support for a variety of different types of PGMs, inference and learning algorithms

Overlap: Also supports belief propagation on discrete, undirected factor graphs

Comparison we can make: Speed of running either sum-product or max-product, ease of specifying model (maybe in terms of number of lines required or something like that?)

Demonstrating our advantage: PGMPy uses dicts and NumPy arrays for message passing, so we should be significantly more efficient. Also, for grid models, etc. we should be able to specify them much more easily with PGMax
