Comments (2)
Hi @Laz4rz,
python -m trace isn't ideal for tracing JAX functions. It is a general-purpose Python tracing tool and may not handle JAX's internals well.
JAX's own tools are a better fit: make_jaxpr shows the sequence of primitive operations that JAX traces your function into.
For a more comprehensive view, such as a timeline of what actually runs, you could use jax.profiler.trace within your JAX program.
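Here is a minimal sketch of jax.profiler.trace, assuming a recent JAX release running on CPU. softmax_ref is a hypothetical hand-written stand-in for the function you want to investigate, and the temporary log directory is arbitrary. jax.profiler.TraceAnnotation labels a region of the captured timeline. The block_until_ready() call matters because JAX dispatches work asynchronously, so without it the computation might finish after the trace has stopped.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def softmax_ref(x):
    # Hypothetical hand-written softmax, used only as something to profile.
    e = jnp.exp(x - x.max())
    return e / e.sum()


fast_softmax = jax.jit(softmax_ref)
x = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])

# Compile once outside the trace so the capture focuses on execution.
fast_softmax(x).block_until_ready()

log_dir = tempfile.mkdtemp(prefix="jax-trace-")
with jax.profiler.trace(log_dir):
    # TraceAnnotation adds a named region to the captured timeline.
    with jax.profiler.TraceAnnotation("softmax_step"):
        result = fast_softmax(x)
        # Block so the asynchronously dispatched work lands inside the trace.
        result.block_until_ready()

# The profiler writes its output files somewhere under log_dir.
captured = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
print(captured)
```

The JAX profiling docs describe viewing the captured files in TensorBoard (with the profile plugin installed) pointed at log_dir, or in Perfetto via the create_perfetto_link / create_perfetto_trace options of jax.profiler.trace.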
make_jaxpr example:
from jax import make_jaxpr
from jax.numpy import array
from jax.nn import softmax
p = array([0.50, 0.60, 0.70, 0.30, 0.25])
print(make_jaxpr(softmax)(p))
Output:
{ lambda ; a:f32[5]. let
    b:f32[] = reduce_max[axes=(0,)] a
    c:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] b
    d:f32[1] = stop_gradient c
    e:f32[5] = sub a d
    f:f32[5] = exp e
    g:f32[] = reduce_sum[axes=(0,)] f
    h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    i:f32[5] = div f h
  in (i,) }
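make_jaxpr returns a ClosedJaxpr object rather than plain text, so the same information can also be inspected programmatically. This is a short sketch assuming a recent JAX version. softmax_ref is a hypothetical hand-written softmax used instead of jax.nn.softmax, so the example does not depend on how jax.nn.softmax happens to be implemented in a given release.

```python
import jax.numpy as jnp
from jax import make_jaxpr


def softmax_ref(x):
    # Hypothetical numerically stable softmax (not jax.nn.softmax).
    e = jnp.exp(x - x.max())
    return e / e.sum()


p = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
closed = make_jaxpr(softmax_ref)(p)

# Printing gives the same kind of listing as the jax.nn.softmax output above.
print(closed)

# Each equation records the primitive it applies and its input/output variables.
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, eqn.invars, "->", eqn.outvars)

# Top-level primitive names only: a nested call (shown as pjit in some
# versions) keeps its own sub-jaxpr in eqn.params.
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)
```

Because nested calls carry their own sub-jaxprs, the printed text (which includes them inline) can show more primitives than the top-level primitive_names list.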
Agreed with @selamw1 regarding JAX-specific tracing, but it's still strange that python -m trace fails ungracefully here. Looking at the traceback, the failure seems to involve objects that have no parent frame, which to me sounds like it may be some object defined in compiled code (JAX primarily uses nanobind for such things).
I'll leave this marked as a bug, but I don't think it's particularly a priority to solve this unless it affects more typical JAX usage.