Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
willow-ahrens committed Jan 8, 2025
1 parent 2bd8d4c commit aa59702
Show file tree
Hide file tree
Showing 48 changed files with 1,834 additions and 842 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorMarket = "8b7d4fe7-0b45-4d0d-9dd8-5cc9b23b4b77"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"

[targets]
test = ["Test", "ArgParse", "LinearAlgebra", "Random", "SparseArrays", "Graphs", "SimpleWeightedGraphs", "HDF5", "NPZ", "Pkg", "TensorMarket", "Documenter"]
test = ["ReTestItems", "Test", "ArgParse", "LinearAlgebra", "Random", "SparseArrays", "Graphs", "SimpleWeightedGraphs", "HDF5", "NPZ", "Pkg", "TensorMarket", "Documenter"]
3 changes: 3 additions & 0 deletions src/tensors/combinators/swizzle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ fill_value(arr::SwizzleArray) = fill_value(typeof(arr))
fill_value(::Type{SwizzleArray{dims, Body}}) where {dims, Body} = fill_value(Body)
Base.similar(arr::SwizzleArray{dims}) where {dims} = SwizzleArray{dims}(similar(arr.body))

isstructequal(a::T, b::T) where {T <: Finch.SwizzleArray} =
isstructequal(a.body, b.body)

countstored(arr::SwizzleArray) = countstored(arr.body)

Base.size(arr::SwizzleArray{dims}) where {dims} = map(n->size(arr.body)[n], dims)
Expand Down
17 changes: 16 additions & 1 deletion src/tensors/fibers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,19 @@ Base.similar(fbr::AbstractFiber, dims::Tuple) = similar(fbr, fill_value(fbr), el
Base.similar(fbr::AbstractFiber, eltype::Type, dims::Tuple) = similar(fbr, convert(eltype, fill_value(fbr)), eltype, dims)
Base.similar(fbr::AbstractFiber, fill_value, eltype::Type, dims::Tuple) = Tensor(similar_level(fbr.lvl, fill_value, eltype, dims...))

moveto(tns::Tensor, device) = Tensor(moveto(tns.lvl, device))
moveto(tns::Tensor, device) = Tensor(moveto(tns.lvl, device))

struct Structure
t
end

Base.:(==)(a::Structure, b::Structure) = isstructequal(a.t, b.t)

isstructequal(a, b) = a === b

Check warning on line 309 in src/tensors/fibers.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/fibers.jl#L309

Added line #L309 was not covered by tests

isstructequal(a::T, b::T) where {T <: Tensor} =
isstructequal(a.lvl, b.lvl)

isstructequal(a::T, b::T) where {T <: Finch.SubFiber} =

Check warning on line 314 in src/tensors/fibers.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/fibers.jl#L314

Added line #L314 was not covered by tests
isstructequal(a.lvl, b.lvl) &&
isstructequal(a.ptr, b.ptr)
4 changes: 4 additions & 0 deletions src/tensors/levels/dense_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ Base.resize!(lvl::DenseLevel{Ti}, dims...) where {Ti} =
@inline level_fill_value(::Type{<:DenseLevel{Ti, Lvl}}) where {Ti, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:DenseLevel{Ti, Lvl}}) where {Ti, Lvl} = DenseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: Dense} =
a.shape == b.shape &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:DenseLevel})() = fbr
function (fbr::SubFiber{<:DenseLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
Expand Down
6 changes: 6 additions & 0 deletions src/tensors/levels/dense_rle_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ end
@inline level_fill_value(::Type{<:RunListLevel{Ti, Ptr, Right, merge, Lvl}}) where {Ti, Ptr, Right, merge, Lvl}= level_fill_value(Lvl)
data_rep_level(::Type{<:RunListLevel{Ti, Ptr, Right, merge, Lvl}}) where {Ti, Ptr, Right, merge, Lvl} = DenseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: RunList} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.right == b.right &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:RunListLevel})() = fbr
function (fbr::SubFiber{<:RunListLevel})(idxs...)
isempty(idxs) && return fbr
Expand Down
3 changes: 3 additions & 0 deletions src/tensors/levels/element_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ set_fill_value!(lvl::ElementLevel{Vf, Tv, Tp}, init) where {Vf, Tv, Tp} =
ElementLevel{init, Tv, Tp}(lvl.val)
Base.resize!(lvl::ElementLevel) = lvl

