Skip to content

Commit

Permalink
[NDTensors] Fix bugs in SortedSets, SmallVectors, and `BlockSpars…
Browse files Browse the repository at this point in the history
…eArrays` (#1211)
  • Loading branch information
mtfishman authored Oct 7, 2023
1 parent 81568ad commit 8d3a807
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 297 deletions.
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ FLoops = "0.2.1"
Folds = "0.2.8"
Functors = "0.2, 0.3, 0.4"
HDF5 = "0.14, 0.15, 0.16, 0.17"
InlineStrings = "1"
Requires = "1.1"
SimpleTraits = "0.9.4"
SplitApplyCombine = "1.2.2"
Expand Down
13 changes: 11 additions & 2 deletions NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@ struct BlockZero{Axes}
axes::Axes
end

function (f::BlockZero)(T::Type, I::CartesianIndex)
return fill!(T(undef, block_size(f.axes, Block(Tuple(I)))), false)
function (f::BlockZero)(
arraytype::Type{<:AbstractArray{T,N}}, I::CartesianIndex{N}
) where {T,N}
return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false)
end

# Fallback to Array if it is abstract
function (f::BlockZero)(
arraytype::Type{AbstractArray{T,N}}, I::CartesianIndex{N}
) where {T,N}
return fill!(Array{T,N}(undef, block_size(f.axes, Block(Tuple(I)))), false)
end

function BlockSparseArray(
Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/BlockSparseArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using NDTensors.BlockSparseArrays
using BlockArrays: BlockArrays

@testset "Test NDTensors.BlockSparseArrays" begin
@testset "README" begin
Expand All @@ -9,4 +10,12 @@ using NDTensors.BlockSparseArrays
),
) isa Any
end
@testset "Mixed block test" begin
blocks = [BlockArrays.Block(1, 1), BlockArrays.Block(2, 2)]
block_data = [randn(2, 2)', randn(3, 3)]
inds = ([2, 3], [2, 3])
A = BlockSparseArray(blocks, block_data, inds)
@test A[BlockArrays.Block(1, 1)] == block_data[1]
@test A[BlockArrays.Block(1, 2)] == zeros(2, 3)
end
end
23 changes: 23 additions & 0 deletions NDTensors/src/SmallVectors/src/BaseExt/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Custom version of `sort` (`SmallVectors.sort`) that directly uses an `order::Ordering`.
function sort(v, order::Base.Sort.Ordering; alg::Base.Sort.Algorithm=Base.Sort.defalg(v))
mv = thaw(v)
SmallVectors.sort!(mv, order; alg)
return freeze(mv)
end

# Custom version of `sort!` (`SmallVectors.sort!`) that directly uses an `order::Ordering`.
function sort!(
v::AbstractVector{T},
order::Base.Sort.Ordering;
alg::Base.Sort.Algorithm=Base.Sort.defalg(v),
scratch::Union{Vector{T},Nothing}=nothing,
) where {T}
if VERSION < v"1.8.4"
Base.sort!(v, alg, order)
else
Base.Sort._sort!(
v, Base.Sort.maybe_apply_initial_optimizations(alg), order, (; scratch)
)
end
return v
end
7 changes: 3 additions & 4 deletions NDTensors/src/SmallVectors/src/BaseExt/sortedunique.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ function unionsortedunique!(itr1, itr2, order::Ordering)
for j in length(itr1):-1:(i1 + 1)
itr1[j] = itr1[j - 1]
end
# Replace with the item from the second list
itr1[i1] = item2
i1 += 1
i2 += 1
Expand Down Expand Up @@ -62,7 +63,7 @@ end
# Union two unique sorted collections into an
# output buffer, returning a unique sorted collection.
function unionsortedunique(itr1, itr2, order::Ordering)
out = thaw_type(itr1)(undef, length(itr1))
out = thaw_type(itr1)()
i1 = firstindex(itr1)
i2 = firstindex(itr2)
iout = firstindex(out)
Expand All @@ -82,14 +83,12 @@ function unionsortedunique(itr1, itr2, order::Ordering)
iout += 1
i2 += 1
else # They are equal
out[iout] = item1
out[iout] = item2
iout += 1
i1 += 1
i2 += 1
end
end
# In case `out` was too long to begin with.
## resize!(out, iout - 1)
# TODO: Use `insertat!`?
r1 = i1:stop1
resize!(out, length(out) + length(r1))
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/SmallVectors/src/SmallVectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ NotImplemented() = NotImplemented("Not implemented.")

include("BaseExt/insertstyle.jl")
include("BaseExt/thawfreeze.jl")
include("BaseExt/sort.jl")
include("BaseExt/sortedunique.jl")
include("abstractarray/insert.jl")
include("abstractsmallvector/abstractsmallvector.jl")
Expand Down
55 changes: 33 additions & 22 deletions NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented())

@inline function resize(vec::AbstractSmallVector, len)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
resize!(mvec, len)
return convert(similar_type(vec), mvec)
end
Expand All @@ -25,14 +25,14 @@ end
end

@inline function empty(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
empty!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer)
@boundscheck checkbounds(vec, index)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
@inbounds mvec[index] = item
return convert(similar_type(vec), mvec)
end
Expand All @@ -44,7 +44,7 @@ end
end

@inline function StaticArrays.push(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
push!(mvec, item)
return convert(similar_type(vec), mvec)
end
Expand All @@ -55,7 +55,7 @@ end
end

@inline function StaticArrays.pop(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
pop!(mvec)
return convert(similar_type(vec), mvec)
end
Expand All @@ -67,7 +67,7 @@ end

# Don't `@inline`, makes it slower.
function StaticArrays.pushfirst(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
pushfirst!(mvec, item)
return convert(similar_type(vec), mvec)
end
Expand All @@ -80,7 +80,7 @@ end

# Don't `@inline`, makes it slower.
function StaticArrays.popfirst(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
popfirst!(mvec)
return convert(similar_type(vec), mvec)
end
Expand Down Expand Up @@ -129,7 +129,7 @@ end

# Don't @inline, makes it slower.
function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
insert!(mvec, index, item)
return convert(similar_type(vec), mvec)
end
Expand All @@ -154,7 +154,7 @@ end
function StaticArrays.deleteat(
vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}}
)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
deleteat!(mvec, index)
return convert(similar_type(vec), mvec)
end
Expand All @@ -163,9 +163,7 @@ end
# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736
# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort.
# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting.
@inline function Base.sort!(
vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false
)
@inline function sort!(vec::AbstractSmallVector, order::Base.Sort.Ordering)
lo, hi = firstindex(vec), lastindex(vec)
lo_plus_1 = (lo + 1)
@inbounds for i in lo_plus_1:hi
Expand All @@ -174,7 +172,7 @@ end
jmax = j
for _ in jmax:-1:lo_plus_1
y = vec[j - 1]
if !(lt(by(x), by(y)) != rev)
if !Base.Sort.lt(order, x, y)
break
end
vec[j] = y
Expand All @@ -185,20 +183,33 @@ end
return vec
end

@inline function Base.sort!(
vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false
)
SmallVectors.sort!(vec, Base.Sort.ord(lt, by, rev))
return vec
end

# Don't @inline, makes it slower.
function Base.sort(vec::AbstractSmallVector; kwargs...)
mvec = Base.copymutable(vec)
sort!(mvec; kwargs...)
function sort(vec::AbstractSmallVector, order::Base.Sort.Ordering)
mvec = thaw(vec)
SmallVectors.sort!(mvec, order)
return convert(similar_type(vec), mvec)
end

@inline function Base.sort(
vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false
)
return SmallVectors.sort(vec, Base.Sort.ord(lt, by, rev))
end

@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...)
insert!(vec, searchsortedfirst(vec, item; kwargs...), item)
return vec
end

function insertsorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
insertsorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end
Expand Down Expand Up @@ -228,7 +239,7 @@ end
end

function mergesorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
mergesorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end
Expand Down Expand Up @@ -261,7 +272,7 @@ end

# Don't @inline, makes it slower.
function Base.circshift(vec::AbstractSmallVector, shift::Integer)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
circshift!(mvec, shift)
return convert(similar_type(vec), mvec)
end
Expand All @@ -277,7 +288,7 @@ end
# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function append(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
append!(mvec, item)
return convert(similar_type(vec), mvec)
end
Expand All @@ -294,14 +305,14 @@ end
# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function prepend(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
mvec = thaw(vec)
prepend!(mvec, item)
return convert(similar_type(vec), mvec)
end

# Don't @inline, makes it slower.
function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector)
mvec1 = Base.copymutable(vec1)
mvec1 = thaw(vec1)
append!(mvec1, vec2)
return convert(similar_type(vec1), mvec1)
end
12 changes: 6 additions & 6 deletions NDTensors/src/SortedSets/src/BaseExt/sorted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ function uniquesorted(vec; lt=isless, by=identity, rev::Bool=false, order::Order
end

function uniquesorted(vec::AbstractVector, order::Ordering)
vec = copy(vec)
i = firstindex(vec)
stopi = lastindex(vec)
mvec = thaw(vec)
i = firstindex(mvec)
stopi = lastindex(mvec)
while i < stopi
if !lt(order, @inbounds(vec[i]), @inbounds(vec[i + 1]))
deleteat!(vec, i)
if !lt(order, @inbounds(mvec[i]), @inbounds(mvec[i + 1]))
deleteat!(mvec, i)
stopi -= 1
else
i += 1
end
end
return vec
return freeze(mvec)
end
37 changes: 35 additions & 2 deletions NDTensors/src/SortedSets/src/DictionariesExt/insert.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,35 @@
SmallVectors.insert(inds::AbstractIndices, i) = insert!(copy(inds), i)
SmallVectors.delete(inds::AbstractIndices, i) = delete!(copy(inds), i)
SmallVectors.insert(inds::AbstractIndices, i) = insert(InsertStyle(inds), inds, i)

function SmallVectors.insert(::InsertStyle, inds::AbstractIndices, i)
return error("Not implemented")
end

function SmallVectors.insert(::IsInsertable, inds::AbstractIndices, i)
inds = copy(inds)
insert!(inds, i)
return inds
end

function SmallVectors.insert(::FastCopy, inds::AbstractIndices, i)
minds = thaw(inds)
insert!(minds, i)
return freeze(minds)
end

SmallVectors.delete(inds::AbstractIndices, i) = delete(InsertStyle(inds), inds, i)

function SmallVectors.delete(::InsertStyle, inds::AbstractIndices, i)
return error("Not implemented")
end

function SmallVectors.delete(::IsInsertable, inds::AbstractIndices, i)
inds = copy(inds)
delete!(inds, i)
return inds
end

function SmallVectors.delete(::FastCopy, inds::AbstractIndices, i)
minds = thaw(inds)
delete!(minds, i)
return freeze(minds)
end
2 changes: 1 addition & 1 deletion NDTensors/src/SortedSets/src/SortedSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ..SmallVectors
using Base: @propagate_inbounds
using Base.Order: Ordering, Forward, ord, lt

export AbstractWrappedIndices, SortedSet, SmallSet, MSmallSet
export AbstractWrappedSet, SortedSet, SmallSet, MSmallSet

include("BaseExt/sorted.jl")
include("DictionariesExt/insert.jl")
Expand Down
Loading

0 comments on commit 8d3a807

Please sign in to comment.