From 612961353e4a81f9861fbca9db714e86f30ad0a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Aug 2023 16:30:20 +0100 Subject: [PATCH 1/2] added parent adjoint for LowerTriangular and UpperTriangular --- src/lib/array.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index d4b81e3e1..37884cded 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -355,6 +355,8 @@ end @adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),) @adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),) +@adjoint parent(x::LinearAlgebra.UpperTriangular) = parent(x), ȳ -> (LinearAlgebra.UpperTriangular(ȳ),) +@adjoint parent(x::LinearAlgebra.LowerTriangular) = parent(x), ȳ -> (LinearAlgebra.LowerTriangular(ȳ),) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) From f0e0cafca94f3a7b76783ab2f9c0dfbfd27290da Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Aug 2023 17:27:22 +0100 Subject: [PATCH 2/2] added test for parent --- test/lib/array.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/lib/array.jl b/test/lib/array.jl index 889301c1e..a3b73aff9 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -96,3 +96,13 @@ end end end end + +@testset "parent" begin + @testset "$constructor" for constructor in [LowerTriangular, UpperTriangular] + x = randn(2, 2) + y, pb = Zygote.pullback(x) do x + sum(parent(constructor(2 .* x))) + end + @test first(pb(one(y))) ≈ constructor(2 * ones(2, 2)) + end +end