Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wma/fix staging #371

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ include("tensors/combinators/swizzle.jl")
include("tensors/combinators/scale.jl")
include("tensors/combinators/product.jl")


include("traits.jl")

export fsparse, fsparse!, fsprand, fspzeros, ffindnz, fread, fwrite, countstored
Expand Down
12 changes: 8 additions & 4 deletions src/interface/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
getindex_rep_def(lvl::RepeatData, idx) = SolidData(ElementData(lvl.default, lvl.eltype))
getindex_rep_def(lvl::RepeatData, idx::Type{<:AbstractUnitRange}) = SolidData(ElementData(lvl.default, lvl.eltype))

Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)...)
@staged function getindex_helper(arr, inds...)
Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds))
@staged function getindex_helper(arr, inds)

Check warning on line 29 in src/interface/index.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/index.jl#L29

Added line #L29 was not covered by tests
inds <: Type{<:Tuple}
inds = inds.parameters
@assert ndims(arr) == length(inds)
N = ndims(arr)

Expand Down Expand Up @@ -69,8 +71,10 @@
end
end

Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds)...)
@staged function setindex_helper(arr, src, inds...)
Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds))
@staged function setindex_helper(arr, src, inds)

Check warning on line 75 in src/interface/index.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/index.jl#L75

Added line #L75 was not covered by tests
inds <: Type{<:Tuple}
inds = inds.parameters
@assert ndims(arr) == length(inds)
@assert sum(ndims.(inds)) == 0 || (ndims(src) == sum(ndims.(inds)))
N = ndims(arr)
Expand Down
40 changes: 22 additions & 18 deletions src/util/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,40 +37,44 @@ ensures the first Finch invocation runs in the latest world, and leaves hooks so
that subsequent calls to [`Finch.refresh`](@ref) can update the world and
invalidate old versions. If the body contains closures, this macro uses an
eval and invokelatest strategy. Otherwise, it uses a generated function.
This macro does not support type parameters, varargs, or keyword arguments.
"""
macro staged(def)
(@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged"))

called = gensym(Symbol(name, :_called))
name_2 = gensym(Symbol(name, :_eval_invokelatest))
name_3 = gensym(Symbol(name, :_evaled))
name_generator = gensym(Symbol(name, :_generator))
name_invokelatest = gensym(Symbol(name, :_invokelatest))
name_eval_invokelatest = gensym(Symbol(name, :_eval_invokelatest))

def = quote
$called = false
function $name_generator($(args...))
$body
end

function $name_invokelatest($(args...))
$(Base.invokelatest)($name_eval_invokelatest, $(args...))
end

function $name_2($(args...))
global $called
if !$called
code = let ($(args...),) = ($(map((arg)->:(typeof($arg)), args)...),)
$body
function $name_eval_invokelatest($(args...))
code = $name_generator($(map((arg)->:(typeof($arg)), args)...),)
def = quote
function $($(QuoteNode(name_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...)))
$($(QuoteNode(name_eval_invokelatest)))($($(map(QuoteNode, args)...)))
end
def = quote
function $($(QuoteNode(name_3)))($($(map(QuoteNode, args)...)))
$code
end
function $($(QuoteNode(name_eval_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...)))
$code
end
($@__MODULE__).eval(def)
$called = true
end
Base.invokelatest(($@__MODULE__).$name_3, $(args...))
($@__MODULE__).eval(def)
$(Base.invokelatest)(($@__MODULE__).$name_eval_invokelatest, $(args...))
end

@generated function $name($(args...))
# Taken from https://github.com/NHDaly/StagedFunctions.jl/blob/6fafbc560421f70b05e3df330b872877db0bf3ff/src/StagedFunctions.jl#L116
body_2 = () -> begin
code = $(body)
code = $name_generator($(args...))
if has_function_def(macroexpand($@__MODULE__, code))
:($($(name_2))($($args...)))
:($($(name_invokelatest))($($(map(QuoteNode, args)...))))
else
quote
$code
Expand Down
Loading