Unable to jit logpdf (stheno issue, open)

patel-zeel commented on August 23, 2024
Unable to jit logpdf

from stheno.

Comments (10)

wesselb commented on August 23, 2024

I see! Hmm, this might be challenging. Dispatch currently relies heavily on types, and the type of a PyTree is somewhat troublesome. You're right that jaxtyping offers a PyTree type, but that type seems to only perform instance checking rather than containing the recursive type definition that we would like. I'll have to think about this! I agree that it would be super useful to support PyTrees.

wesselb commented on August 23, 2024

With the new version of Plum, I believe it should be possible to use the PyTree type from jaxtyping to get the desired behaviour. I'll open issues in the appropriate places.

wesselb commented on August 23, 2024

Hey @patel-zeel! Nice to hear from you. :)

This fails because, as you've noticed, there is some logic that checks for missing values, and this logic unfortunately doesn't work well with the JIT.

The recommended solution is to use B.jit instead, where B comes from import lab.jax as B. B.jit is a thin wrapper around jax.jit which runs an additional "compilation step" that takes care of code like the check for NaNs:

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
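# Log-density of y under the GP evaluated at the inputs x.
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
# Use B.jit rather than jax.jit so that the additional compilation step runs.
grad_fn = B.jit(jax.grad(loss_fn))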
grad_fn(lengthscale)

Output:

DeviceArray(50.12888031, dtype=float64)

Would you be able to check if this also works on your end?
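
To illustrate the general kind of problem, here is a minimal sketch in plain JAX. It is not Stheno's or LAB's actual missing-value logic; the functions sum_without_nans and sum_without_nans_traceable are hypothetical names made up for this example. It only shows why a Python-level branch on array values tends to work eagerly but fail under jax.jit:

```python
import jax
import jax.numpy as jnp


def sum_without_nans(y):
    # Python-level branch on an array value: fine eagerly, but under
    # jax.jit the condition is a tracer without a concrete boolean value.
    if jnp.any(jnp.isnan(y)):
        y = jnp.where(jnp.isnan(y), 0.0, y)
    return jnp.sum(y)


y = jnp.array([1.0, jnp.nan, 3.0])

# Eager execution works, because the condition has a concrete value.
eager = sum_without_nans(y)

# Under jax.jit, converting the traced condition to a bool raises an error.
try:
    jax.jit(sum_without_nans)(y)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True


def sum_without_nans_traceable(y):
    # Trace-friendly variant: no Python branching on array values.
    return jnp.sum(jnp.where(jnp.isnan(y), 0.0, y))


jitted = jax.jit(sum_without_nans_traceable)(y)
```

The traceable variant avoids the branch entirely, which is one common way to make such checks compatible with jax.jit when B.jit is not an option.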

Thanks for mentioning Chex! I wasn't aware of the library. It looks useful—I'm going to have a closer look!

patel-zeel commented on August 23, 2024

Thank you, @wesselb, for a quick response! This works, and the depth of your customization is amazingly unimaginable :)

When I try a slightly different variant (passing a dictionary instead of a single value), it raises plum.function.NotFoundLookupError:

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}
loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))

Output:

Traceback (most recent call last):
  File "/home/patel_zeel/gpax/testbed.py", line 12, in <module>
    print(grad_fn(params))
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/lab/generic.py", line 146, in __call__
    return _jit_run(
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 591, in __call__
    method, return_type = self.resolve_method(*sig_types)
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 556, in resolve_method
    method, return_type = self._methods[self.resolve_signature(signature)]
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 492, in resolve_signature
    raise NotFoundLookupError(
plum.function.NotFoundLookupError: For function "_jit_run", signature Signature(builtins.function, builtins.dict, builtins.dict, builtins.dict) could not be resolved.

Also, I wonder how to use B.jit in places like the jax.lax.scan function, which applies jax.jit internally.
Update: I checked with jax.lax.scan, and it works! It looks like lax.scan does not explicitly use jax.jit internally. Now only the plum.function.NotFoundLookupError problem remains.

wesselb commented on August 23, 2024

I think the problem now is that the typing is a little restrictive. In particular, a dictionary with JAX-valued values isn't recognised as a JAX object, and that's where the method error comes from. It is possible to add that method manually, as follows:

from types import FunctionType

import jax
import jax.numpy as jnp

import lab.jax as B
from plum import Union, Dict
from stheno.jax import GP, EQ


@B.generic._jit_run.dispatch
def _jit_run(
    f: FunctionType,
    compilation_cache: dict,
    jit_kw_args: dict,
    *args: Dict[object, B.JAXNumeric],
    **kw_args,
):
    return B.generic._jit_run.invoke(
        FunctionType, dict, dict, B.JAXNumeric,
    )(f, compilation_cache, jit_kw_args, *args, **kw_args)


x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}
loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))

Output:

{'lengthscale': DeviceArray(50.12888031, dtype=float64)}

It would be possible to make a small amendment to LAB so that this is the default behaviour, if that would be desirable. :)

patel-zeel commented on August 23, 2024

Thanks for the solution, @wesselb. Since the PyTree concept is gaining popularity in JAX, parameters are sometimes passed as dictionaries, lists, or tuples, or as a mixture of these. Is it possible to make B.jit work with any PyTree? Maybe PyTree from jaxtyping could be used somehow, but I am not sure how.

patel-zeel commented on August 23, 2024

Yes, from the links you have shared, it looks like a harder problem. Maybe there can be a hotfix specifically for lab and Stheno till a more generic solution is available.

wesselb commented on August 23, 2024

A hacky halfway house solution is to check for JAX-like objects in the first 10 (or so) layers of a PyTree... Then it wouldn't detect things like ((((((((((jnp.array(1),)))))))))), but maybe that's okay as a temporary solution. Do you think that would be reasonable? Or do you perhaps have another fix in mind?
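
A rough sketch of what such a depth-limited check could look like is given below. This is not LAB's implementation; the function has_jax_array and its max_depth parameter are hypothetical, and the sketch assumes that only dicts, lists and tuples count as containers:

```python
import jax.numpy as jnp


def has_jax_array(obj, max_depth=10):
    """Depth-limited search for JAX arrays in nested dicts, lists and tuples."""
    if isinstance(obj, jnp.ndarray):
        return True
    if max_depth == 0:
        # Give up: anything nested deeper than this goes undetected.
        return False
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return False
    return any(has_jax_array(child, max_depth - 1) for child in children)


params = {"lengthscale": jnp.array(1.0)}
found = has_jax_array(params)

# An array nested three containers deep is missed with max_depth=2.
deep = [[[jnp.array(1)]]]
missed = not has_jax_array(deep, max_depth=2)
```

The trade-off is the one described above: the check is cheap and bounded, but arrays nested deeper than max_depth are silently missed.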

patel-zeel commented on August 23, 2024

I think 10 or so layers seems practical for most applications. However, how difficult would it be to replace that static number 10 with a dynamic depth by recursively checking the PyTree?
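
For reference, one way to check at arbitrary depth is to let JAX flatten the PyTree and inspect the leaves. This is only a sketch of the idea, not how LAB or Plum dispatch works; contains_jax_array is a hypothetical helper, while jax.tree_util.tree_leaves is the standard JAX utility:

```python
import jax
import jax.numpy as jnp


def contains_jax_array(tree):
    """Return True if any leaf of the PyTree is a JAX array."""
    leaves = jax.tree_util.tree_leaves(tree)
    return any(isinstance(leaf, jnp.ndarray) for leaf in leaves)


params = {"lengthscale": jnp.array(1.0), "static": [1, 2, (3.0,)]}
found = contains_jax_array(params)

# Nesting depth does not matter, because tree_leaves flattens the whole tree.
deep = [[[[[[[[[[[[jnp.array(1)]]]]]]]]]]]]
found_deep = contains_jax_array(deep)
```

This runs as a runtime check on values. It does not by itself solve the type-based dispatch question, which depends on how the dispatcher represents PyTree types.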

wesselb commented on August 23, 2024

I think that it would be possible to convert the static number to a dynamic depth. However, perhaps the right solution here is to see if we can actually give PyTrees first-class support. I’ll soon be working on a 2.0 of Plum, which is where the restrictions currently come from. I will put PyTree support on the list of desired improvements!
