diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 4d3a3a4a2e..53b35c794f 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1140,11 +1140,10 @@ end ) where {Ann,Nargs} expr = Vector{Expr}(undef, Nargs) for i = 1:Nargs - if args[i] <: Active - throw(AssertionError("Unsupported Active arg $(args[i])")) - end @inbounds expr[i] = if args[i] <: Const :(args[$i].val) + elseif args[i] <: Active + :(Enzyme.make_zero(args[$i].val)) elseif args[i] <: MixedDuplicated :(args[$i].dval[]) else @@ -1170,9 +1169,10 @@ end for w = 1:width expr = Vector{Expr}(undef, Nargs) for i = 1:Nargs - @assert !(args[i] <: Active) @inbounds expr[i] = if args[i] <: Const :(args[$i].val) + elseif args[i] <: Active + :(Enzyme.make_zero(args[$i].val)) elseif args[i] <: BatchMixedDuplicated :(args[$i].dval[$w][]) else diff --git a/test/applyiter.jl b/test/applyiter.jl index 699a3cd69e..551c5ff781 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -507,7 +507,9 @@ end data = [[3.0], nothing, 2.0] ddata = [[0.0], nothing, 0.0] - @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 2.0 + @test ddata[3] ≈ 3.0 function mktup3(v) tup = tuple(v..., v...) diff --git a/test/runtests.jl b/test/runtests.jl index 5c5d70d9fa..5f909abdb1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1845,15 +1845,16 @@ end dR = zeros(6, 6) @static if VERSION ≥ v"1.11-" + elseif VERSION ≥ v"1.10.8" + autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) + @test 1.0 ≈ dR[1, 1] + @test 1.0 ≈ dR[2, 2] + @test 1.0 ≈ dR[3, 3] + @test 1.0 ≈ dR[4, 4] + @test 1.0 ≈ dR[5, 5] + @test 0.0 ≈ dR[6, 6] else @test_broken autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) - # autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) - # @test 1.0 ≈ dR[1, 1] - # @test 1.0 ≈ dR[2, 2] - # @test 1.0 ≈ dR[3, 3] - # @test 1.0 ≈ dR[4, 4] - # @test 1.0 ≈ dR[5, 5] - # @test 0.0 ≈ dR[6, 6] end end