-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensor] Allow general objects are storage backends for Tensor (#1206)
- Loading branch information
Showing
8 changed files
with
236 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Combiner | ||
promote_rule(::Type{<:Combiner}, arraytype::Type{<:MatrixOrArrayStorage}) = arraytype | ||
|
||
# Generic AbstractArray code | ||
function contract( | ||
array1::MatrixOrArrayStorage, | ||
labels1, | ||
array2::MatrixOrArrayStorage, | ||
labels2, | ||
labelsR=contract_labels(labels1, labels2), | ||
) | ||
output_array = contraction_output(array1, labels1, array2, labels2, labelsR) | ||
contract!(output_array, labelsR, array1, labels1, array2, labels2) | ||
return output_array | ||
end | ||
|
||
function contraction_output( | ||
array1::MatrixOrArrayStorage, array2::MatrixOrArrayStorage, indsR | ||
) | ||
arraytypeR = contraction_output_type(typeof(array1), typeof(array2), indsR) | ||
return NDTensors.similar(arraytypeR, indsR) | ||
end | ||
|
||
function contraction_output_type( | ||
arraytype1::Type{<:MatrixOrArrayStorage}, arraytype2::Type{<:MatrixOrArrayStorage}, inds | ||
) | ||
return similartype(promote_type(arraytype1, arraytype2), inds) | ||
end | ||
|
||
function contraction_output( | ||
array1::MatrixOrArrayStorage, | ||
labelsarray1, | ||
array2::MatrixOrArrayStorage, | ||
labelsarray2, | ||
labelsoutput_array, | ||
) | ||
# TODO: Maybe use `axes` here to be more generic, for example for BlockArrays? | ||
indsoutput_array = contract_inds( | ||
size(array1), labelsarray1, size(array2), labelsarray2, labelsoutput_array | ||
) | ||
output_array = contraction_output(array1, array2, indsoutput_array) | ||
return output_array | ||
end | ||
|
||
# Required interface for specific AbstractArray types | ||
function contract!( | ||
arrayR::MatrixOrArrayStorage, | ||
labelsR, | ||
array1::MatrixOrArrayStorage, | ||
labels1, | ||
array2::MatrixOrArrayStorage, | ||
labels2, | ||
) | ||
props = ContractionProperties(labels1, labels2, labelsR) | ||
compute_contraction_properties!(props, array1, array2, arrayR) | ||
# TODO: Change this to just `contract!`, or maybe `contract_ttgt!`? | ||
_contract!(arrayR, array1, array2, props) | ||
return arrayR | ||
end | ||
|
||
function permutedims!( | ||
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function | ||
) | ||
@strided output_array .= f.(output_array, permutedims(array, perm)) | ||
return output_array | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Used for dispatch to distinguish from Tensors wrapping TensorStorage. | ||
# Remove once TensorStorage is removed. | ||
const ArrayStorage{T,N} = Union{ | ||
Array{T,N},ReshapedArray{T,N},SubArray{T,N},PermutedDimsArray{T,N},StridedView{T,N} | ||
} | ||
const MatrixStorage{T} = Union{ | ||
ArrayStorage{T,2}, | ||
Transpose{T}, | ||
Adjoint{T}, | ||
Symmetric{T}, | ||
Hermitian{T}, | ||
UpperTriangular{T}, | ||
LowerTriangular{T}, | ||
UnitUpperTriangular{T}, | ||
UnitLowerTriangular{T}, | ||
Diagonal{T}, | ||
} | ||
const MatrixOrArrayStorage{T} = Union{MatrixStorage{T},ArrayStorage{T}} | ||
|
||
const ArrayStorageTensor{T,N,S,I} = Tensor{T,N,S,I} where {S<:ArrayStorage{T,N}} | ||
const MatrixStorageTensor{T,S,I} = Tensor{T,2,S,I} where {S<:MatrixStorage{T}} | ||
const MatrixOrArrayStorageTensor{T,S,I} = | ||
Tensor{T,N,S,I} where {N,S<:MatrixOrArrayStorage{T}} | ||
|
||
function Tensor(storage::MatrixOrArrayStorageTensor, inds::Tuple) | ||
return Tensor(NeverAlias(), storage, inds) | ||
end | ||
|
||
function Tensor(as::AliasStyle, storage::MatrixOrArrayStorage, inds::Tuple) | ||
return Tensor{eltype(storage),length(inds),typeof(storage),typeof(inds)}( | ||
as, storage, inds | ||
) | ||
end | ||
|
||
function getindex(tensor::MatrixOrArrayStorageTensor, I::Integer...) | ||
return storage(tensor)[I...] | ||
end | ||
|
||
function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...) | ||
storage(tensor)[I...] = v | ||
return tensor | ||
end | ||
|
||
function contraction_output( | ||
tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR | ||
) | ||
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR) | ||
return NDTensors.similar(tensortypeR, indsR) | ||
end | ||
|
||
function contract!( | ||
tensorR::MatrixOrArrayStorageTensor, | ||
labelsR, | ||
tensor1::MatrixOrArrayStorageTensor, | ||
labels1, | ||
tensor2::MatrixOrArrayStorageTensor, | ||
labels2, | ||
) | ||
contract!(storage(tensorR), labelsR, storage(tensor1), labels1, storage(tensor2), labels2) | ||
return tensorR | ||
end | ||
|
||
function permutedims!( | ||
output_tensor::MatrixOrArrayStorageTensor, | ||
tensor::MatrixOrArrayStorageTensor, | ||
perm, | ||
f::Function, | ||
) | ||
permutedims!(storage(output_tensor), storage(tensor), perm, f) | ||
return output_tensor | ||
end | ||
|
||
# Linear algebra (matrix algebra) | ||
function Base.adjoint(tens::MatrixStorageTensor) | ||
return tensor(adjoint(storage(tens)), reverse(inds(tens))) | ||
end | ||
|
||
function LinearAlgebra.mul!( | ||
C::MatrixStorageTensor, A::MatrixStorageTensor, B::MatrixStorageTensor | ||
) | ||
mul!(storage(C), storage(A), storage(B)) | ||
return C | ||
end | ||
|
||
function LinearAlgebra.svd(tens::MatrixStorageTensor) | ||
F = svd(storage(tens)) | ||
U, S, V = F.U, F.S, F.Vt | ||
i, j = inds(tens) | ||
# TODO: Make this more general with a `similar_ind` function, | ||
# so the dimension can be determined from the length of `S`. | ||
min_ij = dim(i) ≤ dim(j) ? i : j | ||
α = sim(min_ij) # similar_ind(i, space(S)) | ||
β = sim(min_ij) # similar_ind(i, space(S)) | ||
Utensor = tensor(U, (i, α)) | ||
# TODO: Remove conversion to `Diagonal` to make more general, or make a generic `Diagonal` concept that works for `BlockSparseArray`. | ||
# Used for now to avoid introducing wrapper types. | ||
Stensor = tensor(Diagonal(S), (α, β)) | ||
Vtensor = tensor(V, (β, j)) | ||
return Utensor, Stensor, Vtensor, Spectrum(nothing, 0.0) | ||
end | ||
|
||
array(tensor::MatrixOrArrayStorageTensor) = storage(tensor) | ||
|
||
# Combiner | ||
function contraction_output( | ||
tensor1::MatrixOrArrayStorageTensor, tensor2::CombinerTensor, indsR | ||
) | ||
tensortypeR = contraction_output_type(typeof(tensor1), typeof(tensor2), indsR) | ||
return NDTensors.similar(tensortypeR, indsR) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
using NDTensors | ||
using LinearAlgebra | ||
using Test | ||
|
||
using NDTensors: storage, storagetype | ||
|
||
@testset "Tensor wrapping Array" begin | ||
is1 = (2, 3) | ||
D1 = randn(is1) | ||
|
||
is2 = (3, 4) | ||
D2 = randn(is2) | ||
|
||
T1 = tensor(D1, is1) | ||
T2 = tensor(D2, is2) | ||
|
||
@test T1[1, 1] == D1[1, 1] | ||
|
||
x = rand() | ||
T1[1, 1] = x | ||
|
||
@test T1[1, 1] == x | ||
@test array(T1) == D1 | ||
@test storagetype(T1) <: Matrix{Float64} | ||
@test storage(T1) == D1 | ||
@test eltype(T1) == eltype(D1) | ||
@test inds(T1) == is1 | ||
|
||
R = T1 * T2 | ||
@test storagetype(R) <: Matrix{Float64} | ||
@test Array(R) ≈ Array(T1) * Array(T2) | ||
|
||
T1r = randn!(similar(T1)) | ||
@test Array(T1r + T1) ≈ Array(T1r) + Array(T1) | ||
@test Array(permutedims(T1, (2, 1))) ≈ permutedims(Array(T1), (2, 1)) | ||
|
||
U, S, V = svd(T1) | ||
@test U * S * V ≈ T1 | ||
|
||
T12 = contract(T1, (1, -1), T2, (-1, 2)) | ||
@test T12 ≈ T1 * T2 | ||
|
||
D12 = contract(D1, (1, -1), D2, (-1, 2)) | ||
@test D12 ≈ Array(T12) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters