vicariousinc / pgmax
Loopy belief propagation for factor graphs on discrete variables, in JAX!
Home Page: https://pgmax.readthedocs.io
License: MIT License
RBM is relevant as it is a standard model with two classes of variables (which are nicely supported here)
Maybe examples could be divided into two subfolders:
After this is done, make the get_all_vars method into a cached property!
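functools.cached_property is one way to do this. A minimal sketch — the class body and the all_vars rename are illustrative, not the real PGMax code:

```python
from functools import cached_property

class CompositeVariableGroup:
    """Illustrative stand-in for the real class."""

    def __init__(self, subgroups):
        self.subgroups = subgroups

    @cached_property
    def all_vars(self):
        # Computed once on first access, then cached on the instance.
        return [v for group in self.subgroups for v in group]

g = CompositeVariableGroup([["a", "b"], ["c"]])
assert g.all_vars == ["a", "b", "c"]
assert g.all_vars is g.all_vars  # same cached list object on repeat access
```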
Currently the factor graph works with individual variables/factors. But in many cases the factor graphs are constructed with variable groups/factor groups.
The factor graphs can exploit such structures to make certain operations more efficient/convenient. Some examples include:
using the vars_to_evidence dictionary when setting the evidence
Consider making a 'PR Checklist' similar to JAX's.
This interface should greatly simplify the specification of the sanity_check example, making it extremely straightforward. It should also be extensible to cover a wide variety of use cases (such as facilitating #16 in the future).
Currently, our examples are not easy to understand for a new user. I think it'd be nice to:
Learning RBMs on MNIST digits from Sec. 5.5 of PMAP paper (predict method mentioned in the tutorial, which seems to be doing MAP inference) @NishanthJKumar
Ising model experiments from Sec. 5.2 of PMAP paper, 100 EBMs as described.
Currently we are using the deprecated pre-commit github action, which seems to have issues (e.g. for caching).
The new pre-commit ci framework looks better, but doesn't work with private repos.
Switch to the new pre-commit ci after repo is public
Migrate the best solution implemented in the contrib module, taking into account new message flows.
To quote @lazarox :
I’ve never used clipping so I don’t know about that one. Maybe I would just clip the unaries and hope that everything else remains contained.
Normalization should not make any mathematical difference in the algorithm, it’s there only for numerical robustness. You should be able to normalize and renormalize at will without changing the results. So just normalize as needed so that all operations are numerically stable.
For the rest, this is how I do it:
outgoing messages: Factor-to-variable messages
incoming messages: Variable-to-factor messages
messages: vector of log-max-marginals, including all possible assignments of the variable. I.e., for binary variables it’s a two-dimensional vector.
Then:
1. You have a set of outgoing messages. For each message, the maximum value should be zero. Renormalize if not. These messages are the key quantity we are updating. Everything else are derived quantities.
2. Compute the beliefs by adding the messages at each variable (parallel op)
3. Compute the incoming messages by subtracting the outgoing messages from the beliefs (parallel op)
4. For each factor compute the new outgoing messages new_outgoing (parallel op)
5. Renormalize new_outgoing so that the max of each message is 0
6. Compute the message deltas: delta=new_outgoing-outgoing (parallel op)
7. If you now updated the outgoing messages of a single factor to the new ones, you’d get serial MP. Instead do outgoing += eta*delta (parallel op). eta is the stepsize, i.e., 1 - damping.
8. Renormalize outgoing so that the max of each message is 0
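The eight steps above can be sketched in plain numpy — a chain of five binary variables with random potentials; all names here are illustrative, not the PGMax API:

```python
import numpy as np

rng = np.random.default_rng(0)
n_vars, n_states = 5, 2
edges = [(i, i + 1) for i in range(n_vars - 1)]        # chain of pairwise factors
evidence = rng.normal(size=(n_vars, n_states))         # unary log potentials
W = rng.normal(size=(len(edges), n_states, n_states))  # pairwise log potentials

# outgoing[f, side] = factor-to-variable message; start at 0 (max already 0)
outgoing = np.zeros((len(edges), 2, n_states))
eta = 0.5  # stepsize, i.e. 1 - damping

def compute_beliefs(outgoing):
    # Step 2: beliefs = evidence plus all factor-to-variable messages
    beliefs = evidence.copy()
    for f, (i, j) in enumerate(edges):
        beliefs[i] += outgoing[f, 0]
        beliefs[j] += outgoing[f, 1]
    return beliefs

for _ in range(50):
    beliefs = compute_beliefs(outgoing)
    new_outgoing = np.empty_like(outgoing)
    for f, (i, j) in enumerate(edges):
        # Step 3: variable-to-factor messages = beliefs minus outgoing
        inc_i = beliefs[i] - outgoing[f, 0]
        inc_j = beliefs[j] - outgoing[f, 1]
        # Step 4: max-product factor update (max over the other endpoint)
        new_outgoing[f, 1] = (W[f] + inc_i[:, None]).max(axis=0)  # to var j
        new_outgoing[f, 0] = (W[f] + inc_j[None, :]).max(axis=1)  # to var i
    # Step 5: renormalize so the max of each message is 0
    new_outgoing -= new_outgoing.max(axis=-1, keepdims=True)
    # Steps 6-8: damped parallel update, then renormalize again
    outgoing += eta * (new_outgoing - outgoing)
    outgoing -= outgoing.max(axis=-1, keepdims=True)

map_states = compute_beliefs(outgoing).argmax(axis=1)  # decode MAP states
```

On a tree-structured graph like this chain, the damped parallel schedule converges to the exact max-marginals, so the decoded states match brute-force MAP.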
This would explain for instance:
Currently this has to be derived from existing examples.
Add checks and a more informative error message to prevent wrong formats for log potentials.
Originally posted by @lawchekun in #73 (comment)
Change the interface of FactorGraph to take as input a sequence of FactorGroups, instead of the current flat list of individual factors. Do various expansions (getting the flat list of variables and factors) inside the FactorGraph class.
Currently, users have to manually create an array of all possible configs and a uniform potential, but it would be nice to do this behind the scenes in some easy way. Maybe we can make it so that if either of these is None during init, we assume all possible configs or a uniform potential respectively and automatically create them.
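A sketch of what the defaulting could look like — the helper name and signature are hypothetical, not PGMax API:

```python
import itertools
import numpy as np

def default_configs_and_potential(num_states, configs=None, log_potential=None):
    """Hypothetical helper: fill in defaults when either argument is None."""
    if configs is None:
        # Assume every joint assignment of the connected variables is valid.
        configs = np.array(list(itertools.product(*(range(n) for n in num_states))))
    if log_potential is None:
        # Uniform potential: one zero log potential per valid config.
        log_potential = np.zeros(configs.shape[0])
    return configs, log_potential

# Two connected variables with 2 and 3 states -> 6 valid configs
configs, pot = default_configs_and_potential([2, 3])
```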
PMAP is as hard as inference, so the pieces are here. Maybe we could reproduce the RBM on MNIST digits experiments from the PMAP paper. This would become a package for "Learning and inference in GMs".
Once #45 is resolved, also keep track of the mapping from factors to starting indices in the flat message array, to provide finer control over things like message initialization. Concretely:
Keep a mapping from FactorGroups to starting indices in the flat message array, inside the refactored FactorGraph which takes a sequence of FactorGroups at initialization.
Within each FactorGroup, keep a mapping from individual factors (indexed using a tuple of involved variable indices) to starting indices in the flat message array.
Model design and learned weights provided by @lazarox
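The starting indices are just cumulative offsets over per-group message sizes; a minimal sketch with made-up sizes (not PGMax internals):

```python
import numpy as np

# Illustrative: total number of message entries contributed by each FactorGroup
factor_group_msg_sizes = [6, 4, 10]

# Starting index of each group in the flat message array: exclusive prefix sum
starts = np.concatenate([[0], np.cumsum(factor_group_msg_sizes)[:-1]])
group_to_start = dict(enumerate(starts))  # FactorGroup index -> start offset
```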
Can also incorporate some tools that analyze code coverage.
Should have decent coverage for the code base before the initial public release
Currently, the decode_map_states function outputs a mapping from Variables to integers corresponding to each variable's MAP state. However, the user doesn't really have access to Variables; they only use keys to index Variables through a VariableGroup, so it's rather unintuitive/cumbersome to output a mapping from Variable to int. Rather, it should be from keys to int.
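A sketch of the suggested change, assuming the VariableGroup keeps a key-to-variable dict (all names here are hypothetical):

```python
# Hypothetical stand-ins: keys_to_vars maps user-facing keys to Variable
# objects; var_to_map_state is the current decode_map_states output.
keys_to_vars = {"x0": "var_obj_0", "x1": "var_obj_1"}
var_to_map_state = {"var_obj_0": 1, "var_obj_1": 0}

# Re-key the MAP decoding so users see their own keys instead of Variables
keys_to_map_state = {
    key: var_to_map_state[var] for key, var in keys_to_vars.items()
}
```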
In a future PR we should implement this construction from VariableGroup/CompositeVariableGroup and FactorGroups.
Originally posted by @StannisZhou in #30 (comment)
The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.
In the process, we should also deal with the contrib module and create a new examples directory to hold everything.
From @lawchekun :
Rename key_tuple in GenericVariableGroup -> variable_names
Drop keys and instead use names in run_bp
@StannisZhou : Wanted to check the expected format for the segment np.zeros(valid_configs_dict[edge[-1]], dtype=float)
As I got a
Traceback (most recent call last):
File "rcn_example.py", line 153, in <module>
np.zeros(valid_configs_dict[edge[-1]], dtype=float), # This line causes issues
ValueError: maximum supported dimension for an ndarray is 32, found 65
Here's the shape of valid_configs_dict[edge[-1]]
In [2]: valid_configs_dict[edge[-1]].shape
Out[2]: (65, 2)
I tried
np.zeros_like(valid_configs_dict[edge[-1]], dtype=float)
But got
Traceback (most recent call last):
File "rcn_example.py", line 154, in <module>
np.zeros_like(valid_configs_dict[edge[-1]], dtype=float),
File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/graph.py", line 129, in add_factor
self._variable_group, *new_args, **kwargs
File "<string>", line 7, in __init__
File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 380, in __post_init__
self, "_keys_to_factors", MappingProxyType(self._get_keys_to_factors())
File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 504, in _get_keys_to_factors
for ii in range(len(self.connected_var_keys))
File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 504, in <dictcomp>
for ii in range(len(self.connected_var_keys))
File "/home/chekun/miniconda3/envs/pgmax/lib/python3.7/site-packages/pgmax/fg/groups.py", line 174, in __getitem__
if len(curr_key) < 2:
TypeError: object of type 'numpy.int32' has no len()
So I think I'm not parsing it right...
Originally posted by @lawchekun in #73 (comment)
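The first error arises because np.zeros treats the (65, 2) configs array itself as a shape with 65 dimensions. A hedged guess at the intended call (one log potential per valid config; the variable names are stand-ins for the RCN example's):

```python
import numpy as np

# Stand-in for valid_configs_dict[edge[-1]]: 65 valid configs of 2 variables
valid_configs = np.zeros((65, 2), dtype=int)

# np.zeros(valid_configs) interprets the array as a 65-entry shape and fails;
# what is likely wanted is a flat vector, one log potential per valid config:
log_potentials = np.zeros(valid_configs.shape[0], dtype=float)
```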
pgmax/interface/datatypes.py:128: error: Argument 1 to "append" of "list" has incompatible type "Union[Variable, List[Variable]]"; expected "Variable"
Originally posted by @NishanthJKumar in #30 (comment)
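A common fix for this kind of mypy error is to narrow the union with isinstance before appending — an illustrative pattern, not the actual PGMax code:

```python
from typing import List, Union

class Variable:
    pass

def flatten(items: List[Union[Variable, List[Variable]]]) -> List[Variable]:
    out: List[Variable] = []
    for item in items:
        if isinstance(item, list):
            out.extend(item)  # a list of Variables: extend, don't append
        else:
            out.append(item)  # mypy now knows this branch is a single Variable
    return out

v1, v2, v3 = Variable(), Variable(), Variable()
assert flatten([v1, [v2, v3]]) == [v1, v2, v3]
```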
Notes about pomegranate
Package Focus: easy to stack and sequence probabilistic models by considering them to just be an underlying probability distribution. Also supports parallelization and GPU computation.
Overlap: pomegranate implements loopy belief propagation on factor graphs. However, only MAP inference is implemented (so they only seem to support max product and not sum product)
Comparison we can make: Speed of running either sum-product or max-product, ease of specifying model (maybe in terms of number of lines required or something like that?)
Demonstrating our advantage: We should be able to specify models much more easily with PGMax than with pomegranate, since pomegranate only lets you add one variable/factor at a time. Also, since PGMax's inference is JIT'ted end-to-end, I expect it to be more efficient than pomegranate's inference (even though pomegranate leverages GPUs).
Structure compiling takes 5.65s in https://github.com/NishanthJKumar/PGMax/blob/sanity_check/test_notebooks/sanity_check_optimize_mpbp_unpadded.py but 7.09s in https://github.com/NishanthJKumar/PGMax/blob/sanity_check/examples/sanity_check_example.py using the existing interface.
We should be able to optimize further, and parallelize the structure compiling process to significantly reduce the time taken.
Additional info sections are listed here
This seems to be a useful and cool tool that's easy to set up; it should enable better discussions from within Vicarious and eventually beyond once the repo is open sourced!
Instructions here
Reproduce RCN experiments from the Science paper.
Demonstrate inference with a learned model with 100 templates.
Implement forward and backward pass within a single PGM.
Notes about PGMPy
Package Focus: easy to use from a user perspective, wide support for a variety of different types of PGMs, inference and learning algorithms
Overlap: Also supports belief propagation on discrete, undirected factor graphs
Comparison we can make: Speed of running either sum-product or max-product, ease of specifying model (maybe in terms of number of lines required or something like that?)
Demonstrating our advantage: PGMPy uses dicts and NumPy arrays for message passing, so we should be significantly more efficient. Also, for grid models, etc. we should be able to specify them much more easily with PGMax
See #19 (comment)
Should implement interface for:
Make get_evidence and get_init_msgs argument-free, and make any necessary data for the two functions attributes of child classes that inherit from the FactorGraph class.
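A sketch of that design — subclasses hold the data the two methods need, so they take no arguments (class and attribute names here are illustrative, not the actual PGMax API):

```python
import numpy as np

class FactorGraph:
    """Illustrative base class: subclasses supply the data."""

    def get_evidence(self) -> np.ndarray:
        raise NotImplementedError

    def get_init_msgs(self) -> np.ndarray:
        raise NotImplementedError

class MyModelGraph(FactorGraph):
    def __init__(self, unaries: np.ndarray, num_msg_entries: int):
        self.unaries = unaries               # evidence data lives on the subclass
        self.num_msg_entries = num_msg_entries

    def get_evidence(self) -> np.ndarray:
        return self.unaries                  # no arguments needed anymore

    def get_init_msgs(self) -> np.ndarray:
        return np.zeros(self.num_msg_entries)

fg = MyModelGraph(unaries=np.ones((4, 2)), num_msg_entries=16)
```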
Also add docs about how to test coverage locally before opening a PR so that new contributors are aware of this.