Comments (3)
It looks like you need to update your JAX version. `register_dataclass` returned `None` before JAX v0.4.29, so you cannot use the decorator approach in older JAX versions. If you want to use the decorator approach from the docs, you'll have to update to JAX v0.4.29 or newer.
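For context on why the decorator breaks: `@f` above `class C: ...` is equivalent to `C = f(C)`, so whatever the function returns replaces the class binding. Here is a minimal pure-Python sketch of that effect. It uses a hypothetical stand-in, `register_returning_none`, instead of the real JAX function, and only assumes that the function returns `None` as described above.

```python
from dataclasses import dataclass


def register_returning_none(cls):
    """Hypothetical stand-in for a registration function that returns None,
    like jax.tree_util.register_dataclass did before v0.4.29."""
    # (real registration side effects would happen here)
    return None


@register_returning_none
@dataclass(frozen=True)
class Point:
    x: float
    y: float


# The decorator's return value has replaced the class binding.
print(Point)        # None
print(type(Point))  # <class 'NoneType'>
```

The registration side effect still happens, but the name `Point` no longer refers to the class, which is why anything that later uses the class (constructing it, subscripting a generic, `type(...)` checks) fails.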
If updating to a more recent JAX version is not an option, then you can register your class this way:
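```python
from dataclasses import dataclass

import jax

@dataclass(frozen=True, slots=True, kw_only=True)
class A[GENERIC: Api]:  # PEP 695 syntax requires Python 3.12+; `Api` is your own type
    metadata_a: GENERIC
    data: jax.Array

# Call register_dataclass as a plain function so its return value never rebinds `A`.
jax.tree_util.register_dataclass(A, meta_fields=["metadata_a"], data_fields=["data"])
```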
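If you'd still like to keep decorator syntax on an older JAX, one option is a small wrapper that calls `register_dataclass` and always hands back the original class, ignoring whatever JAX returns. This is a sketch under my own assumptions: `register_dataclass_compat` and the example class `B` are hypothetical names, not part of JAX, and it only relies on `register_dataclass` accepting `data_fields` and `meta_fields` as keyword arguments, as in the snippet above.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


def register_dataclass_compat(*, data_fields, meta_fields):
    """Hypothetical helper: decorator form of register_dataclass that works
    whether the installed JAX returns the class or None."""
    def wrap(cls):
        # Older JAX returns None here; the return value is ignored either way.
        jax.tree_util.register_dataclass(
            cls, data_fields=data_fields, meta_fields=meta_fields
        )
        return cls  # always give the original class back to the class statement
    return wrap


@register_dataclass_compat(data_fields=["data"], meta_fields=["metadata_a"])
@dataclass(frozen=True)
class B:
    metadata_a: str
    data: jax.Array


b = B(metadata_a="tag", data=jnp.arange(3.0))
leaves, treedef = jax.tree_util.tree_flatten(b)
print(len(leaves))  # 1: only `data` is a leaf; `metadata_a` is stored in treedef
```

Metadata fields must be hashable, since they become part of the treedef; a plain string satisfies that here.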
Ah, whoops, it should be:

```python
assert type_b == type
assert type_a == type
```

i.e. the type of these classes ought to be `type`. Using the `tree_util` decorator makes it `NoneType`, causing generics to fail.
(updated)
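A sketch of the check described above, using the function-call registration so the class binding survives. To keep it runnable on Python versions before 3.12, it uses `typing.Generic` with a `TypeVar` instead of the PEP 695 `class A[GENERIC: Api]` syntax; `Box` and `T` are illustrative names, not from the original question.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


@dataclass(frozen=True)
class Box(Generic[T]):
    metadata_a: T
    data: jax.Array


# Function-call registration leaves the `Box` binding untouched.
jax.tree_util.register_dataclass(Box, data_fields=["data"], meta_fields=["metadata_a"])

assert type(Box) == type  # still a class, not NoneType
IntBox = Box[int]         # subscripting the generic works

box = Box(metadata_a=1, data=jnp.ones(2))
doubled = jax.tree_util.tree_map(lambda x: x * 2, box)  # maps over `data` only
```

`tree_map` only touches the data fields, so `doubled.metadata_a` stays `1` while `doubled.data` becomes `[2., 2.]`.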
Fixed!