Skip to content

Commit

Permalink
Merge pull request #16 from CliMA/dy/improve_compilation
Browse files Browse the repository at this point in the history
Improve compilation by using macros from Base.Cartesian
  • Loading branch information
dennisYatunin authored Oct 18, 2024
2 parents 2d471e1 + 28a17e6 commit 98c616f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 87 deletions.
13 changes: 6 additions & 7 deletions docs/src/developer_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ CurrentModule = UnrolledUtilities

There are two general ways to implement loop unrolling in Julia—recursively
splatting iterator contents and manually generating unrolled expressions. For
example, the recursively unrolled version of `foreach` is
example, a recursively unrolled version of the `foreach` function is

```julia
# Entry point: splat the items of `itr` into the recursive helper's arguments.
function unrolled_foreach(f, itr)
    return _unrolled_foreach(f, itr...)
end

# Base case: no items remain, so there is nothing left to apply `f` to.
function _unrolled_foreach(f)
    return nothing
end

# Recursive case: apply `f` to the first item, then recurse on the remainder.
function _unrolled_foreach(f, head, rest...)
    f(head)
    return _unrolled_foreach(f, rest...)
end
```

In contrast, the generatively unrolled version of `foreach` is
In contrast, a generatively unrolled implementation of this function looks like

```julia
unrolled_foreach(f, itr) = _unrolled_foreach(Val(length(itr)), f, itr)
Expand All @@ -30,15 +30,14 @@ rec_unroll
```

!!! tip "Tip"
Recursive loop unrolling can be disabled globally with the following
function redefinition:
Recursive loop unrolling can be enabled by redefining this function:

```julia
rec_unroll(itr) = false
rec_unroll(itr) = true
```

The cutoff length of 16 for switching to generative unrolling is motivated by
the benchmarks for [Generative vs. Recursive Unrolling](@ref).
The default choice of generative unrolling is motivated by the benchmarks for
[Generative vs. Recursive Unrolling](@ref).

## Interface API

Expand Down
97 changes: 46 additions & 51 deletions src/generatively_unrolled_functions.jl
Original file line number Diff line number Diff line change
@@ -1,74 +1,69 @@
@generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
# Generated kernel for `gen_unrolled_any`. For an iterator of length `N`, the
# body is the `Base.Cartesian.@nany` expansion of `f(generic_getindex(itr, idx))`
# over the positions `idx = 1, ..., N`.
@generated function _gen_unrolled_any(::Val{N}, f, itr) where {N}
    return quote
        @inline
        return Base.Cartesian.@nany $N idx -> f(generic_getindex(itr, idx))
    end
end

# Entry point: lifts `length(itr)` into a `Val` so the kernel above can be
# generated for that specific length.
@inline function gen_unrolled_any(f, itr)
    return _gen_unrolled_any(Val(length(itr)), f, itr)
end

@generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
# Generated kernel for `gen_unrolled_all`. For an iterator of length `N`, the
# body is the `Base.Cartesian.@nall` expansion of `f(generic_getindex(itr, idx))`
# over the positions `idx = 1, ..., N`.
@generated function _gen_unrolled_all(::Val{N}, f, itr) where {N}
    return quote
        @inline
        return Base.Cartesian.@nall $N idx -> f(generic_getindex(itr, idx))
    end
end

# Entry point: lifts `length(itr)` into a `Val` so the kernel above can be
# generated for that specific length.
@inline function gen_unrolled_all(f, itr)
    return _gen_unrolled_all(Val(length(itr)), f, itr)
end

@generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(f(generic_getindex(itr, $n))) for n in 1:N)...,
nothing,
)
# Generated kernel for `gen_unrolled_foreach`. For an iterator of length `N`,
# the body is `N` consecutive statements `f(generic_getindex(itr, idx))` for
# `idx = 1, ..., N`; the results of `f` are discarded and `nothing` is returned.
@generated function _gen_unrolled_foreach(::Val{N}, f, itr) where {N}
    return quote
        @inline
        Base.Cartesian.@nexprs $N idx -> f(generic_getindex(itr, idx))
        return nothing
    end
end

