Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rrule for stack #681

Merged
merged 7 commits into from
Nov 11, 2022
Merged

rrule for stack #681

merged 7 commits into from
Nov 11, 2022

Conversation

CarloLucibello
Copy link
Contributor

@CarloLucibello CarloLucibello commented Oct 15, 2022

stack will land in julia v1.9 thanks to JuliaLang/julia#43334

cc @mcabbott

@mcabbott
Copy link
Member

mcabbott commented Oct 15, 2022

What I think this misses is that multi-dim containers are allowed, and dims=: their dimensions -- it's not the same as any one integer dims:

julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]) |> summary
"3×2×2 Array{Int64, 3}"

julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]; dims=2) |> summary
"3×4 Matrix{Int64}"

julia> stack([(1,2,3) (4,5,6); (7,8,9) (10,11,12)]; dims=1) |> summary
"4×3 Matrix{Int64}"

What I sketched on while fidding with #671 was this, but not sure I tested it carefully. Using eachslice ought to mean that second derivatives work.

function frule((_, ẋ), ::typeof(stack), x; dims = :)
    return stack(x; dims), stack(ẋ; dims)
end

# Other iterable X also allowed, maybe this should be wider?
function rrule(::typeof(stack), X::AbstractArray; dims = :)
    Y = stack(X; dims)
    sdims = if dims isa Colon
        N = ndims(Y) - ndims(X)
        X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X))
    else
        dims
    end
    project = ProjectTo(x)
    function stack_pullback(dY)
        dX = collect(eachslice(unthunk(dY); dims = sdims))
        return (NoTangent(), project(dX))
    end
    return Y, stack_pullback
end

# This wants #671, but ought to work with Zygote already?
function rrule(config, ::typeof(stack), f, args...; dims = :)
    y, unmap = rrule_via_ad(config, map, f, args...)
    z, unstack = rrule(stack, y)
    function stack_pullback_f(dz)
        _, dy = unstack(dz)
        _, df, dargs... = unmap(dy)
        return (NoTangent(), df, dargs...)
    end
    return z, stack_pullback_f
end

I guess multiple container dimensions here needs the new eachslice, also 1.9 but not in Compat yet, maybe that's OK. I believe that returning views is also OK, and probably all the cat rules should do this.

Handling generic containers prob. needs something like Base._iterator_axes:

_ndims(x) = _ndims(x, IteratorSize(x))
_ndims(x, ::HasLength) = 1
_ndims(x, ::HasShape{N}) where N = N

src/ChainRules.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
test/rulesets/Base/array.jl Outdated Show resolved Hide resolved
test/rulesets/Base/array.jl Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@mcabbott
Copy link
Member

Something fails on CI but not locally. The testing setup hides the error, and doesn't print a useful line number:

stack: Error During Test at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/Base/array.jl:420
[156](https://github.com/JuliaDiff/ChainRules.jl/actions/runs/3301482615/jobs/5446999788#step:6:159)
  Got exception outside of a @test
[157](https://github.com/JuliaDiff/ChainRules.jl/actions/runs/3301482615/jobs/5446999788#step:6:160)

@CarloLucibello
Copy link
Contributor Author

Something fails on CI but not locally.

thanks Michael, there was a missing import in the tests. I think this is ready to go if tests pass.

@CarloLucibello
Copy link
Contributor Author

failures not related

@CarloLucibello
Copy link
Contributor Author

@mcabbott is there something blocking this?

@CarloLucibello
Copy link
Contributor Author

bump

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine, I think a few things can be compacted.

src/ChainRules.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@mcabbott mcabbott merged commit 1597bcc into JuliaDiff:main Nov 11, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants