diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 8b50852ac6..b0a4ea369a 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -44,6 +44,7 @@ FLoops = "0.2.1" Folds = "0.2.8" Functors = "0.2, 0.3, 0.4" HDF5 = "0.14, 0.15, 0.16, 0.17" +InlineStrings = "1" Requires = "1.1" SimpleTraits = "0.9.4" SplitApplyCombine = "1.2.2" diff --git a/NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl b/NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl index 440c1d2e6a..572ed929e2 100644 --- a/NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl +++ b/NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl @@ -17,8 +17,17 @@ struct BlockZero{Axes} axes::Axes end -function (f::BlockZero)(T::Type, I::CartesianIndex) - return fill!(T(undef, block_size(f.axes, Block(Tuple(I)))), false) +function (f::BlockZero)( + arraytype::Type{<:AbstractArray{T,N}}, I::CartesianIndex{N} +) where {T,N} + return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false) +end + +# Fallback to Array if it is abstract +function (f::BlockZero)( + arraytype::Type{AbstractArray{T,N}}, I::CartesianIndex{N} +) where {T,N} + return fill!(Array{T,N}(undef, block_size(f.axes, Block(Tuple(I)))), false) end function BlockSparseArray( diff --git a/NDTensors/src/BlockSparseArrays/test/runtests.jl b/NDTensors/src/BlockSparseArrays/test/runtests.jl index a048b0ac32..db84abdd23 100644 --- a/NDTensors/src/BlockSparseArrays/test/runtests.jl +++ b/NDTensors/src/BlockSparseArrays/test/runtests.jl @@ -1,5 +1,6 @@ using Test using NDTensors.BlockSparseArrays +using BlockArrays: BlockArrays @testset "Test NDTensors.BlockSparseArrays" begin @testset "README" begin @@ -9,4 +10,12 @@ using NDTensors.BlockSparseArrays ), ) isa Any end + @testset "Mixed block test" begin + blocks = [BlockArrays.Block(1, 1), BlockArrays.Block(2, 2)] + block_data = [randn(2, 2)', randn(3, 3)] + inds = ([2, 3], [2, 3]) + A = BlockSparseArray(blocks, block_data, inds) + @test A[BlockArrays.Block(1, 1)] == block_data[1] + @test A[BlockArrays.Block(1, 2)] == zeros(2, 3) + end end diff --git a/NDTensors/src/SmallVectors/src/BaseExt/sort.jl b/NDTensors/src/SmallVectors/src/BaseExt/sort.jl new file mode 100644 index 0000000000..e1b13fd795 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/BaseExt/sort.jl @@ -0,0 +1,23 @@ +# Custom version of `sort` (`SmallVectors.sort`) that directly uses an `order::Ordering`. +function sort(v, order::Base.Sort.Ordering; alg::Base.Sort.Algorithm=Base.Sort.defalg(v)) + mv = thaw(v) + SmallVectors.sort!(mv, order; alg) + return freeze(mv) +end + +# Custom version of `sort!` (`SmallVectors.sort!`) that directly uses an `order::Ordering`. +function sort!( + v::AbstractVector{T}, + order::Base.Sort.Ordering; + alg::Base.Sort.Algorithm=Base.Sort.defalg(v), + scratch::Union{Vector{T},Nothing}=nothing, +) where {T} + if VERSION < v"1.8.4" + Base.sort!(v, alg, order) + else + Base.Sort._sort!( + v, Base.Sort.maybe_apply_initial_optimizations(alg), order, (; scratch) + ) + end + return v +end diff --git a/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl b/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl index 34fa0ff4c9..b8c851a568 100644 --- a/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl +++ b/NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl @@ -30,6 +30,7 @@ function unionsortedunique!(itr1, itr2, order::Ordering) for j in length(itr1):-1:(i1 + 1) itr1[j] = itr1[j - 1] end + # Replace with the item from the second list itr1[i1] = item2 i1 += 1 i2 += 1 @@ -62,7 +63,7 @@ end # Union two unique sorted collections into an # output buffer, returning a unique sorted collection. function unionsortedunique(itr1, itr2, order::Ordering) - out = thaw_type(itr1)(undef, length(itr1)) + out = thaw_type(itr1)() i1 = firstindex(itr1) i2 = firstindex(itr2) iout = firstindex(out) @@ -82,14 +83,12 @@ function unionsortedunique(itr1, itr2, order::Ordering) iout += 1 i2 += 1 else # They are equal - out[iout] = item1 + out[iout] = item2 iout += 1 i1 += 1 i2 += 1 end end - # In case `out` was too long to begin with. - ## resize!(out, iout - 1) # TODO: Use `insertat!`? r1 = i1:stop1 resize!(out, length(out) + length(r1)) diff --git a/NDTensors/src/SmallVectors/src/SmallVectors.jl b/NDTensors/src/SmallVectors/src/SmallVectors.jl index 75b5ec200c..e6c6f3330f 100644 --- a/NDTensors/src/SmallVectors/src/SmallVectors.jl +++ b/NDTensors/src/SmallVectors/src/SmallVectors.jl @@ -31,6 +31,7 @@ NotImplemented() = NotImplemented("Not implemented.") include("BaseExt/insertstyle.jl") include("BaseExt/thawfreeze.jl") +include("BaseExt/sort.jl") include("BaseExt/sortedunique.jl") include("abstractarray/insert.jl") include("abstractsmallvector/abstractsmallvector.jl") diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl index 79fdd5cdda..c8116559e8 100644 --- a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl @@ -14,7 +14,7 @@ Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented()) @inline function resize(vec::AbstractSmallVector, len) - mvec = Base.copymutable(vec) + mvec = thaw(vec) resize!(mvec, len) return convert(similar_type(vec), mvec) end @@ -25,14 +25,14 @@ end end @inline function empty(vec::AbstractSmallVector) - mvec = Base.copymutable(vec) + mvec = thaw(vec) empty!(mvec) return convert(similar_type(vec), mvec) end @inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer) @boundscheck checkbounds(vec, index) - mvec = Base.copymutable(vec) + mvec = thaw(vec) @inbounds mvec[index] = item return convert(similar_type(vec), mvec) end @@ -44,7 +44,7 @@ end end @inline function StaticArrays.push(vec::AbstractSmallVector, item) - mvec = Base.copymutable(vec) + mvec = thaw(vec) push!(mvec, item) return convert(similar_type(vec), mvec) end @@ -55,7 +55,7 @@ end end @inline function StaticArrays.pop(vec::AbstractSmallVector) - mvec = Base.copymutable(vec) + mvec = thaw(vec) pop!(mvec) return convert(similar_type(vec), mvec) end @@ -67,7 +67,7 @@ end # Don't `@inline`, makes it slower. function StaticArrays.pushfirst(vec::AbstractSmallVector, item) - mvec = Base.copymutable(vec) + mvec = thaw(vec) pushfirst!(mvec, item) return convert(similar_type(vec), mvec) end @@ -80,7 +80,7 @@ end # Don't `@inline`, makes it slower. function StaticArrays.popfirst(vec::AbstractSmallVector) - mvec = Base.copymutable(vec) + mvec = thaw(vec) popfirst!(mvec) return convert(similar_type(vec), mvec) end @@ -129,7 +129,7 @@ end # Don't @inline, makes it slower. function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item) - mvec = Base.copymutable(vec) + mvec = thaw(vec) insert!(mvec, index, item) return convert(similar_type(vec), mvec) end @@ -154,7 +154,7 @@ end function StaticArrays.deleteat( vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}} ) - mvec = Base.copymutable(vec) + mvec = thaw(vec) deleteat!(mvec, index) return convert(similar_type(vec), mvec) end @@ -163,9 +163,7 @@ end # https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736 # https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort. # Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting. -@inline function Base.sort!( - vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false -) +@inline function sort!(vec::AbstractSmallVector, order::Base.Sort.Ordering) lo, hi = firstindex(vec), lastindex(vec) lo_plus_1 = (lo + 1) @inbounds for i in lo_plus_1:hi @@ -174,7 +172,7 @@ end jmax = j for _ in jmax:-1:lo_plus_1 y = vec[j - 1] - if !(lt(by(x), by(y)) != rev) + if !Base.Sort.lt(order, x, y) break end vec[j] = y @@ -185,20 +183,33 @@ end return vec end +@inline function Base.sort!( + vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false +) + SmallVectors.sort!(vec, Base.Sort.ord(lt, by, rev)) + return vec +end + # Don't @inline, makes it slower. -function Base.sort(vec::AbstractSmallVector; kwargs...) - mvec = Base.copymutable(vec) - sort!(mvec; kwargs...) +function sort(vec::AbstractSmallVector, order::Base.Sort.Ordering) + mvec = thaw(vec) + SmallVectors.sort!(mvec, order) return convert(similar_type(vec), mvec) end +@inline function Base.sort( + vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false +) + return SmallVectors.sort(vec, Base.Sort.ord(lt, by, rev)) +end + @inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...) insert!(vec, searchsortedfirst(vec, item; kwargs...), item) return vec end function insertsorted(vec::AbstractSmallVector, item; kwargs...) - mvec = Base.copymutable(vec) + mvec = thaw(vec) insertsorted!(mvec, item; kwargs...) return convert(similar_type(vec), mvec) end @@ -228,7 +239,7 @@ end end function mergesorted(vec::AbstractSmallVector, item; kwargs...) - mvec = Base.copymutable(vec) + mvec = thaw(vec) mergesorted!(mvec, item; kwargs...) return convert(similar_type(vec), mvec) end @@ -261,7 +272,7 @@ end # Don't @inline, makes it slower. function Base.circshift(vec::AbstractSmallVector, shift::Integer) - mvec = Base.copymutable(vec) + mvec = thaw(vec) circshift!(mvec, shift) return convert(similar_type(vec), mvec) end @@ -277,7 +288,7 @@ end # Missing from `StaticArrays.jl`. # Don't @inline, makes it slower. function append(vec::AbstractSmallVector, item::AbstractVector) - mvec = Base.copymutable(vec) + mvec = thaw(vec) append!(mvec, item) return convert(similar_type(vec), mvec) end @@ -294,14 +305,14 @@ end # Missing from `StaticArrays.jl`. # Don't @inline, makes it slower. function prepend(vec::AbstractSmallVector, item::AbstractVector) - mvec = Base.copymutable(vec) + mvec = thaw(vec) prepend!(mvec, item) return convert(similar_type(vec), mvec) end # Don't @inline, makes it slower. function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector) - mvec1 = Base.copymutable(vec1) + mvec1 = thaw(vec1) append!(mvec1, vec2) return convert(similar_type(vec1), mvec1) end diff --git a/NDTensors/src/SortedSets/src/BaseExt/sorted.jl b/NDTensors/src/SortedSets/src/BaseExt/sorted.jl index 939eb12274..54a2873765 100644 --- a/NDTensors/src/SortedSets/src/BaseExt/sorted.jl +++ b/NDTensors/src/SortedSets/src/BaseExt/sorted.jl @@ -39,16 +39,16 @@ function uniquesorted(vec; lt=isless, by=identity, rev::Bool=false, order::Order end function uniquesorted(vec::AbstractVector, order::Ordering) - vec = copy(vec) - i = firstindex(vec) - stopi = lastindex(vec) + mvec = thaw(vec) + i = firstindex(mvec) + stopi = lastindex(mvec) while i < stopi - if !lt(order, @inbounds(vec[i]), @inbounds(vec[i + 1])) - deleteat!(vec, i) + if !lt(order, @inbounds(mvec[i]), @inbounds(mvec[i + 1])) + deleteat!(mvec, i) stopi -= 1 else i += 1 end end - return vec + return freeze(mvec) end diff --git a/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl b/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl index 1721487b9b..847f312208 100644 --- a/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl +++ b/NDTensors/src/SortedSets/src/DictionariesExt/insert.jl @@ -1,2 +1,35 @@ -SmallVectors.insert(inds::AbstractIndices, i) = insert!(copy(inds), i) -SmallVectors.delete(inds::AbstractIndices, i) = delete!(copy(inds), i) +SmallVectors.insert(inds::AbstractIndices, i) = insert(InsertStyle(inds), inds, i) + +function SmallVectors.insert(::InsertStyle, inds::AbstractIndices, i) + return error("Not implemented") +end + +function SmallVectors.insert(::IsInsertable, inds::AbstractIndices, i) + inds = copy(inds) + insert!(inds, i) + return inds +end + +function SmallVectors.insert(::FastCopy, inds::AbstractIndices, i) + minds = thaw(inds) + insert!(minds, i) + return freeze(minds) +end + +SmallVectors.delete(inds::AbstractIndices, i) = delete(InsertStyle(inds), inds, i) + +function SmallVectors.delete(::InsertStyle, inds::AbstractIndices, i) + return error("Not implemented") +end + +function SmallVectors.delete(::IsInsertable, inds::AbstractIndices, i) + inds = copy(inds) + delete!(inds, i) + return inds +end + +function SmallVectors.delete(::FastCopy, inds::AbstractIndices, i) + minds = thaw(inds) + delete!(minds, i) + return freeze(minds) +end diff --git a/NDTensors/src/SortedSets/src/SortedSets.jl b/NDTensors/src/SortedSets/src/SortedSets.jl index 09343deb3b..777407f3c5 100644 --- a/NDTensors/src/SortedSets/src/SortedSets.jl +++ b/NDTensors/src/SortedSets/src/SortedSets.jl @@ -7,7 +7,7 @@ using ..SmallVectors using Base: @propagate_inbounds using Base.Order: Ordering, Forward, ord, lt -export AbstractWrappedIndices, SortedSet, SmallSet, MSmallSet +export AbstractWrappedSet, SortedSet, SmallSet, MSmallSet include("BaseExt/sorted.jl") include("DictionariesExt/insert.jl") diff --git a/NDTensors/src/SortedSets/src/abstractwrappedset.jl b/NDTensors/src/SortedSets/src/abstractwrappedset.jl index e1bbf50d24..e1a1094ff8 100644 --- a/NDTensors/src/SortedSets/src/abstractwrappedset.jl +++ b/NDTensors/src/SortedSets/src/abstractwrappedset.jl @@ -1,111 +1,117 @@ -# AbstractWrappedIndices: a wrapper around an `AbstractIndices` +# AbstractWrappedSet: a wrapper around an `AbstractIndices` # with methods automatically forwarded via `parent` # and rewrapped via `rewrap`. -abstract type AbstractWrappedIndices{T,D} <: AbstractIndices{T} end +abstract type AbstractWrappedSet{T,D} <: AbstractIndices{T} end # Required interface -Base.parent(inds::AbstractWrappedIndices) = error("Not implemented") -function Dictionaries.empty_type(::Type{AbstractWrappedIndices{I}}, ::Type{I}) where {I} +Base.parent(set::AbstractWrappedSet) = error("Not implemented") +function Dictionaries.empty_type(::Type{AbstractWrappedSet{I}}, ::Type{I}) where {I} return error("Not implemented") end -SmallVectors.thaw(::AbstractWrappedIndices) = error("Not implemented") -SmallVectors.freeze(::AbstractWrappedIndices) = error("Not implemented") -rewrap(::AbstractWrappedIndices, data) = error("Not implemented") +rewrap(::AbstractWrappedSet, data) = error("Not implemented") + +SmallVectors.thaw(set::AbstractWrappedSet) = rewrap(set, thaw(parent(set))) +SmallVectors.freeze(set::AbstractWrappedSet) = rewrap(set, freeze(parent(set))) # Traits -SmallVectors.InsertStyle(::Type{<:AbstractWrappedIndices{T,D}}) where {T,D} = InsertStyle(D) +SmallVectors.InsertStyle(::Type{<:AbstractWrappedSet{T,D}}) where {T,D} = InsertStyle(D) -# AbstractIndices interface -@propagate_inbounds function Base.iterate(inds::AbstractWrappedIndices, state...) - return iterate(parent(inds), state...) +# AbstractSet interface +@propagate_inbounds function Base.iterate(set::AbstractWrappedSet, state...) + return iterate(parent(set), state...) end # `I` is needed to avoid ambiguity error. -@inline Base.in(tag::I, inds::AbstractWrappedIndices{I}) where {I} = in(tag, parent(inds)) -@inline Base.IteratorSize(inds::AbstractWrappedIndices) = Base.IteratorSize(parent(inds)) -@inline Base.length(inds::AbstractWrappedIndices) = length(parent(inds)) +@inline Base.in(item::I, set::AbstractWrappedSet{I}) where {I} = in(item, parent(set)) +@inline Base.IteratorSize(set::AbstractWrappedSet) = Base.IteratorSize(parent(set)) +@inline Base.length(set::AbstractWrappedSet) = length(parent(set)) -@inline Dictionaries.istokenizable(inds::AbstractWrappedIndices) = - istokenizable(parent(inds)) -@inline Dictionaries.tokentype(inds::AbstractWrappedIndices) = tokentype(parent(inds)) -@inline Dictionaries.iteratetoken(inds::AbstractWrappedIndices, s...) = - iterate(parent(inds), s...) -@inline function Dictionaries.iteratetoken_reverse(inds::AbstractWrappedIndices) - return iteratetoken_reverse(parent(inds)) +@inline Dictionaries.istokenizable(set::AbstractWrappedSet) = istokenizable(parent(set)) +@inline Dictionaries.tokentype(set::AbstractWrappedSet) = tokentype(parent(set)) +@inline Dictionaries.iteratetoken(set::AbstractWrappedSet, s...) = + iterate(parent(set), s...) +@inline function Dictionaries.iteratetoken_reverse(set::AbstractWrappedSet) + return iteratetoken_reverse(parent(set)) end -@inline function Dictionaries.iteratetoken_reverse(inds::AbstractWrappedIndices, t) - return iteratetoken_reverse(parent(inds), t) +@inline function Dictionaries.iteratetoken_reverse(set::AbstractWrappedSet, t) + return iteratetoken_reverse(parent(set), t) end -@inline function Dictionaries.gettoken(inds::AbstractWrappedIndices, i) - return gettoken(parent(inds), i) +@inline function Dictionaries.gettoken(set::AbstractWrappedSet, i) + return gettoken(parent(set), i) end -@propagate_inbounds Dictionaries.gettokenvalue(inds::AbstractWrappedIndices, x) = - gettokenvalue(parent(inds), x) +@propagate_inbounds Dictionaries.gettokenvalue(set::AbstractWrappedSet, x) = + gettokenvalue(parent(set), x) -@inline Dictionaries.isinsertable(inds::AbstractWrappedIndices) = isinsertable(parent(inds)) +@inline Dictionaries.isinsertable(set::AbstractWrappedSet) = isinsertable(parent(set)) # Specify `I` to fix ambiguity error. @inline function Dictionaries.gettoken!( - inds::AbstractWrappedIndices{I}, i::I, values=() + set::AbstractWrappedSet{I}, i::I, values=() ) where {I} - return gettoken!(parent(inds), i, values) + return gettoken!(parent(set), i, values) end -@inline function Dictionaries.deletetoken!(inds::AbstractWrappedIndices, x, values=()) - deletetoken!(parent(inds), x, values) - return inds +@inline function Dictionaries.deletetoken!(set::AbstractWrappedSet, x, values=()) + deletetoken!(parent(set), x, values) + return set end -@inline function Base.empty!(inds::AbstractWrappedIndices, values=()) - empty!(parent(inds)) - return inds +@inline function Base.empty!(set::AbstractWrappedSet, values=()) + empty!(parent(set)) + return set end # Not defined to be part of the `AbstractIndices` interface, # but seems to be needed. -@inline function Base.filter!(pred, inds::AbstractWrappedIndices) - filter!(pred, parent(inds)) - return inds +@inline function Base.filter!(pred, set::AbstractWrappedSet) + filter!(pred, parent(set)) + return set end # TODO: Maybe require an implementation? -@inline function Base.copy(inds::AbstractWrappedIndices, eltype::Type) - return typeof(inds)(copy(parent(inds), eltype)) +@inline function Base.copy(set::AbstractWrappedSet, eltype::Type) + return typeof(set)(copy(parent(set), eltype)) end # Not required for AbstractIndices interface but # helps with faster code paths -SmallVectors.insert(inds::AbstractWrappedIndices, tag) = insert(parent(inds), tag) -Base.insert!(inds::AbstractWrappedIndices, tag) = insert!(parent(inds), tag) +SmallVectors.insert(set::AbstractWrappedSet, item) = rewrap(set, insert(parent(set), item)) +function Base.insert!(set::AbstractWrappedSet, item) + insert!(parent(set), item) + return set +end -SmallVectors.delete(inds::AbstractWrappedIndices, tag) = delete(parent(inds), tag) -Base.delete!(inds::AbstractWrappedIndices, tag) = delete!(parent(inds), tag) +SmallVectors.delete(set::AbstractWrappedSet, item) = rewrap(set, delete(parent(set), item)) +function Base.delete!(set::AbstractWrappedSet, item) + delete!(parent(set), item) + return set +end -function Base.union(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) - return rewrap(inds1, union(parent(inds1), parent(inds2))) +function Base.union(set1::AbstractWrappedSet, set2::AbstractWrappedSet) + return rewrap(set1, union(parent(set1), parent(set2))) end -function Base.union(inds1::AbstractWrappedIndices, inds2) - return rewrap(inds1, union(parent(inds1), inds2)) +function Base.union(set1::AbstractWrappedSet, set2) + return rewrap(set1, union(parent(set1), set2)) end -function Base.intersect(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) - return rewrap(inds1, intersect(parent(inds1), parent(inds2))) +function Base.intersect(set1::AbstractWrappedSet, set2::AbstractWrappedSet) + return rewrap(set1, intersect(parent(set1), parent(set2))) end -function Base.intersect(inds1::AbstractWrappedIndices, inds2) - return rewrap(inds1, intersect(parent(inds1), inds2)) +function Base.intersect(set1::AbstractWrappedSet, set2) + return rewrap(set1, intersect(parent(set1), set2)) end -function Base.setdiff(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) - return rewrap(inds1, setdiff(parent(inds1), parent(inds2))) +function Base.setdiff(set1::AbstractWrappedSet, set2::AbstractWrappedSet) + return rewrap(set1, setdiff(parent(set1), parent(set2))) end -function Base.setdiff(inds1::AbstractWrappedIndices, inds2) - return rewrap(inds1, setdiff(parent(inds1), inds2)) +function Base.setdiff(set1::AbstractWrappedSet, set2) + return rewrap(set1, setdiff(parent(set1), set2)) end -function Base.symdiff(inds1::AbstractWrappedIndices, inds2::AbstractWrappedIndices) - return rewrap(inds1, symdiff(parent(inds1), parent(inds2))) +function Base.symdiff(set1::AbstractWrappedSet, set2::AbstractWrappedSet) + return rewrap(set1, symdiff(parent(set1), parent(set2))) end -function Base.symdiff(inds1::AbstractWrappedIndices, inds2) - return rewrap(inds1, symdiff(parent(inds1), inds2)) +function Base.symdiff(set1::AbstractWrappedSet, set2) + return rewrap(set1, symdiff(parent(set1), set2)) end diff --git a/NDTensors/src/SortedSets/src/sortedset.jl b/NDTensors/src/SortedSets/src/sortedset.jl index 43e2db1c8c..dff1b441a0 100644 --- a/NDTensors/src/SortedSets/src/sortedset.jl +++ b/NDTensors/src/SortedSets/src/sortedset.jl @@ -1,146 +1,171 @@ """ - SortedIndices(iter) + SortedSet(iter) -Construct an `SortedIndices <: AbstractIndices` from an arbitrary Julia iterable with unique +Construct an `SortedSet <: AbstractSet` from an arbitrary Julia iterable with unique elements. Lookup uses that they are sorted. -SortedIndices can be faster than ArrayIndices which use naive search that may be optimal for +SortedSet can be faster than ArrayIndices which use naive search that may be optimal for small collections. Larger collections are better handled by containers like `Indices`. """ -struct SortedIndices{I,Inds<:AbstractArray{I},Order<:Ordering} <: AbstractSet{I} - inds::Inds +struct SortedSet{T,Data<:AbstractArray{T},Order<:Ordering} <: AbstractSet{T} + data::Data order::Order - global @inline _SortedIndices( - inds::Inds, order::Order - ) where {I,Inds<:AbstractArray{I},Order<:Ordering} = new{I,Inds,Order}(inds, order) + global @inline _SortedSet( + data::Data, order::Order + ) where {T,Data<:AbstractArray{T},Order<:Ordering} = new{T,Data,Order}(data, order) end -# Inner constructor -function SortedIndices{I,Inds,Order}( - a::Inds, order::Order; issorted=issorted, allunique=allunique -) where {I,Inds<:AbstractArray{I},Order<:Ordering} +@inline Base.parent(set::SortedSet) = getfield(set, :data) +@inline order(set::SortedSet) = getfield(set, :order) + +# Dictionaries.jl interface +const SortedIndices = SortedSet + +# Inner constructor. +# Sorts and makes unique as needed. +function SortedSet{T,Data,Order}( + a::Data, order::Order +) where {T,Data<:AbstractArray{T},Order<:Ordering} if !issorted(a, order) - a = sort(a, order) + a = SmallVectors.sort(a, order) end if !alluniquesorted(a, order) a = uniquesorted(a, order) end - return _SortedIndices(a, order) + return _SortedSet(a, order) end -@inline function SortedIndices{I,Inds,Order}( - a::AbstractArray, order::Ordering; issorted=issorted, allunique=allunique -) where {I,Inds<:AbstractArray{I},Order<:Ordering} - return SortedIndices{I,Inds,Order}( - convert(Inds, a), convert(Order, order); issorted, allunique - ) +@inline function SortedSet{T,Data,Order}( + a::AbstractArray, order::Ordering +) where {T,Data<:AbstractArray{T},Order<:Ordering} + return SortedSet{T,Data,Order}(convert(Data, a), convert(Order, order)) end -@inline function SortedIndices{I,Inds}( - a::AbstractArray, order::Order; issorted=issorted, allunique=allunique -) where {I,Inds<:AbstractArray{I},Order<:Ordering} - return SortedIndices{I,Inds,Order}(a, order; issorted, allunique) +@inline function SortedSet{T,Data}( + a::AbstractArray, order::Order +) where {T,Data<:AbstractArray{T},Order<:Ordering} + return SortedSet{T,Data,Order}(a, order) end -@inline function SortedIndices( - a::Inds, order::Ordering; issorted=issorted, allunique=allunique -) where {I,Inds<:AbstractArray{I}} - return SortedIndices{I,Inds}(a, order; issorted, allunique) +@inline function SortedSet(a::Data, order::Ordering) where {T,Data<:AbstractArray{T}} + return SortedSet{T,Data}(a, order) end -@inline function SortedIndices{I,Inds}( - a::Inds; - lt=isless, - by=identity, - rev::Bool=false, - order::Ordering=Forward, - issorted=issorted, - allunique=allunique, -) where {I,Inds<:AbstractArray{I}} - order = ord(lt, by, rev, order) - return SortedIndices{I,Inds}(a, order; issorted, allunique) +# Accept other inputs like `Tuple`. +@inline function SortedSet(itr, order::Ordering) + return SortedSet(collect(itr), order) end -const SortedSet = SortedIndices +@inline function SortedSet{T,Data}( + a::Data; lt=isless, by=identity, rev::Bool=false +) where {T,Data<:AbstractArray{T}} + return SortedSet{T,Data}(a, ord(lt, by, rev)) +end # Traits -@inline SmallVectors.InsertStyle(::Type{<:SortedIndices{I,Inds}}) where {I,Inds} = - InsertStyle(Inds) -@inline SmallVectors.thaw(i::SortedIndices) = SortedIndices(thaw(i.inds), i.order) -@inline SmallVectors.freeze(i::SortedIndices) = SortedIndices(freeze(i.inds), i.order) - -@propagate_inbounds SortedIndices(; kwargs...) = SortedIndices{Any}([]; kwargs...) -@propagate_inbounds SortedIndices{I}(; kwargs...) where {I} = - SortedIndices{I,Vector{I}}(I[]; kwargs...) -@propagate_inbounds SortedIndices{I,Inds}(; kwargs...) where {I,Inds} = - SortedIndices{I}(Inds(); kwargs...) - -@propagate_inbounds SortedIndices(iter; kwargs...) = SortedIndices(collect(iter); kwargs...) -@propagate_inbounds SortedIndices{I}(iter; kwargs...) where {I} = - SortedIndices{I}(collect(I, iter); kwargs...) - -@propagate_inbounds SortedIndices(a::AbstractArray{I}; kwargs...) where {I} = - SortedIndices{I}(a; kwargs...) -@propagate_inbounds SortedIndices{I}(a::AbstractArray{I}; kwargs...) where {I} = - SortedIndices{I,typeof(a)}(a; kwargs...) - -@propagate_inbounds SortedIndices{I,Inds}( +@inline SmallVectors.InsertStyle(::Type{<:SortedSet{T,Data}}) where {T,Data} = + InsertStyle(Data) +@inline SmallVectors.thaw(set::SortedSet) = SortedSet(thaw(parent(set)), order(set)) +@inline SmallVectors.freeze(set::SortedSet) = SortedSet(freeze(parent(set)), order(set)) + +@propagate_inbounds SortedSet(; kwargs...) = SortedSet{Any}([]; kwargs...) +@propagate_inbounds SortedSet{T}(; kwargs...) where {T} = + SortedSet{T,Vector{T}}(T[]; kwargs...) +@propagate_inbounds SortedSet{T,Data}(; kwargs...) where {T,Data} = + SortedSet{T}(Data(); kwargs...) + +@propagate_inbounds SortedSet(iter; kwargs...) = SortedSet(collect(iter); kwargs...) +@propagate_inbounds SortedSet{T}(iter; kwargs...) where {T} = + SortedSet{T}(collect(T, iter); kwargs...) + +@propagate_inbounds SortedSet(a::AbstractArray{T}; kwargs...) where {T} = + SortedSet{T}(a; kwargs...) +@propagate_inbounds SortedSet{T}(a::AbstractArray{T}; kwargs...) where {T} = + SortedSet{T,typeof(a)}(a; kwargs...) + +@propagate_inbounds SortedSet{T,Data}( a::AbstractArray; kwargs... -) where {I,Inds<:AbstractArray{I}} = SortedIndices{I,Inds}(Inds(a); kwargs...) +) where {T,Data<:AbstractArray{T}} = SortedSet{T,Data}(Data(a); kwargs...) -function Base.convert(::Type{AbstractIndices{I}}, inds::SortedIndices) where {I} - return convert(SortedIndices{I}, inds) +function Base.convert(::Type{AbstractIndices{T}}, set::SortedSet) where {T} + return convert(SortedSet{T}, set) end -function Base.convert(::Type{SortedIndices}, inds::AbstractIndices{I}) where {I} - return convert(SortedIndices{I}, inds) +function Base.convert(::Type{SortedSet}, set::AbstractIndices{T}) where {T} + return convert(SortedSet{T}, set) end -function Base.convert(::Type{SortedIndices{I}}, inds::AbstractIndices) where {I} - return convert(SortedIndices{I,Vector{I}}, inds) +function Base.convert(::Type{SortedSet{T}}, set::AbstractIndices) where {T} + return convert(SortedSet{T,Vector{T}}, set) end function Base.convert( - ::Type{SortedIndices{I,Inds}}, inds::AbstractIndices -) where {I,Inds<:AbstractArray{I}} - a = convert(Inds, collect(I, inds)) - return @inbounds SortedIndices{I,typeof(a)}(a) + ::Type{SortedSet{T,Data}}, set::AbstractIndices +) where {T,Data<:AbstractArray{T}} + a = convert(Data, collect(T, set)) + return @inbounds SortedSet{T,typeof(a)}(a) end -Base.convert(::Type{SortedIndices{I}}, inds::SortedIndices{I}) where {I} = inds +Base.convert(::Type{SortedSet{T}}, set::SortedSet{T}) where {T} = set function Base.convert( - ::Type{SortedIndices{I}}, inds::SortedIndices{<:Any,Inds} -) where {I,Inds<:AbstractArray{I}} - return convert(SortedIndices{I,Inds}, inds) + ::Type{SortedSet{T}}, set::SortedSet{<:Any,Data} +) where {T,Data<:AbstractArray{T}} + return convert(SortedSet{T,Data}, set) end function Base.convert( - ::Type{SortedIndices{I,Inds}}, inds::SortedIndices{I,Inds} -) where {I,Inds<:AbstractArray{I}} - return inds + ::Type{SortedSet{T,Data}}, set::SortedSet{T,Data} +) where {T,Data<:AbstractArray{T}} + return set end function Base.convert( - ::Type{SortedIndices{I,Inds}}, inds::SortedIndices -) where {I,Inds<:AbstractArray{I}} - a = convert(Inds, parent(inds)) - return @inbounds SortedIndices{I,Inds}(a) + ::Type{SortedSet{T,Data}}, set::SortedSet +) where {T,Data<:AbstractArray{T}} + a = convert(Data, parent(set)) + return @inbounds SortedSet{T,Data}(a) end -@inline Base.parent(inds::SortedIndices) = getfield(inds, :inds) - # Basic interface -@propagate_inbounds function Base.iterate(i::SortedIndices{I}, state...) where {I} - return iterate(parent(i), state...) +@propagate_inbounds function Base.iterate(set::SortedSet{T}, state...) where {T} + return iterate(parent(set), state...) +end + +@inline function Base.in(i::T, set::SortedSet{T}) where {T} + return _insorted(i, parent(set), order(set)) +end +@inline Base.IteratorSize(::SortedSet) = Base.HasLength() +@inline Base.length(set::SortedSet) = length(parent(set)) + +function Base.:(==)(set1::SortedSet, set2::SortedSet) + if length(set1) ≠ length(set2) + return false + end + for (j1, j2) in zip(set1, set2) + if j1 ≠ j2 + return false + end + end + return true end -@inline function Base.in(i::I, inds::SortedIndices{I}) where {I} - return _insorted(i, parent(inds), inds.order) +function Base.issetequal(set1::SortedSet, set2::SortedSet) + if length(set1) ≠ length(set2) + return false + end + if order(set1) ≠ order(set2) + # TODO: Make sure this actually sorts! + set2 = SortedSet(parent(set2), order(set1)) + end + for (j1, j2) in zip(set1, set2) + if lt(order(set1), j1, j2) || lt(order(set1), j2, j1) + return false + end + end + return true end -@inline Base.IteratorSize(::SortedIndices) = Base.HasLength() -@inline Base.length(inds::SortedIndices) = length(parent(inds)) -@inline Dictionaries.istokenizable(i::SortedIndices) = true -@inline Dictionaries.tokentype(::SortedIndices) = Int -@inline Dictionaries.iteratetoken(inds::SortedIndices, s...) = - iterate(LinearIndices(parent(inds)), s...) -@inline function Dictionaries.iteratetoken_reverse(inds::SortedIndices) - li = LinearIndices(parent(inds)) +@inline Dictionaries.istokenizable(::SortedSet) = true +@inline Dictionaries.tokentype(::SortedSet) = Int +@inline Dictionaries.iteratetoken(set::SortedSet, s...) = + iterate(LinearIndices(parent(set)), s...) +@inline function Dictionaries.iteratetoken_reverse(set::SortedSet) + li = LinearIndices(parent(set)) if isempty(li) return nothing else @@ -148,8 +173,8 @@ end return (t, t) end end -@inline function Dictionaries.iteratetoken_reverse(inds::SortedIndices, t) - li = LinearIndices(parent(inds)) +@inline function Dictionaries.iteratetoken_reverse(set::SortedSet, t) + li = LinearIndices(parent(set)) t -= 1 if t < first(li) return nothing @@ -158,21 +183,20 @@ end end end -@inline function Dictionaries.gettoken(inds::SortedIndices, i) - a = parent(inds) - r = searchsorted(a, i, inds.order) +@inline function Dictionaries.gettoken(set::SortedSet, i) + a = parent(set) + r = searchsorted(a, i, order(set)) @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique length(r) == 0 && return (false, 0) return (true, convert(Int, only(r))) end -@propagate_inbounds Dictionaries.gettokenvalue(inds::SortedIndices, x::Int) = - parent(inds)[x] +@propagate_inbounds Dictionaries.gettokenvalue(set::SortedSet, x::Int) = parent(set)[x] -@inline Dictionaries.isinsertable(i::SortedIndices) = isinsertable(parent(inds)) +@inline Dictionaries.isinsertable(set::SortedSet) = isinsertable(parent(set)) -@inline function Dictionaries.gettoken!(inds::SortedIndices{I}, i::I, values=()) where {I} - a = parent(inds) - r = searchsorted(a, i, inds.order) +@inline function Dictionaries.gettoken!(set::SortedSet{T}, i::T, values=()) where {T} + a = parent(set) + r = searchsorted(a, i, order(set)) @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique if length(r) == 0 insert!(a, first(r), i) @@ -182,106 +206,96 @@ end return (true, convert(Int, only(r))) end -@inline function Dictionaries.deletetoken!(inds::SortedIndices, x::Int, values=()) - deleteat!(parent(inds), x) +@inline function Dictionaries.deletetoken!(set::SortedSet, x::Int, values=()) + deleteat!(parent(set), x) foreach(v -> deleteat!(v, x), values) - return inds + return set end -@inline function Base.empty!(inds::SortedIndices, values=()) - empty!(parent(inds)) +@inline function Base.empty!(set::SortedSet, values=()) + empty!(parent(set)) foreach(empty!, values) - return inds + return set end # TODO: Make into `MSmallVector`? # More generally, make a `thaw(::AbstractArray)` function to return # a mutable version of an AbstractArray. -@inline Dictionaries.empty_type( - ::Type{SortedIndices{I,D,Order}}, ::Type{I} -) where {I,D,Order} = SortedIndices{I,Dictionaries.empty_type(D, I),Order} +@inline Dictionaries.empty_type(::Type{SortedSet{T,D,Order}}, ::Type{T}) where {T,D,Order} = + SortedSet{T,Dictionaries.empty_type(D, T),Order} -@inline Dictionaries.empty_type(::Type{<:AbstractVector}, ::Type{I}) where {I} = Vector{I} +@inline Dictionaries.empty_type(::Type{<:AbstractVector}, ::Type{T}) where {T} = Vector{T} -function Base.empty(inds::SortedIndices{I,D}, ::Type{I}) where {I,D} - return Dictionaries.empty_type(typeof(inds), I)(D(), inds.order) +function Base.empty(set::SortedSet{T,D}, ::Type{T}) where {T,D} + return Dictionaries.empty_type(typeof(set), T)(D(), order(set)) end -@inline function Base.copy(inds::SortedIndices, ::Type{I}) where {I} - if I === eltype(inds) - SortedIndices( - copy(parent(inds)), inds.order; issorted=Returns(true), allunique=Returns(true) - ) +@inline function Base.copy(set::SortedSet, ::Type{T}) where {T} + if T === eltype(set) + SortedSet(copy(parent(set)), order(set)) else - SortedIndices( - convert(AbstractArray{I}, parent(inds)), - inds.order; - issorted=Returns(true), - allunique=Returns(true), - ) + SortedSet(convert(AbstractArray{T}, parent(set)), order(set)) end end # TODO: Can this take advantage of sorting? -@inline function Base.filter!(pred, inds::SortedIndices) - filter!(pred, parent(inds)) - return inds +@inline function Base.filter!(pred, set::SortedSet) + filter!(pred, parent(set)) + return set end -function Dictionaries.randtoken(rng::Random.AbstractRNG, inds::SortedIndices) - return rand(rng, keys(parent(inds))) +function Dictionaries.randtoken(rng::Random.AbstractRNG, set::SortedSet) + return rand(rng, keys(parent(set))) end -@inline function Base.sort!( - inds::SortedIndices; lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward -) +@inline function Base.sort!(set::SortedSet; lt=isless, by=identity, rev::Bool=false) + @assert Base.Sort.ord(lt, by, rev) == order(set) # No-op, should be sorted already. - # TODO: Check `ord(lt, by, rev, order) == inds.ord`. - return inds + return set end # Custom faster operations (not required for interface) -function Base.union!(inds::SortedIndices, items::SortedIndices) - if inds.order ≠ items.order +function Base.union!(set::SortedSet, items::SortedSet) + if order(set) ≠ order(items) # Reorder if the orderings are different. - items = SortedIndices(parent(inds), inds.order) + items = SortedSet(parent(set), order(set)) end - unionsortedunique!(parent(inds), parent(items), inds.order) - return inds + unionsortedunique!(parent(set), parent(items), order(set)) + return set end -function Base.union(inds::SortedIndices, items::SortedIndices) - if inds.order ≠ items.order - # Reorder if the orderings are different. - items = SortedIndices(parent(inds), inds.order) +function Base.union(set::SortedSet, items::SortedSet) + if order(set) ≠ order(items) + # TODO: Reorder if the orderings are different. + items = SortedSet(parent(set), order(set)) end - out = unionsortedunique(parent(inds), parent(items), inds.order) - return SortedIndices(out, inds.order; issorted=Returns(true), allunique=Returns(true)) + out = unionsortedunique(parent(set), parent(items), order(set)) + return SortedSet(out, order(set)) end -function Base.union(inds::SortedIndices, items) - return union(inds, SortedIndices(items, inds.order)) +function Base.union(set::SortedSet, items) + return union(set, SortedSet(items, order(set))) end -function Base.intersect(inds::SortedIndices, items::SortedIndices) +function Base.intersect(set::SortedSet, items::SortedSet) # TODO: Make an `intersectsortedunique`. - return intersect(NotInsertable(), inds, items) + return intersect(NotInsertable(), set, items) end -function Base.setdiff(inds::SortedIndices, items) - return setdiff(inds, SortedIndices(items, inds.order)) +function Base.setdiff(set::SortedSet, items) + return setdiff(set, SortedSet(items, order(set))) end -function Base.setdiff(inds::SortedIndices, items::SortedIndices) +function Base.setdiff(set::SortedSet, items::SortedSet) # TODO: Make an `setdiffsortedunique`. - return setdiff(NotInsertable(), inds, items) + return setdiff(NotInsertable(), set, items) end -function Base.symdiff(inds::SortedIndices, items) - return symdiff(inds, SortedIndices(items, inds.order)) +function Base.symdiff(set::SortedSet, items) + return symdiff(set, SortedSet(items, order(set))) end -function Base.symdiff(inds::SortedIndices, items::SortedIndices) +function Base.symdiff(set::SortedSet, items::SortedSet) # TODO: Make an `symdiffsortedunique`. - return symdiff(NotInsertable(), inds, items) + return symdiff(NotInsertable(), set, items) end diff --git a/NDTensors/src/SortedSets/test/runtests.jl b/NDTensors/src/SortedSets/test/runtests.jl index 24500f474d..14061d17dd 100644 --- a/NDTensors/src/SortedSets/test/runtests.jl +++ b/NDTensors/src/SortedSets/test/runtests.jl @@ -3,18 +3,38 @@ using NDTensors.SortedSets using NDTensors.SmallVectors @testset "Test NDTensors.SortedSets" begin - for V in (Vector, MSmallVector{10}, SmallVector{10}) - s1 = SortedSet(V([1, 3, 5])) - s2 = SortedSet(V([2, 3, 6])) + @testset "Basic operations" begin + for V in (Vector, MSmallVector{10}, SmallVector{10}) + for by in (+, -) + s1 = SortedSet(V([1, 5, 3]); by) + s2 = SortedSet(V([2, 3, 6]); by) - # Set interface - @test union(s1, s2) == SortedSet([1, 2, 3, 5, 6]) - @test setdiff(s1, s2) == SortedSet([1, 5]) - @test symdiff(s1, s2) == SortedSet([1, 2, 5, 6]) - @test intersect(s1, s2) == SortedSet([3]) - if SmallVectors.InsertStyle(V) isa IsInsertable - @test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5]) - @test delete!(copy(s1), 3) == SortedSet([1, 5]) + @test thaw(s1) == s1 + @test SmallVectors.insert(s1, 2) isa typeof(s1) + @test SmallVectors.insert(s1, 2) == SortedSet([1, 2, 3, 5]; by) + @test SmallVectors.delete(s1, 3) isa typeof(s1) + @test SmallVectors.delete(s1, 3) == SortedSet([1, 5]; by) + + # Set interface + @test union(s1, s2) == SortedSet([1, 2, 3, 5, 6]; by) + @test union(s1, [3]) == s1 + @test setdiff(s1, s2) == SortedSet([1, 5]; by) + @test symdiff(s1, s2) == SortedSet([1, 2, 5, 6]; by) + @test intersect(s1, s2) == SortedSet([3]; by) + if SmallVectors.InsertStyle(V) isa IsInsertable + @test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5]; by) + @test delete!(copy(s1), 3) == SortedSet([1, 5]; by) + end + end end end + @testset "Replacement behavior" begin + s1 = SortedSet([("a", 3), ("b", 2)]; by=first) + s2 = SortedSet([("a", 5)]; by=first) + s = union(s1, s2) + @test s ≠ s1 + @test issetequal(s, s1) + @test ("a", 5) ∈ parent(s) + @test ("a", 3) ∉ parent(s) + end end diff --git a/NDTensors/src/TagSets/src/TagSets.jl b/NDTensors/src/TagSets/src/TagSets.jl index 53df13d0ee..c4d9f1013f 100644 --- a/NDTensors/src/TagSets/src/TagSets.jl +++ b/NDTensors/src/TagSets/src/TagSets.jl @@ -5,10 +5,11 @@ using ..SortedSets using Base: @propagate_inbounds -export TagSet, SmallTagSet, addtags, removetags, replacetags, commontags, noncommontags +export TagSet, + SmallTagSet, MSmallTagSet, addtags, removetags, replacetags, commontags, noncommontags # A sorted collection of unique tags of type `T`. -struct TagSet{T,D<:AbstractIndices{T}} <: AbstractWrappedIndices{T,D} +struct TagSet{T,D<:AbstractIndices{T}} <: AbstractWrappedSet{T,D} data::D end @@ -37,38 +38,50 @@ end return TagSet{T,D}(split(str, delim)) end -const SmallTagSet{S,T} = TagSet{T,SmallSet{S,T}} -@propagate_inbounds SmallTagSet{S}(; kwargs...) where {S} = SmallTagSet{S}([]; kwargs...) -@propagate_inbounds SmallTagSet{S}(iter; kwargs...) where {S} = - SmallTagSet{S}(collect(iter); kwargs...) -@propagate_inbounds SmallTagSet{S}(a::AbstractArray{I}; kwargs...) where {S,I} = - SmallTagSet{S,I}(a; kwargs...) -# Specialized `SmallSet{S,T} = SortedSet{T,SmallVector{S,T}}` constructor -function SmallTagSet{S,T}(str::AbstractString; delim=default_delim()) where {S,T} - # TODO: Optimize for `SmallSet`. - return SmallTagSet{S,T}(split(str, delim)) +for (SetTyp, TagSetTyp) in ((:SmallSet, :SmallTagSet), (:MSmallSet, :MSmallTagSet)) + @eval begin + const $TagSetTyp{S,T,Order} = TagSet{T,$SetTyp{S,T,Order}} + @propagate_inbounds function $TagSetTyp{S,I}(a::AbstractArray; kwargs...) where {S,I} + return TagSet($SetTyp{S,I}(a; kwargs...)) + end + @propagate_inbounds $TagSetTyp{S}(; kwargs...) where {S} = $TagSetTyp{S}([]; kwargs...) + @propagate_inbounds $TagSetTyp{S}(iter; kwargs...) where {S} = + $TagSetTyp{S}(collect(iter); kwargs...) + @propagate_inbounds $TagSetTyp{S}(a::AbstractArray{I}; kwargs...) where {S,I} = + $TagSetTyp{S,I}(a; kwargs...) + # Strings get split by a deliminator. + function $TagSetTyp{S}(str::T; kwargs...) where {S,T<:AbstractString} + return $TagSetTyp{S,T}(str, kwargs...) + end + # Strings get split by a deliminator. + function $TagSetTyp{S,T}( + str::AbstractString; delim=default_delim(), kwargs... + ) where {S,T} + # TODO: Optimize for `SmallSet`. + return $TagSetTyp{S,T}(split(str, delim); kwargs...) + end + end end # Field accessors -Base.parent(tags::TagSet) = getfield(tags, :data) +Base.parent(set::TagSet) = getfield(set, :data) # AbstractWrappedSet interface. # Specialized version when they are the same data type is faster. -@inline SortedSets.rewrap(vec::TagSet{T,D}, data::D) where {T,D<:AbstractIndices{T}} = - TagSet{T,D}(data) -@inline SortedSets.rewrap(vec::TagSet{T,D}, data) where {T,D<:AbstractIndices{T}} = +@inline SortedSets.rewrap(::TagSet{T,D}, data::D) where {T,D<:AbstractIndices{T}} = TagSet{T,D}(data) +@inline SortedSets.rewrap(::TagSet, data) = TagSet(data) # TagSet interface -addtags(tags::TagSet, items) = union(tags, items) -removetags(tags::TagSet, items) = setdiff(tags, items) -commontags(tags::TagSet, items) = intersect(tags, items) -noncommontags(tags::TagSet, items) = symdiff(tags, items) -function replacetags(tags::TagSet, rem, add) - remtags = setdiff(tags, rem) - if length(tags) ≠ length(remtags) + length(rem) +addtags(set::TagSet, items) = union(set, items) +removetags(set::TagSet, items) = setdiff(set, items) +commontags(set::TagSet, items) = intersect(set, items) +noncommontags(set::TagSet, items) = symdiff(set, items) +function replacetags(set::TagSet, rem, add) + remtags = setdiff(set, rem) + if length(set) ≠ length(remtags) + length(rem) # Not all are removed, no replacement - return tags + return set end return union(remtags, add) end