Avoid NaN-propagation in scalar rules #551
The current tests fail because
|
Sorry for the slow reply. Not yet having looked at the code, but the general idea that the zero from
This feels like it should be a
This could well be changed to |
Yes, I am down with this.
Make the test pass then merge when happy
I just ran into the |
@devmotion thanks for the reminder; this slipped off the end of my to-do list. I'll prioritize finishing this in the next few days. I'll let you know when it's ready for a final review. |
Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Codecov Report
Base: 93.17% // Head: 93.22% // Increases project coverage by +0.05%.
@@ Coverage Diff @@
## main #551 +/- ##
==========================================
+ Coverage 93.17% 93.22% +0.05%
==========================================
Files 15 15
Lines 908 915 +7
==========================================
+ Hits 846 853 +7
Misses 62 62
It seems it was changed at some point but maybe it should be reverted. It seems it causes type inference issues since now the return type can't be inferred if Edit: I think I misread the logs, it seems the type inference issues are actually caused by Edit2: Just noticed that this was already discussed above (e.g. #551 (comment)). Sorry for the noise, I guess I should not have commented on my phone without checking the PR carefully. |
@devmotion this should be ready for review now. |
Looks good to me, I only have some suggestions for the tests.
So it seems in the end it is not necessary to change `zero(::NotImplemented)` and `zero(::Type{<:NotImplemented})`? Maybe to be sure add tests with a `@scalar_rule` where one partial is `NotImplemented` (similar to some of the rules in SpecialFunctions) and test them with `ZeroTangent`, `NoTangent`, and `0.0` (similar to the tests for `suminv`)?
Co-authored-by: David Widmann <[email protected]>
It's not necessary in the sense that we're keeping the old behavior for Or do you think that's acceptable for these rules? |
bump @devmotion |
Looks mainly good.
However, could you add more tests? For instance, involving `Tangent`, `Thunk`, and `NotImplemented`? To me it seems it should be possible to run into the type inference issue for all types of partials where `typeof(partial) !== typeof(zero(partial))`. I'm not sure what's the best solution to this issue (in the case of `NotImplemented` we could define `zero(x::NotImplemented) = x`, consistent with our definition of `+(::NotImplemented, ::NotImplemented)`, but that's not possible in the other cases). Maybe there's no way around these small unions. Or maybe we should just deal with a subset of tangent types here and use some
_zero(x) = x
_zero(x::Number) = zero(x)
_zero(x::ZeroTangent) = x
_zero(x::NoTangent) = x
here to limit it to cases where we are (somewhat) certain that `typeof(partial) === typeof(zero(partial))`?
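To make the condition `typeof(partial) === typeof(zero(partial))` concrete, here is a minimal, self-contained sketch. `MyZero` and `MyLazy` are hypothetical stand-ins invented for illustration; they are not the ChainRulesCore types, and their `zero` methods are assumptions chosen only to reproduce the type mismatch discussed above.

```julia
# Hypothetical stand-in types (NOT the ChainRulesCore definitions).
struct MyZero end              # plays the role of a structural zero tangent
struct MyLazy{F}               # plays the role of a lazily evaluated partial
    f::F
end

# Assumed behavior for the stand-ins: the zero of a lazy partial is a MyZero,
# so its type differs from the type of the partial itself.
Base.zero(::MyZero) = MyZero()
Base.zero(::MyLazy) = MyZero()

# Helper in the spirit of the `_zero` suggestion: only zero out values whose
# zero is known to have the same type; leave everything else untouched.
_zero(x) = x
_zero(x::Number) = zero(x)
_zero(x::MyZero) = x

p = MyLazy(() -> 1.0)

# `flag ? zero(p) : p` can return either MyZero or MyLazy (a small Union) ...
pick_zero(p, flag::Bool) = flag ? zero(p) : p
# ... while the `_zero` variant returns the same type on both branches.
pick_same(p, flag::Bool) = flag ? _zero(p) : p
```

Note the trade-off in this sketch: the fallback `_zero(x) = x` keeps the return type stable but does not actually zero the lazy partial, which is why it only helps for the subset of types where the zero is type-preserving.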
A future PR could work out how to support hard zero tangents for |
I don't follow - I was concerned about partials of these types and for these only the type of |
@sethaxen can we merge this? I would like to, it seems good enough for now. |
I agree that feature-wise it's good enough for a merge, but we should soon address the issues @devmotion has raised. But before merging, it seems we fail IntegrationTests with ChainRulesTestUtils and Diffractor; I can look into the former to make sure they're not due to a problem in this PR. The failures in ChainRules are due to broken tests unexpectedly passing. |
@oxinabox it seems the ChainRulesTestUtils failure was a fluke and now passes. The only failure I can't explain is Diffractor. Do you think that should block this PR? |
Are Diffractor's tests even stable at this point considering CI runs on nightly? The failure here looks completely unrelated, though it is kind of crazy that compiler APIs changed enough in just 3 days (last Diffractor master build to last downstream CI run here) to make them fail. |
Not sure re Diffractor. Some chance that some 2nd order test doesn't like the extra branches? What's the argument for One other comment is that rules using |
IIRC the reason was JuliaDiff/ChainRules.jl#599 (comment). |
I don't really understand how |
I see. If there's a trade-off between pleasing 2nd order AD by avoiding branches, and a marginal speed change should anyone ever differentiate BigFloat, then we should pick AD. (Or perhaps have a branched method
I wondered whether |
@mcabbott I added |
Bump @mcabbott |
As proposed in JuliaDiff/ChainRules.jl#576 (comment), this PR makes zero (co)tangents behave as strong zeros in rules defined by `@scalar_rule`, so that regardless of the value of the partial, if an input (co)tangent is zero, then its product with the partial is also zero.
Before:
This PR:
This feature is similar to ForwardDiff's NaN-safe mode, which the docs note is 5-10% slower in their benchmarks. However, this benchmark doesn't indicate a consistent performance decrease:
Before:
This PR:
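For readers without the original snippets, the strong-zero behavior described in this PR can be sketched in plain Julia. The helper `strong_mul` and the function `dsqrt` are hypothetical names made up for this illustration; this is not the code the PR adds to `@scalar_rule`, only a sketch of the idea under the assumption of scalar `Number` inputs.

```julia
# Hypothetical helper (not the PR's implementation): if the incoming
# (co)tangent Δ is zero, return a zero of the product's type instead of
# computing partial * Δ, which could be NaN for a non-finite partial.
strong_mul(partial::Number, Δ::Number) = iszero(Δ) ? zero(partial * Δ) : partial * Δ

# Partial of sqrt with respect to x; it is Inf at x = 0.
dsqrt(x) = inv(2 * sqrt(x))

x = 0.0
naive = dsqrt(x) * 0.0             # Inf * 0.0 propagates NaN
strong = strong_mul(dsqrt(x), 0.0) # the zero (co)tangent wins
regular = strong_mul(3.0, 2.0)     # nonzero (co)tangents are unaffected
```

The branch on `iszero(Δ)` is the simplest way to express the idea; as discussed in the conversation above, whether to branch or use a branch-free form is a trade-off that matters for higher-order AD.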