Comments (6)
Indeed, the usual behaviour is for jaxtyping's checks to be completely finished by runtime. You're probably constructing a typechecked dataclass or calling a typechecked function somewhere inside your training loop!
from jaxtyping.
This is all happening inside the train_step, which is jitted. You can even see in the trace that the isinstance calls are all inside the PjitFunction(train_step) call. Nothing else is happening in the train loop except for the train_step call.
Is it expected that typechecking dataclasses would run during a jitted function call?
So once a function has been JIT'd, the Python code is never evaluated again. We don't use jax.debug.callback, jax.pure_callback, or jax.experimental.io_callback, and those are the only possible escape hatches for running normal Python in JIT'd code. Since we don't use them, we couldn't run typechecking many times even if we wanted to!
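That claim can be checked directly: any Python side effect in a jitted function (including isinstance-based typechecking) fires only while JAX traces it. A minimal sketch, using a hypothetical counter:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    # This Python body runs only while JAX traces the function, i.e. on
    # the first call for a given shape/dtype. Runtime typechecking
    # decorators would fire at this point, not on later calls.
    global trace_count
    trace_count += 1
    return x + 1

f(jnp.ones(3))      # first call: traces and compiles
f(jnp.ones(3))      # cache hit: the Python body is skipped
print(trace_count)  # 1
```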
One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring at the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region. If so then that's a Flax bug, but before we go pointing fingers: can you test what result you get when you don't return any Flax objects from your JIT'd call?
> One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring during the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region.
Pretty sure this is it. Here's a minimal repro that shows a performance diff only if a dataclass is being returned:
import flax.serialization
from jaxtyping import jaxtyped, Array, Int, ArrayLike, config
from typeguard import typechecked
import jax
from functools import partial
from flax import struct


@partial(jaxtyped, typechecker=typechecked)
@struct.dataclass
class Node:
    x: list[Int[ArrayLike, ""]]
    y: list[str | int | float] = struct.field(pytree_node=False)


# returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> Node:
    return x


x = Node(
    list(range(1000)),
    list(map(str, range(1000))) + list(range(1000)) + list(map(float, range(1000))),
)
f(x)
config.update("jaxtyping_disable", False)
f(x)
%timeit f(x)  # 21ms
config.update("jaxtyping_disable", True)
f(x)
%timeit f(x)  # 15ms


# not returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> dict:
    return flax.serialization.to_state_dict(x)


f(x)
config.update("jaxtyping_disable", False)
f(x)
%timeit f(x)  # 15ms
config.update("jaxtyping_disable", True)
f(x)
%timeit f(x)  # 15ms
And here's the line where the constructor gets called: in the unflatten_func.
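For illustration (this Point class is made up, not Flax's actual code), the pattern in question looks roughly like this: an unflatten function that rebuilds the object through its constructor, so __init__, and any typechecking hooked into it, re-runs every time JIT outputs are reconstructed.

```python
import jax

class Point:
    def __init__(self, x, y):
        # Any typechecking attached to __init__ re-runs on every
        # unflatten, including at the end of each jitted call.
        self.x = x
        self.y = y

jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),           # flatten: children + aux data
    lambda aux, children: Point(*children),  # unflatten goes via __init__
)
```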
Is there another way to do it? I'm not too familiar with PyTree internals, but with a custom PyTree, you need to construct the Python object at some point, right? How does Equinox do it?
Ah, going via __init__ like this is known to be a dodgy thing to do; see the JAX docs here. Not only does this run afoul of typechecking, it also misbehaves if you ever define a custom __init__ method. Correct behaviour is to go via __new__ and then construct the desired fields directly. Here's the Equinox implementation, which does exactly that.
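A minimal sketch of the __new__-based approach, in the spirit of what Equinox does but not its actual code (Point is a made-up class):

```python
import jax

class Point:
    def __init__(self, x, y):
        # Imagine expensive typechecking here; the unflatten function
        # below deliberately bypasses it.
        self.x = x
        self.y = y

def _flatten(p):
    return (p.x, p.y), None

def _unflatten(aux, children):
    p = object.__new__(Point)  # allocate without calling __init__
    object.__setattr__(p, "x", children[0])
    object.__setattr__(p, "y", children[1])
    return p

jax.tree_util.register_pytree_node(Point, _flatten, _unflatten)
```

Because unflattening sets the fields directly, no __init__-time typechecking fires when objects are reconstructed at the end of a jitted call, and a custom __init__ with a different signature causes no trouble.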
I'm afraid this one isn't something we can really fix from jaxtyping's end. I'd suggest avoiding returning the Flax objects in question here. You say this is a flax.struct.dataclass. If that's the case then equinox.Module should be more-or-less a drop-in replacement for it. (It also fixes several other known issues around inheritance, bound methods, etc.) Alternatively you could use normal Python types: tuples/dictionaries/etc.
Ok thanks, that makes total sense. I'm sorry, I'm sure Equinox is great, I just have too much infra built up around flax.struct.dataclass 🥲. FWIW, I'm a huge fan of jaxtyping, and the fact that it's worked this well for me and my sprawling Flax codebase is a testament to its excellent design.

At first, I hacked together a solution that disabled typechecking completely during Flax's PyTree unflattening. But then I looked at the trace again and realized almost all of the overhead was from PyTree typechecks. So I just removed the PyTree[Float[Array, "..."]] annotation from my params tree, and lo and behold, training was fast again. Using Perfetto, I was even able to measure the remaining typechecking overhead from dataclass unflattening: it was only ~1ms, which is pretty acceptable. In theory it should be entirely negated by JAX's asynchronous dispatch, although I'm not sure that's working correctly, considering the unflattening doesn't seem to happen until after all the GPU operations have finished.
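On the asynchronous-dispatch point: one way to sanity-check such measurements (a generic sketch, not the original training step) is to force synchronization with block_until_ready() before stopping the clock, since jitted calls return as soon as the work is dispatched and naive timing can under- or over-attribute host-side work like unflattening:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in for a real training step.
    return x * 2.0

x = jnp.ones((1000, 1000))
step(x).block_until_ready()  # warm-up: compile before timing

t0 = time.perf_counter()
y = step(x)            # returns immediately (asynchronous dispatch)
y.block_until_ready()  # wait for the device computation to finish
print(f"{time.perf_counter() - t0:.6f}s")
```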