Skip to content

Commit

Permalink
[NDTensor] Allow general objects are storage backends for Tensor (#1206)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Sep 28, 2023
1 parent 10f7eee commit d56b4a7
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 8 deletions.
7 changes: 7 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ include("empty/EmptyTensor.jl")
include("empty/tensoralgebra/contract.jl")
include("empty/adapt.jl")

#####################################
# Array Tensor (experimental)
# TODO: Move to `Experimental` module.
#
include("arraytensor/arraytensor.jl")
include("arraytensor/array.jl")

#####################################
# Deprecations
#
Expand Down
66 changes: 66 additions & 0 deletions NDTensors/src/arraytensor/array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Combiner
promote_rule(::Type{<:Combiner}, arraytype::Type{<:MatrixOrArrayStorage}) = arraytype

# Generic AbstractArray code
function contract(
array1::MatrixOrArrayStorage,
labels1,
array2::MatrixOrArrayStorage,
labels2,
labelsR=contract_labels(labels1, labels2),
)
output_array = contraction_output(array1, labels1, array2, labels2, labelsR)
contract!(output_array, labelsR, array1, labels1, array2, labels2)
return output_array
end

function contraction_output(
array1::MatrixOrArrayStorage, array2::MatrixOrArrayStorage, indsR
)
arraytypeR = contraction_output_type(typeof(array1), typeof(array2), indsR)
return NDTensors.similar(arraytypeR, indsR)
end

function contraction_output_type(
arraytype1::Type{<:MatrixOrArrayStorage}, arraytype2::Type{<:MatrixOrArrayStorage}, inds
)
return similartype(promote_type(arraytype1, arraytype2), inds)
end

function contraction_output(
array1::MatrixOrArrayStorage,
labelsarray1,
array2::MatrixOrArrayStorage,
labelsarray2,
labelsoutput_array,
)
# TODO: Maybe use `axes` here to be more generic, for example for BlockArrays?
indsoutput_array = contract_inds(
size(array1), labelsarray1, size(array2), labelsarray2, labelsoutput_array
)
output_array = contraction_output(array1, array2, indsoutput_array)
return output_array
end

# Required interface for specific AbstractArray types
function contract!(
arrayR::MatrixOrArrayStorage,
labelsR,
array1::MatrixOrArrayStorage,
labels1,
array2::MatrixOrArrayStorage,
labels2,
)
props = ContractionProperties(labels1, labels2, labelsR)
compute_contraction_properties!(props, array1, array2, arrayR)
# TODO: Change this to just `contract!`, or maybe `contract_ttgt!`?
_contract!(arrayR, array1, array2, props)
return arrayR
end

function permutedims!(
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
@strided output_array .= f.(output_array, permutedims(array, perm))
return output_array
end
110 changes: 110 additions & 0 deletions NDTensors/src/arraytensor/arraytensor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Used for dispatch to distinguish from Tensors wrapping TensorStorage.
# Remove once TensorStorage is removed.
const ArrayStorage{T,N} = Union{
Array{T,N},ReshapedArray{T,N},SubArray{T,N},PermutedDimsArray{T,N},StridedView{T,N}
}
const MatrixStorage{T} = Union{
ArrayStorage{T,2},
Transpose{T},
Adjoint{T},
Symmetric{T},
Hermitian{T},
UpperTriangular{T},
LowerTriangular{T},
UnitUpperTriangular{T},
UnitLowerTriangular{T},
Diagonal{T},
}
const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}}

const ArrayStorageTensor{T,N,S,I} = Tensor{T,N,S,I} where {S<:ArrayStorage{T,N}}
const MatrixStorageTensor{T,S,I} = Tensor{T,2,S,I} where {S<:MatrixStorage{T}}
const MatrixOrArrayStorageTensor{T,S,I} =
Tensor{T,N,S,I} where {N,S<:MatrixOrArrayStorage{T}}

function Tensor(storage::MatrixOrArrayStorageTensor, inds::Tuple)
return Tensor(NeverAlias(), storage, inds)
end

function Tensor(as::AliasStyle, storage::MatrixOrArrayStorage, inds::Tuple)
return Tensor{eltype(storage),length(inds),typeof(storage),typeof(inds)}(
as, storage, inds
)
end

function getindex(tensor::MatrixOrArrayStorageTensor, I::Integer...)
return storage(tensor)[I...]
end

function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...)
storage(tensor)[I...] = v
return tensor
end

function contraction_output(
tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR
)
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR)
return NDTensors.similar(tensortypeR, indsR)
end

function contract!(
tensorR::MatrixOrArrayStorageTensor,
labelsR,
tensor1::MatrixOrArrayStorageTensor,
labels1,
tensor2::MatrixOrArrayStorageTensor,
labels2,
)
contract!(storage(tensorR), labelsR, storage(tensor1), labels1, storage(tensor2), labels2)
return tensorR
end

function permutedims!(
output_tensor::MatrixOrArrayStorageTensor,
tensor::MatrixOrArrayStorageTensor,
perm,
f::Function,
)
permutedims!(storage(output_tensor), storage(tensor), perm, f)
return output_tensor
end

# Linear algebra (matrix algebra)
function Base.adjoint(tens::MatrixStorageTensor)
return tensor(adjoint(storage(tens)), reverse(inds(tens)))
end

