diff --git a/Project.toml b/Project.toml index 4081239c3..20984db45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.7" +version = "1.16.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 434c5b843..14d083b18 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -292,12 +292,14 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) end end - # Apply `muladd` iteratively. + # Apply `strong_muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. (∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs)) - init_expr = :($∂s_1 * $Δs_1) + # zero gradients are treated as hard zeros. This avoids propagation of NaNs when + # partials are non-finite + init_expr = :(strong_mul($∂s_1, $Δs_1)) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) + :(strong_muladd($∂s_i, $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) end @@ -615,3 +617,26 @@ function _constrain_and_name(arg::Expr, _) return error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type + +""" + strong_mul(x, y) + +Multiply `x` and `y`. If `iszero(y)`, treat `y` as a hard zero even for non-finite `x`. +""" +strong_mul(x, y) = ifelse(iszero(y), zero(x), x) * y + +""" + strong_muladd(x, y, z) + +Multiply `x` and `y` and add to `z`. If `iszero(y)`, treat `y` as a hard zero even for +non-finite `x`. +""" +strong_muladd(x, y, z) = muladd(ifelse(iszero(y), zero(x), x), y, z) + +# slightly faster for BigFloats +strong_mul(x::BigFloat, y::BigFloat) = (iszero(y) ? zero(x) : x) * y +strong_muladd(x::BigFloat, y::BigFloat, z) = muladd((iszero(y) ? zero(x) : x), y, z) + +# avoid raising errors for NotImplemented +strong_mul(x::NotImplemented, y) = (iszero(y) ? zero(x) : x) * y +strong_muladd(x::NotImplemented, y, z) = muladd((iszero(y) ? zero(x) : x), y, z) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 986fc9854..f32c94809 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -87,3 +87,6 @@ arguments. ``` """ struct NoTangent <: AbstractZero end + +Base.zero(::NoTangent) = NoTangent() +Base.zero(::Type{NoTangent}) = NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 5a177566d..764e65115 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -224,6 +224,61 @@ end end end + @testset "strong_mul" begin + ni = @not_implemented("not implemented!") + xvals = ( + 5, + randn(Float32), + randn(Float64), + randn(ComplexF64), + big(randn()), + ZeroTangent(), + ) + yvals = (3, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + yzerovals = (0, 0.0f0, 0.0, 0.0im, big(0.0), ZeroTangent(), NoTangent()) + @testset for x in xvals + x === ni || @testset for y in yvals + @test @inferred(ChainRulesCore.strong_mul(x, y)) == x * y + end + @testset for y in yzerovals + @test @inferred(ChainRulesCore.strong_mul(x, y)) == zero(x * y) + if x isa AbstractFloat + @test ChainRulesCore.strong_mul(oftype(x, Inf), y) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(x, -Inf), y) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(x, NaN), y) == zero(x * y) + end + end + end + end + + @testset "strong_muladd" begin + ni = @not_implemented("not implemented!") + xvals = ( + 5, + randn(Float32), + randn(Float64), + randn(ComplexF64), + big(randn()), + ZeroTangent(), + ) + yvals = (3, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + zvals = (7, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + yzerovals = (0, 0.0f0, 0.0, 0.0 * im, big(0.0), ZeroTangent(), NoTangent()) + @testset for x in xvals, z in zvals + x === ni || @testset for y in yvals + @test @inferred(ChainRulesCore.strong_muladd(x, y, z)) == muladd(x, y, z) + end + @testset for y in yzerovals + @test @inferred(ChainRulesCore.strong_muladd(x, y, z)) == z + if x isa AbstractFloat + @test ChainRulesCore.strong_muladd(oftype(x, Inf), y, z) == z + @test ChainRulesCore.strong_muladd(oftype(x, -Inf), y, z) == z + @test ChainRulesCore.strong_muladd(oftype(x, NaN), y, z) == z + end + end + end + end + @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) @@ -256,6 +311,60 @@ end @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end + @testset "@scalar_rule strong zero (co)tangents" begin + suminv(x, y) = inv(x) + inv(y) + @scalar_rule suminv(x, y) (-(inv(x)^2), -(inv(y)^2)) + + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0)) === + (Inf, -Inf) + @test @inferred(frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), NoTangent(), 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0)) === + (Inf, -Inf) + @test @inferred(frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, NoTangent()), suminv, 1.0, 0.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0)) === + (Inf, -1.0) + + @test @inferred(rrule(suminv, 0.0, 1.0)[2](1.0)) === (NoTangent(), -Inf, -1.0) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](ZeroTangent())) === + (NoTangent(), ZeroTangent(), ZeroTangent()) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](NoTangent())) === + (NoTangent(), NoTangent(), NoTangent()) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](0.0)) === (NoTangent(), 0.0, 0.0) + + @test @inferred(rrule(suminv, 1.0, 0.0)[2](1.0)) === (NoTangent(), -1.0, -Inf) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](ZeroTangent())) === + (NoTangent(), ZeroTangent(), ZeroTangent()) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](NoTangent())) === + (NoTangent(), NoTangent(), NoTangent()) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](0.0)) === (NoTangent(), 0.0, 0.0) + + # cases not covered + t = @thunk(0.0) + @inferred(frule((NoTangent(), t, 1.0), suminv, 0.0, 1.0)) + @inferred(frule((NoTangent(), 1.0, t), suminv, 1.0, 0.0)) + @inferred(rrule(suminv, 0.0, 1.0)[2](t)) + @inferred(rrule(suminv, 1.0, 0.0)[2](t)) + @test_broken rrule(suminv, 0.0, 1.0)[2](t) == (NoTangent(), 0.0, 0.0) + @test_broken rrule(suminv, 1.0, 0.0)[2](t) == (NoTangent(), 0.0, 0.0) + @test_broken frule((NoTangent(), t, 1.0), suminv, 0.0, 1.0) == (Inf, -1.0) + @test_broken frule((NoTangent(), 1.0, t), suminv, 1.0, 0.0) == (Inf, -1.0) + + ni = @not_implemented("not implemented!") + @test_broken rrule(suminv, 0.0, 1.0)[2](ni) == (NoTangent(), 0.0, 0.0) + @test_broken rrule(suminv, 1.0, 0.0)[2](ni) == (NoTangent(), 0.0, 0.0) + @test_broken frule((NoTangent(), ni, 1.0), suminv, 0.0, 1.0) == (Inf, -1.0) + @test_broken frule((NoTangent(), 1.0, ni), suminv, 1.0, 0.0) == (Inf, -1.0) + end + @testset "Regression tests against #276 and #265" begin # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index e3d8642e4..6a1af49e5 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -42,9 +42,7 @@ end @test broadcastable(z) isa Ref{ZeroTangent} @test zero(@thunk(3)) === z - @test zero(NoTangent()) === z @test zero(ZeroTangent) === z - @test zero(NoTangent) === z @test zero(Tangent{Tuple{Int,Int}}((1, 2))) === z for f in (transpose, adjoint, conj) @test f(z) === z @@ -94,6 +92,8 @@ @testset "NoTangent" begin dne = NoTangent() + @test zero(dne) === NoTangent() + @test zero(NoTangent) === NoTangent() @test dne + dne == dne @test dne + 1 == 1 @test 1 + dne == 1