Skip to content

Commit

Permalink
Cautiously make things not generated (#420)
Browse files Browse the repository at this point in the history
* pop singletonstack not generated

* StableRNG and formatting

* Not generated is_homogeneous_and_immutable

* Add 1.11 qualifier

* Try another generated function

* Revert change

* Actually revert change

* Try another

* Remove redundant comments

* Bump patch
  • Loading branch information
willtebbutt authored Dec 14, 2024
1 parent 710b068 commit 9f92f47
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.61"
version = "0.4.62"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
9 changes: 1 addition & 8 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1298,17 +1298,10 @@ straightforward to figure out much time is spent pushing to the block stack when
@inline function __assemble_lazy_zero_rdata(
r::Ref{T}, args::Vararg{CoDual,N}
) where {T<:Tuple,N}
r[] = __make_tuples(T, args)
r[] = map((T, x) -> lazy_zero_rdata(T, primal(x)), fieldtypes(T), args)
return nothing
end

@generated function __make_tuples(::Type{T}, args::Tuple) where {T}
lazy_exprs = map(eachindex(T.parameters)) do n
return :(lazy_zero_rdata($(T.parameters[n]), primal(args[$n])))
end
return Expr(:call, tuple, lazy_exprs...)
end

"""
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)
Expand Down
30 changes: 18 additions & 12 deletions src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
ys = blas_vectors(rng, P, M)
flags = (false, :stability, (lb=1e-3, ub=10.0))
return map(As, xs, ys) do A, x, y
return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y)
(flags..., BLAS.gemv!, tA, randn(rng, P), A, x, randn(rng, P), y)
end
end...,

Expand Down Expand Up @@ -804,12 +804,13 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
#

map(Ps) do P
flags = (false, :none, nothing)
Any[
(false, :none, nothing, BLAS.dot, 3, randn(P, 5), 1, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 2, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 1, randn(P, 9), 3),
(false, :none, nothing, BLAS.dot, 3, randn(P, 12), 3, randn(P, 9), 2),
(false, :none, nothing, BLAS.scal!, 10, P(2.4), randn(P, 30), 2),
(flags..., BLAS.dot, 3, randn(rng, P, 5), 1, randn(rng, P, 4), 1),
(flags..., BLAS.dot, 3, randn(rng, P, 6), 2, randn(rng, P, 4), 1),
(flags..., BLAS.dot, 3, randn(rng, P, 6), 1, randn(rng, P, 9), 3),
(flags..., BLAS.dot, 3, randn(rng, P, 12), 3, randn(rng, P, 9), 2),
(flags..., BLAS.scal!, 10, P(2.4), randn(rng, P, 30), 2),
]
end...,

Expand All @@ -834,18 +835,21 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
map_prod(t_flags, t_flags, Ps) do (tA, tB, P)
As = blas_matrices(rng, P, 5, 5)
Bs = blas_matrices(rng, P, 5, 5)
a = randn(rng, P)
b = randn(rng, P)
return map_prod(As, Bs) do (A, B)
(false, :none, nothing, aliased_gemm!, tA, tB, randn(P), randn(P), A, B)
(false, :none, nothing, aliased_gemm!, tA, tB, a, b, A, B)
end
end...,

# syrk!
map_prod(uplos, t_flags, Ps) do (uplo, t, P)
As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3)
C = randn(P, 3, 3)
flags = (false, :none, nothing)
C = randn(rng, P, 3, 3)
a = randn(rng, P)
b = randn(rng, P)
return map(As) do A
return (flags..., BLAS.syrk!, uplo, t, randn(P), A, randn(P), C)
(false, :none, nothing, BLAS.syrk!, uplo, t, a, A, b, C)
end
end...,

Expand All @@ -855,10 +859,11 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
a = randn(rng, P)
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
return map(As, Bs) do A, B
(false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
(false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, a, A, B)
end
end...,

Expand All @@ -868,13 +873,14 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
a = randn(rng, P)
As = map(blas_matrices(rng, P, R, R)) do A
A[diagind(A)] .+= 1
return A
end
Bs = blas_matrices(rng, P, M, N)
return map(As, Bs) do A, B
(false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
(false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, a, A, B)
end
end...,
)
Expand Down
7 changes: 6 additions & 1 deletion src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,12 @@ function rrule!!(
end
end

@generated is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(P.parameters)
@static if VERSION >= v"1.11"
is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(fieldtypes(P))
else
@generated is_homogeneous_and_immutable(::P) where {P<:Tuple} = allequal(fieldtypes(P))
end

@inline is_homogeneous_and_immutable(p::NamedTuple) = is_homogeneous_and_immutable(Tuple(p))
is_homogeneous_and_immutable(::Any) = false

Expand Down
2 changes: 1 addition & 1 deletion src/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ end
struct SingletonStack{T} end

Base.push!(::SingletonStack, ::Any) = nothing
@generated Base.pop!(::SingletonStack{T}) where {T} = T.instance
Base.pop!(::SingletonStack{T}) where {T} = T.instance

2 comments on commit 9f92f47

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/121406

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.62 -m "<description of version>" 9f92f47253d83665803aeb3b2ebb6e08c4425ce0
git push origin v0.4.62

Please sign in to comment.