# Entry point: lifts `length(itr)` into a `Val` so the kernel above can be
# generated for that specific length.
@inline function gen_unrolled_foreach(f, itr)
    return _gen_unrolled_foreach(Val(length(itr)), f, itr)
end

@generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...),
)
# Generated kernel for `gen_unrolled_map`. For an iterator of length `N`, the
# body builds an `N`-tuple whose `idx`-th entry is
# `f(generic_getindex(itr, idx))`.
@generated function _gen_unrolled_map(::Val{N}, f, itr) where {N}
    return quote
        @inline
        return Base.Cartesian.@ntuple $N idx -> f(generic_getindex(itr, idx))
    end
end

# Entry point: lifts `length(itr)` into a `Val` so the kernel above can be
# generated for that specific length.
@inline function gen_unrolled_map(f, itr)
    return _gen_unrolled_map(Val(length(itr)), f, itr)
end

@generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
:block,
Expr(:meta, :inline),
(:(n == $n && return f(generic_getindex(itr, $n))) for n in 1:N)...,
:(unrolled_applyat_bounds_error()),
) # This block gets optimized into a switch instruction during LLVM codegen.
# Generated kernel for `gen_unrolled_applyat`. For an iterator of length `N`,
# the body is one statement per position `i = 1, ..., N` of the form
#
#     n == i && return f(generic_getindex(itr, i))
#
# followed by a call to `unrolled_applyat_bounds_error()`, which is reached only
# when the runtime index `n` matches none of the positions.
#
# The `@nexprs` counter must not share a name with the runtime index argument:
# `@nexprs` replaces every occurrence of its counter with a literal integer, so
# reusing `n` for both would also overwrite the argument inside the comparison
# and leave nothing bound to compare against. The counter is therefore `i`, and
# each unrolled statement compares it with the argument `n`.
@generated function _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N}
    return quote
        @inline
        Base.Cartesian.@nexprs $N i ->
            (n == i && return f(generic_getindex(itr, i)))
        unrolled_applyat_bounds_error()
    end # This is optimized into a switch instruction during LLVM code generation.
end

# Entry point: applies `f` to the item of `itr` at position `n` by lifting
# `length(itr)` into a `Val` and delegating to the generated kernel above.
@inline gen_unrolled_applyat(f, n, itr) =
    _gen_unrolled_applyat(Val(length(itr)), f, n, itr)

@generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = Expr(
:block,
Expr(:meta, :inline),
foldl(
(op_expr, n) -> :(op($op_expr, generic_getindex(itr, $n))),
(init <: NoInit ? 2 : 1):N;
init = init <: NoInit ? :(generic_getindex(itr, 1)) : :init,
), # Use foldl instead of reduce to guarantee left associativity.
)
# Generated kernel for `gen_unrolled_reduce` when an initial value is supplied.
# The running value starts as `init` and is updated left to right:
#     acc_k = op(acc_{k - 1}, generic_getindex(itr, k))   for k = 1, ..., N
# The final running value is returned; for `N == 0` that is `init` itself.
@generated function _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N}
    return quote
        @inline
        acc_0 = init
        $N == 0 && return acc_0
        return Base.Cartesian.@nexprs $N k ->
            (acc_k = op(acc_{k - 1}, generic_getindex(itr, k)))
    end
end

# Generated kernel for `gen_unrolled_reduce` when `init` is a `NoInit`. The
# running value starts as the first item of `itr`, and the remaining `N - 1`
# items are folded in left to right:
#     acc_{k + 1} = op(acc_k, generic_getindex(itr, k + 1))   for k = 1, ..., N - 1
@generated function _gen_unrolled_reduce(::Val{N}, op, itr, ::NoInit) where {N}
    return quote
        @inline
        acc_1 = generic_getindex(itr, 1)
        $N == 1 && return acc_1
        return Base.Cartesian.@nexprs $(N - 1) k ->
            (acc_{k + 1} = op(acc_k, generic_getindex(itr, k + 1)))
    end
end

# Entry point: lifts `length(itr)` into a `Val`; dispatch on the type of `init`
# then selects one of the two kernels above.
@inline function gen_unrolled_reduce(op, itr, init)
    return _gen_unrolled_reduce(Val(length(itr)), op, itr, init)
end

