[ITensors] ITensor wrapping NamedDimsArray (#1268)
mtfishman authored Nov 25, 2023
1 parent 4c5c991 commit fbf4a1d
Showing 208 changed files with 1,656 additions and 271 deletions.
30 changes: 18 additions & 12 deletions NDTensors/src/NDTensors.jl
@@ -19,27 +19,33 @@ using Strided
using TimerOutputs
using TupleTools

# TODO: Define an `AlgorithmSelection` module
# TODO: List types, macros, and functions being used.
include("algorithm.jl")
include("SetParameters/src/SetParameters.jl")
include("lib/AlgorithmSelection/src/AlgorithmSelection.jl")
using .AlgorithmSelection: AlgorithmSelection
include("lib/BaseExtensions/src/BaseExtensions.jl")
using .BaseExtensions: BaseExtensions
include("lib/SetParameters/src/SetParameters.jl")
using .SetParameters
include("TensorAlgebra/src/TensorAlgebra.jl")
include("lib/BroadcastMapConversion/src/BroadcastMapConversion.jl")
using .BroadcastMapConversion: BroadcastMapConversion
include("lib/Unwrap/src/Unwrap.jl")
using .Unwrap
include("lib/RankFactorization/src/RankFactorization.jl")
using .RankFactorization: RankFactorization
include("lib/TensorAlgebra/src/TensorAlgebra.jl")
using .TensorAlgebra: TensorAlgebra
include("DiagonalArrays/src/DiagonalArrays.jl")
include("lib/DiagonalArrays/src/DiagonalArrays.jl")
using .DiagonalArrays
include("BlockSparseArrays/src/BlockSparseArrays.jl")
include("lib/BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays
include("NamedDimsArrays/src/NamedDimsArrays.jl")
include("lib/NamedDimsArrays/src/NamedDimsArrays.jl")
using .NamedDimsArrays: NamedDimsArrays
include("SmallVectors/src/SmallVectors.jl")
include("lib/SmallVectors/src/SmallVectors.jl")
using .SmallVectors
include("SortedSets/src/SortedSets.jl")
include("lib/SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
include("lib/TagSets/src/TagSets.jl")
using .TagSets
include("Unwrap/src/Unwrap.jl")
using .Unwrap

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo


14 changes: 0 additions & 14 deletions NDTensors/src/NamedDimsArrays/src/abstractnamedunitrange.jl

This file was deleted.

38 changes: 0 additions & 38 deletions NDTensors/src/NamedDimsArrays/test/test_basic.jl

This file was deleted.

23 changes: 0 additions & 23 deletions NDTensors/src/TensorAlgebra/src/fusedims.jl

This file was deleted.

29 changes: 0 additions & 29 deletions NDTensors/src/algorithm.jl

This file was deleted.

2 changes: 2 additions & 0 deletions NDTensors/src/blocksparse/contract.jl
@@ -1,3 +1,5 @@
using .AlgorithmSelection: Algorithm, @Algorithm_str

function contract(
tensor1::BlockSparseTensor,
labelstensor1,
3 changes: 3 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/src/AlgorithmSelection.jl
@@ -0,0 +1,3 @@
module AlgorithmSelection
include("algorithm.jl")
end
39 changes: 39 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/src/algorithm.jl
@@ -0,0 +1,39 @@
"""
Algorithm
A type representing an algorithm backend for a function.
For example, a function might have multiple backend algorithm
implementations, which internally are selected with an `Algorithm` type.
This allows users to extend functionality with a new algorithm but
use the same interface.
"""
struct Algorithm{Alg,Kwargs<:NamedTuple}
kwargs::Kwargs
end

Algorithm{Alg}(kwargs::NamedTuple) where {Alg} = Algorithm{Alg,typeof(kwargs)}(kwargs)
Algorithm{Alg}(; kwargs...) where {Alg} = Algorithm{Alg}(NamedTuple(kwargs))
Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs))

Algorithm(alg::Algorithm) = alg

# TODO: Use `SetParameters`.
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg)

function Base.show(io::IO, alg::Algorithm)
return print(io, "Algorithm type ", algorithm_string(alg), ", ", alg.kwargs)
end
Base.print(io::IO, alg::Algorithm) = print(io, algorithm_string(alg), ", ", alg.kwargs)

"""
@Algorithm_str
A convenience macro for writing [`Algorithm`](@ref) types, typically used when
adding methods to a function that supports multiple algorithm
backends.
"""
macro Algorithm_str(s)
return :(Algorithm{$(Expr(:quote, Symbol(s)))})
end
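
For orientation, here is a minimal usage sketch of the backend-selection pattern that `Algorithm` and `@Algorithm_str` provide; the function `mysolve` and its backends below are hypothetical and not part of this commit.

using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str

# Entry point: turn the `alg` keyword argument into an `Algorithm` object,
# then dispatch on the algorithm type parameter.
mysolve(x; alg="direct", kwargs...) = mysolve(Algorithm(alg; kwargs...), x)
mysolve(::Algorithm"direct", x) = x          # hypothetical direct backend
mysolve(alg::Algorithm"iterative", x) = x    # could read tolerances from `alg.kwargs`

mysolve(2.0)                   # dispatches to the "direct" backend
mysolve(2.0; alg="iterative")  # dispatches to the "iterative" backend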
3 changes: 3 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/test/Project.toml
@@ -0,0 +1,3 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12 changes: 12 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/test/runtests.jl
@@ -0,0 +1,12 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str
@testset "AlgorithmSelection" begin
@test Algorithm"alg"() isa Algorithm{:alg}
@test Algorithm("alg") isa Algorithm{:alg}
@test Algorithm(:alg) isa Algorithm{:alg}
alg = Algorithm"alg"(; x=2, y=3)
@test alg isa Algorithm{:alg}
@test alg.kwargs == (; x=2, y=3)
end
end
3 changes: 3 additions & 0 deletions NDTensors/src/lib/BaseExtensions/src/BaseExtensions.jl
@@ -0,0 +1,3 @@
module BaseExtensions
include("replace.jl")
end
32 changes: 32 additions & 0 deletions NDTensors/src/lib/BaseExtensions/src/replace.jl
@@ -0,0 +1,32 @@
replace(collection, replacements::Pair...) = Base.replace(collection, replacements...)
@static if VERSION < v"1.7.0-DEV.15"
# https://github.com/JuliaLang/julia/pull/38216
# TODO: Add to `Compat.jl` or delete when we drop Julia 1.6 support.
# `replace` for Tuples.
function _replace(f::Base.Callable, t::Tuple, count::Int)
return if count == 0 || isempty(t)
t
else
x = f(t[1])
(x, _replace(f, Base.tail(t), count - !==(x, t[1]))...)
end
end

function replace(f::Base.Callable, t::Tuple; count::Integer=typemax(Int))
return _replace(f, t, Base.check_count(count))
end

function _replace(t::Tuple, count::Int, old_new::Tuple{Vararg{Pair}})
return _replace(t, count) do x
Base.@_inline_meta
for o_n in old_new
isequal(first(o_n), x) && return last(o_n)
end
return x
end
end

function replace(t::Tuple, old_new::Pair...; count::Integer=typemax(Int))
return _replace(t, Base.check_count(count), old_new)
end
end
3 changes: 3 additions & 0 deletions NDTensors/src/lib/BaseExtensions/test/Project.toml
@@ -0,0 +1,3 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
15 changes: 15 additions & 0 deletions NDTensors/src/lib/BaseExtensions/test/runtests.jl
@@ -0,0 +1,15 @@
using SafeTestsets: @safetestset

@safetestset "BaseExtensions" begin
using NDTensors.BaseExtensions: BaseExtensions
using Test: @test, @testset
@testset "replace $(typeof(collection))" for collection in
(["a", "b", "c"], ("a", "b", "c"))
r1 = BaseExtensions.replace(collection, "b" => "d")
@test r1 == typeof(collection)(["a", "d", "c"])
@test typeof(r1) === typeof(collection)
r2 = BaseExtensions.replace(collection, "b" => "d", "a" => "e")
@test r2 == typeof(collection)(["e", "d", "c"])
@test typeof(r2) === typeof(collection)
end
end
@@ -1,4 +1,5 @@
module BlockSparseArrays
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays:
AbstractBlockArray,
BlockArrays,
@@ -1,9 +1,9 @@
module LinearAlgebraExt
using ...AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays: BlockArrays, blockedrange, blocks
using ..BlockSparseArrays: SparseArray, nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
using LinearAlgebra: LinearAlgebra, Hermitian, Transpose, I, eigen, qr
using ...NDTensors: Algorithm, @Algorithm_str # TODO: Move to `AlgorithmSelector` module.
using SparseArrays: SparseArrays, SparseMatrixCSC, spzeros, sparse

# TODO: Move to `SparseArraysExtensions`.
@@ -13,7 +13,12 @@ include("TestBlockSparseArraysUtils.jl")
@testset "README" begin
@test include(
joinpath(
pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl"
pkgdir(BlockSparseArrays),
"src",
"lib",
"BlockSparseArrays",
"examples",
"README.jl",
),
) isa Any
end
48 changes: 48 additions & 0 deletions NDTensors/src/lib/BroadcastMapConversion/src/BroadcastMapConversion.jl
@@ -0,0 +1,48 @@
module BroadcastMapConversion
# Convert broadcast call to map call by capturing array arguments
# with `map_args` and creating a map function with `map_function`.
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.

using Base.Broadcast: Broadcasted

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

function map_args(bc::Broadcasted, rest...)
return (map_args(bc.args...)..., map_args(rest...)...)
end
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
map_args(a, rest...) = map_args(rest...)
map_args() = ()

struct MapFunction{F,Args<:Tuple}
f::F
args::Args
end
struct Arg end

# construct MapFunction
function map_function(bc::Broadcasted)
args = map_function_tuple(bc.args)
return MapFunction(bc.f, args)
end
map_function_tuple(t::Tuple{}) = t
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...)
map_function(a::WrappedScalarArgs) = a[]
map_function(a::AbstractArray) = Arg()
map_function(a) = a

# Evaluate MapFunction
(f::MapFunction)(args...) = apply(f, args)[1]
function apply(f::MapFunction, args)
args, newargs = apply_tuple(f.args, args)
return f.f(args...), newargs
end
apply(a::Arg, args::Tuple) = args[1], Base.tail(args)
apply(a, args) = a, args
apply_tuple(t::Tuple{}, args) = t, args
function apply_tuple(t::Tuple, args)
t1, newargs1 = apply(t[1], args)
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
return (t1, ttail...), newargs
end
end
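
As a rough sketch of what the conversion does with a nested broadcast (the expression below is illustrative, not part of this commit): a lazy `Broadcasted` tree for `c .* a .+ b` is split into its array arguments via `map_args` and a flattened callable via `map_function`, so the whole expression can be evaluated with a plain `map`.

using Base.Broadcast: Broadcasted
using NDTensors.BroadcastMapConversion: map_function, map_args

c = 2.2
a = randn(2, 3)
b = randn(2, 3)

# Lazy representation of `c .* a .+ b`.
bc = Broadcasted(+, (Broadcasted(*, (c, a)), b))

map_args(bc)                            # (a, b): only the array arguments are captured
f = map_function(bc)                    # MapFunction(+, (MapFunction(*, (2.2, Arg())), Arg()))
map(f, map_args(bc)...) ≈ c .* a .+ b   # true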
14 changes: 14 additions & 0 deletions NDTensors/src/lib/BroadcastMapConversion/test/runtests.jl
@@ -0,0 +1,14 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors.BroadcastMapConversion: map_function, map_args
@testset "BroadcastMapConversion" begin
using Base.Broadcast: Broadcasted
c = 2.2
a = randn(2, 3)
b = randn(2, 3)
bc = Broadcasted(*, (c, a))
@test copy(bc) ≈ c * a ≈ map(map_function(bc), map_args(bc)...)
bc = Broadcasted(+, (a, b))
@test copy(bc) ≈ a + b ≈ map(map_function(bc), map_args(bc)...)
end
end