function LinearAlgebra.mul!(
C::MatrixStorageTensor, A::MatrixStorageTensor, B::MatrixStorageTensor
)
mul!(storage(C), storage(A), storage(B))
return C
end

function LinearAlgebra.svd(tens::MatrixStorageTensor)
F = svd(storage(tens))
U, S, V = F.U, F.S, F.Vt
i, j = inds(tens)
# TODO: Make this more general with a `similar_ind` function,
# so the dimension can be determined from the length of `S`.
min_ij = dim(i) dim(j) ? i : j
α = sim(min_ij) # similar_ind(i, space(S))
β = sim(min_ij) # similar_ind(i, space(S))
Utensor = tensor(U, (i, α))
# TODO: Remove conversion to `Diagonal` to make more general, or make a generic `Diagonal` concept that works for `BlockSparseArray`.
# Used for now to avoid introducing wrapper types.
Stensor = tensor(Diagonal(S), (α, β))
Vtensor = tensor(V, (β, j))
return Utensor, Stensor, Vtensor, Spectrum(nothing, 0.0)
end

array(tensor::MatrixOrArrayStorageTensor) = storage(tensor)

# Combiner
function contraction_output(
tensor1::MatrixOrArrayStorageTensor, tensor2::CombinerTensor, indsR
)
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR)
return NDTensors.similar(tensortypeR, indsR)
end
9 changes: 4 additions & 5 deletions NDTensors/src/tensor/tensor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

"""
Tensor{StoreT,IndsT}
A plain old tensor (with order independent
interface and no assumption of labels)
"""
struct Tensor{ElT,N,StoreT<:TensorStorage,IndsT} <: AbstractArray{ElT,N}
struct Tensor{ElT,N,StoreT,IndsT} <: AbstractArray{ElT,N}
storage::StoreT
inds::IndsT

Expand All @@ -21,8 +20,8 @@ struct Tensor{ElT,N,StoreT<:TensorStorage,IndsT} <: AbstractArray{ElT,N}
and tensor(store::TensorStorage, inds) constructors.
"""
function Tensor{ElT,N,StoreT,IndsT}(
::AllowAlias, storage::TensorStorage, inds::Tuple
) where {ElT,N,StoreT<:TensorStorage,IndsT}
::AllowAlias, storage, inds::Tuple
) where {ElT,N,StoreT,IndsT}
@assert ElT == eltype(StoreT)
@assert length(inds) == N
return new{ElT,N,StoreT,IndsT}(storage, inds)
Expand Down Expand Up @@ -74,7 +73,7 @@ end
# already (like a Vector). In the future this may be lifted
# to allow for very large tensor orders in which case Tuple
# operations may become too slow.
function Tensor(as::AliasStyle, storage::TensorStorage, inds)
function Tensor(as::AliasStyle, storage, inds)
return Tensor(as, storage, Tuple(inds))
end

Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/tensoralgebra/generic_tensor_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
function permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)
Base.checkdims_perm(output_tensor, tensor, perm)
error(
"`perutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
)
return output_tensor
end
Expand Down
45 changes: 45 additions & 0 deletions NDTensors/test/arraytensor/arraytensor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using NDTensors
using LinearAlgebra
using Test

using NDTensors: storage, storagetype

@testset "Tensor wrapping Array" begin
is1 = (2, 3)
D1 = randn(is1)

is2 = (3, 4)
D2 = randn(is2)

T1 = tensor(D1, is1)
T2 = tensor(D2, is2)

@test T1[1, 1] == D1[1, 1]

x = rand()
T1[1, 1] = x

@test T1[1, 1] == x
@test array(T1) == D1
@test storagetype(T1) <: Matrix{Float64}
@test storage(T1) == D1
@test eltype(T1) == eltype(D1)
@test inds(T1) == is1

R = T1 * T2
@test storagetype(R) <: Matrix{Float64}
@test Array(R) Array(T1) * Array(T2)

T1r = randn!(similar(T1))
@test Array(T1r + T1) Array(T1r) + Array(T1)
@test Array(permutedims(T1, (2, 1))) permutedims(Array(T1), (2, 1))

U, S, V = svd(T1)
@test U * S * V T1

T12 = contract(T1, (1, -1), T2, (-1, 2))
@test T12 T1 * T2

D12 = contract(D1, (1, -1), D2, (-1, 2))
@test D12 Array(T12)
end
1 change: 1 addition & 0 deletions NDTensors/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
"emptynumber.jl",
"emptystorage.jl",
"combiner.jl",
"arraytensor/arraytensor.jl",
]
println("Running $filename")
include(filename)
Expand Down
4 changes: 2 additions & 2 deletions src/itensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ NDTensors.Dense{Float64,Array{Float64,1}}
## Accessor Functions, Index Functions and Operations
mutable struct ITensor
tensor::Tensor
function ITensor(::AllowAlias, T::Tensor{<:Any,<:Any,<:TensorStorage,<:Tuple})
function ITensor(::AllowAlias, T::Tensor{<:Any,<:Any,<:Any,<:Tuple})
@debug_check begin
is = inds(T)
if !allunique(is)
Expand Down Expand Up @@ -761,7 +761,7 @@ end
Return a view of the TensorStorage of the ITensor.
"""
storage(T::ITensor)::TensorStorage = storage(tensor(T))
storage(T::ITensor) = storage(tensor(T))

storagetype(x::ITensor) = storagetype(tensor(x))

Expand Down

0 comments on commit d56b4a7

Please sign in to comment.