-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rrule for stack #681
rrule for stack #681
Conversation
What I think this misses is that multi-dim containers are allowed, and julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]) |> summary
"3×2×2 Array{Int64, 3}"
julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]; dims=2) |> summary
"3×4 Matrix{Int64}"
julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]; dims=1) |> summary
"4×3 Matrix{Int64}" What I sketched on while fidding with #671 was this, but not sure I tested it carefully. Using function frule((_, ẋ), ::typeof(stack), x; dims = :)
return stack(x; dims), stack(ẋ; dims)
end
# Other iterable X also allowed, maybe this should be wider?
function rrule(::typeof(stack), X::AbstractArray; dims = :)
Y = stack(X; dims)
sdims = if dims isa Colon
N = ndims(Y) - ndims(X)
X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X))
else
dims
end
project = ProjectTo(x)
function stack_pullback(dY)
dX = collect(eachslice(unthunk(dY); dims = sdims))
return (NoTangent(), project(dX))
end
return Y, stack_pullback
end
# This wants #671, but ought to work with Zygote already?
function rrule(config, ::typeof(stack), f, args...; dims = :)
y, unmap = rrule_via_ad(config, map, f, args...)
z, unstack = rrule(stack, y)
function stack_pullback_f(dz)
_, dy = unstack(dz)
_, df, dargs... = unmap(dy)
return (NoTangent(), df, dargs...)
end
return z, stack_pullback_f
end I guess multiple container dimensions here needs the new Handling generic containers prob. needs something like _ndims(x) = _ndims(x, IteratorSize(x))
_ndims(x, ::HasLength) = 1
_ndims(x, ::HasShape{N}) where N = N |
Something fails on CI but not locally. The testing setup hides the error, and doesn't print a useful line number:
|
thanks Michael, there was a missing import in the tests. I think this is ready to go if tests pass. |
failures not related |
@mcabbott is there something blocking this? |
bump |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine, I think a few things can be compacted.
stack
will land in julia v1.9 thanks to JuliaLang/julia#43334cc @mcabbott