diff --git a/Project.toml b/Project.toml index 210d77792..7ee353236 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.8" +version = "0.7.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.9" -ChainRulesTestUtils = "0.4.2" +ChainRulesTestUtils = "0.4.2, 0.5" Compat = "3" FiniteDifferences = "0.10" Reexport = "0.2" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 5cb78318e..979518a89 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -4,14 +4,14 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) function reshape_pullback(Ȳ) - return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DoesNotExist()) + return (NO_FIELDS, reshape(Ȳ, dims), DoesNotExist()) end return reshape(A, dims), reshape_pullback end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) - ∂A = @thunk(reshape(Ȳ, dims)) + ∂A = reshape(Ȳ, dims) return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...) end return reshape(A, dims...), reshape_pullback @@ -63,14 +63,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist()) + return (NO_FIELDS, sum(Ȳ), DoesNotExist()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DoesNotExist(), length(dims))...) + return (NO_FIELDS, sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 25f1f40c2..f4a71fb3e 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -10,7 +10,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} y = sum(sum, x; dims=dims) function sum_pullback(ȳ) # broadcasting the two works out the size no-matter `dims` - x̄ = @thunk broadcast(x, ȳ) do xi, ȳi + x̄ = broadcast(x, ȳ) do xi, ȳi ȳi end return (NO_FIELDS, x̄) @@ -44,7 +44,7 @@ function rrule( ) where {T<:Union{Real,Complex}} y = sum(abs2, x; dims=dims) function sum_abs2_pullback(ȳ) - return (NO_FIELDS, DoesNotExist(), @thunk(2 .* real.(ȳ) .* x)) + return (NO_FIELDS, DoesNotExist(), 2 .* real.(ȳ) .* x) end return y, sum_abs2_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 46e43abb5..3c9f283b0 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -90,7 +90,7 @@ function rrule(::typeof(BLAS.asum), n, X, incx) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) - ∂X = @thunk scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) + ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) end return Ω, asum_pullback diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index c949c3b96..c3af22333 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -101,7 +101,7 @@ function rrule(::typeof(tr), x) # This should really be a FillArray # see https://github.com/JuliaDiff/ChainRules.jl/issues/46 function tr_pullback(ΔΩ) - return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1)))) + return (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1)))) end return tr(x), tr_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 84caab6ea..5b5b8043f 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -8,7 +8,7 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) function svd_pullback(Ȳ::Composite) - ∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)) + ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) return (NO_FIELDS, ∂X) end return F, svd_pullback @@ -73,9 +73,9 @@ function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) function cholesky_pullback(Ȳ::Composite) ∂X = if F.uplo === 'U' - @thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true)) + chol_blocked_rev(Ȳ.U, F.U, 25, true) else - @thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false)) + chol_blocked_rev(Ȳ.L, F.L, 25, false) end return (NO_FIELDS, ∂X) end @@ -85,7 +85,7 @@ end function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky function getproperty_cholesky_pullback(Ȳ) C = Composite{T} - ∂F = @thunk if x === :U + ∂F = if x === :U if F.uplo === 'U' C(U=UpperTriangular(Ȳ),) else diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 677039743..bbc5e8743 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -21,14 +21,14 @@ end function rrule(::typeof(diag), A::AbstractMatrix) function diag_pullback(ȳ) - return (NO_FIELDS, @thunk(Diagonal(ȳ))) + return (NO_FIELDS, Diagonal(ȳ)) end return diag(A), diag_pullback end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, @thunk(diagm(size(A)..., k => ȳ)), DoesNotExist()) + return (NO_FIELDS, diagm(size(A)..., k => ȳ), DoesNotExist()) end return diag(A, k), diag_pullback end @@ -48,11 +48,9 @@ function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...) end function _diagm_back(p, ȳ) - return Thunk() do - k, v = p - d = diag(ȳ, k)[1:length(v)] # handle if diagonal was smaller than matrix - return Composite{typeof(p)}(second = d) - end + k, v = p + d = diag(ȳ, k)[1:length(v)] # handle if diagonal was smaller than matrix + return Composite{typeof(p)}(second = d) end function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) @@ -73,7 +71,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist()) + return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist()) end return Ω, HermOrSym_pullback end @@ -149,28 +147,28 @@ end # ✖️✖️✖️TODO: Deal with complex-valued arrays as well function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) function Adjoint_pullback(ȳ) - return (NO_FIELDS, @thunk(adjoint(ȳ))) + return (NO_FIELDS, adjoint(ȳ)) end return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) function Adjoint_pullback(ȳ) - return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + return (NO_FIELDS, vec(adjoint(ȳ))) end return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) function adjoint_pullback(ȳ) - return (NO_FIELDS, @thunk(adjoint(ȳ))) + return (NO_FIELDS, adjoint(ȳ)) end return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) function adjoint_pullback(ȳ) - return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + return (NO_FIELDS, vec(adjoint(ȳ))) end return adjoint(A), adjoint_pullback end @@ -181,28 +179,28 @@ end function rrule(::Type{<:Transpose}, A::AbstractMatrix) function Transpose_pullback(ȳ) - return (NO_FIELDS, @thunk transpose(ȳ)) + return (NO_FIELDS, transpose(ȳ)) end return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector) function Transpose_pullback(ȳ) - return (NO_FIELDS, @thunk vec(transpose(ȳ))) + return (NO_FIELDS, vec(transpose(ȳ))) end return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix) function transpose_pullback(ȳ) - return (NO_FIELDS, @thunk transpose(ȳ)) + return (NO_FIELDS, transpose(ȳ)) end return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector) function transpose_pullback(ȳ) - return (NO_FIELDS, @thunk vec(transpose(ȳ))) + return (NO_FIELDS, vec(transpose(ȳ))) end return transpose(A), transpose_pullback end @@ -213,40 +211,40 @@ end function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) function UpperTriangular_pullback(ȳ) - return (NO_FIELDS, @thunk Matrix(ȳ)) + return (NO_FIELDS, Matrix(ȳ)) end return UpperTriangular(A), UpperTriangular_pullback end function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) function LowerTriangular_pullback(ȳ) - return (NO_FIELDS, @thunk Matrix(ȳ)) + return (NO_FIELDS, Matrix(ȳ)) end return LowerTriangular(A), LowerTriangular_pullback end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, @thunk(triu(ȳ, k)), DoesNotExist()) + return (NO_FIELDS, triu(ȳ, k), DoesNotExist()) end return triu(A, k), triu_pullback end function rrule(::typeof(triu), A::AbstractMatrix) function triu_pullback(ȳ) - return (NO_FIELDS, @thunk triu(ȳ)) + return (NO_FIELDS, triu(ȳ)) end return triu(A), triu_pullback end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, @thunk(tril(ȳ, k)), DoesNotExist()) + return (NO_FIELDS, tril(ȳ, k), DoesNotExist()) end return tril(A, k), tril_pullback end function rrule(::typeof(tril), A::AbstractMatrix) function tril_pullback(ȳ) - return (NO_FIELDS, @thunk tril(ȳ)) + return (NO_FIELDS, tril(ȳ)) end return tril(A), tril_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index d286d37ab..6b15b10fa 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -12,10 +12,8 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) y_sum, sum_pullback = rrule(sum, x; dims=dims) n = _denom(x, dims) function mean_pullback(ȳ) - ∂x = Thunk() do - _, ∂sum_x = sum_pullback(ȳ) - extern(∂sum_x) / n - end + _, ∂sum_x = sum_pullback(ȳ) + ∂x = extern(∂sum_x) / n return (NO_FIELDS, ∂x) end return y_sum / n, mean_pullback