
Comments (19)

vegapit commented on June 16, 2024

It is just a case of adding numerical approximation functions for the pdfs and cdfs of popular statistical distributions. I implemented the Normal cdf derived in this article using tensors, and the library calculated its gradients without any issue:

fn norm_cdf(x: &Tensor) -> Tensor {
    // Logistic-style approximation of the standard Normal cdf; every op here
    // is a differentiable tensor op, so gradients flow through it.
    let denom = (x * (-358.0 / 23.0) + (x * (37.0 / 294.0)).atan() * 111.0).exp() + 1.0;
    denom.reciprocal()
}
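For reference, the same approximation can be sanity-checked against the exact erf-based cdf in plain Python (this is a standalone check I'm adding for illustration, not part of tch-rs):

```python
import math

def norm_cdf_approx(x: float) -> float:
    # Same logistic-style formula as the tensor version:
    # 1 / (1 + exp(-358x/23 + 111 * atan(37x/294)))
    denom = math.exp(-358.0 * x / 23.0 + 111.0 * math.atan(37.0 * x / 294.0)) + 1.0
    return 1.0 / denom

def norm_cdf_exact(x: float) -> float:
    # Exact standard Normal cdf via the error function.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in [-2.0, -0.5, 0.0, 1.0, 3.0]:
    print(x, norm_cdf_approx(x), norm_cdf_exact(x))
```

On this range the approximation agrees with the erf-based cdf to a few decimal places.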

from tch-rs.

vegapit commented on June 16, 2024

I use Torch to solve maximum-likelihood problems in non-linear parametrised models. The density functions can be reconstructed or approximated using tensor methods, but it is clearly more user-friendly to provide them in the wrappers, as PyTorch does.

spebern commented on June 16, 2024

I started to work on a crate to port the distributions: https://github.com/spebern/tch-distr-rs
Besides the porting itself, the most tedious work is testing everything.

I think the best way to ease porting is to test against the Python implementations directly, by supplying the same input to both and comparing the outputs. Later, this could also be extended with fuzzed inputs.

The Distribution trait is open for discussion and pull requests are more than welcome to add more distributions/tests.

LaurentMazare commented on June 16, 2024

Thanks for the suggestion. This indeed seems like a nice thing to add - the underlying torch primitives for sampling should already exist in the Rust wrappers, e.g. normal, so I guess this would mostly consist of adding a trait for distributions and implementing it, mimicking the python implementation, e.g. Normal.
Would that be useful to you?
Also which methods/distributions would be the most interesting to you?

jerry73204 commented on June 16, 2024

Yes, I see that normal function. I need probability arithmetic rather than sample generators. Take GQN for example: I need a Normal object whose parameters are defined by tensors, and to compute the log_prob (log probability density) at any value under that Normal. I also have to compute the KL divergence between two normals. There's no need for sampling.

Torch and TensorFlow each have their own such functions. Interestingly, you'll see NotImplementedError if you look into Torch's source code. So I bet improving rv would be a good direction. Currently I just write the raw formulas by hand, being careful about floating-point precision.
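For what it's worth, both operations have simple closed forms. A plain-Python sketch (my own illustration, not code from any of the libraries discussed; a tensor version would apply the same formulas elementwise):

```python
import math

def normal_log_prob(x, mu, sigma):
    # log N(x; mu, sigma) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - 0.5 log(2 pi)
    return (-((x - mu) ** 2) / (2.0 * sigma ** 2)
            - math.log(sigma) - 0.5 * math.log(2.0 * math.pi))

def normal_kl(mu1, s1, mu2, s2):
    # KL( N(mu1, s1^2) || N(mu2, s2^2) )
    #   = log(s2/s1) + (s1^2 + (mu1 - mu2)^2) / (2 s2^2) - 1/2
    return math.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * s2 ** 2) - 0.5
```

Both are compositions of differentiable elementary operations, so a tensor implementation backprops for free.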

LaurentMazare commented on June 16, 2024

Do you need to backprop through your KL divergence? If that's the case, I'm not sure that rv could be used out of the box, but maybe I'm missing something.
When looking at the Normal distribution in torch, I don't see many NotImplementedErrors besides those in the base Distribution class. Also, the kl divergence for the normal distribution can be found here.

jerry73204 commented on June 16, 2024

The NotImplementedError goes here. Sorry for the imprecise comment.

Backprop is desired in my case. As far as I know, torch.distributions looks like an add-on to PyTorch rather than part of libtorch, so an implementation in Rust would amount to a completely new library.

Let me write some code for this and see what a proper way to build this feature would be.

LaurentMazare commented on June 16, 2024

Re NotImplementedError: indeed, that's the base class for distributions in pytorch, so nothing is implemented there; the various methods are implemented in the derived classes like Normal.
The distribution bit is included within the main pytorch repo and package (contrary to, say, vision, which is an external repo and pypi package). I don't have a strong opinion on whether this belongs in the main crate or in a separate one, but starting with an external crate sounds good, and if there is some upside to merging it into the main repo we can consider it later.
Let me know if you notice some pytorch primitives missing from tch-rs that could be useful to you!

LaurentMazare commented on June 16, 2024

Just to mention that I added a variational auto-encoder example to tch-rs. This includes a KL divergence loss here.
It's certainly very far from what a nice distributions api would provide, but it may be handy.
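For context, assuming the example follows the usual VAE formulation (KL between N(mu, sigma^2) and N(0, 1), with the encoder outputting logvar = log(sigma^2)), the KL term reduces to a simple expression. A generic plain-Python sketch, not the exact code from the example:

```python
import math

def vae_kl(mus, logvars):
    # KL( N(mu, sigma^2) || N(0, 1) ) summed over latent dimensions:
    # -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    return -0.5 * sum(1.0 + lv - m ** 2 - math.exp(lv)
                      for m, lv in zip(mus, logvars))
```

The term is zero exactly when the posterior matches the standard Normal prior (mu = 0, logvar = 0) and positive otherwise.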

LaurentMazare commented on June 16, 2024

@vegapit yes that's mostly about adding such functions and probably some traits for the various distributions.
You can see the implementation for the normal distribution in the python api here. The code for the cdf is a bit different from yours and relies on torch.erf. Not sure which one has better precision.

def cdf(self, value):
    return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

vegapit commented on June 16, 2024

@LaurentMazare I did not know the error function was already implemented in Torch, otherwise I would have used it. I do not know which numerical approximation torch.erf uses, but my guess is that its precision is similar to that of the function I described above.

LaurentMazare commented on June 16, 2024

Btw, the error function is also available in tch-rs: https://docs.rs/tch/0.1.0/tch/struct.Tensor.html#method.erf
(as the Rust low-level wrappers are automatically generated, we mostly get these for free)

jerry73204 commented on June 16, 2024

A little update here. My previous gqnrs project has some useful traits for probability distributions, though only the Normal distribution is implemented there. Shall we start a working branch to fill in the blanks for the Bernoulli, Exponential, Categorical, etc. distributions?
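As a rough sketch of the kind of interface such a trait could mirror (shown in Python for brevity; the class and method names here are illustrative, not taken from gqnrs or tch-distr-rs), each distribution exposes a common log_prob and fills in its own formula:

```python
import math

class Bernoulli:
    def __init__(self, p):
        self.p = p

    def log_prob(self, value):
        # log p(x) = x log(p) + (1 - x) log(1 - p), for x in {0, 1}
        return value * math.log(self.p) + (1 - value) * math.log(1.0 - self.p)

class Exponential:
    def __init__(self, rate):
        self.rate = rate

    def log_prob(self, value):
        # log p(x) = log(rate) - rate * x, for x >= 0
        return math.log(self.rate) - self.rate * value
```

In Rust this would naturally become a Distribution trait with a log_prob method over tensors, implemented per distribution.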

LaurentMazare commented on June 16, 2024

Yes, that kind of trait would indeed probably be useful. @vegapit, do you think this would cover your use case?

LaurentMazare commented on June 16, 2024

This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

spebern commented on June 16, 2024

That would be really nice! It definitely needs some polishing and more implemented distributions, but I really think that testing against the Python implementations takes away a lot of work.

dbsxdbsx commented on June 16, 2024

> This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

@LaurentMazare, I wonder if it would be a good idea to introduce @spebern's tch-distr-rs as a feature of tch-rs, so that users would not need to pull in two crates when using Torch in Rust.

LaurentMazare commented on June 16, 2024

If it's just to avoid an additional dependency for crates that want to use this, I would lean towards keeping an external crate and, in general, having smaller composable crates for the bits that are outside the core of tch-rs; e.g. I'm thinking about moving the vision models into their own crate, the RL bits into their own thing, etc.

dbsxdbsx commented on June 16, 2024

> If it's just to avoid having an additional dependency for crates that would want to use this, I would lean more towards keeping an external crate, and in general having smaller composable crates for the bits that are outside of the core tch-rs, e.g. I'm more thinking about moving the vision models in their own crate, the RL bits to their own thing too etc.

@LaurentMazare, the reason I hope tch-distr-rs could be part of tch-rs is that in PyTorch the distributions code is also part of the main Python torch module, even though it is not part of the C++ version. Moreover, I do not think it is proper to treat tch-distr-rs as a tool ONLY for reinforcement learning or some other specific field.

Therefore, I suggest making it an optional feature, which would be both flexible (users could decide whether to include it via the "features" key in Cargo.toml) and an easy transition for users familiar with PyTorch.
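A sketch of what that could look like on the user side (the feature name "distributions" is hypothetical; no such feature exists in tch-rs today):

```toml
# Hypothetical Cargo.toml fragment: opting in to a made-up "distributions"
# feature of the tch crate, for illustration only.
[dependencies]
tch = { version = "0.1", features = ["distributions"] }
```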
