[ITensors] ITensor wrapping NamedDimsArray #1268

Merged 24 commits on Nov 25, 2023
Changes from 12 commits
6 changes: 4 additions & 2 deletions NDTensors/src/NDTensors.jl
@@ -24,6 +24,10 @@ using TupleTools
include("algorithm.jl")
include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("Unwrap/src/Unwrap.jl")
using .Unwrap
include("RankFactorization/src/RankFactorization.jl")
using .RankFactorization: RankFactorization
include("TensorAlgebra/src/TensorAlgebra.jl")
using .TensorAlgebra: TensorAlgebra
include("DiagonalArrays/src/DiagonalArrays.jl")
@@ -38,8 +42,6 @@ include("SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
using .TagSets
include("Unwrap/src/Unwrap.jl")
using .Unwrap

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

@@ -0,0 +1,3 @@
module NamedDimsArraysAdaptExt
include("adapt_structure.jl")
end
@@ -0,0 +1,5 @@
using Adapt: Adapt, adapt
using NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, named, unname
function Adapt.adapt_structure(to, na::AbstractNamedDimsArray)
return named(adapt(to, unname(na)), dimnames(na))
end
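As a rough usage sketch (not part of the diff), the overload above adapts the unnamed parent array and re-wraps the original dimension names, so storage conversions preserve the names; the names "i" and "j" below are arbitrary:

using Adapt: adapt
using NDTensors.NamedDimsArrays: dimnames, named
na = named(randn(Float64, 2, 3), ("i", "j"))
na32 = adapt(Array{Float32}, na)  # converts the parent array, keeps the names
eltype(na32)                      # Float32
dimnames(na32)                    # ("i", "j")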
@@ -0,0 +1,4 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -0,0 +1,11 @@
@eval module $(gensym())
using Test: @test, @testset
using Adapt: adapt
using NDTensors.NamedDimsArrays: named
@testset "NamedDimsArraysAdaptExt (eltype=$elt)" for elt in (Float32, Float64)
na = named(randn(elt, 2, 2), ("i", "j"))
na_complex = adapt(Array{complex(elt)}, na)
@test na ≈ na_complex
@test eltype(na_complex) === complex(elt)
end
end
@@ -1,6 +1,7 @@
module NamedDimsArraysTensorAlgebraExt
using ..NamedDimsArrays: NamedDimsArrays
using ...NDTensors.TensorAlgebra: TensorAlgebra

include("contract.jl")
include("fusedims.jl")
include("qr.jl")
include("eigen.jl")
include("svd.jl")
end
@@ -1,5 +1,5 @@
using NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, named, unname
using NDTensors.TensorAlgebra: contract
using NDTensors.TensorAlgebra: TensorAlgebra, contract

function TensorAlgebra.contract(
na1::AbstractNamedDimsArray, na2::AbstractNamedDimsArray, α, β; kwargs...
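A minimal sketch of the intended named contraction (mirroring the tests later in this diff): dimensions with matching names are contracted and the remaining names label the result. The sizes and names here are arbitrary:

using NDTensors.NamedDimsArrays: named, unname
using NDTensors.TensorAlgebra: TensorAlgebra
i, j, k = named(2, "i"), named(3, "j"), named(4, "k")
na1 = randn(i, j)                        # named 2×3 array
na2 = randn(j, k)                        # named 3×4 array
na12 = TensorAlgebra.contract(na1, na2)  # contracts over the shared name "j"
unname(na12, (i, k)) ≈ unname(na1) * unname(na2)  # true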
@@ -0,0 +1,47 @@
## using ..ITensors: IndexID
using LinearAlgebra: LinearAlgebra, Diagonal, Hermitian, eigen
## using ..NDTensors.DiagonalArrays: DiagonalMatrix
using ...NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, name, named, randname, unname
using ...NDTensors.RankFactorization: Spectrum, truncate!!
function LinearAlgebra.eigen(
na::Hermitian{T,<:AbstractNamedDimsArray{T}};
mindim=nothing,
maxdim=nothing,
cutoff=nothing,
use_absolute_cutoff=nothing,
use_relative_cutoff=nothing,
) where {T<:Union{Real,Complex}}
# TODO: Handle array wrappers around
# `AbstractNamedDimsArray` more elegantly.
d, u = eigen(Hermitian(unname(parent(na))))

# Sort by largest to smallest eigenvalues
# TODO: Replace `cpu` with `Expose` dispatch.
p = sortperm(d; rev=true, by=abs)
d = d[p]
u = u[:, p]

length_d = length(d)
truncerr = zero(Float64) # Make more generic
if any(!isnothing, (maxdim, cutoff))
d, truncerr, _ = truncate!!(
d; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
)
length_d = length(d)
if length_d < size(u, 2)
u = u[:, 1:length_d]
end
end
spec = Spectrum(d, truncerr)

# TODO: Handle array wrappers more generally.
names_a = dimnames(parent(na))
# TODO: Make this more generic, handle `dag`, etc.
l = randname(names_a[1]) # IndexID(rand(UInt64), "", 0)
r = randname(names_a[2]) # IndexID(rand(UInt64), "", 0)
names_d = (l, r)
nd = named(Diagonal(d), names_d)
names_u = (names_a[2], r)
nu = named(u, names_u)
return nd, nu, spec
end
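A rough usage sketch for the Hermitian eigendecomposition above (no truncation keywords are passed here, so the full spectrum is kept); building the named matrix from a + a' is just for illustration:

using LinearAlgebra: Hermitian, eigen
using NDTensors.NamedDimsArrays: dimnames, named, unname
a = randn(4, 4)
na = named(a + a', ("i", "j"))      # named real-symmetric matrix
d, u, spec = eigen(Hermitian(na))   # eigenvalues sorted by decreasing magnitude
dimnames(d)                         # two freshly generated names
unname(u) * unname(d) * unname(u)' ≈ a + a'  # true, since nothing was truncated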
@@ -0,0 +1,42 @@
using ...NDTensors.NamedDimsArrays: dimnames, name, named, unname
using ...NDTensors.TensorAlgebra: TensorAlgebra, fusedims, splitdims

function TensorAlgebra.fusedims(na::AbstractNamedDimsArray, fusions::Pair...)
# TODO: generalize to multiple fused groups of dimensions
@assert isone(length(fusions))
fusion = only(fusions)

split_names = first(fusion)
fused_name = last(fusion)

split_dims = map(split_name -> findfirst(isequal(split_name), dimnames(na)), split_names)
fused_dim = findfirst(isequal(fused_name), dimnames(na))
@assert isnothing(fused_dim)

unfused_dims = Tuple.(setdiff(1:ndims(na), split_dims))
partitioned_perm = (unfused_dims..., split_dims)

a_fused = fusedims(unname(na), partitioned_perm...)
names_fused = (setdiff(dimnames(na), split_names)..., fused_name)
return named(a_fused, names_fused)
end

function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
fused_names = first.(splitters)
split_namedlengths = last.(splitters)
splitters_unnamed = map(splitters) do splitter
fused_name, split_namedlengths = splitter
fused_dim = findfirst(isequal(fused_name), dimnames(na))
split_lengths = unname.(split_namedlengths)
return fused_dim => split_lengths
end
a_split = splitdims(unname(na), splitters_unnamed...)
names_split = Any[tuple.(dimnames(na))...]
for splitter in splitters
fused_name, split_namedlengths = splitter
fused_dim = findfirst(isequal(fused_name), dimnames(na))
split_names = name.(split_namedlengths)
names_split[fused_dim] = split_names
end
names_split = reduce((x, y) -> (x..., y...), names_split)
return named(a_split, names_split)
end
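A hedged sketch of the call pattern these two functions implement: fusedims merges a group of named dimensions into one new name, and splitdims undoes it given the named lengths to split into. Names and sizes below are arbitrary:

using NDTensors.NamedDimsArrays: dimnames, named
using NDTensors.TensorAlgebra: fusedims, splitdims
na = named(randn(2, 3, 4), ("i", "j", "k"))
na_fused = fusedims(na, ("j", "k") => "jk")
dimnames(na_fused)                  # ("i", "jk"), a 2×12 array
na_split = splitdims(na_fused, "jk" => (named(3, "j"), named(4, "k")))
dimnames(na_split)                  # ("i", "j", "k"), back to 2×3×4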
@@ -0,0 +1,15 @@
# using ..ITensors: IndexID
using LinearAlgebra: LinearAlgebra, qr
using ...NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, name, named, randname, unname
function LinearAlgebra.qr(na::AbstractNamedDimsArray; positive=nothing)
# TODO: Make this more systematic.
names_a = dimnames(na)
# TODO: Create a `TensorAlgebra.qr`.
q, r = qr(unname(na))
# TODO: Use `sim` or `rand(::IndexID)`.
name_qr = randname(names_a[1]) # IndexID(rand(UInt64), "", 0)
# TODO: Make this GPU-friendly.
nq = named(Matrix(q), (names_a[1], name_qr))
nr = named(Matrix(r), (name_qr, names_a[2]))
return nq, nr
end
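Since the extension test later in this diff still marks QR as broken, this is only the intended call pattern, not verified behavior at this commit:

using LinearAlgebra: qr
using NDTensors.NamedDimsArrays: dimnames, named, unname
na = named(randn(4, 4), ("i", "j"))
nq, nr = qr(na)                      # q keeps "i", r keeps "j", joined by a fresh random name
unname(nq) * unname(nr) ≈ unname(na)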
@@ -0,0 +1,53 @@
using LinearAlgebra: LinearAlgebra, Diagonal, svd
using ...NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, named, randname, unname
using ...NDTensors.RankFactorization: Spectrum, truncate!!
function LinearAlgebra.svd(
na::AbstractNamedDimsArray;
mindim=nothing,
maxdim=nothing,
cutoff=nothing,
use_absolute_cutoff=nothing,
use_relative_cutoff=nothing,
alg=nothing,
min_blockdim=nothing,
)
# TODO: Handle array wrappers around
# `AbstractNamedDimsArray` more elegantly.
USV = svd(unname(na))
u, s, v = USV.U, USV.S, USV.Vt

# Sort from largest to smallest singular values
# TODO: Replace `cpu` with `Expose` dispatch.
p = sortperm(s; rev=true, by=abs)
u = u[:, p]
s = s[p]
v = v[p, :]

s² = s .^ 2
length_s = length(s)
truncerr = zero(Float64) # Make more generic
if any(!isnothing, (maxdim, cutoff))
s², truncerr, _ = truncate!!(
s²; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
)
length_s = length(s²)
# TODO: Avoid this if they are already the
# correct size.
u = u[:, 1:length_s]
s = s[1:length_s]
v = v[1:length_s, :]
end
spec = Spectrum(s², truncerr)

# TODO: Handle array wrappers more generally.
names_a = dimnames(na)
# TODO: Make this more generic, handle `dag`, etc.
l = randname(names_a[1]) # IndexID(rand(UInt64), "", 0)
r = randname(names_a[2]) # IndexID(rand(UInt64), "", 0)
names_u = (names_a[1], l)
nu = named(u, names_u)
names_s = (l, r)
ns = named(Diagonal(s), names_s)
names_v = (r, names_a[2])
nv = named(v, names_v)
return nu, ns, nv, spec
end
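A rough usage sketch for the truncated SVD overload above; the truncation keywords follow the usual ITensors conventions, and maxdim here is just an example value:

using LinearAlgebra: svd
using NDTensors.NamedDimsArrays: dimnames, named, unname
na = named(randn(4, 6), ("i", "j"))
nu, ns, nv, spec = svd(na; maxdim=3)
size(unname(ns))                    # (3, 3) after truncating to maxdim=3
dimnames(nu)[1], dimnames(nv)[2]    # ("i", "j"), the outer names are preserved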
@@ -0,0 +1,3 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -1,13 +1,29 @@
using Test: @test, @testset
@eval module $(gensym())
using Test: @test, @testset, @test_broken
using NDTensors.NamedDimsArrays: named, unname
using NDTensors.TensorAlgebra: TensorAlgebra

@testset "NamedDimsArraysTensorAlgebraExt" begin
using LinearAlgebra: qr
@testset "NamedDimsArraysTensorAlgebraExt contract (eltype=$(elt))" for elt in (
Float32, ComplexF32, Float64, ComplexF64
)
i = named(2, "i")
j = named(2, "j")
k = named(2, "k")
na1 = randn(i, j)
na2 = randn(j, k)
na1 = randn(elt, i, j)
na2 = randn(elt, j, k)
na_dest = TensorAlgebra.contract(na1, na2)
@test eltype(na_dest) === elt
@test unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2)
end
@testset "NamedDimsArraysTensorAlgebraExt QR (eltype=$(elt))" for elt in (
Float32, ComplexF32, Float64, ComplexF64
)
di = 2
dj = 2
i = named(di, "i")
j = named(dj, "j")
na = randn(elt, i, j)
@test_broken error("QR not implemented yet")
# q, r = qr(na)
end
end
10 changes: 10 additions & 0 deletions NDTensors/src/NamedDimsArrays/src/BaseExtension/BaseExtension.jl
@@ -0,0 +1,10 @@
module BaseExtension
replace(collection, replacements::Pair...) = Base.replace(collection, replacements...)
function replace(collection::Tuple, replacements::Pair...)
if VERSION < v"1.7"
# TODO: Add to `Compat.jl` or delete when we drop Julia 1.6 support.
return Tuple(Base.replace(collect(collection), replacements...))
end
return Base.replace(collection, replacements...)
end
end
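This shim exists because Base.replace only gained a Tuple method in Julia 1.7; on 1.6 it round-trips through a Vector. A sketch of the intended behavior (the fully qualified module path is assumed here, since the submodule is not exported):

using NDTensors: NDTensors
NDTensors.NamedDimsArrays.BaseExtension.replace(("i", "j", "k"), "j" => "b")  # ("i", "b", "k")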
6 changes: 6 additions & 0 deletions NDTensors/src/NamedDimsArrays/src/NamedDimsArrays.jl
@@ -1,12 +1,18 @@
module NamedDimsArrays
include("BaseExtension/BaseExtension.jl")

include("traits.jl")
include("randname.jl")
include("abstractnamedint.jl")
include("abstractnamedunitrange.jl")
include("abstractnameddimsarray.jl")
include("namedint.jl")
include("namedunitrange.jl")
include("nameddimsarray.jl")
include("constructors.jl")
include("tensoralgebra.jl")

# Extensions
include("../ext/NamedDimsArraysAdaptExt/src/NamedDimsArraysAdaptExt.jl")
include("../ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl")
end
25 changes: 19 additions & 6 deletions NDTensors/src/NamedDimsArrays/src/abstractnameddimsarray.jl
@@ -25,10 +25,14 @@ isnamed(::AbstractNamedDimsArray) = true
# Helper function, move to `utils.jl`.
named_tuple(t::Tuple, names) = ntuple(i -> named(t[i], names[i]), length(t))

# TODO: Should `axes` output named axes or not?
# TODO: Use the proper type, `namedaxistype(a)`.
Base.axes(a::AbstractNamedDimsArray) = named_tuple(axes(unname(a)), dimnames(a))
# Base.axes(a::AbstractNamedDimsArray) = named_tuple(axes(unname(a)), dimnames(a))
Base.axes(a::AbstractNamedDimsArray) = axes(unname(a))
namedaxes(a::AbstractNamedDimsArray) = named.(axes(unname(a)), dimnames(a))
# TODO: Use the proper type, `namedlengthtype(a)`.
Base.size(a::AbstractNamedDimsArray) = length.(axes(a))
Base.size(a::AbstractNamedDimsArray) = size(unname(a))
namedsize(a::AbstractNamedDimsArray) = named.(size(unname(a)), dimnames(a))
Base.getindex(a::AbstractNamedDimsArray, I...) = unname(a)[I...]
function Base.setindex!(a::AbstractNamedDimsArray, x, I...)
unname(a)[I...] = x
@@ -46,7 +50,10 @@ rename(a::AbstractNamedDimsArray, names) = named(unname(a), names)

# replacenames(a, :i => :a, :j => :b)
# `rename` in `NamedPlus.jl`.
replacenames(a::AbstractNamedDimsArray, names::Pair) = error("Not implemented yet")
function replacenames(na::AbstractNamedDimsArray, replacements::Pair...)
# `BaseExtension.replace` needed for `Tuple` support on Julia 1.6 and older.
return named(unname(na), BaseExtension.replace(dimnames(na), replacements...))
end

# Either define new names or replace names
setnames(a::AbstractArray, names) = named(a, names)
@@ -60,16 +67,22 @@ function get_name_perm(a::AbstractNamedDimsArray, names::Tuple)
return getperm(dimnames(a), names)
end

# Ambiguity error
function get_name_perm(a::AbstractNamedDimsArray, names::Tuple{})
@assert iszero(ndims(a))
return ()
end

function get_name_perm(
a::AbstractNamedDimsArray, namedints::Tuple{Vararg{AbstractNamedInt}}
)
return getperm(size(a), namedints)
return getperm(namedsize(a), namedints)
end

function get_name_perm(
a::AbstractNamedDimsArray, namedaxes::Tuple{Vararg{AbstractNamedUnitRange}}
a::AbstractNamedDimsArray, new_namedaxes::Tuple{Vararg{AbstractNamedUnitRange}}
)
return getperm(axes(a), namedaxes)
return getperm(namedaxes(a), new_namedaxes)
end

# Indexing
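A quick sketch of the name-handling helpers added above (replacenames, namedsize, namedaxes), with arbitrary names and sizes; note that plain size and axes now return unnamed values per the change above:

using NDTensors.NamedDimsArrays: dimnames, named, namedsize, replacenames
na = named(randn(2, 3), ("i", "j"))
na2 = replacenames(na, "i" => "a", "j" => "b")
dimnames(na2)                       # ("a", "b")
size(na)                            # (2, 3), unnamed
namedsize(na)                       # (named(2, "i"), named(3, "j"))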
48 changes: 48 additions & 0 deletions NDTensors/src/NamedDimsArrays/src/constructors.jl
@@ -0,0 +1,48 @@
using Random: AbstractRNG, default_rng

# TODO: Use `AbstractNamedUnitRange`, determine the `AbstractNamedDimsArray`
# from a default value. Useful for distinguishing between `NamedDimsArray`
# and `ITensor`.
# Convenient constructors
default_eltype() = Float64
for f in [:rand, :randn]
@eval begin
function Base.$f(
rng::AbstractRNG, elt::Type{<:Number}, dims::Tuple{NamedInt,Vararg{NamedInt}}
)
a = $f(rng, elt, unname.(dims))
return named(a, name.(dims))
end
function Base.$f(
rng::AbstractRNG, elt::Type{<:Number}, dim1::NamedInt, dims::Vararg{NamedInt}
)
return $f(rng, elt, (dim1, dims...))
end
Base.$f(elt::Type{<:Number}, dims::Tuple{NamedInt,Vararg{NamedInt}}) =
$f(default_rng(), elt, dims)
Base.$f(elt::Type{<:Number}, dim1::NamedInt, dims::Vararg{NamedInt}) =
$f(elt, (dim1, dims...))
Base.$f(dims::Tuple{NamedInt,Vararg{NamedInt}}) = $f(default_eltype(), dims)
Base.$f(dim1::NamedInt, dims::Vararg{NamedInt}) = $f((dim1, dims...))
end
end
for f in [:zeros, :ones]
@eval begin
function Base.$f(elt::Type{<:Number}, dims::Tuple{NamedInt,Vararg{NamedInt}})
a = $f(elt, unname.(dims))
return named(a, name.(dims))
end
function Base.$f(elt::Type{<:Number}, dim1::NamedInt, dims::Vararg{NamedInt})
return $f(elt, (dim1, dims...))
end
Base.$f(dims::Tuple{NamedInt,Vararg{NamedInt}}) = $f(default_eltype(), dims)
Base.$f(dim1::NamedInt, dims::Vararg{NamedInt}) = $f((dim1, dims...))
end
end
function Base.fill(value, dims::Tuple{NamedInt,Vararg{NamedInt}})
a = fill(value, unname.(dims))
return named(a, name.(dims))
end
function Base.fill(value, dim1::NamedInt, dims::Vararg{NamedInt})
return fill(value, (dim1, dims...))
end
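A sketch of the convenience constructors defined above: a NamedInt carries both an extent and a name, so these methods build the unnamed array and wrap it with the names:

using NDTensors.NamedDimsArrays: dimnames, named
i, j = named(2, "i"), named(3, "j")
x = randn(i, j)                     # Float64 by default (default_eltype)
y = zeros(Float32, i, j)            # eltype Float32, names ("i", "j")
z = fill(1.0, i, j)                 # filled named 2×3 array
dimnames(y), size(x)                # (("i", "j"), (2, 3))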