

ricardoV94 commented on June 25, 2024

@HarshvirSandhu thanks for reporting the issue.

That tutorial was written back when JAX + jitting was more flexible. These days, every array shape inside a jitted function must be known at trace time. That means a graph like the one in the example can never be translated to JAX, because it is fundamentally a function with dynamic shapes: the size of the output depends on the value of an input. The first error can be avoided by specifying that the scalar has an integer dtype, but that only kicks the can further down the road. After also updating the dispatch function signature, we get this:

import jax.numpy as jnp

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt

@jax_funcify.register(Eye)
def jax_funcify_Eye(op, node, **kwargs):

    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""

    x_at = pt.scalar(dtype="int")
    eye_var = pt.eye(x_at)

    out_fg = FunctionGraph(outputs=[eye_var])

    compare_jax_and_py(out_fg, [3])


test_jax_Eye()
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,). 'N' argument of jnp.eye().
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
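The same failure can be reproduced in plain JAX, without PyTensor. The sketch below assumes a recent JAX install, and `eye` and `eye_static` are illustrative names, not part of either library. It jits a function whose output size depends on an argument's value, then applies the workaround the error message suggests: marking that argument as static with `static_argnums`, so JAX treats it as a compile-time constant.

```python
import jax
import jax.numpy as jnp


def eye(n):
    # The output shape depends on the *value* of n, not on any input's shape.
    return jnp.eye(n)


# Under jit, n is traced by default, so jnp.eye cannot build an array
# with a concrete shape and raises a TypeError at trace time.
try:
    jax.jit(eye)(3)
    traced_ok = True
except TypeError:
    traced_ok = False

# Marking n as static: JAX compiles a separate version for each
# distinct value of n it sees, so every compiled version has a fixed shape.
eye_static = jax.jit(eye, static_argnums=0)
result = eye_static(3)

print(traced_ok)     # False
print(result.shape)  # (3, 3)
```

Note that `static_argnums` moves the value out of the traced computation entirely and recompiles per value, so it is not a drop-in fix for a graph whose shapes genuinely vary at runtime.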

We should use a different example, and perhaps move this one into a section of its own explaining that not all PyTensor graphs can be converted into a jitted JAX graph, typically those with dynamic shapes.
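As an illustration of that distinction (a plain-JAX sketch, not PyTensor's dispatch code; `eye_like` is a hypothetical name), a size derived from an input's *shape* rather than its value is already known at trace time, so the function jits without trouble:

```python
import jax
import jax.numpy as jnp


@jax.jit
def eye_like(x):
    # x.shape is static metadata during tracing, so n is a plain Python int.
    n = x.shape[0]
    return jnp.eye(n, dtype=x.dtype)


out = eye_like(jnp.zeros(4, dtype=jnp.float32))
print(out.shape)  # (4, 4)
```

Calling `eye_like` with an input of a different length triggers a retrace, since each distinct input shape gets its own compiled version.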


HangenYuu commented on June 25, 2024

Hi,

I would like to take up this issue as part of my GSoC application. I will open a PR once I have switched to a more suitable example.

