Skip to content

Commit

Permalink
Merge pull request #610 from willow-ahrens/wma/fix609
Browse files Browse the repository at this point in the history
fix 609
  • Loading branch information
willow-ahrens authored Jun 26, 2024
2 parents a7f5b89 + ccc99c0 commit 6d5b2cf
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
52 changes: 41 additions & 11 deletions src/interface/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,6 @@ format.
"""
dropfills(src) = dropfills!(similar(src), src)

"""
dropfills!(dst, src)
Copy only the non- fill values from `src` into `dst`. The shape and format of
`dst` must match `src`
"""
dropfills!(dst::AbstractTensor, src) = dropfills_helper!(dst, src)
dropfills!(dst::SwizzleArray{dims}, src::SwizzleArray{dims}) where {dims} = swizzle(dropfills_helper!(dst.body, src.body), dims...)

@staged function dropfills_helper!(dst, src)
ndims(dst) > ndims(src) && throw(DimensionMismatch("more dimensions in destination than source"))
ndims(dst) < ndims(src) && throw(DimensionMismatch("less dimensions in destination than source"))
Expand All @@ -83,7 +74,7 @@ dropfills!(dst::SwizzleArray{dims}, src::SwizzleArray{dims}) where {dims} = swiz
T = eltype(dst)
d = fill_value(dst)
return quote
@finch begin
@finch mode=:fast begin
dst .= $(fill_value(dst))
$(Expr(:for, exts, quote
let tmp = src[$(idxs...)]
Expand All @@ -95,4 +86,43 @@ dropfills!(dst::SwizzleArray{dims}, src::SwizzleArray{dims}) where {dims} = swiz
end
return dst
end
end
end

"""
dropfills!(dst, src)
Copy only the non-fill values from `src` into `dst`.
"""
dropfills!(dst::AbstractTensor, src::AbstractTensor) =
dropfills_helper!(dst, src)

dropfills!(dst::AbstractTensor, src::AbstractArray) =
dropfills_helper!(dst, src)

dropfills!(dst::AbstractArray, src::AbstractTensor) =
dropfills_helper!(dst, src)

function dropfills_swizzled!(dst, src, perm)
if issorted(perm)
return dropfills_helper!(dst, src)
else
tmp = rep_construct(permutedims_rep(data_rep(src), perm))
tmp = dropfills_helper!(swizzle(tmp, invperm(perm)...), src)
return copyto_helper!(dst, tmp.body)
end
end

dropfills!(dst::AbstractArray, src::SwizzleArray{dims}) where {dims} =
dropfills_swizzled!(dst, src.body, dims)

dropfills!(dst::AbstractTensor, src::SwizzleArray{dims}) where {dims} =
dropfills_swizzled!(dst, src.body, dims)

dropfills!(dst::SwizzleArray{dims}, src::SwizzleArray{dims2}) where {dims, dims2} =
swizzle(dropfills!(dst.body, swizzle(src, invperm(dims)...)), dims...)

dropfills!(dst::SwizzleArray{dims}, src::AbstractTensor) where {dims} =
swizzle(dropfills!(dst.body, swizzle(src, invperm(dims)...)), dims...)

dropfills!(dst::SwizzleArray{dims}, src::AbstractArray) where {dims} =
swizzle(dropfills!(dst.body, swizzle(src, invperm(dims)...)), dims...)
5 changes: 5 additions & 0 deletions test/test_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -809,4 +809,9 @@ using Finch: AsArray
=#
end

begin
A = Tensor(Dense(SparseList(Element(0.0))))
B = dropfills!(swizzle(A, 2, 1), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0])
@test B == swizzle(Tensor(Dense{Int64}(SparseList{Int64}(Element{0.0, Float64, Int64}([4.4, 1.1, 2.2, 5.5, 3.3]), 3, [1, 2, 3, 5, 6], [3, 1, 1, 3, 1]), 4)), 2, 1)
end
end

0 comments on commit 6d5b2cf

Please sign in to comment.