Skip to content

Commit

Permalink
Fix inlining of generated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Sep 10, 2024
1 parent aa04164 commit 57ae250
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 46 deletions.
10 changes: 4 additions & 6 deletions src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,20 @@ include("generatively_unrolled_functions.jl")
error("unrolled_applyat has detected an out-of-bounds index")

@inline unrolled_reduce(op, itr, init) =
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
isempty(itr) && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
unrolled_reduce(op, itr, init)

# TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than
# unrolled_reduce(op, StaticOneTo(N), init) for the non-orographic gravity wave
# parametrization test in ClimaAtmos, to the point where the StaticOneTo version
# completely hangs while the Val version compiles in only a few seconds.
@inline unrolled_reduce(op, val_N::Val, init) =
val_unrolled_reduce(op, val_N, init)
@inline unrolled_reduce(op, val_N::Val; init = NoInit()) =
val_N isa Val{0} && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
unrolled_reduce(op, val_N, init)
error("unrolled_reduce requires an init value for Val(0)") :
val_unrolled_reduce(op, val_N, init)

@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
unrolled_reduce(op, Iterators.map(f, itrs...), init)
Expand Down
87 changes: 53 additions & 34 deletions src/generatively_unrolled_functions.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} =
Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} =
Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} =
Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)
@generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(f(generic_getindex(itr, $n))) for n in 1:N)...,
nothing,
)
@inline gen_unrolled_foreach(f, itr) =
_gen_unrolled_foreach(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} =
Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
@generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(n == $n && return f(generic_getindex(itr, $n))) for n in 1:N)...,
:(unrolled_applyat_bounds_error()),
) # This block gets optimized into a switch instruction during LLVM codegen.
@inline gen_unrolled_applyat(f, n, itr) =
_gen_unrolled_applyat(Val(length(itr)), f, n, itr)

@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} =
@generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl(
init <: NoInit ? (2:N) : (1:N);
(op_expr, n) -> :(op($op_expr, generic_getindex(itr, $n))),
(init <: NoInit ? 2 : 1):N;
init = init <: NoInit ? :(generic_getindex(itr, 1)) : :init,
) do prev_op_expr, n
:(op($prev_op_expr, generic_getindex(itr, $n)))
end # Use foldl instead of reduce to guarantee left associativity.
), # Use foldl instead of reduce to guarantee left associativity.
)
@inline gen_unrolled_reduce(op, itr, init) =
_gen_unrolled_reduce(Val(length(itr)), op, itr, init)

@inline @generated function _gen_unrolled_accumulate(
@generated function _gen_unrolled_accumulate(
::Val{N},
op,
itr,
Expand All @@ -43,29 +59,32 @@
first_item_expr = :(generic_getindex(itr, 1))
init_expr = init <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
transformed_exprs_and_op_exprs =
accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n
accumulate(1:N; init = (nothing, init_expr)) do (_, op_expr), n
var = gensym()
op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $prev_op_expr; transform($var)), op_expr)
next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $op_expr; transform($var)), next_op_expr)
end
return Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...)
return Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...),
)
end
@inline gen_unrolled_accumulate(op, itr, init, transform) =
_gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform)

# TODO: The following is experimental and will likely be removed in the future.
# For some reason, combining these two methods into one (or combining them with
# the method for gen_unrolled_reduce defined above) causes compilation of the
# non-orographic gravity wave parametrization test in ClimaAtmos to hang. Even
# more bizarrely, using the assignment form of the first method definition below
# (as opposed to the function syntax used here) causes compilation to hang as
# well. This behavior has not yet been replicated in a minimal working example.
@inline @generated function val_unrolled_reduce(op, ::Val{N}, init) where {N}
return foldl((:init, 1:N...)) do prev_op_expr, item_expr
:(op($prev_op_expr, $item_expr))
end
end
@inline @generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} =
foldl(1:N) do prev_op_expr, item_expr
:(op($prev_op_expr, $item_expr))
end
# Combining these two methods into one (or combining them with the method for
# gen_unrolled_reduce defined above) causes compilation of the non-orographic
# gravity wave parametrization test in ClimaAtmos to hang. Is this due to a bug
# in the compiler?
@generated val_unrolled_reduce(op, ::Val{N}, init) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl((op_expr, item_expr) -> :(op($op_expr, $item_expr)), (:init, 1:N...)),
)
@generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl((op_expr, item_expr) -> :(op($op_expr, $item_expr)), 1:N),
)
12 changes: 6 additions & 6 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -877,23 +877,23 @@ title = "Very Long Iterators"
comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())

@testset "unrolled functions of Tuples vs. StaticOneTos" begin
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8186))
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8185))
@test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints"
@test_unrolled(
(itr,),
unrolled_mapreduce(log, +, itr),
mapreduce(log, +, itr),
"Ints",
)
end # These can each take 40 seconds to compile for ntuple(identity, 8186).
for itr in (ntuple(identity, 8187), StaticOneTo(8187))
end # These can each take 40 seconds to compile for ntuple(identity, 8185).
for itr in (ntuple(identity, 8186), StaticOneTo(8186))
@test_throws "gc handles" unrolled_reduce(+, itr)
@test_throws "gc handles" unrolled_mapreduce(log, +, itr)
end
# TODO: Why does the compiler throw an error when generating functions that
# get unrolled into more than 8186 lines of LLVM code?
# get unrolled into more than 8185 lines of LLVM code?

for itr in (StaticOneTo(8186), StaticOneTo(8187))
for itr in (StaticOneTo(8185), StaticOneTo(8186))
@test_unrolled(
(itr,),
unrolled_reduce(+, Val(length(itr))),
Expand All @@ -902,7 +902,7 @@ comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
)
end
@test_throws "gc handles" unrolled_reduce(+, Val(8188))
# TODO: Why is the limit 8187 for the Val version of unrolled_reduce?
# TODO: Why is the limit 8186 for the Val version of unrolled_reduce?
end

title = "Generative vs. Recursive Unrolling"
Expand Down

0 comments on commit 57ae250

Please sign in to comment.