Skip to content

Commit

Permalink
Simplify and test adjoint(::BlockSparseMatrix)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 7, 2024
1 parent 3838f62 commit 229b857
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BlockArrays: BlockArrays, Block, BlockedUnitRange, blockedrange, blocklength
using Dictionaries: Dictionary
using ..SparseArrayDOKs: SparseArrayDOK
using ..GradedAxes: dual

# TODO: Delete this.
## using BlockArrays: blocks
Expand Down Expand Up @@ -118,3 +119,12 @@ blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B
## # TODO: Preserve GPU data!
## return BlockSparseArray{elt}(undef, axes)
## end

# Avoid proliferating wrapper types
function Base.adjoint(A::BlockSparseMatrix)
return BlockSparseMatrix(adjoint(blocks(A)), dual.(reverse(axes(A))))
end

function Base.transpose(A::BlockSparseMatrix)
return BlockSparseMatrix(transpose(blocks(A)), reverse(axes(A)))
end
22 changes: 22 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,28 @@ include("TestBlockSparseArraysUtils.jl")
a[3, 3] = NaN
@test isnan(norm(a))
end
@testset "Adjoint and transpose" begin
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
@views for b in [Block(1, 2), Block(2, 1)]
a[b] = randn(elt, size(a[b]))
end

ad = a'
@test ad isa BlockSparseMatrix
@test ad' == a
@test ad[Block(1, 1)] == adjoint(a[Block(1, 1)])
@test ad[Block(1, 2)] == adjoint(a[Block(2, 1)])
@test ad[1, 1] == conj(a[1, 1])
@test ad[1, 2] == conj(a[2, 1])

at = transpose(a)
@test at isa BlockSparseMatrix
@test traspose(at) == a
@test at[Block(1, 1)] == transpose(a[Block(1, 1)])
@test at[Block(1, 2)] == transpose(a[Block(2, 1)])
@test at[1, 1] == a[1, 1]
@test at[1, 2] == a[2, 1]
end
@testset "Tensor algebra" begin
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
@views for b in [Block(1, 2), Block(2, 1)]
Expand Down

0 comments on commit 229b857

Please sign in to comment.