Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ITensors] ITensor wrapping NamedDimsArray #1268

Merged
merged 24 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,33 @@ using Strided
using TimerOutputs
using TupleTools

# TODO: Define an `AlgorithmSelection` module
# TODO: List types, macros, and functions being used.
include("algorithm.jl")
include("SetParameters/src/SetParameters.jl")
include("lib/AlgorithmSelection/src/AlgorithmSelection.jl")
using .AlgorithmSelection: AlgorithmSelection
include("lib/BaseExtensions/src/BaseExtensions.jl")
using .BaseExtensions: BaseExtensions
include("lib/SetParameters/src/SetParameters.jl")
using .SetParameters
include("TensorAlgebra/src/TensorAlgebra.jl")
include("lib/BroadcastMapConversion/src/BroadcastMapConversion.jl")
using .BroadcastMapConversion: BroadcastMapConversion
include("lib/Unwrap/src/Unwrap.jl")
using .Unwrap
include("lib/RankFactorization/src/RankFactorization.jl")
using .RankFactorization: RankFactorization
include("lib/TensorAlgebra/src/TensorAlgebra.jl")
using .TensorAlgebra: TensorAlgebra
include("DiagonalArrays/src/DiagonalArrays.jl")
include("lib/DiagonalArrays/src/DiagonalArrays.jl")
using .DiagonalArrays
include("BlockSparseArrays/src/BlockSparseArrays.jl")
include("lib/BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays
include("NamedDimsArrays/src/NamedDimsArrays.jl")
include("lib/NamedDimsArrays/src/NamedDimsArrays.jl")
using .NamedDimsArrays: NamedDimsArrays
include("SmallVectors/src/SmallVectors.jl")
include("lib/SmallVectors/src/SmallVectors.jl")
using .SmallVectors
include("SortedSets/src/SortedSets.jl")
include("lib/SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
include("lib/TagSets/src/TagSets.jl")
using .TagSets
include("Unwrap/src/Unwrap.jl")
using .Unwrap

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down

This file was deleted.

This file was deleted.

14 changes: 0 additions & 14 deletions NDTensors/src/NamedDimsArrays/src/abstractnamedunitrange.jl

This file was deleted.

38 changes: 0 additions & 38 deletions NDTensors/src/NamedDimsArrays/test/test_basic.jl

This file was deleted.

23 changes: 0 additions & 23 deletions NDTensors/src/TensorAlgebra/src/fusedims.jl

This file was deleted.

29 changes: 0 additions & 29 deletions NDTensors/src/algorithm.jl

This file was deleted.

2 changes: 2 additions & 0 deletions NDTensors/src/blocksparse/contract.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .AlgorithmSelection: Algorithm, @Algorithm_str

function contract(
tensor1::BlockSparseTensor,
labelstensor1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module AlgorithmSelection
include("algorithm.jl")
end
39 changes: 39 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/src/algorithm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Algorithm

A type representing an algorithm backend for a function.

For example, a function might have multiple backend algorithm
implementations, which internally are selected with an `Algorithm` type.

This allows users to extend functionality with a new algorithm but
use the same interface.
"""
struct Algorithm{Alg,Kwargs<:NamedTuple}
kwargs::Kwargs
end

Algorithm{Alg}(kwargs::NamedTuple) where {Alg} = Algorithm{Alg,typeof(kwargs)}(kwargs)
Algorithm{Alg}(; kwargs...) where {Alg} = Algorithm{Alg}(NamedTuple(kwargs))
Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs))

Algorithm(alg::Algorithm) = alg

# TODO: Use `SetParameters`.
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg)

function Base.show(io::IO, alg::Algorithm)
return print(io, "Algorithm type ", algorithm_string(alg), ", ", alg.kwargs)
end
Base.print(io::IO, alg::Algorithm) = print(io, algorithm_string(alg), ", ", alg.kwargs)

"""
@Algorithm_str

A convenience macro for writing [`Algorithm`](@ref) types, typically used when
adding methods to a function that supports multiple algorithm
backends.
"""
macro Algorithm_str(s)
return :(Algorithm{$(Expr(:quote, Symbol(s)))})
end
3 changes: 3 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12 changes: 12 additions & 0 deletions NDTensors/src/lib/AlgorithmSelection/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str
@testset "AlgorithmSelection" begin
@test Algorithm"alg"() isa Algorithm{:alg}
@test Algorithm("alg") isa Algorithm{:alg}
@test Algorithm(:alg) isa Algorithm{:alg}
alg = Algorithm"alg"(; x=2, y=3)
@test alg isa Algorithm{:alg}
@test alg.kwargs == (; x=2, y=3)
end
end
3 changes: 3 additions & 0 deletions NDTensors/src/lib/BaseExtensions/src/BaseExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module BaseExtensions
include("replace.jl")
end
32 changes: 32 additions & 0 deletions NDTensors/src/lib/BaseExtensions/src/replace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
replace(collection, replacements::Pair...) = Base.replace(collection, replacements...)
@static if VERSION < v"1.7.0-DEV.15"
# https://github.com/JuliaLang/julia/pull/38216
# TODO: Add to `Compat.jl` or delete when we drop Julia 1.6 support.
# `replace` for Tuples.
function _replace(f::Base.Callable, t::Tuple, count::Int)
return if count == 0 || isempty(t)
t
else
x = f(t[1])
(x, _replace(f, Base.tail(t), count - !==(x, t[1]))...)
end
end

function replace(f::Base.Callable, t::Tuple; count::Integer=typemax(Int))
return _replace(f, t, Base.check_count(count))
end

function _replace(t::Tuple, count::Int, old_new::Tuple{Vararg{Pair}})
return _replace(t, count) do x
Base.@_inline_meta
for o_n in old_new
isequal(first(o_n), x) && return last(o_n)
end
return x
end
end

function replace(t::Tuple, old_new::Pair...; count::Integer=typemax(Int))
return _replace(t, Base.check_count(count), old_new)
end
end
3 changes: 3 additions & 0 deletions NDTensors/src/lib/BaseExtensions/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
15 changes: 15 additions & 0 deletions NDTensors/src/lib/BaseExtensions/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using SafeTestsets: @safetestset

@safetestset "BaseExtensions" begin
using NDTensors.BaseExtensions: BaseExtensions
using Test: @test, @testset
@testset "replace $(typeof(collection))" for collection in
(["a", "b", "c"], ("a", "b", "c"))
r1 = BaseExtensions.replace(collection, "b" => "d")
@test r1 == typeof(collection)(["a", "d", "c"])
@test typeof(r1) === typeof(collection)
r2 = BaseExtensions.replace(collection, "b" => "d", "a" => "e")
@test r2 == typeof(collection)(["e", "d", "c"])
@test typeof(r2) === typeof(collection)
end
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module BlockSparseArrays
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays:
AbstractBlockArray,
BlockArrays,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module LinearAlgebraExt
using ...AlgorithmSelection: Algorithm, @Algorithm_str
using BlockArrays: BlockArrays, blockedrange, blocks
using ..BlockSparseArrays: SparseArray, nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
using LinearAlgebra: LinearAlgebra, Hermitian, Transpose, I, eigen, qr
using ...NDTensors: Algorithm, @Algorithm_str # TODO: Move to `AlgorithmSelector` module.
using SparseArrays: SparseArrays, SparseMatrixCSC, spzeros, sparse

# TODO: Move to `SparseArraysExtensions`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ include("TestBlockSparseArraysUtils.jl")
@testset "README" begin
@test include(
joinpath(
pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl"
pkgdir(BlockSparseArrays),
"src",
"lib",
"BlockSparseArrays",
"examples",
"README.jl",
),
) isa Any
end
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module BroadcastMapConversion
# Convert broadcast call to map call by capturing array arguments
# with `map_args` and creating a map function with `map_function`.
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.

using Base.Broadcast: Broadcasted

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

function map_args(bc::Broadcasted, rest...)
return (map_args(bc.args...)..., map_args(rest...)...)
end
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
map_args(a, rest...) = map_args(rest...)
map_args() = ()

struct MapFunction{F,Args<:Tuple}
f::F
args::Args
end
struct Arg end

# construct MapFunction
function map_function(bc::Broadcasted)
args = map_function_tuple(bc.args)
return MapFunction(bc.f, args)
end
map_function_tuple(t::Tuple{}) = t
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...)
map_function(a::WrappedScalarArgs) = a[]
map_function(a::AbstractArray) = Arg()
map_function(a) = a

# Evaluate MapFunction
(f::MapFunction)(args...) = apply(f, args)[1]
function apply(f::MapFunction, args)
args, newargs = apply_tuple(f.args, args)
return f.f(args...), newargs
end
apply(a::Arg, args::Tuple) = args[1], Base.tail(args)
apply(a, args) = a, args
apply_tuple(t::Tuple{}, args) = t, args
function apply_tuple(t::Tuple, args)
t1, newargs1 = apply(t[1], args)
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
return (t1, ttail...), newargs
end
end
14 changes: 14 additions & 0 deletions NDTensors/src/lib/BroadcastMapConversion/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors.BroadcastMapConversion: map_function, map_args
@testset "BroadcastMapConversion" begin
using Base.Broadcast: Broadcasted
c = 2.2
a = randn(2, 3)
b = randn(2, 3)
bc = Broadcasted(*, (c, a))
@test copy(bc) ≈ c * a ≈ map(map_function(bc), map_args(bc)...)
bc = Broadcasted(+, (a, b))
@test copy(bc) ≈ a + b ≈ map(map_function(bc), map_args(bc)...)
end
end
Loading
Loading