Skip to content

Commit

Permalink
Add StaticOneTo inlining hack back in
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Sep 3, 2024
1 parent 3a78132 commit 9e77e49
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
34 changes: 20 additions & 14 deletions src/generatively_unrolled_functions.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,57 @@
f_exprs(N) = (:(f(generic_getindex(itr, $n))) for n in 1:N)
item_expr(itr_type, n) =
itr_type <: StaticOneTo ? n : :(generic_getindex(itr, $n))

f_exprs(itr_type, N) = (:(f($(item_expr(itr_type, n)))) for n in 1:N)

@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} =
Expr(:||, f_exprs(N)...)
Expr(:||, f_exprs(itr, N)...)
@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} =
Expr(:&&, f_exprs(N)...)
Expr(:&&, f_exprs(itr, N)...)
@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} =
Expr(:block, f_exprs(N)..., nothing)
Expr(:block, f_exprs(itr, N)..., nothing)
@inline gen_unrolled_foreach(f, itr) =
_gen_unrolled_foreach(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} =
Expr(:tuple, f_exprs(N)...)
Expr(:tuple, f_exprs(itr, N)...)
@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
:block,
map(((n, expr),) -> :(n == $n && return $expr), enumerate(f_exprs(N)))...,
map(
((n, expr),) -> :(n == $n && return $expr),
enumerate(f_exprs(itr, N)),
)...,
:(unrolled_applyat_bounds_error()),
) # This block should get optimized into a switch statement during LLVM codegen.
@inline gen_unrolled_applyat(f, n, itr) =
_gen_unrolled_applyat(Val(length(itr)), f, n, itr)

function nested_op_expr(N, init_type)
init_expr = init_type <: NoInit ? :(generic_getindex(itr, 1)) : :init
function nested_op_expr(itr_type, N, init_type)
init_expr = init_type <: NoInit ? item_expr(itr_type, 1) : :init
n_range = init_type <: NoInit ? (2:N) : (1:N)
return foldl(n_range; init = init_expr) do prev_op_expr, n
:(op($prev_op_expr, generic_getindex(itr, $n)))
:(op($prev_op_expr, $(item_expr(itr_type, n))))
end # Use foldl instead of reduce to guarantee left associativity.
end

@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} =
nested_op_expr(N, init)
nested_op_expr(itr, N, init)
@inline gen_unrolled_reduce(op, itr, init) =
_gen_unrolled_reduce(Val(length(itr)), op, itr, init)

function transformed_sequential_op_exprs(N, init_type)
first_item_expr = :(generic_getindex(itr, 1))
function transformed_sequential_op_exprs(itr_type, N, init_type)
first_item_expr = item_expr(itr_type, 1)
init_expr =
init_type <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
transformed_exprs_and_next_op_exprs =
accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n
var = gensym()
next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
next_op_expr = :(op($var, $(item_expr(itr_type, n + 1))))
(:($var = $prev_op_expr; transform($var)), next_op_expr)
end
return map(first, transformed_exprs_and_next_op_exprs)
Expand All @@ -57,6 +63,6 @@ end
itr,
init,
transform,
) where {N} = Expr(:tuple, transformed_sequential_op_exprs(N, init)...)
) where {N} = Expr(:tuple, transformed_sequential_op_exprs(itr, N, init)...)
@inline gen_unrolled_accumulate(op, itr, init, transform) =
_gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform)
9 changes: 5 additions & 4 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function print_comparison_table(title, comparison_table_dict, io = stdout)
elseif contains(compile_time, "similar") &&
contains(allocs, "similar")
# similar compilation and total allocations
color(writing_to_docs ? "lightgray" : "light_gray")
color(writing_to_docs ? "silver" : "light_gray")
elseif contains(compile_time, "less") && contains(allocs, "more") ||
contains(compile_time, "more") && contains(allocs, "less")
# mixed compilation and total allocations
Expand Down Expand Up @@ -870,7 +870,7 @@ end

@testset "unrolled functions of a large Tuple or StaticOneTo" begin
itr = StaticOneTo(8186)
@test_unrolled((itr,), unrolled_reduce(+, itr), reduce(+, itr), "Ints",)
@test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints"
@test_unrolled(
(itr,),
unrolled_mapreduce(log, +, itr),
Expand All @@ -881,9 +881,10 @@ end
# The previous tests also work with ntuple(identity, 8186), but they take a
# very long time to compile.

@test_throws "gc handles" unrolled_reduce(+, StaticOneTo(8187))
@test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187))
itr′ = StaticOneTo(8187)
@test_unrolled (itr′,) unrolled_reduce(+, itr′) reduce(+, itr′) "Ints"

@test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187))
@test_throws "gc handles" unrolled_reduce(+, ntuple(identity, 8187))
@test_throws "gc handles" unrolled_mapreduce(log, +, ntuple(identity, 8187))

Expand Down

0 comments on commit 9e77e49

Please sign in to comment.