Skip to content

Commit

Permalink
Generalize and test the rrules for sum (#217)
Browse files Browse the repository at this point in the history
* Release type constraints

* Update rrule for sum abs2

* Test sum for complex

* Test sum abs2

* Generalize sum frule

* Add sum abs2 frule

* Bump version number

* Apply suggestions from code review

Co-authored-by: willtebbutt <[email protected]>

* Add fall back for pre-Julia 1.2

Co-authored-by: willtebbutt <[email protected]>
  • Loading branch information
sethaxen and willtebbutt authored Jun 30, 2020
1 parent 693e51b commit 1a3e34a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.0"
version = "0.7.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
33 changes: 28 additions & 5 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
##### `sum`
#####

function frule((_, ẋ), ::typeof(sum), x)
return sum(x), sum(ẋ)
function frule((_, ẋ), ::typeof(sum), x; dims=:)
return sum(x; dims=dims), sum(ẋ; dims=dims)
end

function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
y = sum(sum, x; dims=dims)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
Expand All @@ -18,10 +18,33 @@ function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
return y, sum_pullback
end

function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
function frule(
(_, _, ẋ),
::typeof(sum),
::typeof(abs2),
x::AbstractArray{T};
dims=:,
) where {T<:Union{Real,Complex}}
y = sum(abs2, x; dims=dims)
∂y = if dims isa Colon
2 * real(dot(x, ẋ))
elseif VERSION v"1.2" # multi-iterator mapreduce introduced in v1.2
2 * mapreduce(_realconjtimes, +, x, ẋ; dims=dims)
else
2 * sum(_realconjtimes.(x, ẋ); dims=dims)
end
return y, ∂y
end

function rrule(
::typeof(sum),
::typeof(abs2),
x::AbstractArray{T};
dims=:,
) where {T<:Union{Real,Complex}}
y = sum(abs2, x; dims=dims)
function sum_abs2_pullback(ȳ)
return (NO_FIELDS, DoesNotExist(), @thunk(2 .* x))
return (NO_FIELDS, DoesNotExist(), @thunk(2 .* real.(ȳ) .* x))
end
return y, sum_abs2_pullback
end
59 changes: 36 additions & 23 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
@testset "Maps and Reductions" begin
@testset "sum" begin
@testset "Vector" begin
M = 3
frule_test(sum, (randn(M), randn(M)))
rrule_test(sum, randn(), (randn(M), randn(M)))
end
@testset "Matrix" begin
M, N = 3, 4
frule_test(sum, (randn(M, N), randn(M, N)))
rrule_test(sum, randn(), (randn(M, N), randn(M, N)))
end
@testset "Array{T, 3}" begin
M, N, P = 3, 7, 11
frule_test(sum, (randn(M, N, P), randn(M, N, P)))
rrule_test(sum, randn(), (randn(M, N, P), randn(M, N, P)))
end
@testset "keyword arguments" begin
n = 4
X = randn(n, n+1)
y, pullback = rrule(sum, X; dims=2)
= randn(size(y))
_, x̄_ad = pullback(ȳ)
x̄_fd = only(j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X))
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
sizes = (3, 4, 7)
@testset "dims = $dims" for dims in (:, 1)
fkwargs = (dims=dims,)
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
s = sizes[1:N]
x = randn(T, s...)
= randn(T, s...)
= randn(T, s...)
y = sum(x; dims=dims)
Δy = randn(eltype(y), size(y)...)
frule_test(sum, (x, ẋ); fkwargs=fkwargs)
rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs)
end
end
end # sum

@testset "sum abs2" begin
sizes = (3, 4, 7)
@testset "dims = $dims" for dims in (:, 1)
fkwargs = (dims=dims,)
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
s = sizes[1:N]
x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...)
y = sum(abs2, x; dims=dims)
Δy = randn(eltype(y), size(y)...)
@testset "frule" begin
# can't use frule_test here because it doesn't yet ignore nothing tangents
y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims)
@test y_ad == y
ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ))
@test ẏ_ad ẏ_fd
end
@testset "rrule" begin
rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs)
end
end
end
end # sum abs2
end

2 comments on commit 1a3e34a

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17249

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" 1a3e34a888e129e1408ad7a85a86066926729669
git push origin v0.7.1

Please sign in to comment.