enzymead / enzyme-jax
License: Other
To mark which of these rewrites we consider worth doing, are doing, or still need to do:
%195 = stablehlo.iota dim = 0 : tensor<1024xi32>
%196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
%175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
%176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
%175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
%189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
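For context (my own sketch, not part of the original list): a quick way to check which of these StableHLO patterns a given JAX program actually lowers to is to dump the module with jax.jit(...).lower(...).compiler_ir. The padded_mul function below is just an illustrative stand-in that pads and then multiplies, similar in spirit to the %175/%177 sequence above.

import jax
import jax.numpy as jnp

def padded_mul(x, y):
    # pad the last axis by 1024 on the low side, then multiply elementwise,
    # mirroring the pad-then-multiply pattern listed above
    return jnp.pad(x, ((0, 0), (0, 0), (1024, 0))) * y

x = jnp.ones((1, 3, 1024), jnp.float32)
y = jnp.ones((1, 3, 2048), jnp.float32)
print(jax.jit(padded_mul).lower(x, y).compiler_ir(dialect="stablehlo"))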
I'm getting the following error in EnzymeAD/Reactant.jl when trying to differentiate stablehlo.einsum:
julia> f = Reactant.compile(grad, (a′, b′))
error: could not compute the adjoint for this operation %2 = "stablehlo.einsum"(%1, %0) <{einsum_config = "ij,jk->ik"}> : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64>
Pipeline failed
I'm opening the issue here because I believe this is where the EnzymeMLIR rules for the HLO dialects are declared, right?
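As a reference point (my sketch, not from the issue): for the "ij,jk->ik" config the missing adjoint is just the standard matmul VJP, which can be checked against jax.vjp in plain JAX:

import jax
import jax.numpy as jnp

# For C = einsum("ij,jk->ik", A, B) and cotangent dC, the adjoints are
#   dA = einsum("ik,jk->ij", dC, B)  i.e. dC @ B.T
#   dB = einsum("ij,ik->jk", A, dC)  i.e. A.T @ dC
A = jnp.arange(4.0).reshape(2, 2)
B = jnp.arange(4.0, 8.0).reshape(2, 2)
dC = jnp.ones((2, 2))

_, vjp = jax.vjp(lambda a, b: jnp.einsum("ij,jk->ik", a, b), A, B)
dA, dB = vjp(dC)
assert jnp.allclose(dA, dC @ B.T)
assert jnp.allclose(dB, A.T @ dC)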
Seems like we are passing a Python object here (Enzyme-JAX/src/enzyme_ad/jax/primitives.py, lines 463 to 464 in ba24493) instead of the MlirModule object from MLIR-C that enzyme_call is expecting (Enzyme-JAX/src/enzyme_ad/jax/enzyme_call.cc, lines 1150 to 1154 in ba24493).
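A minimal sketch of the mismatch as I read it, assuming jaxlib ships the standard upstream MLIR Python bindings (the capsule attribute and the failure path below are my interpretation, not something stated in the issue):

from jaxlib.mlir import ir

with ir.Context():
    module = ir.Module.parse("module {}")
    # A jaxlib ir.Module is a Python wrapper; its C-API handle is exposed through the
    # standard MLIR capsule protocol as a PyCapsule named "mlir.ir.Module._CAPIPtr".
    capsule = module._CAPIPtr
    # enzyme_call.optimize_module(module, pipeline) fails when its binding unwraps that
    # capsule with PyCapsule_GetPointer and the name it expects does not match, which is
    # what the "PyCapsule_GetPointer called with incorrect name" in the test log below
    # points to.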
$> python test/test.py ✔ enzyme-jax
Running tests under Python 3.12.3: /Users/mofeing/.pyenv/versions/enzyme-jax/bin/python
[ RUN ] EnzymeJax.test_custom_cpp_kernel
I0529 13:40:36.518865 8541272768 xla_bridge.py:884] Unable to initialize backend 'cuda':
I0529 13:40:36.518960 8541272768 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0529 13:40:36.519386 8541272768 xla_bridge.py:884] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/mofeing/.pyenv/versions/3.12.3/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/mofeing/.pyenv/versions/3.12.3/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
[[43. 43. 43.]
[43. 43. 43.]]
[[85. 85. 85.]
[85. 85. 85.]]
[Array([[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.]], dtype=float32)]
(Array([[43., 43., 43.],
[43., 43., 43.]], dtype=float32), Array([[85., 85., 85.],
[85., 85., 85.]], dtype=float32), [Array([[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.]], dtype=float32)])
(Array([[1., 1., 1.],
[1., 1., 1.]], dtype=float32), Array([[1., 1., 1.],
[1., 1., 1.]], dtype=float32), [Array([[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.]], dtype=float32)])
(Array([[43., 43., 43.],
[43., 43., 43.]], dtype=float32), Array([[85., 85., 85.],
[85., 85., 85.]], dtype=float32), [Array([[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.],
[56., 56., 56., 56.]], dtype=float32)])
[[128. 128. 128.]
[128. 128. 128.]]
[ OK ] EnzymeJax.test_custom_cpp_kernel
[ RUN ] EnzymeJax.test_enzyme_mlir_jit
[12. 23. 34.]
[ 50.1 70.2 110.3]
[12. 23. 34.]
(Array([500., 700., 110.], dtype=float32), Array([500., 700., 110.], dtype=float32))
[ OK ] EnzymeJax.test_enzyme_mlir_jit
[ RUN ] EnzymePipeline.test_pipeline
[ FAILED ] EnzymePipeline.test_pipeline
======================================================================
ERROR: test_pipeline (__main__.EnzymePipeline.test_pipeline)
EnzymePipeline.test_pipeline
----------------------------------------------------------------------
ValueError: PyCapsule_GetPointer called with incorrect name
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mofeing/Developer/Enzyme-JAX/test/test.py", line 16, in test_pipeline
optimize_module(module)
File "/Users/mofeing/.pyenv/versions/enzyme-jax/lib/python3.12/site-packages/enzyme_ad/jax/primitives.py", line 463, in optimize_module
enzyme_call.optimize_module(mod, pipeline)
TypeError: optimize_module(): incompatible function arguments. The following argument types are supported:
1. (arg0: MlirModule, arg1: str) -> None
Invoked with: <jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x116b147b0>, '\n inline{default-pipeline=canonicalize max-iterations=4},\n canonicalize,cse,\n canonicalize,enzyme-hlo-generate-td{\n patterns=compare_op_canon<16>;\nbroadcast_in_dim_op_canon<16>;\nconvert_op_canon<16>;\ndynamic_broadcast_in_dim_op_not_actually_dynamic<16>;\nchained_dynamic_broadcast_in_dim_canonicalization<16>;\ndynamic_broadcast_in_dim_all_dims_non_expanding<16>;\nnoop_reduce_op_canon<16>;\nempty_reduce_op_canon<16>;\ndynamic_reshape_op_canon<16>;\nget_tuple_element_op_canon<16>;\nreal_op_canon<16>;\nimag_op_canon<16>;\nget_dimension_size_op_canon<16>;\ngather_op_canon<16>;\nreshape_op_canon<16>;\nmerge_consecutive_reshapes<16>;\ntranspose_is_reshape<16>;\nzero_extent_tensor_canon<16>;\nreorder_elementwise_and_shape_op<16>;\n\ncse_broadcast_in_dim<16>;\ncse_slice<16>;\ncse_transpose<16>;\ncse_convert<16>;\ncse_pad<16>;\ncse_dot_general<16>;\ncse_reshape<16>;\ncse_mul<16>;\ncse_div<16>;\ncse_add<16>;\ncse_subtract<16>;\ncse_min<16>;\ncse_max<16>;\ncse_neg<16>;\ncse_concatenate<16>;\n\nconcatenate_op_canon<16>(1024);\nselect_op_canon<16>(1024);\nadd_simplify<16>;\nsub_simplify<16>;\nand_simplify<16>;\nmax_simplify<16>;\nmin_simplify<16>;\nor_simplify<16>;\nnegate_simplify<16>;\nmul_simplify<16>;\ndiv_simplify<16>;\nrem_simplify<16>;\npow_simplify<16>;\nsqrt_simplify<16>;\ncos_simplify<16>;\nsin_simplify<16>;\nnoop_slice<16>;\nconst_prop_through_barrier<16>;\nslice_slice<16>;\nshift_right_logical_simplify<16>;\npad_simplify<16>;\nnegative_pad_to_slice<16>;\ntanh_simplify<16>;\nexp_simplify<16>;\nslice_simplify<16>;\nconvert_simplify<16>;\nreshape_simplify<16>;\ndynamic_slice_to_static<16>;\ndynamic_update_slice_elim<16>;\nconcat_to_broadcast<16>;\nreduce_to_reshape<16>;\nbroadcast_to_reshape<16>;\ngather_simplify<16>;\niota_simplify<16>(1024);\nbroadcast_in_dim_simplify<16>(1024);\nconvert_concat<1>;\ndynamic_update_to_concat<1>;\nslice_of_dynamic_update<1>;\nslice_elementwise<1>;\nslice_pad<1>;\ndot_reshape_dot<1>;\nconcat_const_prop<1>;\nconcat_fuse<1>;\npad_reshape_pad<1>;\npad_pad<1>;\nconcat_push_binop_add<1>;\nconcat_push_binop_mul<1>;\nscatter_to_dynamic_update_slice<1>;\nreduce_concat<1>;\nslice_concat<1>;\n\nbin_broadcast_splat_add<1>;\nbin_broadcast_splat_subtract<1>;\nbin_broadcast_splat_div<1>;\nbin_broadcast_splat_mul<1>;\nreshape_iota<16>;\nslice_reshape_slice<1>;\ndot_general_simplify<16>;\ntranspose_simplify<16>;\nreshape_empty_broadcast<1>;\nadd_pad_pad_to_concat<1>;\nbroadcast_reshape<1>;\n\nslice_reshape_concat<1>;\nslice_reshape_elementwise<1>;\nslice_reshape_transpose<1>;\nslice_reshape_dot_general<1>;\nconcat_pad<1>;\n\nreduce_pad<1>;\nbroadcast_pad<1>;\n\nzero_product_reshape_pad<1>;\nmul_zero_pad<1>;\ndiv_zero_pad<1>;\n\nbinop_const_reshape_pad<1>;\nbinop_const_pad_add<1>;\nbinop_const_pad_subtract<1>;\nbinop_const_pad_mul<1>;\nbinop_const_pad_div<1>;\n\nslice_reshape_pad<1>;\nbinop_binop_pad_pad_add<1>;\nbinop_binop_pad_pad_mul<1>;\nbinop_pad_pad_add<1>;\nbinop_pad_pad_subtract<1>;\nbinop_pad_pad_mul<1>;\nbinop_pad_pad_div<1>;\nbinop_pad_pad_min<1>;\nbinop_pad_pad_max<1>;\n\nunary_pad_push_convert<1>;\nunary_pad_push_tanh<1>;\nunary_pad_push_exp<1>;\n\ntranspose_pad<1>;\n\ntranspose_dot_reorder<1>;\ndot_transpose<1>;\nconvert_convert_float<1>;\nconcat_to_pad<1>;\nconcat_appending_reshape<1>;\nreshape_iota<1>;\n\nbroadcast_reduce<1>;\nslice_dot_general<1>;\n\ndot_reshape_pad<1>;\npad_dot_general<1>(0);\n\ndot_reshape_pad<1>;\npad_dot_general<1>(1);\n },\n transform-interpreter,\n 
enzyme-hlo-remove-transform\n '
----------------------------------------------------------------------
Ran 3 tests in 27.736s
FAILED (errors=1)
Extending the tests (line 68 in d9e2ae0) with the following calls fails:
> x = jax.jacrev(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_rev' not implemented
> x = jax.jacfwd(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_fwd' not implemented
> x = jax.hessian(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Differentiation rule for 'enzyme_aug' not implemented
> x = jax.jit(jax.vmap(lambda x: add_one(x, jnp.array([1., 2., 3.]))))(jnp.array([jnp.array([1., 2., 3.])]*5))
NotImplementedError: Batching rule for 'enzyme_primal' not implemented
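For anyone hitting this: the errors are asking for batching (and differentiation) rules to be registered for these primitives. The sketch below shows what a batching-rule registration looks like for a toy stand-in primitive; toy_add_one_p is hypothetical, and only the jax.interpreters.batching.primitive_batchers registration itself is standard JAX, since the enzyme_primal/enzyme_fwd/enzyme_rev internals are not shown here.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

toy_add_one_p = core.Primitive("toy_add_one")

def toy_add_one(x):
    return toy_add_one_p.bind(x)

toy_add_one_p.def_impl(lambda x: x + 1.0)
toy_add_one_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def _toy_add_one_batch(batched_args, batch_dims):
    (x,), (bdim,) = batched_args, batch_dims
    # The op is elementwise, so batching is just the op on the full array and the
    # batch dimension passes through unchanged.
    return toy_add_one(x), bdim

batching.primitive_batchers[toy_add_one_p] = _toy_add_one_batch

print(jax.vmap(toy_add_one)(jnp.ones((5, 3))))  # works once the rule is registered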
NOTE: Strikethrough ops are deliberately not annotated.
ComplexOp
BroadcastCompareOp
BroadcastSelectOp
TopKOp
ImportError: Python version mismatch: module was compiled for Python 3.10, but the interpreter version is incompatible: 3.11.3 (main, Apr 19 2023, 18:51:09) [Clang 14.0.6 ].
The pip package installs v0.0.4, which is a problem since the new tests do not work with the old version. Should be a quick fix?