From 229b85723562918950cd2f388a246a9becfd46ed Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 7 Nov 2024 15:49:48 -0500 Subject: [PATCH] Simplify and test `adjoint(::BlockSparseMatrix)` --- .../src/blocksparsearray/blocksparsearray.jl | 10 +++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 22 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index 3f563f9070..95a3b0ba83 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -1,6 +1,7 @@ using BlockArrays: BlockArrays, Block, BlockedUnitRange, blockedrange, blocklength using Dictionaries: Dictionary using ..SparseArrayDOKs: SparseArrayDOK +using ..GradedAxes: dual # TODO: Delete this. ## using BlockArrays: blocks @@ -118,3 +119,12 @@ blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B ## # TODO: Preserve GPU data! ## return BlockSparseArray{elt}(undef, axes) ## end + +# Avoid proliferating wrapper types +function Base.adjoint(A::BlockSparseMatrix) + return BlockSparseMatrix(adjoint(blocks(A)), dual.(reverse(axes(A)))) +end + +function Base.transpose(A::BlockSparseMatrix) + return BlockSparseMatrix(transpose(blocks(A)), reverse(axes(A))) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index ad4b5449bc..cbdac811b1 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -98,6 +98,28 @@ include("TestBlockSparseArraysUtils.jl") a[3, 3] = NaN @test isnan(norm(a)) end + @testset "Adjoint and transpose" begin + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) + @views for b in [Block(1, 2), Block(2, 1)] + a[b] = randn(elt, size(a[b])) + end + + ad = a' + @test ad isa BlockSparseMatrix + @test ad' == a + @test ad[Block(1, 1)] == adjoint(a[Block(1, 1)]) + @test ad[Block(1, 2)] == adjoint(a[Block(2, 1)]) + @test ad[1, 1] == conj(a[1, 1]) + @test ad[1, 2] == conj(a[2, 1]) + + at = transpose(a) + @test at isa BlockSparseMatrix + @test traspose(at) == a + @test at[Block(1, 1)] == transpose(a[Block(1, 1)]) + @test at[Block(1, 2)] == transpose(a[Block(2, 1)]) + @test at[1, 1] == a[1, 1] + @test at[1, 2] == a[2, 1] + end @testset "Tensor algebra" begin a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) @views for b in [Block(1, 2), Block(2, 1)]