
sbi-dev / pyknos

25.0 25.0 5.0 109 KB

Conditional density estimation with neural networks

License: Apache License 2.0

Python 100.00%
density-estimation mixture-density-networks normalizing-flows

pyknos's People

Contributors

anasbekheit, harry-fuyu, jan-matthis, janfb, jnsbck, michaeldeistler


pyknos's Issues

Tests fail randomly due to RNG

Some of the tests check that the forward and backward passes of a transform compose to the identity function. These are numerical checks with specified tolerances, often 1e-4, so they can pass or fail depending on the RNG state, which no test currently sets.
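A minimal sketch of one possible fix, assuming a helper (not part of the current test suite) that seeds both torch and numpy before each check:

```python
import numpy as np
import torch

# Hypothetical helper: seed all RNGs at the start of each test so the
# 1e-4 tolerance checks are reproducible across runs.
def set_seed(seed: int = 0) -> None:
    torch.manual_seed(seed)
    np.random.seed(seed)

set_seed(0)
a = torch.randn(3)
set_seed(0)
b = torch.randn(3)
assert torch.equal(a, b)  # identical draws after re-seeding
```

In pytest, the same helper could be called from an autouse fixture so every test starts from a known RNG state.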

Desirable MDN features for SNPE-C / SNPE-A

After having implemented non-atomic SNPE-C, I am writing this issue to keep track of features that would have been desirable to have in mdn.

Get mixture components

The only non-protected methods of mdn are log_prob() and sample(). It would be great to have a non-protected get_mixture_components(). Unlike the existing protected _get_mixture_components(), it should also apply the embedding_net.
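A minimal sketch of the proposed wrapper, using a dummy class since the names and internals here are assumptions, not the actual pyknos API:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: a public get_mixture_components() that applies the
# embedding net before delegating to the protected method.
class DummyMDN(nn.Module):
    def __init__(self):
        super().__init__()
        self._embedding_net = nn.Identity()  # stand-in embedding net

    def _get_mixture_components(self, context):
        # stand-in for the real computation of logits/means/precisions
        return context * 2

    def get_mixture_components(self, context):
        embedded = self._embedding_net(context)
        return self._get_mixture_components(embedded)

mdn = DummyMDN()
out = mdn.get_mixture_components(torch.ones(2))
```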

Evaluating the log_prob

The following code should be moved into a separate static method called evaluate_mixture_log_prob():

batch_size, n_mixtures, output_dim = means.size()
inputs = inputs.view(-1, 1, output_dim)

# Split up the evaluation into parts.
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # log mixture weights
b = -(output_dim / 2.0) * np.log(2 * np.pi)  # Gaussian normalizing constant
c = sumlogdiag  # log-determinant term from the Cholesky factors
d1 = (inputs.expand_as(means) - means).view(
     batch_size, n_mixtures, output_dim, 1
)
d2 = torch.matmul(precisions, d1)
d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
          batch_size, n_mixtures
)
# The parts are then combined into the mixture log-probability:
return torch.logsumexp(a + b + c + d, dim=-1)

This would make it possible to evaluate the log_prob of a MoG without instantiating an mdn. Since snpe_c has to do this for every training data point at every iteration, it would be computationally cheaper.

Together with the above refactoring of get_mixture_components, this fully separates the two main steps of calling log_prob() on an mdn.

Log-prob based on cov

Right now, we use sumlogdiag for log_prob. If the Cholesky factorization is not yet available, it would be better to use log(det(cov)) directly.
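A small sketch of the two equivalent routes to the log-determinant term, on a toy covariance matrix:

```python
import torch

# If the Cholesky factor L of the covariance is available,
# log(det(cov)) = 2 * sum(log(diag(L))); otherwise torch.logdet
# computes it directly from the covariance.
cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])

L = torch.linalg.cholesky(cov)  # lower-triangular factor
logdet_via_chol = 2.0 * torch.log(torch.diagonal(L)).sum()
logdet_direct = torch.logdet(cov)

assert torch.allclose(logdet_via_chol, logdet_direct)
```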

Different init strategy for the means

Means are currently initialized close to 0. Initializing them at more spread-out random locations might be better.
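A minimal sketch of the idea, assuming the means come out of a final linear layer (layer names and the [-2, 2] range are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical: re-initialize the bias of the means layer at random
# locations instead of near zero, so mixture components start spread out.
n_mixtures, output_dim = 5, 2
means_layer = nn.Linear(32, n_mixtures * output_dim)

with torch.no_grad():
    # e.g. uniform in [-2, 2] instead of the default near-zero init
    means_layer.bias.uniform_(-2.0, 2.0)
```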

Variable number of layers

This requires writing a forward function that loops over a torch.nn.ModuleList.
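A minimal sketch of such a forward loop (class and parameter names are illustrative, not pyknos API):

```python
import torch
import torch.nn as nn

# Hypothetical hidden stack with a configurable number of layers,
# iterated in forward() via nn.ModuleList.
class HiddenStack(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        dims = [in_features] + [hidden_features] * num_layers
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

net = HiddenStack(in_features=3, hidden_features=16, num_layers=4)
out = net(torch.randn(8, 3))
```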

Add MDN with support for parameter transformations

Building on Marcel's work, add an MDN (Mixture Density Network) implementation that plays well with, e.g., SNPE-A proposal posterior corrections.

This is the wishlist:

  • a basic MDN that works and is usable from sbi (mixture of full-covariance Gaussians)
  • the MDN becomes highly compositional: only the last layer (the MDN layer) is MDN-specific, and the rest is built e.g. using Sequential, possibly with heuristics given data dimensions, etc.
  • the MDN layer returns a MixtureSameFamily density (becoming part of PyTorch, PR)
  • (long term) the MDN layer can self-configure given the particulars of the desired MixtureSameFamily
  • building the whole MDN does not require specifying redundant information, much like Keras's Sequential (or look at thinc for a more functional take).
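The compositional point above can be sketched as follows; MDNLayer here is a hypothetical stand-in that only maps features to raw mixture parameters, with a plain nn.Sequential as the body:

```python
import torch
import torch.nn as nn

# Hypothetical MDN head: only this layer is MDN-specific; it emits the
# raw mixture parameters (here just logits and means, for brevity).
class MDNLayer(nn.Module):
    def __init__(self, in_features: int, n_mixtures: int, output_dim: int):
        super().__init__()
        self.n_mixtures, self.output_dim = n_mixtures, output_dim
        self.logits = nn.Linear(in_features, n_mixtures)
        self.means = nn.Linear(in_features, n_mixtures * output_dim)

    def forward(self, h):
        means = self.means(h).view(-1, self.n_mixtures, self.output_dim)
        return self.logits(h), means

# The body is ordinary Sequential composition, as in Keras.
body = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
head = MDNLayer(32, n_mixtures=5, output_dim=2)
logits, means = head(body(torch.randn(4, 3)))
```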

refactor MDN class

The MDN class has several inline comments and TODOs that are unclear. We should refactor it and improve documentation.

Revamp model specification and parameterization

get_models from nflows doesn't scale and is annoying to use.

Instead, as a minimum, we want a pluggable (i.e., extensible) system for defining models that is based on templating.

  • A designated directory, e.g. zoo or lib, contains modules that group parameterizable functions. Lookup is done via getattr. The user can specify in a global configuration where else to look for models. A model's identifier is the namespace plus the function name. Parameters can be read from similarly named dictionaries, serialized as e.g. yaml files (example).
  • Models can be reused simply by copying and renaming.
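The getattr-based lookup could look roughly like this; the function name is hypothetical, and a stdlib module stands in for a zoo/ module in the demo:

```python
import importlib

# Hypothetical resolver: a model identifier is "namespace.function_name";
# the namespace is imported as a module and the function looked up via getattr.
def resolve_model(identifier: str):
    namespace, func_name = identifier.rsplit(".", 1)
    module = importlib.import_module(namespace)
    return getattr(module, func_name)

# Demo with a stdlib module standing in for a model namespace.
sqrt = resolve_model("math.sqrt")
```

Extending the search path is then a matter of prepending user-configured package prefixes before the import.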

For extra points:

  • extend the configuration system to training hyperparameters, logging, etc. As an example, consider the respective sections here.
  • study thinc's configuration system and potentially adopt it (it needs to cover the functionality described above). It's based on Google's gin, though they say it's "simpler and emphasizes a different workflow via a subset of Gin's functionality".

What we don't want under any circumstances: magic numbers sprinkled through the code, or long command lines that disappear into the console's log (at best). What we want: declarative configuration with a measure of imperative flexibility (this is exactly what gin/thinc.config provide; see thinc's intro notebook).

The system should eventually be reusable in sbi.

deprecation of torch.triangular_solve

We need to replace the deprecated torch.triangular_solve with torch.linalg.solve_triangular here:

https://github.com/mackelab/pyknos/blob/5ea3d5b81ecc0f72110d3d03d4a6148c919f3c7c/pyknos/mdn/mdn.py#L277-L287

see

pyknos/mdn/mdn.py:279: UserWarning: torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular and will be removed in a future PyTorch release.
  torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
  X = torch.triangular_solve(B, A).solution
  should be replaced with
  X = torch.linalg.solve_triangular(A, B). (Triggered internally at  /Users/distiller/project/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:1672.)
    zero_mean_samples, _ = torch.triangular_solve(

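The migration itself is mechanical; a small self-contained example of the new call (toy matrices, not the pyknos code):

```python
import torch

# torch.linalg.solve_triangular takes the arguments in reversed order
# (matrix first, right-hand side second) and upper= is keyword-only.
A = torch.tensor([[2.0, 0.0], [1.0, 3.0]])  # lower-triangular system matrix
B = torch.tensor([[2.0], [7.0]])

# Old (deprecated): X, _ = torch.triangular_solve(B, A, upper=False)
X = torch.linalg.solve_triangular(A, B, upper=False)

assert torch.allclose(A @ X, B)  # X solves A X = B
```

Note also that the new function returns the solution tensor directly instead of a named tuple, so the `, _` unpacking goes away.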

[Code Quality] Incomplete Repo information

Your repo does not comply with the standards we defined in the lab.
Make sure your repo has:

  • a description including the GitHub handle of the owner
  • a README.md longer than three lines

If you don't update your repo, it will be disabled and then archived.
