diff --git a/Project.toml b/Project.toml index 404b9f3df..916e55da2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.61" +version = "0.4.62" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index ca8b50931..77afbe2c4 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -1298,17 +1298,10 @@ straightforward to figure out much time is spent pushing to the block stack when @inline function __assemble_lazy_zero_rdata( r::Ref{T}, args::Vararg{CoDual,N} ) where {T<:Tuple,N} - r[] = __make_tuples(T, args) + r[] = map((T, x) -> lazy_zero_rdata(T, primal(x)), fieldtypes(T), args) return nothing end -@generated function __make_tuples(::Type{T}, args::Tuple) where {T} - lazy_exprs = map(eachindex(T.parameters)) do n - return :(lazy_zero_rdata($(T.parameters[n]), primal(args[$n]))) - end - return Expr(:call, tuple, lazy_exprs...) -end - """ pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index f74cf02c1..4985a05d6 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -743,7 +743,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) ys = blas_vectors(rng, P, M) flags = (false, :stability, (lb=1e-3, ub=10.0)) return map(As, xs, ys) do A, x, y - return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y) + (flags..., BLAS.gemv!, tA, randn(rng, P), A, x, randn(rng, P), y) end end..., @@ -804,12 +804,13 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # map(Ps) do P + flags = (false, :none, nothing) Any[ - (false, :none, nothing, BLAS.dot, 3, randn(P, 5), 1, randn(P, 4), 1), - (false, :none, nothing, BLAS.dot, 3, randn(P, 6), 2, randn(P, 4), 1), - (false, :none, nothing, BLAS.dot, 3, randn(P, 6), 1, randn(P, 9), 3), - (false, :none, nothing, BLAS.dot, 3, randn(P, 12), 3, randn(P, 9), 2), - (false, :none, nothing, BLAS.scal!, 10, P(2.4), randn(P, 30), 2), + (flags..., BLAS.dot, 3, randn(rng, P, 5), 1, randn(rng, P, 4), 1), + (flags..., BLAS.dot, 3, randn(rng, P, 6), 2, randn(rng, P, 4), 1), + (flags..., BLAS.dot, 3, randn(rng, P, 6), 1, randn(rng, P, 9), 3), + (flags..., BLAS.dot, 3, randn(rng, P, 12), 3, randn(rng, P, 9), 2), + (flags..., BLAS.scal!, 10, P(2.4), randn(rng, P, 30), 2), ] end..., @@ -834,18 +835,21 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) map_prod(t_flags, t_flags, Ps) do (tA, tB, P) As = blas_matrices(rng, P, 5, 5) Bs = blas_matrices(rng, P, 5, 5) + a = randn(rng, P) + b = randn(rng, P) return map_prod(As, Bs) do (A, B) - (false, :none, nothing, aliased_gemm!, tA, tB, randn(P), randn(P), A, B) + (false, :none, nothing, aliased_gemm!, tA, tB, a, b, A, B) end end..., # syrk! map_prod(uplos, t_flags, Ps) do (uplo, t, P) As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3) - C = randn(P, 3, 3) - flags = (false, :none, nothing) + C = randn(rng, P, 3, 3) + a = randn(rng, P) + b = randn(rng, P) return map(As) do A - return (flags..., BLAS.syrk!, uplo, t, randn(P), A, randn(P), C) + (false, :none, nothing, BLAS.syrk!, uplo, t, a, A, b, C) end end..., @@ -855,10 +859,11 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) ) do (side, ul, tA, dA, M, N, P) t = tA == 'N' R = side == 'L' ? M : N + a = randn(rng, P) As = blas_matrices(rng, P, R, R) Bs = blas_matrices(rng, P, M, N) return map(As, Bs) do A, B - (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B) + (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, a, A, B) end end..., @@ -868,13 +873,14 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) ) do (side, ul, tA, dA, M, N, P) t = tA == 'N' R = side == 'L' ? M : N + a = randn(rng, P) As = map(blas_matrices(rng, P, R, R)) do A A[diagind(A)] .+= 1 return A end Bs = blas_matrices(rng, P, M, N) return map(As, Bs) do A, B - (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B) + (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, a, A, B) end end..., ) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 837278906..7cd893fe8 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -651,7 +651,12 @@ function rrule!!( end end -@generated is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(P.parameters) +@static if VERSION >= v"1.11" + is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(fieldtypes(P)) +else + @generated is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(fieldtypes(P)) +end + @inline is_homogeneous_and_immutable(p::NamedTuple) = is_homogeneous_and_immutable(Tuple(p)) is_homogeneous_and_immutable(::Any) = false diff --git a/src/stack.jl b/src/stack.jl index b691b8f98..7a72eb2d4 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -35,4 +35,4 @@ end struct SingletonStack{T} end Base.push!(::SingletonStack, ::Any) = nothing -@generated Base.pop!(::SingletonStack{T}) where {T} = T.instance +Base.pop!(::SingletonStack{T}) where {T} = T.instance