Skip to content

Commit

Permalink
Prefer numeric zero over ZeroTangent for numeric arrays
Browse files Browse the repository at this point in the history
Add tests

Fix tests
  • Loading branch information
BioTurboNick committed Sep 23, 2024
1 parent f17ba75 commit d4bf828
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ end
# For arrays, whitelist the safe ones, but always look inside Any[]:
@inline wrap_chainrules_input(dxs::AbstractArray{<:Number}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray{<:Union{Nothing,T}}) where T <: Number = map(x -> x === nothing ? zero(T) : x, dxs)
@inline wrap_chainrules_input(dxs::AbstractArray) = map(wrap_chainrules_input, dxs)

#=
Expand Down
12 changes: 12 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,15 @@ end
@test z2d_compiled.c.a === z2d_fallback.c.a
@test z2d_compiled.c.b === z2d_fallback.c.b
end

@testset "ChainRules translation" begin
@test Zygote.wrap_chainrules_input(nothing) == ZeroTangent()
@test Zygote.wrap_chainrules_input((nothing,)) == ZeroTangent()
@test Zygote.wrap_chainrules_input([nothing]) == ZeroTangent()
@test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == Tangent{Any}(Tangent{Any}(1.0, 2.0), 3.0)
@test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == Tangent{Any}(a = 1.0, b = 2.0)
@test Zygote.wrap_chainrules_input(Ref(1)) == 1
@test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0]
@test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
@test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
end

0 comments on commit d4bf828

Please sign in to comment.