isstructequal(a::T, b::T) where {T <: Element} =
a.val == b.val

function Base.show(io::IO, lvl::ElementLevel{Vf, Tv, Tp, Val}) where {Vf, Tv, Tp, Val}
print(io, "Element{")
show(io, Vf)
Expand Down
5 changes: 5 additions & 0 deletions src/tensors/levels/mutex_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ end
@inline level_fill_value(::Type{<:MutexLevel{AVal, Lvl}}) where {AVal, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:MutexLevel{AVal,Lvl}}) where {AVal,Lvl} = data_rep_level(Lvl)

isstructequal(a::T, b::T) where {T <: Mutex} =
typeof(a.locks) == typeof(b.locks) &&
isstructequal(a.lvl, b.lvl)
# Temporary hack to deal with SpinLock allocate undefined references.

# FIXME: These.
(fbr::Tensor{<:MutexLevel})() = SubFiber(fbr.lvl, 1)()
(fbr::SubFiber{<:MutexLevel})() = fbr #TODO this is not consistent somehow
Expand Down
2 changes: 2 additions & 0 deletions src/tensors/levels/pattern_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ end
(fbr::AbstractFiber{<:PatternLevel})() = true
data_rep_level(::Type{<:PatternLevel}) = ElementData(false, Bool)

isstructequal(a::T, b::T) where {T <: Pattern} = true

Check warning on line 46 in src/tensors/levels/pattern_levels.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/levels/pattern_levels.jl#L46

Added line #L46 was not covered by tests

postype(::Type{<:PatternLevel{Tp}}) where {Tp} = Tp

