using TransformVariables, Enzyme, StaticArrays
# Problem sizes: K is the length of the ω vectors; the N_* values are the
# lengths of the per-group vector blocks (EE, UU, UE, EU).
K = 7
N_EE, N_UU, N_UE, N_EU = 20, 20, 20, 20

# Transformation from a flat vector to a NamedTuple of parameters.
# NOTE(review): presumably the field order determines how the flat vector is
# split among the fields — keep it unchanged.
trans = as((
    # common
    ω_intercept = as(SVector{K}),
    ω_std = as(SVector{K}, asℝ₊),
    ω_corr_factor = as(SVector{K}),
    ζ_std = asℝ₊,
    ε_std = asℝ₊,
    κ = as(Real, 0.5, 1.5),
    B1 = as(SVector{3}),
    B2 = as(SVector{3}),
    BC = as(SVector{3}),
    # EE
    α̂1_EE = as(view, N_EE),
    α̂2_EE = as(view, N_EE),
    β̂1_EE = as(view, N_EE),
    β̂2_EE = as(view, N_EE),
    M̂_EE = as(view, N_EE),
    # EU
    α̂1_EU = as(view, N_EU),
    α̂2_EU = as(view, N_EU),
    β̂1_EU = as(view, N_EU),
    β̂2_EU = as(view, N_EU),
    M̂_EU = as(view, N_EU),
    ŵ2_EU = as(view, N_EU),
    # UE
    α̂1_UE = as(view, N_UE),
    α̂2_UE = as(view, N_UE),
    β̂1_UE = as(view, N_UE),
    β̂2_UE = as(view, N_UE),
    M̂_UE = as(view, N_UE),
    ŵ1_UE = as(view, N_UE),
    # UU
    α̂1_UU = as(view, N_UU),
    α̂2_UU = as(view, N_UU),
    β̂1_UU = as(view, N_UU),
    β̂2_UU = as(view, N_UU),
    M̂_UU = as(view, N_UU),
    ŵ1_UU = as(view, N_UU),
    ŵ2_UU = as(view, N_UU),
))
# `_s`: recursively add up every scalar reachable from `x` (test helper).

# A scalar is its own sum.
function _s(x::Real)
    return x
end

# An array sums the recursive sums of its elements.
function _s(x::AbstractArray)
    return sum(_s, x)
end

# A NamedTuple sums the recursive sums of its field values.
# NOTE(review): `values(::NamedTuple)` converts via `Tuple`, the same call path
# the pasted Enzyme stack trace reports from `transform_with`; this helper may
# hit the same issue under Enzyme — TODO confirm.
function _s(x::NamedTuple)
    return sum(_s, values(x))
end
# Objective under test: apply transformation `t` to the flat vector `x`, then
# reduce the transformed result to a scalar with `_s`.
function g(t, x)
    transformed = transform(t, x)
    return _s(transformed)
end
# Evaluation point: the all-zeros vector of length `dimension(trans)`.
x = zeros(Float64, dimension(trans))
# Run the primal (non-AD) call first, so a failure below can be attributed to AD.
g(trans, x)
# Zero-initialised shadow for `x`, passed via `Duplicated`; per Enzyme's
# convention the reverse pass writes the gradient w.r.t. `x` into it.
∂ℓ_∂x = zero(x)
# `ReverseWithPrimal` yields a pair; the second element (the primal value of
# `g(trans, x)`) is kept in `y`.
_, y = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal,
    g,
    Enzyme.Active,               # activity of the scalar return value
    Enzyme.Const(trans),         # the transformation itself is not differentiated
    Enzyme.Duplicated(x, ∂ℓ_∂x),
)
ERROR: AssertionError: Found unhandled active variable in tuple splat, jl_apply_iterate @NamedTuple{ω_intercept::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.Identity}, ω_std::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.ShiftedExp{true, Int64}}, ω_corr_factor::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.Identity}, ζ_std::TransformVariables.ShiftedExp{true, Int64}, ε_std::TransformVariables.ShiftedExp{true, Int64}, κ::TransformVariables.ScaledShiftedLogistic{Float64}, B1::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, B2::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, BC::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, α̂1_EE::TransformVariables.ViewTransformation{1}, α̂2_EE::TransformVariables.ViewTransformation{1}, β̂1_EE::TransformVariables.ViewTransformation{1}, β̂2_EE::TransformVariables.ViewTransformation{1}, M̂_EE::TransformVariables.ViewTransformation{1}, α̂1_EU::TransformVariables.ViewTransformation{1}, α̂2_EU::TransformVariables.ViewTransformation{1}, β̂1_EU::TransformVariables.ViewTransformation{1}, β̂2_EU::TransformVariables.ViewTransformation{1}, M̂_EU::TransformVariables.ViewTransformation{1}, ŵ2_EU::TransformVariables.ViewTransformation{1}, α̂1_UE::TransformVariables.ViewTransformation{1}, α̂2_UE::TransformVariables.ViewTransformation{1}, β̂1_UE::TransformVariables.ViewTransformation{1}, β̂2_UE::TransformVariables.ViewTransformation{1}, M̂_UE::TransformVariables.ViewTransformation{1}, ŵ1_UE::TransformVariables.ViewTransformation{1}, α̂1_UU::TransformVariables.ViewTransformation{1}, α̂2_UU::TransformVariables.ViewTransformation{1}, β̂1_UU::TransformVariables.ViewTransformation{1}, β̂2_UU::TransformVariables.ViewTransformation{1}, M̂_UU::TransformVariables.ViewTransformation{1}, ŵ1_UU::TransformVariables.ViewTransformation{1}, 
ŵ2_UU::TransformVariables.ViewTransformation{1}}
Stacktrace:
[1] error_if_active_iter(arg::Base.RefValue{@NamedTuple{…}})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/rules/jitrules.jl:775
[2] Tuple
@ ./namedtuple.jl:200 [inlined]
[3] values
@ ./namedtuple.jl:379 [inlined]
[4] transform_with
@ ~/code/julia/TransformVariables/src/aggregation.jl:388 [inlined]
[5] transform
@ ~/code/julia/TransformVariables/src/generic.jl:268
[6] g
@ ./REPL[31]:1 [inlined]
[7] g
@ ./REPL[31]:0 [inlined]
[8] augmented_julia_g_3346_inner_1wrap
@ ./REPL[31]:0
[9] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:5299 [inlined]
[10] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Duplicated{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4977
[11] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4930
[12] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Vararg{…})
@ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:198
[13] autodiff(::ReverseMode{true, FFIABI}, ::typeof(g), ::Type, ::Const{TransformVariables.TransformTuple{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:224
[14] top-level scope
@ REPL[35]:1