-
-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix broadcasting into buffers #1488
Conversation
Fixes #254 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed this PR on call. It looks mostly good, but there are some outstanding points:
- remove extraneous comments
- name anonymous functions
- can you share the error that occurs when
@non_differentiable Base.Broadcast.Broadcasted(::Nothing)
is removed?
Thank you for the review! I will address the first two points soon, and for the last one I don't seem to be able to reproduce any error, so I'll remove it. Perhaps it came up due to inadvertently breaking something while debugging. I also wanted to ask, since I removed |
I also blindly removed this line
Should I restore it? |
That would be best. What worries me more is that removing that |
Zygote doesn't yet have an adjoint for I've gone back and added some broken tests to point out where adjoints are missing for Having done a fair amount of work on this, I'm not happy with the implementation yet. I think a lot of shortcomings of missing adjoints could be fixed by the following approach:
Would this approach be sound? Any ideas? |
I would prefer to even ditch the trait and just check for a set of known good types like the ones you listed to determine if |
I actually don't have time to work on the improvements I to this pr I suggested, but in order to wrap up the changes I made there are two points to address:
using Zygote
Zygote.gradient(collect(1:10)) do x
b = Zygote.Buffer(x)
tmp1 = sum(copy(b))
copyto!(b, fill(30))
tmp2 = sum(copy(b))
copyto!(b, [2i for i in 1:5])
tmp3 = sum(copy(b))
return tmp1 + tmp2 + tmp3
end # ERROR: Buffer is frozen |
I think it's for the case where you have a buffer of non-numbers. Examples would be a buffer of differentiable structs, or a buffer of arrays. That neither of these cases were tested is bad, but also not uncommon for Zygote (which historically has poor test coverage in general).
My understanding is that differentiable arguments which are mutated should not have gradients returned for correctness reasons. Instead, a copy is kept in the mutable gradient cache managed by |
Thanks! I added back the zeroing out of |
I added tests for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for implementing a pretty tricky feature, @lxvm !
I said that this pr would fix #254 and I just wanted to say that its MWE is slightly broken using Zygote
f = (du, u, p, t) -> du .= 0
(y, p, λ, t) = ([1.02634, 0.909691], [1.5, 1.0, 3.0, 1.0], [0.973655, 1.09031], 10.0)
_dy, back = Zygote.pullback(y) do u
out_ = Zygote.Buffer(u)
f(out_, u, p, t)
copy(out_)
end
dλ[:] = vec(back(λ)[1]) # ERROR: MethodError: no method matching vec(::Nothing) The good news is that the gradient through the overwritten buffer is |
Hi,
I'm using Zygote as an AD backend in Integrals.jl and while I was writing tests I noticed I couldn't assign a number to length-1
Buffer
using broadcasting. I think this is because the method signature for the pullback onmaterialize!
is too restrictive, sincecopyto!
allows for arbitrary iterators on the rhs ofbuf .= itr
. I also added a test for a MWE.PR Checklist
Update: A more complete MWE brings up a second issue:
MWE 2
Update 2: I added an adjoint for copyto! that fixes the MWE, however I'll try to add a generic adjoint for
copyto!(buffer, itr)
nextUpdate 3: I started about this the wrong way and the manual has details here on how to bypass broadcasting machinery, so an adjoint for
Base.materialize!
will have to be discarded and has to be replaced by an adjoint forcopyto!(buffer, broadcasted)
Update 4: I finished writing an adjoint and added a test for broadcasted assignment to a buffer from a generator. I'll happily incorporate any feedback and improvements.