Comments (6)

patrick-kidger commented on July 28, 2024

Indeed, the usual behaviour is for jaxtyping to be completely done by runtime: under JIT, its checks run only while the function is being traced. Probably you're constructing a typechecked dataclass or calling a typechecked function somewhere inside your training loop!

from jaxtyping.

kvablack commented on July 28, 2024

This is all happening inside the train_step, which is jitted. You can even see in the trace that the isinstance calls are all inside the PjitFunction(train_step) call. Nothing else is happening in the train loop except for the train_step call.

Is it expected that typechecking dataclasses would run during a jitted function call?

patrick-kidger commented on July 28, 2024

So once a function has been JIT'd, its Python code is never evaluated again (for inputs with the same shapes and dtypes).

We don't use jax.debug.callback or jax.pure_callback or jax.experimental.io_callback, and those are the only possible escape hatches for running normal Python in JIT'd code. We don't use those so we couldn't run typechecking many times even if we wanted to!
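To make the "never evaluated again" point concrete, here is a minimal sketch. The function `double` and the counter `trace_count` are hypothetical names made up for illustration; the counter is a plain Python side effect, so it only changes while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only touched by ordinary Python code


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: executes during tracing only
    return x * 2


x = jnp.ones(3)
double(x)  # first call: traces, compiles, runs
double(x)  # cached: the Python body is not executed
double(x)  # cached again
count_same_shape = trace_count

double(jnp.ones(4))  # new input shape: triggers one more trace
count_new_shape = trace_count

print(count_same_shape, count_new_shape)
```

Any typecheck placed in the function body behaves like the counter here: it runs when the function is traced, not on every call.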

One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring during the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region.

If so then that's a Flax bug, but before we go pointing fingers: can you test what result you get when you don't return any Flax objects from your JIT'd call?

kvablack commented on July 28, 2024

"One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring during the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region."

Pretty sure this is it. Here's a minimal repro that shows a performance diff only if a dataclass is being returned:


import flax.serialization
from jaxtyping import jaxtyped, Array, Int, ArrayLike, config
from typeguard import typechecked
import jax
from functools import partial
from flax import struct

@partial(jaxtyped, typechecker=typechecked)
@struct.dataclass
class Node:
    x: list[Int[ArrayLike, ""]]
    y: list[str | int | float] = struct.field(pytree_node=False)

# returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> Array:
    return x

x = Node(list(range(1000)), list(map(str, range(1000))) + list(range(1000)) + list(map(float, range(1000))))

f(x)

config.update("jaxtyping_disable", False)
f(x)
%timeit f(x) # 21ms

config.update("jaxtyping_disable", True)
f(x)
%timeit f(x) # 15ms


# not returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> Array:
    return flax.serialization.to_state_dict(x)


f(x)

config.update("jaxtyping_disable", False)
f(x)
%timeit f(x) # 15ms

config.update("jaxtyping_disable", True)
f(x)
%timeit f(x) # 15ms

And here's the line where the constructor gets called: it's in the unflatten_func.

Is there another way to do it? I'm not too familiar with PyTree internals, but with a custom PyTree, you need to construct the Python object at some point, right? How does Equinox do it?
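For anyone following along, the effect can be sketched without Flax or jaxtyping. This is only an illustration under assumptions: `Point`, `INIT_CALLS`, and the flatten/unflatten helpers are hypothetical names, and the custom node is registered through `jax.tree_util.register_pytree_node` with an unflatten function that goes through the constructor.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

INIT_CALLS = 0  # hypothetical counter standing in for "work done in __init__"


class Point:
    def __init__(self, x, y):
        global INIT_CALLS
        INIT_CALLS += 1  # a typechecker wrapped around __init__ would run here
        self.x = x
        self.y = y


def _flatten(p):
    # children (array leaves), auxiliary static data
    return (p.x, p.y), None


def _unflatten(aux, children):
    # Rebuilds the object through the constructor, so __init__ runs each time
    return Point(*children)


tree_util.register_pytree_node(Point, _flatten, _unflatten)


@jax.jit
def shift(p):
    return Point(p.x + 1.0, p.y + 1.0)


p = Point(jnp.zeros(3), jnp.ones(3))
shift(p)  # first call: traces and compiles
before = INIT_CALLS
shift(p)  # cached call: no tracing, but the returned Point is rebuilt
after_one = INIT_CALLS
shift(p)
after_two = INIT_CALLS
```

The counter keeps growing on cached calls because the output object has to be reconstructed in Python after the compiled computation returns its arrays.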

patrick-kidger commented on July 28, 2024

Ah, going via __init__ like this is known to be a dodgy thing to do. See the JAX docs here.

Not only does this run afoul of typechecking, it also misbehaves if you ever define a custom __init__ method.

Correct behaviour is to go via __new__ and then construct the desired fields directly. Here's the Equinox implementation, which does exactly that:

https://github.com/patrick-kidger/equinox/blob/955c0347de2690b07e59aad70e5666ded4ee28ef/equinox/_module.py#L913
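As a rough sketch of the idea (not the Equinox code itself, and using only the standard library): the class `Checked` and the helper `rebuild_via_new` below are hypothetical names for illustration. The helper allocates the object with `object.__new__` and writes the fields with `object.__setattr__`, so neither `__init__` nor any validation attached to it runs.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Checked:
    x: int

    def __post_init__(self):
        # Stand-in for a typecheck that fires on normal construction
        if not isinstance(self.x, int):
            raise TypeError("x must be an int")


def rebuild_via_new(cls, field_values):
    """Rebuild an instance without calling __init__ (hypothetical helper)."""
    obj = object.__new__(cls)  # allocate only; __init__ is skipped
    for field, value in zip(dataclasses.fields(cls), field_values):
        # object.__setattr__ also bypasses the frozen-dataclass guard
        object.__setattr__(obj, field.name, value)
    return obj


# Normal construction runs the check and rejects a bad value.
try:
    Checked("not an int")
    init_rejected = False
except TypeError:
    init_rejected = True

# Rebuilding via __new__ skips the check entirely.
rebuilt = rebuild_via_new(Checked, ["not an int"])
```

Skipping the check matters during unflattening because the leaves handed back may be tracers or placeholder objects rather than the annotated types.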

I'm afraid this one isn't something we can really fix from jaxtyping's end. I'd suggest avoiding returning the Flax objects in question here. You say this is a flax.struct.dataclass. If that's the case then equinox.Module should be more-or-less a drop-in replacement for this. (It also fixes several other known issues around inheritance, bound methods, etc. etc.) Alternatively you could use normal Python types: tuples/dictionaries/etc.

kvablack commented on July 28, 2024

Ok thanks, that makes total sense. I'm sorry, I'm sure Equinox is great, I just have too much infra built up around flax.struct.dataclass 🥲 . FWIW, I'm a huge fan of jaxtyping, and the fact that it's worked this well for me and my sprawling Flax codebase is a testament to its excellent design.

At first, I hacked together a solution that disabled typechecking completely during Flax's PyTree unflattening. But then I looked at the trace again and realized almost all of the overhead was from PyTree typechecks. So I just removed the PyTree[Float[Array, "..."]] annotation from my params tree, and lo and behold training was fast again. Using Perfetto, I was even able to measure the remaining typechecking overhead from dataclass unflattening and it was only ~1ms, which is pretty acceptable. In theory I think it should be entirely negated by JAX's asynchronous dispatch, although I'm not sure if it's working correctly, considering that the unflattening doesn't seem to happen until after all the GPU operations have finished.
