sbi-dev / pyknos
Conditional density estimation with neural networks
License: Apache License 2.0
Some of the tests check that the forward and backward passes of a transform compose to give the identity function. These are numerical checks with specified tolerances, often 1e-4. They can pass or fail depending on RNG state, which is not currently set by any test.
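A minimal illustration of the fix, seeding the RNG before the round-trip check (the affine transform here is only a stand-in for the real transforms under test):

```python
import torch

torch.manual_seed(0)  # fix RNG state so the tolerance check is reproducible

# Stand-in invertible transform (scale-and-shift), not a pyknos transform.
scale, shift = torch.rand(3) + 0.5, torch.randn(3)
forward = lambda x: x * scale + shift
inverse = lambda y: (y - shift) / scale

x = torch.randn(100, 3)
assert torch.allclose(inverse(forward(x)), x, atol=1e-4)
```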
We want a sample-based comparison of a ground-truth and an estimated density, or of two estimated densities.
See sbi-dev/sbi#1221
The current way of handling the batch dimension results in mixing up samples from different batch entries, e.g., from different xs.
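One simple option for such a sample-based comparison is a kernel two-sample statistic; a minimal sketch (the kernel and bandwidth choices are illustrative, not a pyknos API):

```python
import torch

def gaussian_kernel(x, y, bandwidth=1.0):
    """Gaussian kernel matrix between sample sets x (n, dim) and y (m, dim)."""
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * bandwidth**2))

def mmd(x, y, bandwidth=1.0):
    """Biased estimate of the squared maximum mean discrepancy."""
    return (
        gaussian_kernel(x, x, bandwidth).mean()
        + gaussian_kernel(y, y, bandwidth).mean()
        - 2 * gaussian_kernel(x, y, bandwidth).mean()
    )
```

Samples from the same density give a value near zero, while samples from clearly different densities give a larger value.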
After having implemented non-atomic SNPE-C, I am writing this issue to keep track of things that would have been desirable to exist in `mdn`.
The only non-protected methods of `mdn` are `log_prob()` and `sample()`. It would be great to have a non-protected `get_mixture_components()`. Unlike the already existing, protected `_get_mixture_components()`, it should also call the `embedding_net`.
The following code should be put in a separate static method called `evaluate_mixture_log_prob()`:
```python
batch_size, n_mixtures, output_dim = means.size()
inputs = inputs.view(-1, 1, output_dim)

# Split up evaluation into parts.
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # log mixture weights
b = -(output_dim / 2.0) * np.log(2 * np.pi)                 # Gaussian normalization
c = sumlogdiag                                              # 0.5 * logdet of each precision
# Mahalanobis term: -0.5 * (x - mean)^T @ precision @ (x - mean)
d1 = (inputs.expand_as(means) - means).view(
    batch_size, n_mixtures, output_dim, 1
)
d2 = torch.matmul(precisions, d1)
d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
    batch_size, n_mixtures
)
```
This would allow evaluating the `log_prob` of a MoG without instantiating an `mdn`. Since snpe_c has to do this for every training data point at every iteration, it would be computationally cheaper.
Along with the above refactoring of `get_mixture_components()`, this fully separates the two main steps of calling `log_prob()` in an `mdn`.
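Such a static method could look roughly like the following self-contained sketch (names and shapes follow the snippet above; the final `logsumexp` combination is the standard MoG log-probability):

```python
import math

import torch

def evaluate_mixture_log_prob(inputs, logits, means, precisions, sumlogdiag):
    """Log-probability of inputs under a mixture of full-covariance Gaussians.

    inputs:     (batch_size, output_dim)
    logits:     (batch_size, n_mixtures), unnormalized mixture weights
    means:      (batch_size, n_mixtures, output_dim)
    precisions: (batch_size, n_mixtures, output_dim, output_dim)
    sumlogdiag: (batch_size, n_mixtures), sum of the log-diagonal of the
                Cholesky factor of each precision (= 0.5 * its logdet)
    """
    batch_size, n_mixtures, output_dim = means.size()
    inputs = inputs.view(-1, 1, output_dim)

    a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    b = -(output_dim / 2.0) * math.log(2 * math.pi)
    c = sumlogdiag
    d1 = (inputs.expand_as(means) - means).view(
        batch_size, n_mixtures, output_dim, 1
    )
    d2 = torch.matmul(precisions, d1)
    d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
        batch_size, n_mixtures
    )
    return torch.logsumexp(a + b + c + d, dim=-1)
```

The result can be cross-checked against `torch.distributions.MixtureSameFamily`.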
Right now, we use `sumlogdiag` for `log_prob`. If one does not yet have the Cholesky factorization, it would be better to use `log(det(cov))`.
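Both routes give the same value; a quick sketch of the equivalence, with `torch.logdet` for the case where no Cholesky factor is at hand:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4)
cov = A @ A.T + 4 * torch.eye(4)  # a well-conditioned covariance matrix

# With a Cholesky factor available: sum of log-diagonals is cheap and stable.
chol = torch.linalg.cholesky(cov)
logdet_chol = 2 * torch.diagonal(chol).log().sum()

# Without one: compute the log-determinant directly.
logdet_direct = torch.logdet(cov)
```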
Means are initialized close to 0. Maybe initializing at more random locations would be better.
This requires writing a forward function that loops over a `torch.nn.ModuleList`.
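A minimal sketch of such a forward function (the container name and linear blocks are illustrative, not pyknos code):

```python
import torch
from torch import nn

class StackedBlocks(nn.Module):
    """Applies each sub-module in sequence via an explicit loop."""

    def __init__(self, blocks):
        super().__init__()
        # nn.ModuleList (unlike a plain list) registers all parameters.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

net = StackedBlocks([nn.Linear(3, 3) for _ in range(4)])
out = net(torch.randn(5, 3))
```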
Building on Marcel's work, add an MDN (Mixture Density Network) implementation that can play well with e.g. SNPE-A proposal posterior corrections.
This is the wishlist:
- `Sequential` models, and possibly heuristics given data dimensions, etc.
- `MixtureSameFamily` density (becoming part of PyTorch, PR)
README.md:
- `conda` and `pip`
- `git lfs`
- `black`

See https://github.com/mackelab/sbi/blob/master/README.md for reference.
The MDN class has several inline comments and TODOs that are unclear. We should refactor it and improve documentation.
The readme says the examples are in an examples folder, but I don't see such a folder :)
`get_models` from nflows doesn't scale and is annoying to use.
Instead, as a minimum, we want a pluggable (i.e., extensible) system for defining models that is based on templating.
A `zoo` or `lib` package contains modules that group parameterizable functions. Lookup is done via `getattr`. The user gets to specify in a global configuration where else to look for models. A model's identifier is the namespace plus the function name. Parameters can be read from similarly-named dictionaries, serialized as e.g. YAML files (example). For extra points:
What we don't want, under ANY circumstance: magic numbers sprinkled in the code, or long command lines that disappear in the console's log (at best). What we want: declarative configuration with a measure of imperative flexibility (this is exactly what gin/thinc.config provide; see thinc's intro notebook).
The system should eventually be reusable in sbi.
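A minimal sketch of the `getattr`-based lookup described above (the `zoo` module and `mdn` factory below are hypothetical placeholders):

```python
import types

# Hypothetical 'zoo' module grouping parameterizable model factories.
zoo = types.ModuleType("zoo")

def mdn_factory(hidden_features=50, n_mixtures=10):
    """Stand-in factory; a real one would build and return the network."""
    return {"kind": "mdn", "hidden_features": hidden_features, "n_mixtures": n_mixtures}

zoo.mdn = mdn_factory

# Namespaces to search; a global configuration could extend this dict.
SEARCH_PATH = {"zoo": zoo}

def build_model(identifier, params=None):
    """Resolve '<namespace>.<function name>' via getattr and build the model."""
    namespace, name = identifier.split(".", 1)
    factory = getattr(SEARCH_PATH[namespace], name)
    return factory(**(params or {}))
```

The `params` dictionary is what would be deserialized from the similarly-named YAML file.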
We need to switch to `torch.linalg.solve_triangular` here, since `torch.triangular_solve` is deprecated; see:
```
pyknos/mdn/mdn.py:279: UserWarning: torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular and will be removed in a future PyTorch release.
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at /Users/distiller/project/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:1672.)
  zero_mean_samples, _ = torch.triangular_solve(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
```
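A sketch of the migration, noting the reversed argument order (the matrices here are arbitrary stand-ins for the ones in `mdn.py`):

```python
import torch

torch.manual_seed(0)
A = torch.tril(torch.randn(4, 4)) + 5 * torch.eye(4)  # lower-triangular, well-conditioned
B = torch.randn(4, 2)

# Old (deprecated): X = torch.triangular_solve(B, A).solution
# New: arguments are reversed, and `upper` must be passed explicitly.
X = torch.linalg.solve_triangular(A, B, upper=False)
```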
Your repo does not comply with the standards we defined in the lab.
Make sure your repo has:
If you don't update your repo, it will be disabled and then archived.