From a640aa11261a8a5ddb3d671418b23c73ef8000ed Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 10 Jan 2024 17:02:40 -0500 Subject: [PATCH 1/3] fix staging --- src/Finch.jl | 1 - src/interface/index.jl | 12 ++++++++---- src/util/util.jl | 41 ++++++++++++++++++++++------------------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/Finch.jl b/src/Finch.jl index 0b5686230..dcd52fff8 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -125,7 +125,6 @@ include("tensors/combinators/swizzle.jl") include("tensors/combinators/scale.jl") include("tensors/combinators/product.jl") - include("traits.jl") export fsparse, fsparse!, fsprand, fspzeros, ffindnz, fread, fwrite, countstored diff --git a/src/interface/index.jl b/src/interface/index.jl index 6d95a4b10..cec3c757a 100644 --- a/src/interface/index.jl +++ b/src/interface/index.jl @@ -25,8 +25,10 @@ getindex_rep_def(lvl::RepeatData, idx::Drop) = SolidData(ElementData(lvl.default getindex_rep_def(lvl::RepeatData, idx) = SolidData(ElementData(lvl.default, lvl.eltype)) getindex_rep_def(lvl::RepeatData, idx::Type{<:AbstractUnitRange}) = SolidData(ElementData(lvl.default, lvl.eltype)) -Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)...) -@staged function getindex_helper(arr, inds...) +Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)) +@staged function getindex_helper(arr, inds::Tuple) + inds <: Type{<:Tuple} + inds = inds.parameters @assert ndims(arr) == length(inds) N = ndims(arr) @@ -69,8 +71,10 @@ Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds) end end -Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds)...) -@staged function setindex_helper(arr, src, inds...) +Base.setindex!(arr::Tensor, src, inds...) = setindex_helper(arr, src, to_indices(arr, inds)) +@staged function setindex_helper(arr, src, inds) + inds <: Type{<:Tuple} + inds = inds.parameters @assert ndims(arr) == length(inds) @assert sum(ndims.(inds)) == 0 || (ndims(src) == sum(ndims.(inds))) N = ndims(arr) diff --git a/src/util/util.jl b/src/util/util.jl index 5ad8226d2..0dc98e43d 100644 --- a/src/util/util.jl +++ b/src/util/util.jl @@ -41,36 +41,39 @@ eval and invokelatest strategy. Otherwise, it uses a generated function. macro staged(def) (@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged")) - called = gensym(Symbol(name, :_called)) - name_2 = gensym(Symbol(name, :_eval_invokelatest)) - name_3 = gensym(Symbol(name, :_evaled)) + name_generator = gensym(Symbol(name, :_generator)) + name_invokelatest = gensym(Symbol(name, :_invokelatest)) + name_eval_invokelatest = gensym(Symbol(name, :_eval_invokelatest)) def = quote - $called = false + function $name_generator($(args...)) + $body + end + + function $name_invokelatest($(args...)) + $invokelatest($name_eval_invokelatest, $(args...)) + end - function $name_2($(args...)) - global $called - if !$called - code = let ($(args...),) = ($(map((arg)->:(typeof($arg)), args)...),) - $body + function $name_eval_invokelatest($(args...)) + code = $name_generator($(map((arg)->:(typeof($arg)), args)...),) + def = quote + function $($(QuoteNode(name_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...))) + $($(QuoteNode(name_eval_invokelatest)))($($(map(QuoteNode, args)...))) end - def = quote - function $($(QuoteNode(name_3)))($($(map(QuoteNode, args)...))) - $code - end + function $($(QuoteNode(name_eval_invokelatest)))($($(map(arg -> :(:($($(QuoteNode(arg)))::$(typeof($arg)))), args)...))) + $code end - ($@__MODULE__).eval(def) - $called = true end - Base.invokelatest(($@__MODULE__).$name_3, $(args...)) + ($@__MODULE__).eval(def) + Base.invokelatest(($@__MODULE__).$name_eval_invokelatest, $(args...)) end @generated function $name($(args...)) # Taken from https://github.com/NHDaly/StagedFunctions.jl/blob/6fafbc560421f70b05e3df330b872877db0bf3ff/src/StagedFunctions.jl#L116 body_2 = () -> begin - code = $(body) - if has_function_def(macroexpand($@__MODULE__, code)) - :($($(name_2))($($args...))) + code = $name_generator($(args...)) + if true #has_function_def(macroexpand($@__MODULE__, code)) + :($($(name_invokelatest))($($(map(QuoteNode, args)...)))) else quote $code From 04c62eb57ecd6084be92c6c7066d48aeaff3154e Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 10 Jan 2024 17:10:22 -0500 Subject: [PATCH 2/3] fix the issue --- src/interface/index.jl | 2 +- src/util/util.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/interface/index.jl b/src/interface/index.jl index cec3c757a..53b4608ac 100644 --- a/src/interface/index.jl +++ b/src/interface/index.jl @@ -26,7 +26,7 @@ getindex_rep_def(lvl::RepeatData, idx) = SolidData(ElementData(lvl.default, lvl. getindex_rep_def(lvl::RepeatData, idx::Type{<:AbstractUnitRange}) = SolidData(ElementData(lvl.default, lvl.eltype)) Base.getindex(arr::Tensor, inds...) = getindex_helper(arr, to_indices(arr, inds)) -@staged function getindex_helper(arr, inds::Tuple) +@staged function getindex_helper(arr, inds) inds <: Type{<:Tuple} inds = inds.parameters @assert ndims(arr) == length(inds) diff --git a/src/util/util.jl b/src/util/util.jl index 0dc98e43d..699ac6583 100644 --- a/src/util/util.jl +++ b/src/util/util.jl @@ -37,6 +37,7 @@ ensures the first Finch invocation runs in the latest world, and leaves hooks so that subsequent calls to [`Finch.refresh`](@ref) can update the world and invalidate old versions. If the body contains closures, this macro uses an eval and invokelatest strategy. Otherwise, it uses a generated function. +This macro does not support type parameters, varargs, or keyword arguments. """ macro staged(def) (@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged")) @@ -72,7 +73,7 @@ macro staged(def) # Taken from https://github.com/NHDaly/StagedFunctions.jl/blob/6fafbc560421f70b05e3df330b872877db0bf3ff/src/StagedFunctions.jl#L116 body_2 = () -> begin code = $name_generator($(args...)) - if true #has_function_def(macroexpand($@__MODULE__, code)) + if has_function_def(macroexpand($@__MODULE__, code)) :($($(name_invokelatest))($($(map(QuoteNode, args)...)))) else quote From 3111ea4085d54c452e30699f87db27d0621b9a11 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Wed, 10 Jan 2024 17:15:30 -0500 Subject: [PATCH 3/3] fix --- src/util/util.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/util/util.jl b/src/util/util.jl index 699ac6583..6f3090c34 100644 --- a/src/util/util.jl +++ b/src/util/util.jl @@ -52,7 +52,7 @@ macro staged(def) end function $name_invokelatest($(args...)) - $invokelatest($name_eval_invokelatest, $(args...)) + $(Base.invokelatest)($name_eval_invokelatest, $(args...)) end function $name_eval_invokelatest($(args...)) @@ -66,7 +66,7 @@ macro staged(def) end end ($@__MODULE__).eval(def) - Base.invokelatest(($@__MODULE__).$name_eval_invokelatest, $(args...)) + $(Base.invokelatest)(($@__MODULE__).$name_eval_invokelatest, $(args...)) end @generated function $name($(args...))