diff --git a/Project.toml b/Project.toml index b32ccd742..c6f8c9543 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.7" +version = "1.45.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -20,7 +20,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" -Compat = "3.42.0, 4" +Compat = "3.46, 4.2" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0" IrrationalConstants = "0.1.1" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 9f63eeb11..28e73c166 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -22,6 +22,8 @@ import ChainRulesCore: rrule, frule # Experimental: using ChainRulesCore: derivatives_given_output +using Compat: stack + # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2461c5561..4ae424151 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -610,3 +610,30 @@ function _extrema_dims(x, dims) end return y, extrema_pullback_dims end + +##### +##### `stack` +##### + +function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :) + return stack(x; dims), stack(ẋ; dims) +end + +# Other iterable X also allowed, maybe this should be wider? +function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = :) + Y = stack(X; dims) + sdims = if dims isa Colon + N = ndims(Y) - ndims(X) + X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X)) + else + dims + end + project = ProjectTo(X) + function stack_pullback(Δ) + dY = unthunk(Δ) + dY isa AbstractZero && return (NoTangent(), dY) + dX = collect(eachslice(dY; dims = sdims)) + return (NoTangent(), project(reshape(dX, project.axes))) + end + return Y, stack_pullback +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 921c81534..afdeb4cb9 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -416,3 +416,33 @@ end B = hcat(A[:,:,1], A[:,:,1]) @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1] end + +@testset "stack" begin + # vector container + xs = [rand(3, 4), rand(3, 4)] + test_frule(stack, xs) + test_frule(stack, xs; fkwargs=(dims=1,)) + + test_rrule(stack, xs, check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false) + + # multidimensional container + ms = [rand(2,3) for _ in 1:4, _ in 1:5]; + + if VERSION > v"1.9-" # this needs new eachslice, not yet in Compat + test_rrule(stack, ms, check_inferred=false) + end + test_rrule(stack, ms, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, ms, fkwargs=(dims=3,), check_inferred=false) + + # non-array inner objects + ts = [Tuple(rand(3)) for _ in 1:4, _ in 1:2]; + + if VERSION > v"1.9-" + test_rrule(stack, ts, check_inferred=false) + end + test_rrule(stack, ts, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, ts, fkwargs=(dims=2,), check_inferred=false) +end diff --git a/test/runtests.jl b/test/runtests.jl index 71444f388..a9f25c55c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils using Adapt using Base.Broadcast: broadcastable using ChainRules +using ChainRules: stack using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm