From 224e553132f26979e7f5c34fccaf769788022299 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 30 Jun 2020 16:56:03 -0700 Subject: [PATCH] Add rules for evalpoly (#190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add rules for evalpoly * Rename pullback * Apply suggestions from code review Co-authored-by: Lyndon White * Make backx its own method for the future * Update UniformScaling to-dos * Add matrix poly tests * Use correct indices * Reorganize * Use generated functions with fallbacks * Reimplement rules for matrices * Deactivate complex tests Until complex conventions are clarified and FiniteDifferences v0.10.0 is supported. * Add generated functions for tuple case * Add comment * Rename defs to exs * Refactor and test fallbacks * Simplify indexing * Don't store output as an intermediate * Support scalar x with matrix pi * Make extensible for other ps * Move fallback tests under rrule tests * Reorder args and remove unnecessary product * Eliminate unneeded mul and reorganize * Remove unnecessary product * Fix length of ys and wrap lines * Place final ∂yi to in loop This for some reason profiles much faster * Keep other rules consistent with vector * Unify tests * Increment version number * Try equality check outside of tuple * Approximate check due to muladd * Approximate check scalar output too * Decrement version number Co-authored-by: Lyndon White --- Project.toml | 2 +- src/rulesets/Base/base.jl | 143 +++++++++++++++++++++++++++++++++++++ test/rulesets/Base/base.jl | 35 +++++++++ 3 files changed, 179 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2a0f7c481..7388caafb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.1" +version = "0.7.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c70af02c8..82609236a 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -150,3 +150,146 @@ function rrule(::typeof(identity), x) end return (x, identity_pullback) end + +##### +##### `evalpoly` +##### + +if VERSION ≥ v"1.4" + function frule((_, Δx, Δp), ::typeof(evalpoly), x, p) + N = length(p) + @inbounds y = p[N] + Δy = Δp[N] + @inbounds for i in (N - 1):-1:1 + Δy = muladd(Δx, y, muladd(x, Δy, Δp[i])) + y = muladd(x, y, p[i]) + end + return y, Δy + end + + function rrule(::typeof(evalpoly), x, p) + y, ys = _evalpoly_intermediates(x, p) + function evalpoly_pullback(Δy) + ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) + return NO_FIELDS, ∂x, ∂p + end + return y, evalpoly_pullback + end + + # evalpoly but storing intermediates + function _evalpoly_intermediates(x, p::Tuple) + return if @generated + N = length(p.parameters) + exs = [] + vars = [] + ex = :(p[$N]) + for i in 1:(N - 1) + yi = Symbol("y", i) + push!(vars, yi) + push!(exs, :($yi = $ex)) + ex = :(muladd(x, $yi, p[$(N - i)])) + end + push!(exs, :(y = $ex)) + Expr(:block, exs..., :(y, ($(vars...),))) + else + _evalpoly_intermediates_fallback(x, p) + end + end + function _evalpoly_intermediates_fallback(x, p::Tuple) + N = length(p) + y = p[N] + ys = (y, ntuple(N - 2) do i + return y = muladd(x, y, p[N - i]) + end...) + y = muladd(x, y, p[1]) + return y, ys + end + function _evalpoly_intermediates(x, p) + N = length(p) + @inbounds yn = one(x) * p[N] + ys = similar(p, typeof(yn), N - 1) + @inbounds ys[1] = yn + @inbounds for i in 2:(N - 1) + ys[i] = muladd(x, ys[i - 1], p[N - i + 1]) + end + @inbounds y = muladd(x, ys[N - 1], p[1]) + return y, ys + end + + # TODO: Handle following cases + # 1) x is a UniformScaling, pᵢ is a matrix + # 2) x is a matrix, pᵢ is a UniformScaling + @inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi' + @inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x) + @inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi)) + @inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x + + @inline _evalpoly_backp(pi, ∂yi) = ∂yi + + function _evalpoly_back(x, p::Tuple, ys, Δy) + return if @generated + exs = [] + vars = [] + N = length(p.parameters) + for i in 2:(N - 1) + ∂pi = Symbol("∂p", i) + push!(vars, ∂pi) + push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi))) + push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi))) + push!(exs, :(∂yi = x′ * ∂yi)) + end + push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN + Expr( + :block, + :(x′ = x'), + :(∂yi = Δy), + :(∂p1 = _evalpoly_backp(p[1], ∂yi)), + :(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)), + :(∂yi = x′ * ∂yi), + exs..., + :(∂p = (∂p1, $(vars...))), + :(∂x, Composite{typeof(p),typeof(∂p)}(∂p)), + ) + else + _evalpoly_back_fallback(x, p, ys, Δy) + end + end + function _evalpoly_back_fallback(x, p::Tuple, ys, Δy) + x′ = x' + ∂yi = Δy + N = length(p) + ∂p1 = _evalpoly_backp(p[1], ∂yi) + ∂x = _evalpoly_backx(x, ys[N - 1], ∂yi) + ∂yi = x′ * ∂yi + ∂p = ( + ∂p1, + ntuple(N - 2) do i + ∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi) + ∂pi = _evalpoly_backp(p[i+1], ∂yi) + ∂yi = x′ * ∂yi + return ∂pi + end..., + _evalpoly_backp(p[N], ∂yi), # ∂pN + ) + return ∂x, Composite{typeof(p),typeof(∂p)}(∂p) + end + function _evalpoly_back(x, p, ys, Δy) + x′ = x' + ∂yi = one(x′) * Δy + N = length(p) + @inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi) + ∂p = similar(p, typeof(∂p1)) + @inbounds begin + ∂x = _evalpoly_backx(x, ys[N - 1], ∂yi) + ∂yi = x′ * ∂yi + ∂p[1] = ∂p1 + for i in 2:(N - 1) + ∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi) + ∂p[i] = _evalpoly_backp(p[i], ∂yi) + ∂yi = x′ * ∂yi + end + ∂p[N] = _evalpoly_backp(p[N], ∂yi) + end + return ∂x, ∂p + end +end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index e2c621a9f..0c19b8277 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -144,6 +144,41 @@ ) end + VERSION ≥ v"1.4" && @testset "evalpoly" begin + # test fallbacks for when code generation fails + @testset "fallbacks for $T" for T in (Float64, ComplexF64) + x, p = randn(T), Tuple(randn(T, 10)) + y_fb, ys_fb = ChainRules._evalpoly_intermediates_fallback(x, p) + y, ys = ChainRules._evalpoly_intermediates(x, p) + @test y_fb ≈ y + @test collect(ys_fb) ≈ collect(ys) + + Δy, ys = randn(T), Tuple(randn(T, 9)) + ∂x_fb, ∂p_fb = ChainRules._evalpoly_back_fallback(x, p, ys, Δy) + ∂x, ∂p = ChainRules._evalpoly_back(x, p, ys, Δy) + @test ∂x_fb ≈ ∂x + @test collect(∂p_fb) ≈ collect(∂p) + end + + @testset "x dim: $(nx), pi dim: $(np), type: $T" for T in (Float64, ComplexF64), nx in (tuple(), 3), np in (tuple(), 3) + # skip x::Matrix, pi::Number case, which is not supported by evalpoly + isempty(np) && !isempty(nx) && continue + m = 5 + sx = (nx..., nx...) + sp = (np..., np...) + x, ẋ, x̄ = randn(T, sx...), randn(T, sx...), randn(T, sx...) + p = [randn(T, sp...) for _ in 1:m] + ṗ = [randn(T, sp...) for _ in 1:m] + p̄ = [randn(T, sp...) for _ in 1:m] + Ω = evalpoly(x, p) + Ω̄ = randn(T, size(Ω)...) + frule_test(evalpoly, (x, ẋ), (p, ṗ)) + frule_test(evalpoly, (x, ẋ), (Tuple(p), Tuple(ṗ))) + rrule_test(evalpoly, Ω̄, (x, x̄), (p, p̄)) + rrule_test(evalpoly, Ω̄, (x, x̄), (Tuple(p), Tuple(p̄))) + end + end + @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) test_scalar(one, x) test_scalar(zero, x)