Avoid NaN-propagation in scalar rules #551

Open
wants to merge 18 commits into
base: main

sethaxen
Member

As proposed in JuliaDiff/ChainRules.jl#576 (comment), this PR makes zero (co)tangents behave as strong zeros in rules defined by @scalar_rule, so that regardless of the value of the partial, if an input (co)tangent is zero, then its product with the partial is also zero.

Before:

julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

This PR:

julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, 0.0)
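
To make the before/after difference concrete, here is a minimal sketch in plain Julia (no ChainRulesCore needed). The helper names `∂sqrt`, `weak_product`, `strong_zero_mul`, and `strong_product` are hypothetical and only illustrate the strong-zero idea; they are not the code in this PR.

```julia
# Partial derivative of sqrt at x: d/dx sqrt(x) = 1 / (2 * sqrt(x)).
∂sqrt(x) = inv(2 * sqrt(x))

# Ordinary product: at x = 0 the partial is Inf, and 0.0 * Inf is NaN in IEEE 754.
weak_product(Δ, x) = Δ * ∂sqrt(x)

# Hypothetical strong-zero product: a zero input tangent wins regardless of the partial.
strong_zero_mul(Δ, ∂) = iszero(Δ) ? zero(∂) : Δ * ∂
strong_product(Δ, x) = strong_zero_mul(Δ, ∂sqrt(x))

weak_result = weak_product(0.0, 0.0)      # zero tangent times infinite partial
strong_result = strong_product(0.0, 0.0)  # zero tangent treated as a strong zero
```

For nonzero tangents both products agree, so only the zero-tangent case changes.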

This feature is similar to ForwardDiff's NaN-safe mode, which the docs note is 5-10% slower in their benchmarks. However, this benchmark doesn't indicate a consistent performance decrease:

using ChainRules, ChainRulesCore, BenchmarkTools, Random

myhypot(a, b, c) = hypot(a, b, c)
@scalar_rule myhypot(a::Real, b::Real, c::Real) @setup(z = inv(Ω)) (z * a, z * b, z * c)
x = rand(MersenneTwister(42), 1000)

struct MyRuleConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} end
function ChainRulesCore.rrule_via_ad(cfg::MyRuleConfig, f, args...; kwargs...)
    return rrule(cfg, f, args...; kwargs...)
end

jvp(f, x, ẋ) = frule(MyRuleConfig(), (NoTangent(), ẋ), f, x)

function j′vp(f, ȳ, x...)
    y, back = rrule(MyRuleConfig(), f, x...)
    return map(unthunk, Base.tail(back(ȳ)))
end

suite = BenchmarkGroup()
suite["jvp"] = BenchmarkGroup()
suite["jvp"]["inv"] = @benchmarkable jvp(inv, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["sqrt"] = @benchmarkable jvp(sqrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["cbrt"] = @benchmarkable jvp(cbrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log"] = @benchmarkable jvp(log, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log2"] = @benchmarkable jvp(log2, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log10"] = @benchmarkable jvp(log10, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log1p"] = @benchmarkable jvp(log1p, $(Ref(-1.0))[], $(Ref(0.0))[])

suite["j′vp"] = BenchmarkGroup()
suite["j′vp"]["inv"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], inv, $x)
suite["j′vp"]["sqrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], sqrt, $x)
suite["j′vp"]["cbrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], cbrt, $x)
suite["j′vp"]["log"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log, $x)
suite["j′vp"]["log2"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log2, $x)  
suite["j′vp"]["log10"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log10, $x)
suite["j′vp"]["log1p"] = @benchmarkable j′vp(sum, $(Ref(-1.0))[], log1p, $x)
suite["j′vp"]["myhypot"] =
    @benchmarkable j′vp(myhypot, $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[])

tune!(suite)
results = run(suite);
mean(results)

Before:

2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(2.442 ns)
	  "log" => TrialEstimate(2.221 ns)
	  "sqrt" => TrialEstimate(2.361 ns)
	  "log2" => TrialEstimate(2.453 ns)
	  "log1p" => TrialEstimate(2.572 ns)
	  "log10" => TrialEstimate(2.675 ns)
	  "inv" => TrialEstimate(1.243 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(11.483 μs)
	  "log" => TrialEstimate(9.677 μs)
	  "sqrt" => TrialEstimate(6.867 μs)
	  "log2" => TrialEstimate(9.708 μs)
	  "log1p" => TrialEstimate(12.228 μs)
	  "log10" => TrialEstimate(11.265 μs)
	  "myhypot" => TrialEstimate(5.421 ns)
	  "inv" => TrialEstimate(5.240 μs)

This PR:

2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(2.574 ns)
	  "log" => TrialEstimate(2.683 ns)
	  "sqrt" => TrialEstimate(1.477 ns)
	  "log2" => TrialEstimate(2.475 ns)
	  "log1p" => TrialEstimate(2.933 ns)
	  "log10" => TrialEstimate(2.964 ns)
	  "inv" => TrialEstimate(1.245 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(11.344 μs)
	  "log" => TrialEstimate(9.320 μs)
	  "sqrt" => TrialEstimate(6.841 μs)
	  "log2" => TrialEstimate(9.516 μs)
	  "log1p" => TrialEstimate(12.917 μs)
	  "log10" => TrialEstimate(9.932 μs)
	  "myhypot" => TrialEstimate(6.075 ns)
	  "inv" => TrialEstimate(4.947 μs)
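
One way to read the two tables is to compute per-function ratios. The sketch below copies the `jvp` means (in nanoseconds) from the output above; the variable names are just for illustration, and a ratio above 1 means this PR was slower in that run.

```julia
# Mean jvp times in nanoseconds, copied from the "Before" and "This PR" tables above.
before = Dict("cbrt" => 2.442, "log" => 2.221, "sqrt" => 2.361, "log2" => 2.453,
              "log1p" => 2.572, "log10" => 2.675, "inv" => 1.243)
after = Dict("cbrt" => 2.574, "log" => 2.683, "sqrt" => 1.477, "log2" => 2.475,
             "log1p" => 2.933, "log10" => 2.964, "inv" => 1.245)

# Ratio > 1: this PR is slower for that function; ratio < 1: it is faster.
ratios = Dict(name => after[name] / before[name] for name in keys(before))

# Print the ratios in alphabetical order.
for name in sort(collect(keys(ratios)))
    println(rpad(name, 6), " ", round(ratios[name]; digits=3))
end
```

The ratios fall on both sides of 1 (for example `sqrt` is faster and `log` is slower), and at these nanosecond scales the differences may well be within measurement noise.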

@sethaxen
Member Author

The current tests fail because

  1. zero(::NotImplemented) throws a NotImplementedException
  2. zero(NoTangent()) is a ZeroTangent(), so this change breaks inferrability. This causes the ChainRules integration tests to fail for copysign and ldexp, for which one of the partials is NoTangent()
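
The inferrability problem in item 2 can be reproduced with stand-in types. This is a sketch under assumptions: `MockNoTangent`, `MockZeroTangent`, `mock_zero`, and `guarded_partial` are hypothetical names that mimic the behavior described above, not ChainRulesCore's actual definitions.

```julia
# Stand-ins for ChainRulesCore's NoTangent and ZeroTangent (mock types, illustration only).
struct MockNoTangent end
struct MockZeroTangent end

# Mimics the behavior in item 2: the zero of a "no tangent" is a "zero tangent".
mock_zero(::MockNoTangent) = MockZeroTangent()
mock_zero(x::Number) = zero(x)

# The zero guard under discussion: use zero(∂) when the input (co)tangent is zero.
guarded_partial(Δ, ∂) = iszero(Δ) ? mock_zero(∂) : ∂

# The branch taken depends on a runtime value, so the two branch types are merged.
inferred_mixed = only(Base.return_types(guarded_partial, (Float64, MockNoTangent)))

# With a Float64 partial both branches return Float64, so inference stays concrete.
inferred_float = only(Base.return_types(guarded_partial, (Float64, Float64)))
```

Whenever `typeof(zero(∂))` differs from `typeof(∂)`, the inferred return type is a union rather than a concrete type, which is what an `@inferred` test rejects.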

@oxinabox
Member

Sorry for the slow reply

Not yet having looked at the code, but the general idea that the zero from @scalar_rule should be a strong zero is correct.
And indeed it used to be a ZeroTangent(), but we changed it because that caused type widening.
(These issues would be resolved if we had JuliaLang/julia#38241)

> zero(::NotImplemented) throws a NotImplementedException

This feels like it should be a ZeroTangent()
As it is one of the things that is in fact safe to do even if the tangent wasn't implemented.
Since we don't care about it

> zero(NoTangent()) is a ZeroTangent()

This could well be changed so that zero(NoTangent()) isa NoTangent.

Member

@oxinabox oxinabox left a comment

Yes, I am down with this.
Make the tests pass, then merge when happy.

@devmotion
Member

devmotion commented Oct 13, 2022

I just ran into the sqrt issue. Do you think you'll be able to finish and merge this PR soonish, @sethaxen? Or would you like some help here?

@sethaxen
Member Author

@devmotion thanks for the reminder; this slipped off the end of my to-do list. I'll prioritize finishing this in the next few days. I'll let you know when it's ready for a final review.

sethaxen and others added 2 commits October 13, 2022 16:28
Co-authored-by: David Widmann
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@codecov-commenter

codecov-commenter commented Oct 13, 2022

Codecov Report

Base: 93.17% // Head: 93.22% // Increases project coverage by +0.05% 🎉

Coverage data is based on head (3cdd108) compared to base (ed9a007).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
+ Coverage   93.17%   93.22%   +0.05%     
==========================================
  Files          15       15              
  Lines         908      915       +7     
==========================================
+ Hits          846      853       +7     
  Misses         62       62              
Impacted Files Coverage Δ
src/rule_definition_tools.jl 96.95% <100.00%> (+0.09%) ⬆️
src/tangent_types/abstract_zero.jl 96.29% <100.00%> (+0.29%) ⬆️

@devmotion
Member

devmotion commented Oct 13, 2022

> > zero(::NotImplemented) throws a NotImplementedException
>
> This feels like it should be a ZeroTangent()
> As it is one of the things that is in fact safe to do even if the tangent wasn't implemented.
> Since we don't care about it

It seems it was changed at some point but maybe it should be reverted. It seems it causes type inference issues since now the return type can't be inferred if NotImplemented is involved.

Edit: I think I misread the logs; it seems the type inference issues are actually caused by zero(::NoTangent) = ZeroTangent()?

Edit2: Just noticed that this was already discussed above (eg #551 (comment)). Sorry for the noise, I guess I should not have commented on my phone without checking the PR carefully.

@sethaxen
Member Author

@devmotion this should be ready for review now.

Member

@devmotion devmotion left a comment

Looks good to me, I only have some suggestions for the tests.

So it seems in the end it is not necessary to change zero(::NotImplemented) and zero(::Type{<:NotImplemented})? Maybe to be sure add tests with a @scalar_rule where one partial is NotImplemented (similar to some of the rules in SpecialFunctions) and test them with ZeroTangent, NoTangent, and 0.0 (similar to the tests for suminv)?

@sethaxen
Member Author

> So it seems in the end it is not necessary to change zero(::NotImplemented) and zero(::Type{<:NotImplemented})? Maybe to be sure add tests with a @scalar_rule where one partial is NotImplemented (similar to some of the rules in SpecialFunctions) and test them with ZeroTangent, NoTangent, and 0.0 (similar to the tests for suminv)?

It's not necessary in the sense that we're keeping the old behavior for NotImplemented. The right thing to do is probably @oxinabox's suggestion in #551 (comment) (making zero(::NotImplemented) = ZeroTangent()), but this causes inferred types of scalar rules with non-implemented (co)tangents to be type unions. See e.g. https://github.com/JuliaDiff/ChainRulesCore.jl/actions/runs/3255105837/jobs/5344073780#step:6:204.

Or do you think that's acceptable for these rules?

@sethaxen
Member Author

bump @devmotion

Member

@devmotion devmotion left a comment

Looks mainly good.

However, could you add more tests? For instance, involving Tangent, Thunk, and NotImplemented? To me it seems, it should be possible to run into the type inference issue for all types of partials where typeof(partial) !== typeof(zero(partial)). I'm not sure what's the best solution to this issue (in the case of NotImplemented we could define zero(x::NotImplemented) = x, consistent with our definition of +(::NotImplemented, ::NotImplemented), but that's not possible in the other cases). Maybe there's no way around these small unions. Or maybe we should just deal with a subset of tangent types here and use some

_zero(x) = x
_zero(x::Number) = zero(x)
_zero(x::ZeroTangent) = x
_zero(x::NoTangent) = x

here to limit it to cases where we are (somewhat) certain that typeof(partial) === typeof(zero(partial))?
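
As a quick check of that property, the sketch below applies the same `_zero` methods to stand-in types. `DemoZero`, `DemoNoTangent`, and `DemoThunk` are hypothetical placeholders for ZeroTangent, NoTangent, and Thunk, so this only illustrates the dispatch pattern, not ChainRulesCore itself.

```julia
# Placeholder types standing in for ZeroTangent, NoTangent, and Thunk (assumptions).
struct DemoZero end
struct DemoNoTangent end
struct DemoThunk{F}
    f::F
end

# Same shape as the proposed helper: zero out numbers, pass everything else through.
_zero(x) = x
_zero(x::Number) = zero(x)
_zero(x::DemoZero) = x
_zero(x::DemoNoTangent) = x

# A mix of partials, including non-finite numbers and a thunk-like wrapper.
partials = (Inf, 3.0f0, 1 + 2im, DemoZero(), DemoNoTangent(), DemoThunk(() -> 1.0))

# For every partial, the type of the replacement equals the type of the original.
type_preserving = all(typeof(_zero(p)) === typeof(p) for p in partials)
```

The trade-off is that partials hitting the generic fallback (the thunk-like case here) are returned unchanged, so they keep the old NaN-propagating behavior.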

@sethaxen
Member Author

> However, could you add more tests? For instance, involving Tangent, Thunk, and NotImplemented? To me it seems, it should be possible to run into the type inference issue for all types of partials where typeof(partial) !== typeof(zero(partial)).

Tangent type for tangents of scalar primals is not supported, so type-inferrability here is not a concern. This PR does not change the behavior for NotImplemented and Thunk (iszero either errors or returns false for both, respectively). For NotImplemented, type-inferrability is not a problem since an error is raised anyways. For Thunk, I added type-inferrability tests.

A future PR could work out how to support hard zero tangents for NotImplemented and Thunk.

@devmotion
Member

> (iszero either errors or returns false for both, respectively).

I don't follow - I was concerned about partials of these types, and for these only the type of zero(partial) should be relevant; in particular, we do not call iszero on these (I think).
With the current implementation of NotImplemented, zero will throw an error - but changing it to return ZeroTangent(), as proposed above, would cause type inference issues.
Maybe it's safest to replace partial with zero(partial) only if its type is Number, ZeroTangent, or NoTangent.

@oxinabox
Member

@sethaxen can we merge this? I would like to, it seems good enough for now.
We can always make follow ups

@sethaxen
Member Author

> @sethaxen can we merge this? I would like to, it seems good enough for now. We can always make follow ups

I agree that feature-wise it's good enough for a merge, but we should soon address the issues @devmotion has raised.

But before merging, it seems we fail IntegrationTests with ChainRulesTestUtils and Diffractor; I can look into the former to make sure they're not due to a problem in this PR. The failures in ChainRules are due to broken tests unexpectedly passing.

@sethaxen
Member Author

@oxinabox it seems the ChainRulesTestUtils failure was a fluke and now passes. The only failure I can't explain is Diffractor. Do you think that should block this PR?

@ToucheSir
Contributor

Are Diffractor's tests even stable at this point considering CI runs on nightly? The failure here looks completely unrelated, though it is kind of crazy that compiler APIs changed enough in just 3 days (last Diffractor master build to last downstream CI run here) to make them fail.

@mcabbott
Member

Not sure re Diffractor. Some chance that some 2nd order test doesn't like the extra branches?

What's the argument for iszero(Δs_i) ? zero(∂s_i) : ∂s_i instead of ifelse? The branch does not defer any calculation.
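
For readers comparing the two forms, here is a small plain-Julia sketch of how they differ in evaluation. The counter and `counted_zero` are hypothetical instrumentation added only for this illustration; for Float64 partials `zero(∂)` is trivially cheap, which is the point of the question above.

```julia
# Counter to observe how often the zero(...) argument is actually evaluated.
calls = Ref(0)
counted_zero(∂) = (calls[] += 1; zero(∂))

Δ, ∂ = 1.0, Inf

# Ternary: only the selected branch runs, so counted_zero is skipped when Δ is nonzero.
calls[] = 0
r_branch = iszero(Δ) ? counted_zero(∂) : ∂
calls_branch = calls[]

# ifelse is an ordinary function call: both value arguments are evaluated first.
calls[] = 0
r_ifelse = ifelse(iszero(Δ), counted_zero(∂), ∂)
calls_ifelse = calls[]
```

Both forms return the same value; they differ only in whether `zero(∂)` is computed unconditionally, and in whether the generated code contains a branch.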

One other comment is that rules using derivatives_given_output will now differ, unless they re-implement this iszero check. Should it then be a function like strong_times(x, y) or is the hassle of having another name here worse than just re-doing it elsewhere (e.g. in CR?)

@devmotion
Member

> What's the argument for iszero(Δs_i) ? zero(∂s_i) : ∂s_i instead of ifelse? The branch does not defer any calculation.

IIRC the reason was JuliaDiff/ChainRules.jl#599 (comment).

@sethaxen
Member Author

> One other comment is that rules using derivatives_given_output will now differ, unless they re-implement this iszero check. Should it then be a function like strong_times(x, y) or is the hassle of having another name here worse than just re-doing it elsewhere (e.g. in CR?)

I don't really understand how derivatives_given_output is used. Do you have a suggestion for how this could be done in that function instead of during propagation?

@mcabbott
Member

mcabbott commented Jan 25, 2023

I see. If there's a trade-off between pleasing 2nd order AD by avoiding branches, and a marginal speed change should anyone ever differentiate BigFloat, then we should pick AD. (Or perhaps have a branched method strong_times(::BigFloat, ::BigFloat).)

derivatives_given_output is used e.g. here:

https://github.com/JuliaDiff/ChainRules.jl/blob/9adf759bc63432dc518ccf499d6938fc5a217113/src/rulesets/Base/mapreduce.jl#L89-L90

I wondered whether dyₖ * conj(∂yₖ∂xᵢ) there ought to be ChainRulesCore.strong_times(dyₖ, conj(∂yₖ∂xᵢ)) (or perhaps strong_dot absorbs the conj) to share code. In broadcasting-like use, for simple functions, there is some chance we ought to care about SIMD here.

@sethaxen
Member Author

sethaxen commented Feb 5, 2023

@mcabbott I added strong_mul and strong_muladd (since conjugation is performed elsewhere in the macro). Is this what you had in mind?
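
For illustration only, helpers of this kind might look roughly like the sketch below. The `_sketch` names and their signatures are assumptions made for this example and are not the definitions added in the PR.

```julia
# Hypothetical strong multiply: a zero (co)tangent Δ forces a zero product.
strong_mul_sketch(Δ, ∂) = Δ * (iszero(Δ) ? zero(∂) : ∂)

# Hypothetical strong multiply-add, for accumulating over several partials.
strong_muladd_sketch(Δ, ∂, acc) = muladd(Δ, iszero(Δ) ? zero(∂) : ∂, acc)

# Input tangents and partials; two partials are non-finite but have zero tangents.
Δs = (0.0, 1.0, 0.0)
∂s = (Inf, 2.0, NaN)

# Naive accumulation: 0.0 * Inf and 0.0 * NaN poison the sum.
naive = sum(Δs .* ∂s)

# Strong accumulation: terms with zero tangents contribute exactly zero.
strong = foldl((acc, (Δ, ∂)) -> strong_muladd_sketch(Δ, ∂, acc), zip(Δs, ∂s); init=0.0)
```

Factoring the check into named functions would let other rules (such as those built on derivatives_given_output) reuse the same zero handling instead of re-implementing it.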

@sethaxen
Member Author

Bump @mcabbott

6 participants