# Generated kernel for `gen_unrolled_accumulate`, specialized on the iterator
# length `N`.
# NOTE(review): this span shows two renderings of the same definition back to
# back: an `Expr`-building form and a `quote` form that uses `Base.Cartesian`
# macros. They share the argument list below. The inline comments describe the
# `quote` form.
@generated function _gen_unrolled_accumulate(
@generated _gen_unrolled_accumulate(
::Val{N},
op,
itr,
init,
transform,
) where {N}
first_item_expr = :(generic_getindex(itr, 1))
init_expr = init <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
transformed_exprs_and_op_exprs =
accumulate(1:N; init = (nothing, init_expr)) do (_, op_expr), n
var = gensym()
next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $op_expr; transform($var)), next_op_expr)
end
return Expr(
:block,
Expr(:meta, :inline),
Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...),
)
) where {N} = quote
@inline
# An iterator of length zero yields an empty tuple.
$N == 0 && return ()
first_itr_item = generic_getindex(itr, 1)
# The first running value is the first item itself when `init` is a `NoInit`,
# and `op(init, first item)` otherwise.
value_1 = init isa NoInit ? first_itr_item : op(init, first_itr_item)
# Fold the remaining items in left to right; each running value gets its own
# variable: value_{n + 1} = op(value_n, item n + 1) for n = 1, ..., N - 1.
Base.Cartesian.@nexprs $(N - 1) n ->
(value_{n + 1} = op(value_n, generic_getindex(itr, n + 1)))
# Apply `transform` to every running value and collect the results in an
# `N`-tuple.
return Base.Cartesian.@ntuple $N n -> transform(value_n)
end
# Entry point: lifts `length(itr)` into a `Val` and delegates to the generated
# kernel above.
@inline gen_unrolled_accumulate(op, itr, init, transform) =
_gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform)
Expand Down
13 changes: 3 additions & 10 deletions src/unrollable_iterator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
rec_unroll(itr)
Whether to use recursive loop unrolling instead of generative loop unrolling for
the iterator `itr`.
In general, recursive loop unrolling is faster to compile for small iterators,
but it becomes extremely slow to compile for long iterators, and it usually
generates suboptimal LLVM code for long iterators. On the other hand, generative
loop unrolling is slow to compile for small iterators, but its compilation time
does not grow as rapidly with respect to iterator size, and it always generates
optimal LLVM code. The default is currently to use recursive unrolling for
iterator lengths up to 16, and to use generative unrolling for longer iterators.
the iterator `itr`. Recursive unrolling can lead to suboptimal LLVM code for
iterators of more than 32 items, so this is set to `false` by default.
"""
@inline rec_unroll(itr) = length(itr) <= 16
@inline function rec_unroll(itr)
    # Recursive unrolling is off for every iterator, whatever its length.
    return false
end

"""
generic_getindex(itr, n)
Expand Down
21 changes: 2 additions & 19 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -877,32 +877,15 @@ title = "Very Long Iterators"
comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())

@testset "unrolled functions of Tuples vs. StaticOneTos" begin
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8185))
for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(9000))
@test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints"
@test_unrolled(
(itr,),
unrolled_mapreduce(log, +, itr),
mapreduce(log, +, itr),
"Ints",
)
end # These can each take 40 seconds to compile for ntuple(identity, 8185).
for itr in (ntuple(identity, 8186), StaticOneTo(8186))
@test_throws "gc handles" unrolled_reduce(+, itr)
@test_throws "gc handles" unrolled_mapreduce(log, +, itr)
end
# TODO: Why does the compiler throw an error when generating functions that
# get unrolled into more than 8185 lines of LLVM code?

for itr in (StaticOneTo(8185), StaticOneTo(8186))
@test_unrolled(
(itr,),
unrolled_reduce(+, Val(length(itr))),
reduce(+, itr),
"Ints",
)
end
@test_throws "gc handles" unrolled_reduce(+, Val(8188))
# TODO: Why is the limit 8186 for the Val version of unrolled_reduce?
end # These can take over a minute to compile for ntuple(identity, 9000).
end

title = "Generative vs. Recursive Unrolling"
Expand Down

0 comments on commit 98c616f

Please sign in to comment.