Avoid NaN (co)tangents for sqrt(0) #599
function frule((_, Δx), ::typeof(sqrt), x::Number)
    Ω = sqrt(x)
    ∂Ω = Δx / 2Ω
    return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
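To make the effect of the guarded return value concrete, here is a minimal standalone sketch using only Base Julia. The helper names `naive_sqrt_tangent` and `strong_zero_sqrt_tangent` are hypothetical and are not part of ChainRules.jl; they only mirror the tangent expression from the hunk above.

```julia
# Hypothetical helpers for illustration only; not the ChainRules.jl API.

# Unguarded tangent: Δx / (2 * sqrt(x)). At x == 0 and Δx == 0 this is 0/0.
naive_sqrt_tangent(x, Δx) = Δx / (2 * sqrt(x))

# Guarded tangent: a zero input tangent at x == 0 is treated as a strong zero.
function strong_zero_sqrt_tangent(x, Δx)
    Ω = sqrt(x)
    ∂Ω = Δx / 2Ω  # literal juxtaposition: parsed as Δx / (2 * Ω)
    return ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
end

t_naive   = naive_sqrt_tangent(0.0, 0.0)        # NaN
t_guarded = strong_zero_sqrt_tangent(0.0, 0.0)  # 0.0
t_inf     = strong_zero_sqrt_tangent(0.0, 1.0)  # Inf: nonzero tangents are not masked
t_regular = strong_zero_sqrt_tangent(4.0, 1.0)  # 0.25
```

Only the case where both `x` and `Δx` are zero changes; a nonzero tangent at `x == 0` still yields `Inf`.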
Is there a specific reason to use ifelse instead of a ternary operator (which does not require evaluating both branches)?
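The evaluation difference behind this question can be seen in a small sketch. `ifelse` is an ordinary function, so both value arguments are evaluated before the call, whereas the ternary operator evaluates only the selected branch. The counter and the `counted_zero` helper below are hypothetical, introduced only to make the evaluations visible.

```julia
# Hypothetical instrumentation: count how often the "zero" branch is evaluated.
calls = Ref(0)
counted_zero(x) = (calls[] += 1; zero(x))

cond = true

# Ternary: the branch that is not selected is never evaluated.
calls[] = 0
a = cond ? 1.0 : counted_zero(1.0)
ternary_calls = calls[]   # 0

# ifelse: both value arguments are evaluated before selection.
calls[] = 0
b = ifelse(cond, 1.0, counted_zero(1.0))
ifelse_calls = calls[]    # 1
```

For cheap, non-allocating types such as `Float64` the extra evaluation is typically negligible; for allocating types it may not be.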
My reasoning was that, as long as the type of ∂Ω is inferrable, the two branches do no extra work, and using & and ifelse could both perform better if this is used in an inner loop and might allow Zygote to perform better for higher-order AD (since Zygote tends to be slow when hitting control flow but has a special rule for ifelse). However, I was unable to devise a benchmark that showed a substantial difference.
I know that in older Julia versions there were cases where we could improve performance in SciML by avoiding zero or moving it out of loops. But I couldn't immediately reproduce this with a simple example; maybe it's not relevant here and/or has been fixed in recent Julia versions.
An example where it matters:
julia> using BenchmarkTools
julia> function f(x)
           s = zero(x)
           for i in 1:10
               s += iseven(i) ? zero(x) : x
           end
           return s
       end
f (generic function with 1 method)
julia> function g(x)
           s = zero(x)
           for i in 1:10
               s += ifelse(iseven(i), zero(x), x)
           end
           return s
       end
g (generic function with 1 method)
julia> @btime f($(big"1.0"));
45.640 μs (3002 allocations: 164.17 KiB)
julia> @btime g($(big"1.0"));
56.341 μs (4002 allocations: 218.86 KiB)
This seems fine. How many other functions will need this?
At first glance, I'm reluctant to add custom
I opened a PR to ChainRulesCore that would supersede this one if merged: JuliaDiff/ChainRulesCore.jl#551
Functions like The motivating case for sqrt is I think something like
Is this difference important though? There are plenty of cases where, in a well-behaved primal function, an intermediate can be non-finite, resulting in introduction of
julia> using StatsFuns
julia> normcdf(0.0, 1.0, Inf) # a constant function for all finite values of mu and sigma
1.0
julia> FiniteDifferences.grad(central_fdm(5, 1), x -> normcdf(0.0, x, Inf), 1.0)
(6.085449991639748e-14,)
julia> Zygote.gradient(x -> normcdf(0.0, x, Inf), 1.0)
(NaN,)
This happens because the gradient of
Perhaps, but I don't see
This PR fixes #576 by treating zero (co)tangents in sqrt as strong zeros. It partially fixes FluxML/Zygote.jl#1101 also, but to fix it entirely, we would need to do the same thing to the rule for ^.
Benchmark
This simple benchmark indicates that the performance decrease from this modified rule in Zygote is not extreme.