-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
208 changed files
with
1,656 additions
and
271 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 0 additions & 6 deletions
6
...amedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl
This file was deleted.
Oops, something went wrong.
13 changes: 0 additions & 13 deletions
13
NDTensors/src/NamedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/test/runtests.jl
This file was deleted.
Oops, something went wrong.
14 changes: 0 additions & 14 deletions
14
NDTensors/src/NamedDimsArrays/src/abstractnamedunitrange.jl
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
NDTensors/src/lib/AlgorithmSelection/src/AlgorithmSelection.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
module AlgorithmSelection | ||
include("algorithm.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Algorithm | ||
A type representing an algorithm backend for a function. | ||
For example, a function might have multiple backend algorithm | ||
implementations, which internally are selected with an `Algorithm` type. | ||
This allows users to extend functionality with a new algorithm but | ||
use the same interface. | ||
""" | ||
struct Algorithm{Alg,Kwargs<:NamedTuple} | ||
kwargs::Kwargs | ||
end | ||
|
||
Algorithm{Alg}(kwargs::NamedTuple) where {Alg} = Algorithm{Alg,typeof(kwargs)}(kwargs) | ||
Algorithm{Alg}(; kwargs...) where {Alg} = Algorithm{Alg}(NamedTuple(kwargs)) | ||
Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs)) | ||
|
||
Algorithm(alg::Algorithm) = alg | ||
|
||
# TODO: Use `SetParameters`. | ||
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg) | ||
|
||
function Base.show(io::IO, alg::Algorithm) | ||
return print(io, "Algorithm type ", algorithm_string(alg), ", ", alg.kwargs) | ||
end | ||
Base.print(io::IO, alg::Algorithm) = print(io, algorithm_string(alg), ", ", alg.kwargs) | ||
|
||
""" | ||
@Algorithm_str | ||
A convenience macro for writing [`Algorithm`](@ref) types, typically used when | ||
adding methods to a function that supports multiple algorithm | ||
backends. | ||
""" | ||
macro Algorithm_str(s) | ||
return :(Algorithm{$(Expr(:quote, Symbol(s)))}) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
@eval module $(gensym()) | ||
using Test: @test, @testset | ||
using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str | ||
@testset "AlgorithmSelection" begin | ||
@test Algorithm"alg"() isa Algorithm{:alg} | ||
@test Algorithm("alg") isa Algorithm{:alg} | ||
@test Algorithm(:alg) isa Algorithm{:alg} | ||
alg = Algorithm"alg"(; x=2, y=3) | ||
@test alg isa Algorithm{:alg} | ||
@test alg.kwargs == (; x=2, y=3) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
module BaseExtensions | ||
include("replace.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
replace(collection, replacements::Pair...) = Base.replace(collection, replacements...) | ||
@static if VERSION < v"1.7.0-DEV.15" | ||
# https://github.com/JuliaLang/julia/pull/38216 | ||
# TODO: Add to `Compat.jl` or delete when we drop Julia 1.6 support. | ||
# `replace` for Tuples. | ||
function _replace(f::Base.Callable, t::Tuple, count::Int) | ||
return if count == 0 || isempty(t) | ||
t | ||
else | ||
x = f(t[1]) | ||
(x, _replace(f, Base.tail(t), count - !==(x, t[1]))...) | ||
end | ||
end | ||
|
||
function replace(f::Base.Callable, t::Tuple; count::Integer=typemax(Int)) | ||
return _replace(f, t, Base.check_count(count)) | ||
end | ||
|
||
function _replace(t::Tuple, count::Int, old_new::Tuple{Vararg{Pair}}) | ||
return _replace(t, count) do x | ||
Base.@_inline_meta | ||
for o_n in old_new | ||
isequal(first(o_n), x) && return last(o_n) | ||
end | ||
return x | ||
end | ||
end | ||
|
||
function replace(t::Tuple, old_new::Pair...; count::Integer=typemax(Int)) | ||
return _replace(t, Base.check_count(count), old_new) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" | ||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
using SafeTestsets: @safetestset | ||
|
||
@safetestset "BaseExtensions" begin | ||
using NDTensors.BaseExtensions: BaseExtensions | ||
using Test: @test, @testset | ||
@testset "replace $(typeof(collection))" for collection in | ||
(["a", "b", "c"], ("a", "b", "c")) | ||
r1 = BaseExtensions.replace(collection, "b" => "d") | ||
@test r1 == typeof(collection)(["a", "d", "c"]) | ||
@test typeof(r1) === typeof(collection) | ||
r2 = BaseExtensions.replace(collection, "b" => "d", "a" => "e") | ||
@test r2 == typeof(collection)(["e", "d", "c"]) | ||
@test typeof(r2) === typeof(collection) | ||
end | ||
end |
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions
1
...lockSparseArrays/src/BlockSparseArrays.jl → ...lockSparseArrays/src/BlockSparseArrays.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
.../src/LinearAlgebraExt/LinearAlgebraExt.jl → .../src/LinearAlgebraExt/LinearAlgebraExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
NDTensors/src/lib/BroadcastMapConversion/src/BroadcastMapConversion.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
module BroadcastMapConversion | ||
# Convert broadcast call to map call by capturing array arguments | ||
# with `map_args` and creating a map function with `map_function`. | ||
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl. | ||
|
||
using Base.Broadcast: Broadcasted | ||
|
||
const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} | ||
|
||
function map_args(bc::Broadcasted, rest...) | ||
return (map_args(bc.args...)..., map_args(rest...)...) | ||
end | ||
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...) | ||
map_args(a, rest...) = map_args(rest...) | ||
map_args() = () | ||
|
||
struct MapFunction{F,Args<:Tuple} | ||
f::F | ||
args::Args | ||
end | ||
struct Arg end | ||
|
||
# construct MapFunction | ||
function map_function(bc::Broadcasted) | ||
args = map_function_tuple(bc.args) | ||
return MapFunction(bc.f, args) | ||
end | ||
map_function_tuple(t::Tuple{}) = t | ||
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...) | ||
map_function(a::WrappedScalarArgs) = a[] | ||
map_function(a::AbstractArray) = Arg() | ||
map_function(a) = a | ||
|
||
# Evaluate MapFunction | ||
(f::MapFunction)(args...) = apply(f, args)[1] | ||
function apply(f::MapFunction, args) | ||
args, newargs = apply_tuple(f.args, args) | ||
return f.f(args...), newargs | ||
end | ||
apply(a::Arg, args::Tuple) = args[1], Base.tail(args) | ||
apply(a, args) = a, args | ||
apply_tuple(t::Tuple{}, args) = t, args | ||
function apply_tuple(t::Tuple, args) | ||
t1, newargs1 = apply(t[1], args) | ||
ttail, newargs = apply_tuple(Base.tail(t), newargs1) | ||
return (t1, ttail...), newargs | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
@eval module $(gensym()) | ||
using Test: @test, @testset | ||
using NDTensors.BroadcastMapConversion: map_function, map_args | ||
@testset "BroadcastMapConversion" begin | ||
using Base.Broadcast: Broadcasted | ||
c = 2.2 | ||
a = randn(2, 3) | ||
b = randn(2, 3) | ||
bc = Broadcasted(*, (c, a)) | ||
@test copy(bc) ≈ c * a ≈ map(map_function(bc), map_args(bc)...) | ||
bc = Broadcasted(+, (a, b)) | ||
@test copy(bc) ≈ a + b ≈ map(map_function(bc), map_args(bc)...) | ||
end | ||
end |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.