function moveto(lvl::PatternLevel{Tp}, device) where {Tp}
Expand Down
3 changes: 3 additions & 0 deletions src/tensors/levels/separate_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pattern!(lvl::SeparateLevel) = SeparateLevel(pattern!(lvl.lvl), map(pattern!, lv
set_fill_value!(lvl::SeparateLevel, init) = SeparateLevel(set_fill_value!(lvl.lvl, init), map(lvl_2->set_fill_value!(lvl_2, init), lvl.val))
Base.resize!(lvl::SeparateLevel, dims...) = SeparateLevel(resize!(lvl.lvl, dims...), map(lvl_2->resize!(lvl_2, dims...), lvl.val))

isstructequal(a::T, b::T) where {T <: Separate} =
all(isstructequal(x,y) for (x,y) in zip(a.val, b.val)) && isstructequal(a.lvl, b.lvl)

function Base.show(io::IO, lvl::SeparateLevel{Lvl, Val}) where {Lvl, Val}
print(io, "Separate(")
if get(io, :compact, false)
Expand Down
6 changes: 6 additions & 0 deletions src/tensors/levels/sparse_band_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ end
@inline level_fill_value(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBandLevel{Ti, Idx, Ofs, Lvl}}) where {Ti, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseBand} =
a.shape == b.shape &&
a.idx == b.idx &&
a.ofs == b.ofs &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseBandLevel})() = fbr
function (fbr::SubFiber{<:SparseBandLevel})(idxs...)
isempty(idxs) && return fbr
Expand Down
7 changes: 7 additions & 0 deletions src/tensors/levels/sparse_bytemap_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ end
@inline level_fill_value(::Type{<:SparseByteMapLevel{Ti, Ptr, Tbl, Srt, Lvl}}) where {Ti, Ptr, Tbl, Srt, Lvl}= level_fill_value(Lvl)
data_rep_level(::Type{<:SparseByteMapLevel{Ti, Ptr, Tbl, Srt, Lvl}}) where {Ti, Ptr, Tbl, Srt, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseByteMap} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.tbl == b.tbl &&
a.srt == b.srt &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseByteMapLevel})() = fbr
function (fbr::SubFiber{<:SparseByteMapLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
Expand Down
6 changes: 6 additions & 0 deletions src/tensors/levels/sparse_coo_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ end
@inline level_fill_value(::Type{<:SparseCOOLevel{N, TI, Ptr, Tbl, Lvl}}) where {N, TI, Ptr, Tbl, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseCOOLevel{N, TI, Ptr, Tbl, Lvl}}) where {N, TI, Ptr, Tbl, Lvl} = (SparseData^N)(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseCOO} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.tbl == b.tbl &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseCOOLevel})() = fbr
(fbr::SubFiber{<:SparseCOOLevel})() = fbr
function (fbr::SubFiber{<:SparseCOOLevel{N, TI}})(idxs...) where {N, TI}
Expand Down
5 changes: 5 additions & 0 deletions src/tensors/levels/sparse_dict_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ end
@inline level_fill_value(::Type{<:SparseDictLevel{Ti, Ptr, Idx, Val, Tbl, Pool, Lvl}}) where {Ti, Ptr, Idx, Val, Tbl, Pool, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseDictLevel{Ti, Ptr, Idx, Val, Tbl, Pool, Lvl}}) where {Ti, Ptr, Idx, Val, Tbl, Pool, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseDict} =
a.shape == b.shape &&
a.tbl == b.tbl &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseDictLevel})() = fbr
function (fbr::SubFiber{<:SparseDictLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
Expand Down
6 changes: 6 additions & 0 deletions src/tensors/levels/sparse_interval_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ end
@inline level_fill_value(::Type{<:SparseIntervalLevel{Ti, Left, Right, Lvl}}) where {Ti, Left, Right, Lvl}= level_fill_value(Lvl)
data_rep_level(::Type{<:SparseIntervalLevel{Ti, Left, Right, Lvl}}) where {Ti, Left, Right, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseInterval} =
a.shape == b.shape &&
a.left == b.left &&
a.right == b.right &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseIntervalLevel})() = fbr
function (fbr::SubFiber{<:SparseIntervalLevel})(idxs...)
isempty(idxs) && return fbr
Expand Down
6 changes: 6 additions & 0 deletions src/tensors/levels/sparse_list_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ end
@inline level_fill_value(::Type{<:SparseListLevel{Ti, Ptr, Idx, Lvl}}) where {Ti, Ptr, Idx, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseListLevel{Ti, Ptr, Idx, Lvl}}) where {Ti, Ptr, Idx, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseList} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.idx == b.idx &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseListLevel})() = fbr
function (fbr::SubFiber{<:SparseListLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
Expand Down
5 changes: 5 additions & 0 deletions src/tensors/levels/sparse_point_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ end
@inline level_fill_value(::Type{<:SparsePointLevel{Ti, Idx, Lvl}}) where {Ti, Idx, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparsePointLevel{Ti, Idx, Lvl}}) where {Ti, Idx, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparsePoint} =
a.shape == b.shape &&
a.idx == b.idx &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparsePointLevel})() = fbr
function (fbr::SubFiber{<:SparsePointLevel{Ti}})(idxs...) where {Ti}
isempty(idxs) && return fbr
Expand Down
7 changes: 7 additions & 0 deletions src/tensors/levels/sparse_rle_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ end
@inline level_fill_value(::Type{<:SparseRunListLevel{Ti, Ptr, Left, Right, merge, Lvl}}) where {Ti, Ptr, Left, Right, merge, Lvl}= level_fill_value(Lvl)
data_rep_level(::Type{<:SparseRunListLevel{Ti, Ptr, Left, Right, merge, Lvl}}) where {Ti, Ptr, Left, Right, merge, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseRunList} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.left == b.left &&
a.right == b.right &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseRunListLevel})() = fbr
function (fbr::SubFiber{<:SparseRunListLevel})(idxs...)
isempty(idxs) && return fbr
Expand Down
7 changes: 7 additions & 0 deletions src/tensors/levels/sparse_vbl_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ end
@inline level_fill_value(::Type{<:SparseBlockListLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = level_fill_value(Lvl)
data_rep_level(::Type{<:SparseBlockListLevel{Ti, Ptr, Idx, Ofs, Lvl}}) where {Ti, Ptr, Idx, Ofs, Lvl} = SparseData(data_rep_level(Lvl))

isstructequal(a::T, b::T) where {T <: SparseBlockList} =
a.shape == b.shape &&
a.ptr == b.ptr &&
a.idx == b.idx &&
a.ofs == b.ofs &&
isstructequal(a.lvl, b.lvl)

(fbr::AbstractFiber{<:SparseBlockListLevel})() = fbr
function (fbr::SubFiber{<:SparseBlockListLevel})(idxs...)
isempty(idxs) && return fbr
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RewriteTools = "5969e224-3634-4c61-9f66-721b69e98b8a"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Expand Down
Loading

0 comments on commit aa59702

Please sign in to comment.