diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 10e7d8abb..7b070f730 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -162,6 +162,7 @@ end # For arrays, whitelist the safe ones, but always look inside Any[]: @inline wrap_chainrules_input(dxs::AbstractArray{<:Number}) = dxs @inline wrap_chainrules_input(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs +@inline wrap_chainrules_input(dxs::AbstractArray{<:Union{Nothing,T}}) where T <: Number = map(x -> x === nothing ? zero(T) : x, dxs) @inline wrap_chainrules_input(dxs::AbstractArray) = map(wrap_chainrules_input, dxs) #= diff --git a/test/chainrules.jl b/test/chainrules.jl index 7e55720de..3d5fcb035 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -420,3 +420,15 @@ end @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b end + +@testset "ChainRules translation" begin + @test Zygote.wrap_chainrules_input(nothing) == ZeroTangent() + @test Zygote.wrap_chainrules_input((nothing,)) == ZeroTangent() + @test Zygote.wrap_chainrules_input([nothing]) == ZeroTangent() + @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == Tangent{Any}(Tangent{Any}(1.0, 2.0), 3.0) + @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == Tangent{Any}(a = 1.0, b = 2.0) + @test Zygote.wrap_chainrules_input(Ref(1)) == 1 + @test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0] + @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]] + @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible +end