Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inlining of generated functions #13

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,20 @@ include("generatively_unrolled_functions.jl")
error("unrolled_applyat has detected an out-of-bounds index")

@inline unrolled_reduce(op, itr, init) =
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
isempty(itr) && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
unrolled_reduce(op, itr, init)

# TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than
# unrolled_reduce(op, StaticOneTo(N), init) for the non-orographic gravity wave
# parametrization test in ClimaAtmos, to the point where the StaticOneTo version
# completely hangs while the Val version compiles in only a few seconds.
@inline unrolled_reduce(op, val_N::Val, init) =
val_unrolled_reduce(op, val_N, init)
@inline unrolled_reduce(op, val_N::Val; init = NoInit()) =
val_N isa Val{0} && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
unrolled_reduce(op, val_N, init)
error("unrolled_reduce requires an init value for Val(0)") :
val_unrolled_reduce(op, val_N, init)

@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
unrolled_reduce(op, Iterators.map(f, itrs...), init)
Expand Down
78 changes: 50 additions & 28 deletions src/generatively_unrolled_functions.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} =
Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = Expr(
:block,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine to me.

Expr(:meta, :inline),
Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} =
Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} =
Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing)
@generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(f(generic_getindex(itr, $n))) for n in 1:N)...,
nothing,
)
@inline gen_unrolled_foreach(f, itr) =
_gen_unrolled_foreach(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} =
Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...)
@generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr)

@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
@generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(n == $n && return f(generic_getindex(itr, $n))) for n in 1:N)...,
:(unrolled_applyat_bounds_error()),
) # This block gets optimized into a switch instruction during LLVM codegen.
@inline gen_unrolled_applyat(f, n, itr) =
_gen_unrolled_applyat(Val(length(itr)), f, n, itr)

@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} =
@generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl(
init <: NoInit ? (2:N) : (1:N);
(op_expr, n) -> :(op($op_expr, generic_getindex(itr, $n))),
(init <: NoInit ? 2 : 1):N;
init = init <: NoInit ? :(generic_getindex(itr, 1)) : :init,
) do prev_op_expr, n
:(op($prev_op_expr, generic_getindex(itr, $n)))
end # Use foldl instead of reduce to guarantee left associativity.
), # Use foldl instead of reduce to guarantee left associativity.
)
@inline gen_unrolled_reduce(op, itr, init) =
_gen_unrolled_reduce(Val(length(itr)), op, itr, init)

@inline @generated function _gen_unrolled_accumulate(
@generated function _gen_unrolled_accumulate(
::Val{N},
op,
itr,
Expand All @@ -43,29 +59,35 @@
first_item_expr = :(generic_getindex(itr, 1))
init_expr = init <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
transformed_exprs_and_op_exprs =
accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n
accumulate(1:N; init = (nothing, init_expr)) do (_, op_expr), n
var = gensym()
op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $prev_op_expr; transform($var)), op_expr)
next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $op_expr; transform($var)), next_op_expr)
end
return Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...)
return Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...),
)
end
@inline gen_unrolled_accumulate(op, itr, init, transform) =
_gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform)

# TODO: The following is experimental and will likely be removed in the future.
# For some reason, combining these two methods into one (or combining them with
# the method for gen_unrolled_reduce defined above) causes compilation of the
# non-orographic gravity wave parametrization test in ClimaAtmos to hang. Even
# more bizarrely, using the assignment form of the first method definition below
# (as opposed to the function syntax used here) causes compilation to hang as
# well. This behavior has not yet been replicated in a minimal working example.
@inline @generated function val_unrolled_reduce(op, ::Val{N}, init) where {N}
# non-orographic gravity wave parametrization test in ClimaAtmos to hang.
# Wrapping the first method's result in a block and adding an inline annotation
# also causes compilation to hang. Even using the assignment form of the first
# method definition below (as opposed to the function syntax used here) causes
# it to hang. This has not yet been replicated in a minimal working example.
@generated function val_unrolled_reduce(op, ::Val{N}, init) where {N}
return foldl((:init, 1:N...)) do prev_op_expr, item_expr
:(op($prev_op_expr, $item_expr))
end
end
@inline @generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} =
foldl(1:N) do prev_op_expr, item_expr
:(op($prev_op_expr, $item_expr))
end
@generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl((op_expr, item_expr) -> :(op($op_expr, $item_expr)), 1:N),
)
12 changes: 6 additions & 6 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -877,23 +877,23 @@ title = "Very Long Iterators"
comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())

@testset "unrolled functions of Tuples vs. StaticOneTos" begin
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8186))
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8185))
@test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints"
@test_unrolled(
(itr,),
unrolled_mapreduce(log, +, itr),
mapreduce(log, +, itr),
"Ints",
)
end # These can each take 40 seconds to compile for ntuple(identity, 8186).
for itr in (ntuple(identity, 8187), StaticOneTo(8187))
end # These can each take 40 seconds to compile for ntuple(identity, 8185).
for itr in (ntuple(identity, 8186), StaticOneTo(8186))
@test_throws "gc handles" unrolled_reduce(+, itr)
@test_throws "gc handles" unrolled_mapreduce(log, +, itr)
end
# TODO: Why does the compiler throw an error when generating functions that
# get unrolled into more than 8186 lines of LLVM code?
# get unrolled into more than 8185 lines of LLVM code?

for itr in (StaticOneTo(8186), StaticOneTo(8187))
for itr in (StaticOneTo(8185), StaticOneTo(8186))
@test_unrolled(
(itr,),
unrolled_reduce(+, Val(length(itr))),
Expand All @@ -902,7 +902,7 @@ comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
)
end
@test_throws "gc handles" unrolled_reduce(+, Val(8188))
# TODO: Why is the limit 8187 for the Val version of unrolled_reduce?
# TODO: Why is the limit 8186 for the Val version of unrolled_reduce?
end

title = "Generative vs. Recursive Unrolling"
Expand Down
Loading