From 9e77e49c001827b327f4198e56ac97cbe40747da Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Tue, 3 Sep 2024 14:47:55 -0700 Subject: [PATCH] Add StaticOneTo inlining hack back in --- src/generatively_unrolled_functions.jl | 34 +++++++++++++++----------- test/test_and_analyze.jl | 9 ++++--- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/generatively_unrolled_functions.jl b/src/generatively_unrolled_functions.jl index e8afe0b..3eb1740 100644 --- a/src/generatively_unrolled_functions.jl +++ b/src/generatively_unrolled_functions.jl @@ -1,51 +1,57 @@ -f_exprs(N) = (:(f(generic_getindex(itr, $n))) for n in 1:N) +item_expr(itr_type, n) = + itr_type <: StaticOneTo ? n : :(generic_getindex(itr, $n)) + +f_exprs(itr_type, N) = (:(f($(item_expr(itr_type, n)))) for n in 1:N) @inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = - Expr(:||, f_exprs(N)...) + Expr(:||, f_exprs(itr, N)...) @inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr) @inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = - Expr(:&&, f_exprs(N)...) + Expr(:&&, f_exprs(itr, N)...) @inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr) @inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = - Expr(:block, f_exprs(N)..., nothing) + Expr(:block, f_exprs(itr, N)..., nothing) @inline gen_unrolled_foreach(f, itr) = _gen_unrolled_foreach(Val(length(itr)), f, itr) @inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = - Expr(:tuple, f_exprs(N)...) + Expr(:tuple, f_exprs(itr, N)...) @inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr) @inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr( :block, - map(((n, expr),) -> :(n == $n && return $expr), enumerate(f_exprs(N)))..., + map( + ((n, expr),) -> :(n == $n && return $expr), + enumerate(f_exprs(itr, N)), + )..., :(unrolled_applyat_bounds_error()), ) # This block should get optimized into a switch statement during LLVM codegen. @inline gen_unrolled_applyat(f, n, itr) = _gen_unrolled_applyat(Val(length(itr)), f, n, itr) -function nested_op_expr(N, init_type) - init_expr = init_type <: NoInit ? :(generic_getindex(itr, 1)) : :init +function nested_op_expr(itr_type, N, init_type) + init_expr = init_type <: NoInit ? item_expr(itr_type, 1) : :init n_range = init_type <: NoInit ? (2:N) : (1:N) return foldl(n_range; init = init_expr) do prev_op_expr, n - :(op($prev_op_expr, generic_getindex(itr, $n))) + :(op($prev_op_expr, $(item_expr(itr_type, n)))) end # Use foldl instead of reduce to guarantee left associativity. end @inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = - nested_op_expr(N, init) + nested_op_expr(itr, N, init) @inline gen_unrolled_reduce(op, itr, init) = _gen_unrolled_reduce(Val(length(itr)), op, itr, init) -function transformed_sequential_op_exprs(N, init_type) - first_item_expr = :(generic_getindex(itr, 1)) +function transformed_sequential_op_exprs(itr_type, N, init_type) + first_item_expr = item_expr(itr_type, 1) init_expr = init_type <: NoInit ? first_item_expr : :(op(init, $first_item_expr)) transformed_exprs_and_next_op_exprs = accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n var = gensym() - next_op_expr = :(op($var, generic_getindex(itr, $(n + 1)))) + next_op_expr = :(op($var, $(item_expr(itr_type, n + 1)))) (:($var = $prev_op_expr; transform($var)), next_op_expr) end return map(first, transformed_exprs_and_next_op_exprs) @@ -57,6 +63,6 @@ end itr, init, transform, -) where {N} = Expr(:tuple, transformed_sequential_op_exprs(N, init)...) +) where {N} = Expr(:tuple, transformed_sequential_op_exprs(itr, N, init)...) @inline gen_unrolled_accumulate(op, itr, init, transform) = _gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform) diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl index 9684306..fb31132 100644 --- a/test/test_and_analyze.jl +++ b/test/test_and_analyze.jl @@ -55,7 +55,7 @@ function print_comparison_table(title, comparison_table_dict, io = stdout) elseif contains(compile_time, "similar") && contains(allocs, "similar") # similar compilation and total allocations - color(writing_to_docs ? "lightgray" : "light_gray") + color(writing_to_docs ? "silver" : "light_gray") elseif contains(compile_time, "less") && contains(allocs, "more") || contains(compile_time, "more") && contains(allocs, "less") # mixed compilation and total allocations @@ -870,7 +870,7 @@ end @testset "unrolled functions of a large Tuple or StaticOneTo" begin itr = StaticOneTo(8186) - @test_unrolled((itr,), unrolled_reduce(+, itr), reduce(+, itr), "Ints",) + @test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints" @test_unrolled( (itr,), unrolled_mapreduce(log, +, itr), @@ -881,9 +881,10 @@ end # The previous tests also work with ntuple(identity, 8186), but they take a # very long time to compile. - @test_throws "gc handles" unrolled_reduce(+, StaticOneTo(8187)) - @test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187)) + itr′ = StaticOneTo(8187) + @test_unrolled (itr′,) unrolled_reduce(+, itr′) reduce(+, itr′) "Ints" + @test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187)) @test_throws "gc handles" unrolled_reduce(+, ntuple(identity, 8187)) @test_throws "gc handles" unrolled_mapreduce(log, +, ntuple(identity, 8187))