diff --git a/Project.toml b/Project.toml
index 62f5f0a..646c778 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,13 @@ version = "0.1.2"
[compat]
julia = "1.10"
+StaticArrays = "1"
+
+[weakdeps]
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+
+[extensions]
+UnrolledUtilitiesStaticArraysExt = "StaticArrays"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
diff --git a/docs/make.jl b/docs/make.jl
index ee728d5..2930c69 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,23 +2,25 @@ using Documenter
include(joinpath("..", "test", "test_and_analyze.jl"))
-comparison_table_file = joinpath("docs", "src", "comparison_table.md")
+comparison_table_file = joinpath("docs", "src", "comparison_tables.md")
open(comparison_table_file, "w") do io
- println(io, "# Comparison Table\n```@raw html")
- println(io, "<div style=\"width: 80vw\">") # use 80% of viewport
- print_comparison_table(io, true)
- println(io, "</div>")
- println(io, "```")
+ println(io, "# Comparison Tables")
+ for (title, comparison_table_dict) in comparison_table_dicts
+ print_comparison_table(title, comparison_table_dict, io)
+ end
end
makedocs(;
sitename = "UnrolledUtilities.jl",
modules = [UnrolledUtilities],
- pages = ["Home" => "index.md", "Comparison Table" => "comparison_table.md"],
+ pages = [
+ "Home" => "index.md",
+ "Comparison Tables" => "comparison_tables.md",
+ ],
format = Documenter.HTML(
prettyurls = get(ENV, "CI", nothing) == "true",
- size_threshold_ignore = ["comparison_table.md"],
+ size_threshold_ignore = ["comparison_tables.md"],
),
clean = true,
)
diff --git a/docs/src/index.md b/docs/src/index.md
index 8aaec38..5fa4958 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,15 +1,28 @@
+```@meta
+CurrentModule = UnrolledUtilities
+```
+
# UnrolledUtilities.jl
-A collection of generated functions in which all loops are unrolled and inlined:
+## Unrolled Functions
+
+This package exports the following functions, in which all loops are unrolled
+and inlined:
- `unrolled_any(f, itr)`: similar to `any`
- `unrolled_all(f, itr)`: similar to `all`
- `unrolled_foreach(f, itrs...)`: similar to `foreach`
- `unrolled_map(f, itrs...)`: similar to `map`
+- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)`
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
-- `unrolled_zip(itrs...)`: similar to `zip`
-- `unrolled_enumerate(itrs...)`: similar to `enumerate`, but with the ability to
- handle multiple iterators
+- `unrolled_accumulate(op, itr; [init], [transform])`: similar to `accumulate`,
+ but with an optional `transform` function applied to every accumulated value
+- `unrolled_push(itr, item)`: similar to `push!`, but non-mutating
+- `unrolled_append(itr1, itr2)`: similar to `append!`, but non-mutating
+- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take` (and to
+ `itr[1:N]`), but with `N` wrapped in a `Val`
+- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop` (and to
+ `itr[(N + 1):end]`), but with `N` wrapped in a `Val`
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
@@ -18,11 +31,6 @@ A collection of generated functions in which all loops are unrolled and inlined:
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
-- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)`
-- `unrolled_take(itr, ::Val{N})`: similar to `itr[1:N]` (and to
- `Iterators.take`), but with `N` wrapped in a `Val`
-- `unrolled_drop(itr, ::Val{N})`: similar to `itr[(N + 1):end]` (and to
- `Iterators.drop`), but with `N` wrapped in a `Val`
These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
@@ -42,34 +50,77 @@ iterators have singleton element types (and when the result of calling `f`
and/or `op` on these elements is inferrable). However, they can also be much
more expensive to compile than their counterparts from `Base` and
`Base.Iterators`, in which case they should not be used unless there is a clear
-performance benefit. Some notable exceptions to this are `unrolled_zip`,
-`unrolled_take`, and `unrolled_drop`, which tend to be easier to compile than
-`zip`, `Iterators.take`, `Iterators.drop`, and standard indexing notation.
+performance benefit. Two notable exceptions to this are `unrolled_take` and
+`unrolled_drop`, which are faster to compile than their non-static versions.
+
+## Interface
+
+These functions can be used to unroll loops over all iterators with statically
+inferrable lengths. Compatibility with any such iterator type can be added
+through the following interface:
+
+```@docs
+rec_unroll
+generic_getindex
+output_type_for_promotion
+NoOutputType
+ConditionalOutputType
+output_promote_rule
+constructor_from_tuple
+```
+
+This interface is used to provide built-in compatibility with
+- statically sized iterators from `Base` (`Tuple` and `NamedTuple`)
+- lazy iterators from `Base` (`enumerate`, `zip`, `Iterators.map`, and other
+ generator expressions)
+- statically sized iterators from
+ [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) (`SVector`
+ and `MVector`)
+- custom lazy and low-storage iterators (`StaticOneTo` and `StaticBitVector`)
+
+```@docs
+StaticOneTo
+StaticBitVector
+```
+
+## When to Unroll
For a more precise indication of whether you should use `UnrolledUtilities`,
-please consult the autogenerated [Comparison Table](@ref). This table contains a
-comprehensive set of potential use cases, each with a measurement of performance
-optimization, the time required for compilation, and the memory usage during
-compilation. Most cases involve simple functions `f` and/or `op`, but the last
-few demonstrate the benefits of unrolling with non-trivial recursive functions.
+please consult the autogenerated [Comparison Tables](@ref). These tables contain
+a comprehensive set of potential use cases, along with a few measurements that
+summarize their performance, compilation, and allocations:
+- run time (best of several trial measurements)
+- compilation time (as reported by the compiler)
+- overall level of optimization (type stability, constant propagation, etc.) and
+ allocations during run time (as reported by the garbage collector)
+- total allocations during compilation and first run (as reported by the garbage
+ collector and, when possible, the Julia process's resident set size estimator)
-The rows of the table are highlighted as follows:
-- green indicates an improvement in performance and either no change in
- compilation or easier compilation (i.e., either similar or smaller values of
- compilation time and memory usage)
-- dark blue indicates an improvement in performance and harder compilation
- (i.e., larger values of compilation time and/or memory usage)
-- light blue indicates no change in performance and easier compilation
-- yellow indicates no change in performance and no change in compilation
-- magenta indicates no change in performance, an increase in compilation time,
- and a decrease in compilation memory usage
-- red indicates no change in performance and harder compilation
+The rows of the tables are highlighted as follows:
+- light blue indicates an improvement in performance due to better optimization
+ and either an improvement or no change in compilation time and total
+ allocations
+- green indicates either faster run time or fewer allocations during run time
+ and either an improvement or no change in compilation time and total
+ allocations
+- dark blue indicates an improvement in performance due to better optimization
+ and either slower compilation or more total allocations
+- yellow indicates either faster run time or fewer allocations during run time
+ and either slower compilation or more total allocations
+- magenta indicates no change in performance and either an improvement or no
+ change in compilation time and total allocations
+- light gray indicates no change in performance and no change in compilation
+ time and total allocations
+- dark gray indicates no change in performance and either faster compilation
+ with more total allocations or slower compilation with fewer total allocations
+- red indicates a deterioration in performance, or no change in
+ performance and either slower compilation or more total allocations
-Rows highlighted in green and blue present a clear advantage for unrolling,
-whereas those highlighted in yellow, magenta, and red either have no clear
-advantage, or they have a clear disadvantage. It is recommended that you only
-unroll when your use case is similar to a row in the first category.
+Rows highlighted in gray present no clear advantage to unrolling, while those
+highlighted in red present a clear disadvantage. It is recommended that you only
+unroll when your use case is similar to a row in one of the remaining
+categories, each of which demonstrates some advantage to unrolling.
-The table is also printed out by this package's unit tests, so these
+The tables are also printed out by this package's unit tests, so these
measurements can be compared across different operating systems by checking the
[CI pipeline](https://github.com/CliMA/UnrolledUtilities.jl/actions/workflows/ci.yml).
diff --git a/ext/UnrolledUtilitiesStaticArraysExt.jl b/ext/UnrolledUtilitiesStaticArraysExt.jl
new file mode 100644
index 0000000..67058a7
--- /dev/null
+++ b/ext/UnrolledUtilitiesStaticArraysExt.jl
@@ -0,0 +1,12 @@
+module UnrolledUtilitiesStaticArraysExt
+
+import UnrolledUtilities
+import StaticArrays: SVector, MVector
+
+@inline UnrolledUtilities.output_type_for_promotion(::SVector) = SVector
+@inline UnrolledUtilities.constructor_from_tuple(::Type{SVector}) = SVector
+
+@inline UnrolledUtilities.output_type_for_promotion(::MVector) = MVector
+@inline UnrolledUtilities.constructor_from_tuple(::Type{MVector}) = MVector
+
+end
diff --git a/src/StaticBitVector.jl b/src/StaticBitVector.jl
new file mode 100644
index 0000000..5605faf
--- /dev/null
+++ b/src/StaticBitVector.jl
@@ -0,0 +1,155 @@
+"""
+ StaticBitVector{N, [U]}(f)
+ StaticBitVector{N, [U]}([bit])
+
+A statically-sized analogue of `BitVector` with `Unsigned` chunks of type `U`,
+which can be constructed using either a function `f(n)` or a constant `bit`. By
+default, `U` is set to `UInt8` and `bit` is set to `false`.
+
+This iterator can only store `Bool`s, so its `output_type_for_promotion` is a
+`ConditionalOutputType`. Efficient methods are provided for `unrolled_map`,
+`unrolled_accumulate`, `unrolled_take`, and `unrolled_drop`, though the methods
+for `unrolled_map` and `unrolled_accumulate` only apply when their output's
+first item is a `Bool`. No other unrolled functions can use `StaticBitVector`s
+as output types.
+"""
+struct StaticBitVector{N, U <: Unsigned, I <: NTuple{<:Any, U}} <:
+ StaticSequence{N}
+ ints::I
+end
+@inline StaticBitVector{N, U}(ints) where {N, U} =
+ StaticBitVector{N, U, typeof(ints)}(ints)
+@inline StaticBitVector{N}(args...) where {N} =
+ StaticBitVector{N, UInt8}(args...)
+
+@inline function StaticBitVector{N, U}(bit::Bool = false) where {N, U}
+ n_bits_per_int = 8 * sizeof(U)
+ n_ints = cld(N, n_bits_per_int)
+ ints = ntuple(Returns(bit ? ~zero(U) : zero(U)), Val(n_ints))
+ return StaticBitVector{N, U}(ints)
+end
+
+@inline function StaticBitVector{N, U}(f::Function) where {N, U}
+ n_bits_per_int = 8 * sizeof(U)
+ n_ints = cld(N, n_bits_per_int)
+ ints = ntuple(Val(n_ints)) do int_index
+ @inline
+ first_index = n_bits_per_int * (int_index - 1) + 1
+ unrolled_reduce(
+ StaticOneTo(min(n_bits_per_int, N - first_index + 1));
+ init = zero(U),
+ ) do int, bit_index
+ @inline
+ bit_offset = bit_index - 1
+ int | U(f(first_index + bit_offset)::Bool) << bit_offset
+ end
+ end
+ return StaticBitVector{N, U}(ints)
+end
+
+@inline function int_index_and_bit_offset(::Type{U}, n) where {U}
+ int_offset, bit_offset = divrem(n - 1, 8 * sizeof(U))
+ return (int_offset + 1, bit_offset)
+end
+
+@inline function generic_getindex(
+ itr::StaticBitVector{<:Any, U},
+ n::Integer,
+) where {U}
+ int_index, bit_offset = int_index_and_bit_offset(U, n)
+ int = itr.ints[int_index]
+ return Bool(int >> bit_offset & one(int))
+end
+
+@inline function Base.setindex(
+ itr::StaticBitVector{N, U},
+ bit::Bool,
+ n::Integer,
+) where {N, U}
+ int_index, bit_offset = int_index_and_bit_offset(U, n)
+ int = itr.ints[int_index]
+ int′ = int & ~(one(int) << bit_offset) | U(bit) << bit_offset
+ ints = Base.setindex(itr.ints, int′, int_index)
+ return StaticBitVector{N, U}(ints)
+end
+
+@inline output_type_for_promotion(::StaticBitVector{<:Any, U}) where {U} =
+ ConditionalOutputType(Bool, StaticBitVector{<:Any, U})
+
+@inline function unrolled_map_into(
+ ::Type{StaticBitVector{<:Any, U}},
+ f,
+ itrs...,
+) where {U}
+ lazy_itr = Iterators.map(f, itrs...)
+ N = length(lazy_itr)
+ return StaticBitVector{N, U}(Base.Fix1(generic_getindex, lazy_itr))
+end
+
+@inline function unrolled_accumulate_into(
+ ::Type{StaticBitVector{<:Any, U}},
+ op,
+ itr,
+ init,
+ transform,
+) where {U}
+ N = length(itr)
+ n_bits_per_int = 8 * sizeof(U)
+ n_ints = cld(N, n_bits_per_int)
+ ints = unrolled_accumulate_into_tuple(
+ StaticOneTo(n_ints);
+ init = (nothing, init),
+ transform = first,
+ ) do (_, init_value_for_new_int), int_index
+ @inline
+ first_index = n_bits_per_int * (int_index - 1) + 1
+ unrolled_reduce(
+ StaticOneTo(min(n_bits_per_int, N - first_index + 1));
+ init = (zero(U), init_value_for_new_int),
+ ) do (int, prev_value), bit_index
+ @inline
+ bit_offset = bit_index - 1
+ item = generic_getindex(itr, first_index + bit_offset)
+ new_value =
+ first_index + bit_offset == 1 && prev_value isa NoInit ?
+ item : op(prev_value, item)
+ (int | U(transform(new_value)::Bool) << bit_offset, new_value)
+ end
+ end
+ return StaticBitVector{N, U}(ints)
+end
+
+# TODO: Add unrolled_push and unrolled_append
+
+@inline function unrolled_take(
+ itr::StaticBitVector{<:Any, U},
+ ::Val{N},
+) where {N, U}
+ n_bits_per_int = 8 * sizeof(U)
+ n_ints = cld(N, n_bits_per_int)
+ ints = unrolled_take(itr.ints, Val(n_ints))
+ return StaticBitVector{N, U}(ints)
+end
+
+@inline function unrolled_drop(
+ itr::StaticBitVector{N_old, U},
+ ::Val{N},
+) where {N_old, N, U}
+ n_bits_per_int = 8 * sizeof(U)
+ n_ints = cld(N_old - N, n_bits_per_int)
+ n_dropped_ints = length(itr.ints) - n_ints
+ bit_offset = N - n_bits_per_int * n_dropped_ints
+ ints_without_offset = unrolled_drop(itr.ints, Val(n_dropped_ints))
+ ints = if bit_offset == 0
+ ints_without_offset
+ else
+ cur_ints = ints_without_offset
+ next_ints = unrolled_push(unrolled_drop(cur_ints, Val(1)), nothing)
+ unrolled_map_into_tuple(cur_ints, next_ints) do cur_int, next_int
+ @inline
+ isnothing(next_int) ? cur_int >> bit_offset :
+ cur_int >> bit_offset | next_int << (n_bits_per_int - bit_offset)
+ end
+ end
+ return StaticBitVector{N_old - N, U}(ints)
+end
diff --git a/src/StaticOneTo.jl b/src/StaticOneTo.jl
new file mode 100644
index 0000000..2721169
--- /dev/null
+++ b/src/StaticOneTo.jl
@@ -0,0 +1,18 @@
+"""
+ StaticOneTo(N)
+
+A lazy and statically-sized analogue of `Base.OneTo(N)`.
+
+This iterator can only store the integers from 1 to `N`, so its
+`output_type_for_promotion` is `NoOutputType()`. An efficient method is provided
+for `unrolled_take`, but no other unrolled functions can use `StaticOneTo`s as
+output types.
+"""
+struct StaticOneTo{N} <: StaticSequence{N} end
+@inline StaticOneTo(N) = StaticOneTo{N}()
+
+@inline generic_getindex(::StaticOneTo, n) = n
+
+@inline output_type_for_promotion(::StaticOneTo) = NoOutputType()
+
+@inline unrolled_take(::StaticOneTo, ::Val{N}) where {N} = StaticOneTo(N)
diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl
index dc69559..52cf883 100644
--- a/src/UnrolledUtilities.jl
+++ b/src/UnrolledUtilities.jl
@@ -4,10 +4,14 @@ export unrolled_any,
unrolled_all,
unrolled_foreach,
unrolled_map,
+ unrolled_applyat,
unrolled_reduce,
unrolled_mapreduce,
- unrolled_zip,
- unrolled_enumerate,
+ unrolled_accumulate,
+ unrolled_push,
+ unrolled_append,
+ unrolled_take,
+ unrolled_drop,
unrolled_in,
unrolled_unique,
unrolled_filter,
@@ -15,114 +19,145 @@ export unrolled_any,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
- unrolled_applyat,
- unrolled_take,
- unrolled_drop
-
-inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N
-# We could also add support for statically-sized iterators that are not Tuples.
-
-f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type))
-@inline @generated unrolled_any(f, itr) = Expr(:||, f_exprs(itr)...)
-@inline @generated unrolled_all(f, itr) = Expr(:&&, f_exprs(itr)...)
-
-function zipped_f_exprs(itr_types)
- L = length(itr_types)
- L == 0 && error("unrolled functions need at least one iterator as input")
- N = minimum(inferred_length, itr_types)
- return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N)
-end
-@inline @generated unrolled_foreach(f, itrs...) =
- Expr(:block, zipped_f_exprs(itrs)..., nothing)
-@inline @generated unrolled_map(f, itrs...) =
- Expr(:tuple, zipped_f_exprs(itrs)...)
-
-function nested_op_expr(itr_type)
- N = inferred_length(itr_type)
- N == 0 && error("unrolled_reduce needs an `init` value for empty iterators")
- item_exprs = (:(itr[$n]) for n in 1:N)
- return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs)
-end
-@inline @generated unrolled_reduce_without_init(op, itr) = nested_op_expr(itr)
-
-struct NoInit end
+ StaticOneTo,
+ StaticBitVector
+
+include("unrollable_iterator_interface.jl")
+include("recursively_unrolled_functions.jl")
+include("generatively_unrolled_functions.jl")
+
+struct NoInit end # Analogue of Base._InitialValue for reduction/accumulation.
+
+@inline unrolled_any(f, itr) =
+ (rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
+@inline unrolled_any(itr) = unrolled_any(identity, itr)
+
+@inline unrolled_all(f, itr) =
+ (rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
+@inline unrolled_all(itr) = unrolled_all(identity, itr)
+
+@inline unrolled_foreach(f, itr) =
+ (rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
+@inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))
+
+@inline unrolled_map_into_tuple(f, itr) =
+ (rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
+@inline unrolled_map_into_tuple(f, itrs...) =
+ unrolled_map_into_tuple(splat(f), zip(itrs...))
+@inline unrolled_map_into(output_type, f, itrs...) =
+ constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itrs...))
+@inline unrolled_map(f, itrs...) =
+ unrolled_map_into(map_output_type(f, itrs...), f, itrs...)
+
+@inline unrolled_applyat(f, n, itr) =
+ (rec_unroll(itr) ? rec_unrolled_applyat : gen_unrolled_applyat)(f, n, itr)
+@inline unrolled_applyat(f, n, itrs...) =
+ unrolled_applyat(splat(f), n, zip(itrs...))
+@inline unrolled_applyat_bounds_error() =
+ error("unrolled_applyat has detected an out-of-bounds index")
+
+@inline unrolled_reduce(op, itr, init) =
+ (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
- unrolled_reduce_without_init(op, init isa NoInit ? itr : (init, itr...))
+ isempty(itr) && init isa NoInit ?
+ error("unrolled_reduce requires an init value for empty iterators") :
+ unrolled_reduce(op, itr, init)
@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
- unrolled_reduce(op, unrolled_map(f, itrs...); init)
+ unrolled_reduce(op, Iterators.map(f, itrs...), init)
+
+@inline unrolled_accumulate_into_tuple(op, itr, init, transform) =
+ (rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
+ op,
+ itr,
+ init,
+ transform,
+ )
+@inline unrolled_accumulate_into(output_type, op, itr, init, transform) =
+ constructor_from_tuple(output_type)(
+ unrolled_accumulate_into_tuple(op, itr, init, transform),
+ )
+@inline unrolled_accumulate(op, itr; init = NoInit(), transform = identity) =
+ unrolled_accumulate_into(
+ accumulate_output_type(op, itr, init, transform),
+ op,
+ itr,
+ init,
+ transform,
+ )
-@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)
+@inline unrolled_push(itr, item) =
+ inferred_constructor_from_tuple(itr)((itr..., item))
-@inline unrolled_enumerate(itrs...) =
- unrolled_zip(ntuple(identity, Val(length(itrs[1]))), itrs...)
+@inline unrolled_append_into(output_type, itr1, itr2) =
+ constructor_from_tuple(output_type)((itr1..., itr2...))
+@inline unrolled_append(itr1, itr2) =
+ unrolled_append_into(promoted_output_type((itr1, itr2)), itr1, itr2)
+
+@inline unrolled_take(itr, ::Val{N}) where {N} =
+ inferred_constructor_from_tuple(itr)(
+ ntuple(Base.Fix1(generic_getindex, itr), Val(N)),
+ )
+
+@inline unrolled_drop(itr, ::Val{N}) where {N} =
+ inferred_constructor_from_tuple(itr)(
+ ntuple(n -> generic_getindex(itr, N + n), Val(length(itr) - N)),
+ )
@inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Using === instead of == or isequal improves type stability for singletons.
@inline unrolled_unique(itr) =
- unrolled_reduce(itr; init = ()) do unique_items, item
+ unrolled_reduce(itr; init = inferred_empty(itr)) do unique_items, item
@inline
- unrolled_in(item, unique_items) ? unique_items : (unique_items..., item)
+ unrolled_in(item, unique_items) ? unique_items :
+ unrolled_push(unique_items, item)
end
@inline unrolled_filter(f, itr) =
- unrolled_reduce(itr; init = ()) do filtered_items, item
+ unrolled_reduce(itr; init = inferred_empty(itr)) do items_with_true_f, item
@inline
- f(item) ? (filtered_items..., item) : filtered_items
+ f(item) ? unrolled_push(items_with_true_f, item) : items_with_true_f
end
@inline unrolled_split(f, itr) =
- unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item
+ unrolled_reduce(
+ itr;
+ init = (inferred_empty(itr), inferred_empty(itr)),
+ ) do (items_with_true_f, items_with_false_f), item
@inline
- f(item) ? ((f_items..., item), not_f_items) :
- (f_items, (not_f_items..., item))
+ f(item) ? (unrolled_push(items_with_true_f, item), items_with_false_f) :
+ (items_with_true_f, unrolled_push(items_with_false_f, item))
end
@inline unrolled_flatten(itr) =
- unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ())
+ unrolled_reduce(unrolled_append, itr; init = promoted_empty(itr))
@inline unrolled_flatmap(f, itrs...) =
- unrolled_flatten(unrolled_map(f, itrs...))
+ unrolled_flatten(Iterators.map(f, itrs...))
@inline unrolled_product(itrs...) =
- unrolled_reduce(itrs; init = ((),)) do product_itr, itr
+ unrolled_reduce(itrs; init = (promoted_empty(itrs),)) do product_itr, itr
@inline
unrolled_flatmap(itr) do item
@inline
- unrolled_map(product_tuple -> (product_tuple..., item), product_itr)
+ unrolled_map_into_tuple(Base.Fix2(unrolled_push, item), product_itr)
end
end
-@inline unrolled_applyat(f, n, itrs...) = unrolled_foreach(
- (i, items...) -> i == n && f(items...),
- unrolled_enumerate(itrs...),
-)
+abstract type StaticSequence{N} end
-@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
-@inline unrolled_drop(itr, ::Val{N}) where {N} =
- ntuple(i -> itr[N + i], Val(length(itr) - N))
-# When its second argument is a Val, ntuple is unrolled via Base.@ntuple.
-
-@static if hasfield(Method, :recursion_relation)
- # Remove recursion limits for functions whose arguments are also functions.
- for func in (
- unrolled_any,
- unrolled_all,
- unrolled_foreach,
- unrolled_map,
- unrolled_reduce_without_init,
- unrolled_reduce,
- unrolled_mapreduce,
- unrolled_filter,
- unrolled_split,
- unrolled_flatmap,
- unrolled_applyat,
- )
- for method in methods(func)
- method.recursion_relation = (_...) -> true
- end
- end
-end
+@inline Base.length(::StaticSequence{N}) where {N} = N
+@inline Base.firstindex(::StaticSequence) = 1
+@inline Base.lastindex(itr::StaticSequence) = length(itr)
+@inline Base.getindex(itr::StaticSequence, n::Integer) =
+ generic_getindex(itr, n)
+@inline Base.iterate(itr::StaticSequence, n = 1) =
+ n > length(itr) ? nothing : (generic_getindex(itr, n), n + 1)
+
+include("StaticOneTo.jl")
+include("StaticBitVector.jl")
+
+include("recursion_limits.jl") # This must be included at the end of the module.
end
diff --git a/src/generatively_unrolled_functions.jl b/src/generatively_unrolled_functions.jl
new file mode 100644
index 0000000..e8afe0b
--- /dev/null
+++ b/src/generatively_unrolled_functions.jl
@@ -0,0 +1,62 @@
+f_exprs(N) = (:(f(generic_getindex(itr, $n))) for n in 1:N)
+
+@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} =
+ Expr(:||, f_exprs(N)...)
+@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr)
+
+@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} =
+ Expr(:&&, f_exprs(N)...)
+@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr)
+
+@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} =
+ Expr(:block, f_exprs(N)..., nothing)
+@inline gen_unrolled_foreach(f, itr) =
+ _gen_unrolled_foreach(Val(length(itr)), f, itr)
+
+@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} =
+ Expr(:tuple, f_exprs(N)...)
+@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr)
+
+@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr(
+ :block,
+ map(((n, expr),) -> :(n == $n && return $expr), enumerate(f_exprs(N)))...,
+ :(unrolled_applyat_bounds_error()),
+) # This block should get optimized into a switch statement during LLVM codegen.
+@inline gen_unrolled_applyat(f, n, itr) =
+ _gen_unrolled_applyat(Val(length(itr)), f, n, itr)
+
+function nested_op_expr(N, init_type)
+ init_expr = init_type <: NoInit ? :(generic_getindex(itr, 1)) : :init
+ n_range = init_type <: NoInit ? (2:N) : (1:N)
+ return foldl(n_range; init = init_expr) do prev_op_expr, n
+ :(op($prev_op_expr, generic_getindex(itr, $n)))
+ end # Use foldl instead of reduce to guarantee left associativity.
+end
+
+@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} =
+ nested_op_expr(N, init)
+@inline gen_unrolled_reduce(op, itr, init) =
+ _gen_unrolled_reduce(Val(length(itr)), op, itr, init)
+
+function transformed_sequential_op_exprs(N, init_type)
+ first_item_expr = :(generic_getindex(itr, 1))
+ init_expr =
+ init_type <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
+ transformed_exprs_and_next_op_exprs =
+ accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n
+ var = gensym()
+ next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
+ (:($var = $prev_op_expr; transform($var)), next_op_expr)
+ end
+ return map(first, transformed_exprs_and_next_op_exprs)
+end
+
+@inline @generated _gen_unrolled_accumulate(
+ ::Val{N},
+ op,
+ itr,
+ init,
+ transform,
+) where {N} = Expr(:tuple, transformed_sequential_op_exprs(N, init)...)
+@inline gen_unrolled_accumulate(op, itr, init, transform) =
+ _gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform)
diff --git a/src/recursion_limits.jl b/src/recursion_limits.jl
new file mode 100644
index 0000000..2098fe9
--- /dev/null
+++ b/src/recursion_limits.jl
@@ -0,0 +1,55 @@
+# Remove recursion limits from functions that call themselves, and also from all
+# functions whose arguments can be arbitrary functions (including themselves).
+@static if hasfield(Method, :recursion_relation)
+ for func in (
+ generic_getindex,
+ output_type_for_promotion,
+ _rec_unrolled_any,
+ _rec_unrolled_all,
+ _rec_unrolled_foreach,
+ _rec_unrolled_map,
+ _rec_unrolled_applyat,
+ _rec_unrolled_reduce,
+ _rec_unrolled_accumulate,
+ rec_unrolled_any,
+ rec_unrolled_all,
+ rec_unrolled_foreach,
+ rec_unrolled_map,
+ rec_unrolled_applyat,
+ rec_unrolled_reduce,
+ rec_unrolled_accumulate,
+ _gen_unrolled_any,
+ _gen_unrolled_all,
+ _gen_unrolled_foreach,
+ _gen_unrolled_map,
+ _gen_unrolled_applyat,
+ _gen_unrolled_reduce,
+ _gen_unrolled_accumulate,
+ gen_unrolled_any,
+ gen_unrolled_all,
+ gen_unrolled_foreach,
+ gen_unrolled_map,
+ gen_unrolled_applyat,
+ gen_unrolled_reduce,
+ gen_unrolled_accumulate,
+ unrolled_any,
+ unrolled_all,
+ unrolled_foreach,
+ unrolled_map_into_tuple,
+ unrolled_map_into,
+ unrolled_map,
+ unrolled_applyat,
+ unrolled_reduce,
+ unrolled_mapreduce,
+ unrolled_accumulate_into_tuple,
+ unrolled_accumulate_into,
+ unrolled_accumulate,
+ unrolled_filter,
+ unrolled_split,
+ unrolled_flatmap,
+ )
+ for method in methods(func)
+ method.recursion_relation = Returns(true)
+ end
+ end
+end
diff --git a/src/recursively_unrolled_functions.jl b/src/recursively_unrolled_functions.jl
new file mode 100644
index 0000000..db88bef
--- /dev/null
+++ b/src/recursively_unrolled_functions.jl
@@ -0,0 +1,47 @@
+@inline _rec_unrolled_any(f) = false
+@inline _rec_unrolled_any(f, item, items...) =
+ f(item) || _rec_unrolled_any(f, items...)
+@inline rec_unrolled_any(f, itr) = _rec_unrolled_any(f, itr...)
+
+@inline _rec_unrolled_all(f) = true
+@inline _rec_unrolled_all(f, item, items...) =
+ f(item) && _rec_unrolled_all(f, items...)
+@inline rec_unrolled_all(f, itr) = _rec_unrolled_all(f, itr...)
+
+@inline _rec_unrolled_foreach(f) = nothing
+@inline _rec_unrolled_foreach(f, item, items...) =
+ (f(item); _rec_unrolled_foreach(f, items...))
+@inline rec_unrolled_foreach(f, itr) = _rec_unrolled_foreach(f, itr...)
+
+@inline _rec_unrolled_map(f) = ()
+@inline _rec_unrolled_map(f, item, items...) =
+ (f(item), _rec_unrolled_map(f, items...)...)
+@inline rec_unrolled_map(f, itr) = _rec_unrolled_map(f, itr...)
+
+@inline _rec_unrolled_applyat(f, offset_n) = unrolled_applyat_bounds_error()
+@inline _rec_unrolled_applyat(f, offset_n, item, items...) =
+ offset_n == 1 ? f(item) : _rec_unrolled_applyat(f, offset_n - 1, items...)
+@inline rec_unrolled_applyat(f, n, itr) = _rec_unrolled_applyat(f, n, itr...)
+
+@inline _rec_unrolled_reduce(op, prev_value) = prev_value
+@inline _rec_unrolled_reduce(op, prev_value, item, items...) =
+ _rec_unrolled_reduce(op, op(prev_value, item), items...)
+@inline rec_unrolled_reduce(op, itr, init) =
+ init isa NoInit ? _rec_unrolled_reduce(op, itr...) :
+ _rec_unrolled_reduce(op, init, itr...)
+
+@inline _rec_unrolled_accumulate(op, transform, prev_value) =
+ (transform(prev_value),)
+@inline _rec_unrolled_accumulate(op, transform, prev_value, item, items...) = (
+ transform(prev_value),
+ _rec_unrolled_accumulate(op, transform, op(prev_value, item), items...)...,
+)
+@inline rec_unrolled_accumulate(op, itr, init, transform) =
+ isempty(itr) ? () :
+ init isa NoInit ? _rec_unrolled_accumulate(op, transform, itr...) :
+ _rec_unrolled_accumulate(
+ op,
+ transform,
+ op(init, generic_getindex(itr, 1)),
+ unrolled_drop(itr, Val(1))...,
+ )
diff --git a/src/unrollable_iterator_interface.jl b/src/unrollable_iterator_interface.jl
new file mode 100644
index 0000000..9784996
--- /dev/null
+++ b/src/unrollable_iterator_interface.jl
@@ -0,0 +1,216 @@
+#=
+To unroll over a statically-sized iterator of type T, follow these steps:
+- Add a method for either getindex(::T, n) or generic_getindex(::T, n)
+- If every unrolled function that needs to construct an iterator when given an
+ iterator of type T can return a Tuple instead, stop here; otherwise, to return
+ a non-Tuple iterator whenever possible, follow these steps:
+ - Add a method for output_type_for_promotion(::T) = O, where O can be T, a
+ supertype of T, or some other type or AmbiguousOutputType
+ - If an output of type O can be used together with an output of type O′, add
+ a method for output_promote_rule(O, O′)
+ - If an output of type O can be efficiently constructed from a Tuple, add a
+ method for constructor_from_tuple(O); otherwise, add a method for each
+ unrolled function that can efficiently generate an output of type O
+ without temporarily storing it as a Tuple
+=#
+
+"""
+ rec_unroll(itr)
+
+Whether to use recursive loop unrolling instead of generative loop unrolling for
+the iterator `itr`.
+
+In general, recursive loop unrolling is faster to compile for small iterators,
+but it becomes extremely slow to compile for large iterators, and it usually
+generates suboptimal LLVM code for large iterators. On the other hand,
+generative loop unrolling is slow to compile for small iterators, but its
+compilation time does not grow as rapidly with respect to iterator size, and it
+always generates optimal LLVM code. The default is currently to use recursive
+unrolling for iterators lengths up to 16, based on the latest compilation and
+performance comparisons of recursively and generatively unrolled functions.
+"""
+@inline rec_unroll(itr) = length(itr) <= 16
+
+"""
+ generic_getindex(itr, n)
+
+An alternative to `Base.getindex` that can also handle internal types such as
+`Base.Generator` and `Base.Iterators.Enumerate`. Defaults to `Base.getindex`.
+"""
+@inline generic_getindex(itr, n) = getindex(itr, n)
+@inline generic_getindex(itr::Base.Generator, n) =
+ itr.f(generic_getindex(itr.iter, n))
+@inline generic_getindex(itr::Base.Iterators.Enumerate, n) =
+ (n, generic_getindex(itr.itr, n))
+@inline generic_getindex(itr::Base.Iterators.Zip, n) =
+ unrolled_map_into_tuple(Base.Fix2(generic_getindex, n), itr.is)
+
+@inline first_item_type(itr) =
+ Base.promote_op(Base.Fix2(generic_getindex, 1), typeof(itr))
+@inline second_item_type(itr) =
+ Base.promote_op(Base.Fix2(generic_getindex, 2), typeof(itr))
+
+"""
+ output_type_for_promotion(itr)
+
+The type of output that unrolled functions should try to generate for the input
+iterator `itr`, or a `ConditionalOutputType` if the output type depends on the
+type of items that need to be stored in it, or `NoOutputType()` if `itr` is a
+lazy iterator without any associated output type. Defaults to `Tuple`.
+"""
+@inline output_type_for_promotion(_) = Tuple
+@inline output_type_for_promotion(::NamedTuple{names}) where {names} =
+ NamedTuple{names}
+@inline output_type_for_promotion(itr::Base.Generator) =
+ output_type_for_promotion(itr.iter)
+@inline output_type_for_promotion(itr::Base.Iterators.Enumerate) =
+ output_type_for_promotion(itr.itr)
+@inline output_type_for_promotion(itr::Base.Iterators.Zip) =
+ maybe_ambiguous_promoted_output_type(itr.is)
+
+abstract type AmbiguousOutputType end
+
+"""
+ NoOutputType()
+
+The `output_type_for_promotion` of lazy iterators.
+"""
+struct NoOutputType <: AmbiguousOutputType end
+
+"""
+ ConditionalOutputType(allowed_item_type, output_type, [fallback_type])
+
+An `output_type_for_promotion` that can have one of two possible values. If the
+first item in the output is a subtype of `allowed_item_type`, the output will
+have the type `output_type`; otherwise, it will have the type `fallback_type`,
+which is set to `Tuple` by default.
+"""
+struct ConditionalOutputType{I, O, O′} <: AmbiguousOutputType end
+@inline ConditionalOutputType(
+ allowed_item_type::Type,
+ output_type::Type,
+ fallback_type::Type = Tuple,
+) = ConditionalOutputType{allowed_item_type, output_type, fallback_type}()
+
+@inline unambiguous_output_type(_, ::Type{O}) where {O} = O
+@inline unambiguous_output_type(_, ::NoOutputType) = Tuple
+@inline unambiguous_output_type(
+ get_first_item_type,
+ ::ConditionalOutputType{I, O, O′},
+) where {I, O, O′} = get_first_item_type() <: I ? O : O′
+
+"""
+ output_promote_rule(output_type1, output_type2)
+
+The type of output that should be generated when two iterators do not have the
+same `output_type_for_promotion`, or `Union{}` if these iterators should not be
+used together. Only one method of `output_promote_rule` needs to be defined for
+any pair of output types.
+
+By default, all types take precedence over `NoOutputType()`, and the conditional
+part of any `ConditionalOutputType` takes precedence over an unconditional type
+(so that only the `fallback_type` of any conditional type gets promoted). The
+default result for all other pairs of unequal output types is `Union{}`.
+"""
+@inline output_promote_rule(_, _) = Union{}
+@inline output_promote_rule(::Type{O}, ::Type{O}) where {O} = O
+@inline output_promote_rule(::NoOutputType, output_type) = output_type
+@inline output_promote_rule(
+ ::ConditionalOutputType{I, O, O′},
+ ::Type{O′′},
+) where {I, O, O′, O′′} =
+ ConditionalOutputType(I, O, output_promote_rule(O′, O′′))
+@inline output_promote_rule(
+ ::Type{O′},
+ ::ConditionalOutputType{I, O, O′′},
+) where {I, O, O′, O′′} =
+ ConditionalOutputType(I, O, output_promote_rule(O′, O′′))
+@inline output_promote_rule(
+ ::ConditionalOutputType{I, O, O′},
+ ::ConditionalOutputType{I, O, O′′},
+) where {I, O, O′, O′′} =
+ ConditionalOutputType(I, O, output_promote_rule(O′, O′′))
+
+@inline function output_promote_result(O1, O2)
+ O12 = output_promote_rule(O1, O2)
+ O21 = output_promote_rule(O2, O1)
+ O12 == O21 == Union{} &&
+ error("output_promote_rule is undefined for $O1 and $O2")
+ (O12 == O21 || O21 == Union{}) && return O12
+ O12 == Union{} && return O21
+ error("output_promote_rule yields inconsistent results for $O1 and $O2: \
+ $O12 for $O1 followed by $O2, versus $O21 for $O2 followed by $O1")
+end
+
+@inline maybe_ambiguous_promoted_output_type(itrs) =
+ isempty(itrs) ? Tuple : # Generate a Tuple when given 0 inputs.
+ unrolled_mapreduce(output_type_for_promotion, output_promote_result, itrs)
+
+@inline inferred_output_type(itr) =
+ unambiguous_output_type(output_type_for_promotion(itr)) do
+ @inline
+ first_item_type(itr)
+ end
+@inline inferred_output_type(itr::Base.Generator) =
+ unambiguous_output_type(output_type_for_promotion(itr.iter)) do
+ @inline
+ Base.promote_op(itr.f, first_item_type(itr.iter))
+ end
+@inline inferred_output_type(itr::Base.Iterators.Enumerate) =
+ unambiguous_output_type(output_type_for_promotion(itr.itr)) do
+ @inline
+ Tuple{Int, first_item_type(itr.itr)}
+ end
+@inline inferred_output_type(itr::Base.Iterators.Zip) =
+ unambiguous_output_type(maybe_ambiguous_promoted_output_type(itr.is)) do
+ @inline
+ Tuple{unrolled_map_into_tuple(first_item_type, itr.is)...}
+ end
+
+@inline promoted_output_type(itrs) =
+ unambiguous_output_type(maybe_ambiguous_promoted_output_type(itrs)) do
+ @inline
+ first_item_type(generic_getindex(itrs, 1))
+ end
+
+@inline map_output_type(f, itrs...) =
+ inferred_output_type(Iterators.map(f, itrs...))
+
+@inline accumulate_output_type(op, itr, init, transform) =
+ unambiguous_output_type(output_type_for_promotion(itr)) do
+ @inline
+ no_init = init isa NoInit
+ arg1_type = no_init ? first_item_type(itr) : typeof(init)
+ arg2_type = no_init ? second_item_type(itr) : first_item_type(itr)
+ Base.promote_op(transform, Base.promote_op(op, arg1_type, arg2_type))
+ end
+
+"""
+ constructor_from_tuple(output_type)
+
+A function that can be used to efficiently construct an output of type
+`output_type` from a `Tuple`, or `identity` if such an output should not be
+constructed from a `Tuple`. Defaults to `identity`, which also handles the case
+where `output_type` is already `Tuple`. The `output_type` here is guaranteed to
+be a `Type`, rather than a `ConditionalOutputType` or `NoOutputType`.
+
+Many statically-sized iterators (e.g., `SVector`s) are essentially wrappers for
+`Tuple`s, and their constructors for `Tuple`s can be reduced to no-ops. The main
+exceptions are [`StaticOneTo`](@ref UnrolledUtilities.StaticOneTo)s and
+[`StaticBitVector`](@ref UnrolledUtilities.StaticBitVector)s, which do not
+provide constructors for `Tuple`s because there is no performance benefit to
+making a lazy or low-storage data structure once a corresponding high-storage
+data structure has already been constructed.
+"""
+@inline constructor_from_tuple(::Type) = identity
+@inline constructor_from_tuple(::Type{NT}) where {NT <: NamedTuple} = NT
+
+@inline inferred_constructor_from_tuple(itr) =
+ constructor_from_tuple(inferred_output_type(itr))
+
+@inline promoted_constructor_from_tuple(itrs) =
+ constructor_from_tuple(promoted_output_type(itrs))
+
+@inline inferred_empty(itr) = inferred_constructor_from_tuple(itr)(())
+
+@inline promoted_empty(itrs) = promoted_constructor_from_tuple(itrs)(())
diff --git a/test/aqua.jl b/test/aqua.jl
index d7becf1..ff1edd1 100644
--- a/test/aqua.jl
+++ b/test/aqua.jl
@@ -1,3 +1,4 @@
+using Test
import Aqua, UnrolledUtilities
# This is separate from all the other tests because Aqua.test_all checks for
diff --git a/test/runtests.jl b/test/runtests.jl
index 631181c..0cfddab 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,9 @@ using SafeTestsets
@safetestset "Test and Analyze" begin
@time include("test_and_analyze.jl")
- print_comparison_table()
+ for (title, comparison_table_dict) in comparison_table_dicts
+ print_comparison_table(title, comparison_table_dict)
+ end
end
@safetestset "Aqua" begin
@time include("aqua.jl")
diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl
index 70415e3..f14da78 100644
--- a/test/test_and_analyze.jl
+++ b/test/test_and_analyze.jl
@@ -6,45 +6,78 @@ using InteractiveUtils
using UnrolledUtilities
-comparison_table_dict = OrderedDict()
+comparison_table_dicts = OrderedDict()
-function print_comparison_table(io = stdout, generate_html = false)
+function print_comparison_table(title, comparison_table_dict, io = stdout)
table_data =
mapreduce(vcat, collect(comparison_table_dict)) do (key, entries)
stack(entry -> (key..., entry...), entries; dims = 1)
end
- highlighter(f, color) =
- generate_html ? HtmlHighlighter(f, HtmlDecoration(; color)) :
- Highlighter(f, Crayon(; foreground = Symbol(color)))
-
- better_performance_but_harder_to_compile =
- highlighter(generate_html ? "royalblue" : "blue") do data, i, j
- data[i, 4] != data[i, 5] &&
- (endswith(data[i, 6], "slower") || endswith(data[i, 7], "more"))
- end
- better_performance =
- highlighter(generate_html ? "mediumseagreen" : "green") do data, i, j
- data[i, 4] != data[i, 5]
- end
- mixed_compilation =
- highlighter(generate_html ? "mediumorchid" : "magenta") do data, i, j
- (endswith(data[i, 6], "slower") && endswith(data[i, 7], "less")) ||
- (endswith(data[i, 6], "faster") && endswith(data[i, 7], "more"))
- end
- harder_to_compile =
- highlighter(generate_html ? "indianred" : "red") do data, i, j
- endswith(data[i, 6], "slower") || endswith(data[i, 7], "more")
- end
- easier_to_compile =
- highlighter(generate_html ? "darkturquoise" : "cyan") do data, i, j
- endswith(data[i, 6], "faster") || endswith(data[i, 7], "less")
+ writing_to_docs = io isa IOStream
+
+ color(color_str) =
+        writing_to_docs ? HtmlDecoration(; color = color_str) :
+ Crayon(; foreground = Symbol(color_str))
+ highlighter_color(optimization, run_time, compile_time, allocs) =
+ if (
+ contains(optimization, "better") ||
+ contains(optimization, "fewer allocs")
+ ) && !contains(run_time, "more") ||
+ contains(optimization, "identical") && contains(run_time, "less")
+ # better performance
+ if !contains(compile_time, "more") && !contains(allocs, "more")
+ # similar or better compilation and total allocations
+ if contains(optimization, "better")
+ # better optimization
+ color(writing_to_docs ? "darkturquoise" : "cyan")
+ else
+ # faster run time or fewer allocations at run time
+ color(writing_to_docs ? "mediumseagreen" : "green")
+ end
+ else
+ # worse compilation or total allocations
+ if contains(optimization, "better")
+ # better optimization
+ color(writing_to_docs ? "royalblue" : "blue")
+ else
+ # faster run time or fewer allocations at run time
+ color(writing_to_docs ? "khaki" : "yellow")
+ end
+ end
+ elseif contains(optimization, "identical") &&
+ contains(run_time, "similar")
+ # similar performance
+ if contains(compile_time, "less") && !contains(allocs, "more") ||
+ !contains(compile_time, "more") && contains(allocs, "less")
+ # better compilation or total allocations
+ color(writing_to_docs ? "mediumorchid" : "magenta")
+ elseif contains(compile_time, "similar") &&
+ contains(allocs, "similar")
+ # similar compilation and total allocations
+ color(writing_to_docs ? "lightgray" : "light_gray")
+ elseif contains(compile_time, "less") && contains(allocs, "more") ||
+ contains(compile_time, "more") && contains(allocs, "less")
+ # mixed compilation and total allocations
+ color(writing_to_docs ? "darkgray" : "dark_gray")
+ else
+ # worse compilation and total allocations
+ color(writing_to_docs ? "indianred" : "red")
+ end
+ else
+ # worse performance
+ color(writing_to_docs ? "indianred" : "red")
end
- no_difference =
- highlighter((data, i, j) -> true, generate_html ? "khaki" : "yellow")
+ highlighter = (writing_to_docs ? HtmlHighlighter : Highlighter)(
+ Returns(true),
+ (_, data, row, _) -> highlighter_color(data[row, 6:9]...),
+ )
+
+ # TODO: Why does Sys.maxrss() always seem to be 0 on Ubuntu systems?
+ has_rss = any(contains('['), table_data[:, 9])
other_kwargs =
- generate_html ?
+ writing_to_docs ?
(;
backend = Val(:html),
table_style = Dict(
@@ -53,38 +86,86 @@ function print_comparison_table(io = stdout, generate_html = false)
),
) :
(;
+ title,
+ title_alignment = :c,
title_same_width_as_table = true,
- columns_width = [45, 45, 0, 0, 0, 0, 0],
+ columns_width = [45, 45, 15, 10, 30, 25, 20, 20, has_rss ? 30 : 20],
linebreaks = true,
autowrap = true,
crop = :none,
)
+ if writing_to_docs
+ println(io, "## $title")
+ println(io, "```@raw html")
+        println(io, "<div style=\"width: max(80vw, 100%)\">") # 80% of viewport
+ end
pretty_table(
io,
table_data;
- title = "Comparison of UnrolledUtilities to Base and Base.Iterators",
- title_alignment = :c,
alignment = :l,
header = [
"Unrolled Expression",
"Reference Expression",
- "Iterator Contents",
- "Unrolled Performance",
- "Reference Performance",
- "Unrolled Compilation Time",
- "Unrolled Compilation Memory",
+ "Itr Type",
+ "Itr Length",
+ "Itr Contents",
+ "Optimization",
+ "Run Time",
+ "Compilation Time",
+ "Total $(has_rss ? "GC [and RSS] " : "")Allocations",
],
- highlighters = (
- better_performance_but_harder_to_compile,
- better_performance,
- mixed_compilation,
- harder_to_compile,
- easier_to_compile,
- no_difference,
- ),
+ highlighters = highlighter,
other_kwargs...,
)
+ if writing_to_docs
+        println(io, "</div>")
+ println(io, "```")
+ else
+ println(io)
+ end
+end
+
+function time_string(nanoseconds)
+ nanoseconds == 0 && return "$nanoseconds ns"
+ n_decimal_digits = floor(Int, log10(nanoseconds) + 1)
+ return if n_decimal_digits <= 3
+ "$nanoseconds ns"
+ elseif n_decimal_digits <= 6
+ "$(round(Int, nanoseconds / 10^3)) μs"
+ elseif n_decimal_digits <= 9
+ "$(round(Int, nanoseconds / 10^6)) ms"
+ else
+ "$(round(Int, nanoseconds / 10^9)) s"
+ end
+end
+
+function memory_string(bytes)
+ bytes == 0 && return "$bytes B"
+ n_binary_digits = floor(Int, log2(bytes) + 1)
+ return if n_binary_digits <= 10
+ "$bytes B"
+ elseif n_binary_digits <= 20
+ "$(round(Int, bytes / 2^10)) kB"
+ elseif n_binary_digits <= 30
+ "$(round(Int, bytes / 2^20)) MB"
+ else
+ "$(round(Int, bytes / 2^30)) GB"
+ end
+end
+
+function comparison_string(value1, value2, to_string, to_number = identity)
+ ratio = to_number(value1) / to_number(value2)
+ itr_str = if ratio >= 2
+ floored_ratio = ratio == Inf ? Inf : floor(Int, ratio)
+ "$floored_ratio times more"
+ elseif inv(ratio) >= 2
+ floored_inv_ratio = ratio == 0 ? Inf : floor(Int, inv(ratio))
+ "$floored_inv_ratio times less"
+ else
+ "similar"
+ end
+ return "$itr_str ($(to_string(value1)) vs. $(to_string(value2)))"
end
function drop_line_numbers(expr)
@@ -118,18 +199,61 @@ function code_instance(f, args...)
end
end
-macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
+macro benchmark(expression)
+ return quote
+ prev_time = time_ns()
+ $(esc(expression))
+ new_time = time_ns()
+ best_time = new_time - prev_time
+
+ # Benchmark for at most 0.1 s (10^8 ns), ignoring the first call above.
+ n_trials = 0
+ start_time = new_time
+ while n_trials < 10^4 && new_time - start_time < 10^8
+ prev_time = time_ns()
+ $(esc(expression))
+ new_time = time_ns()
+ best_time = min(best_time, new_time - prev_time)
+ n_trials += 1
+ end
+
+ best_time
+ end
+end
+
+macro test_unrolled(
+ args_expr,
+ unrolled_expr,
+ reference_expr,
+ itr_contents_str,
+ skip_allocations_test = false,
+ skip_type_stability_test = false,
+ reference_preamble_expr = nothing,
+)
@assert Meta.isexpr(args_expr, :tuple)
arg_names = args_expr.args
@assert all(arg_name -> arg_name isa Symbol, arg_names)
args = map(esc, arg_names)
unrolled_expr_str = simplified_expression_string(unrolled_expr)
reference_expr_str = simplified_expression_string(reference_expr)
- expr_info_str =
- length(args) == 1 ? "$unrolled_expr_str with 1 iterator that contains" :
- "$unrolled_expr_str with $(length(args)) iterators that each contain"
+ reference_preamble_str =
+ isnothing(reference_preamble_expr) ? "" :
+ string(reference_preamble_expr) * '\n'
+ contains_str = length(args) == 1 ? " that contains" : "s that each contain"
quote
- @info "Testing $($expr_info_str) $($(esc(contents_info_str)))"
+ itr_types = map(arg -> typeof(arg).name.wrapper, ($(args...),))
+ itr_lengths = map(length, ($(args...),))
+
+ itr_type_str =
+ length(unique(itr_types)) == 1 ? string(itr_types[1]) :
+ join(itr_types, '/')
+ itr_length_str =
+ length(unique(itr_lengths)) == 1 ? string(itr_lengths[1]) :
+ join(itr_lengths, '/')
+ itr_str = "$itr_type_str$($contains_str) $itr_length_str \
+ $($(esc(itr_contents_str)))"
+
+ @info "Testing $($unrolled_expr_str) with $($(length(args))) $itr_str"
unrolled_func($(arg_names...)) = $(esc(unrolled_expr))
reference_func($(arg_names...)) = $(esc(reference_expr))
@@ -146,26 +270,27 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
reference_func_and_nothing($(args...))
# Test for allocations.
- @test (@allocated unrolled_func_and_nothing($(args...))) == 0
- is_reference_non_allocating =
- (@allocated reference_func_and_nothing($(args...))) == 0
+ unrolled_run_memory = @allocated unrolled_func_and_nothing($(args...))
+ reference_run_memory = @allocated reference_func_and_nothing($(args...))
+ $(esc(skip_allocations_test)) || @test unrolled_run_memory == 0
# Test for type-stability.
- @test_opt unrolled_func($(args...))
+ is_unrolled_stable =
+ isempty(JET.get_reports(@report_opt unrolled_func($(args...))))
is_reference_stable =
isempty(JET.get_reports(@report_opt reference_func($(args...))))
-
- unrolled_instance = code_instance(unrolled_func, $(args...))
- reference_instance = code_instance(reference_func, $(args...))
+ $(esc(skip_type_stability_test)) || @test_opt unrolled_func($(args...))
# Test for constant propagation.
- is_unrolled_const = isdefined(unrolled_instance, :rettype_const)
- Base.issingletontype(typeof(($(args...),))) && @test is_unrolled_const
- is_reference_const = isdefined(reference_instance, :rettype_const)
+ is_unrolled_const =
+ isdefined(code_instance(unrolled_func, $(args...)), :rettype_const)
+ is_reference_const =
+ isdefined(code_instance(reference_func, $(args...)), :rettype_const)
+ # Base.issingletontype(typeof(($(args...),))) && @test is_unrolled_const
buffer = IOBuffer()
- # Check whether the functions are fully optimized out.
+ # Determine whether the functions are fully optimized out.
args_type = Tuple{map(typeof, ($(args...),))...}
code_llvm(buffer, unrolled_func, args_type; debuginfo = :none)
is_unrolled_optimized_out =
@@ -174,86 +299,116 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
is_reference_optimized_out =
length(split(String(take!(buffer)), '\n')) == 5
+ # Test the overall level of optimization.
+ unrolled_opt_str, unrolled_opt_score = if unrolled_run_memory > 0
+ "$(memory_string(unrolled_run_memory)) allocs", 1 / unrolled_run_memory
+ elseif !is_unrolled_stable
+ "type-unstable", 2
+ elseif !is_unrolled_const && !is_unrolled_optimized_out
+ "type-stable", 3
+ elseif !is_unrolled_optimized_out
+ "constant", 4
+ else
+ "optimized out", 5
+ end
+ reference_opt_str, reference_opt_score = if reference_run_memory > 0
+ "$(memory_string(reference_run_memory)) allocs",
+ 1 / reference_run_memory
+ elseif !is_reference_stable
+ "type-unstable", 2
+ elseif !is_reference_const && !is_reference_optimized_out
+ "type-stable", 3
+ elseif !is_reference_optimized_out
+ "constant", 4
+ else
+ "optimized out", 5
+ end
+ $(esc(skip_type_stability_test)) ||
+ @test unrolled_opt_score >= reference_opt_score
+
+ # Measure the run times.
+ unrolled_run_time = @benchmark unrolled_func($(args...))
+ reference_run_time = @benchmark reference_func($(args...))
+
+ # Measure the compilation times and memory allocations in separate
+ # processes to ensure that they are not under-counted.
arg_name_strs = ($(map(string, arg_names)...),)
arg_names_str = join(arg_name_strs, ", ")
arg_definition_strs =
map((name, value) -> "$name = $value", arg_name_strs, ($(args...),))
arg_definitions_str = join(arg_definition_strs, '\n')
- unrolled_command_str = """
+ command_str(func_str) = """
using UnrolledUtilities
- unrolled_func($arg_names_str) = $($(string(unrolled_expr)))
$arg_definitions_str
- stats1 = @timed unrolled_func($arg_names_str)
- stats2 = @timed unrolled_func($arg_names_str)
- print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes)
- """
- reference_command_str = """
- reference_func($arg_names_str) = $($(string(reference_expr)))
- $arg_definitions_str
- stats1 = @timed reference_func($arg_names_str)
- stats2 = @timed reference_func($arg_names_str)
- print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes)
+ Base.cumulative_compile_timing(true)
+ nanoseconds1 = Base.cumulative_compile_time_ns()[1]
+ rss_bytes_1 = Sys.maxrss()
+ Δgc_bytes = @allocated $func_str
+ rss_bytes_2 = Sys.maxrss()
+ nanoseconds2 = Base.cumulative_compile_time_ns()[1]
+ Base.cumulative_compile_timing(false)
+ Δnanoseconds = nanoseconds2 - nanoseconds1
+ Δrss_bytes = rss_bytes_2 - rss_bytes_1
+ print(Δnanoseconds, ", ", Δgc_bytes, ", ", Δrss_bytes)
"""
- # Get the unrolled function's time-to-first-run and its memory usage.
+ unrolled_command_str = command_str($(string(unrolled_expr)))
run(pipeline(`julia --project -e $unrolled_command_str`, buffer))
- unrolled_time, unrolled_memory =
- parse.((Float64, Int), split(String(take!(buffer)), ','))
+ unrolled_compile_time, unrolled_total_memory, unrolled_total_rss =
+ parse.((Int, Int, Int), split(String(take!(buffer)), ','))
# Make a new buffer to avoid a potential data race:
# https://discourse.julialang.org/t/iobuffer-becomes-not-writable-after-run/92323/3
close(buffer)
buffer = IOBuffer()
- # Get the reference function's time-to-first-run and its memory usage.
+ reference_command_str =
+ $reference_preamble_str * command_str($(string(reference_expr)))
run(pipeline(`julia --project -e $reference_command_str`, buffer))
- reference_time, reference_memory =
- parse.((Float64, Int), split(String(take!(buffer)), ','))
+ reference_compile_time, reference_total_memory, reference_total_rss =
+ parse.((Int, Int, Int), split(String(take!(buffer)), ','))
close(buffer)
- # Record all relevant information in comparison_table_dict.
- unrolled_performance_str = if !is_unrolled_const
- "type-stable"
- elseif !is_unrolled_optimized_out
- "const return value"
- else
- "fully optimized out"
- end
- reference_performance_str = if !is_reference_non_allocating
- "allocating"
- elseif !is_reference_stable
- "type-unstable"
- elseif !is_reference_const
- "type-stable"
- elseif !is_reference_optimized_out
- "const return value"
- else
- "fully optimized out"
- end
- time_ratio = unrolled_time / reference_time
- time_ratio_str = if time_ratio >= 1.5
- "$(round(Int, time_ratio)) times slower"
- elseif inv(time_ratio) >= 1.5
- "$(round(Int, inv(time_ratio))) times faster"
- else
- "similar"
- end
- memory_ratio = unrolled_memory / reference_memory
- memory_ratio_str = if memory_ratio >= 1.5
- "$(round(Int, memory_ratio)) times more"
- elseif inv(memory_ratio) >= 1.5
- "$(round(Int, inv(memory_ratio))) times less"
+ optimization_str = if unrolled_opt_score > reference_opt_score
+ if unrolled_opt_score <= 1
+ "fewer allocs ($unrolled_opt_str vs. $reference_opt_str)"
+ else
+ "better ($unrolled_opt_str vs. $reference_opt_str)"
+ end
+ elseif unrolled_opt_score < reference_opt_score
+ "worse ($unrolled_opt_str vs. $reference_opt_str)"
else
- "similar"
+ "identical ($unrolled_opt_str)"
end
+ run_time_str = comparison_string(
+ unrolled_run_time,
+ reference_run_time,
+ time_string,
+ )
+ compile_time_str = comparison_string(
+ unrolled_compile_time,
+ reference_compile_time,
+ time_string,
+ )
+ memory_str = comparison_string(
+ (unrolled_total_memory, unrolled_total_rss),
+ (reference_total_memory, reference_total_rss),
+ ((gc_bytes, rss_bytes),) ->
+ rss_bytes == 0 ? memory_string(gc_bytes) :
+ "$(memory_string(gc_bytes)) [$(memory_string(rss_bytes))]",
+ first, # Use GC value for comparison since RSS might be unavailable.
+ )
+
dict_key = ($unrolled_expr_str, $reference_expr_str)
dict_entry = (
- $(esc(contents_info_str)),
- unrolled_performance_str,
- reference_performance_str,
- time_ratio_str,
- memory_ratio_str,
+ itr_type_str,
+ itr_length_str,
+ $(esc(itr_contents_str)),
+ optimization_str,
+ run_time_str,
+ compile_time_str,
+ memory_str,
)
if dict_key in keys(comparison_table_dict)
push!(comparison_table_dict[dict_key], dict_entry)
@@ -263,160 +418,212 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
end
end
-@testset "empty iterators" begin
- itr = ()
- str = "nothing"
- @test_unrolled (itr,) unrolled_any(error, itr) any(error, itr) str
- @test_unrolled (itr,) unrolled_all(error, itr) all(error, itr) str
- @test_unrolled (itr,) unrolled_foreach(error, itr) foreach(error, itr) str
- @test_unrolled (itr,) unrolled_map(error, itr, itr) map(error, itr, itr) str
- @test_unrolled(
- (itr,),
- unrolled_reduce(error, itr; init = 0),
- reduce(error, itr; init = 0),
- str,
- )
+tuple_of_tuples(num_tuples, min_tuple_length, singleton, identical) =
+ ntuple(num_tuples) do index
+ tuple_length = min_tuple_length + (identical ? 0 : (index - 1) % 7)
+ ntuple(singleton ? Val : identity, tuple_length)
+ end
+function tuples_of_tuples_contents_str(itrs...)
+ str = ""
+ all(itr -> length(itr) > 1 && length(unique(itr)) == 1, itrs) &&
+ (str *= "identical ")
+ all(itr -> length(itr) > 1 && length(unique(itr)) != 1, itrs) &&
+ (str *= "distinct ")
+ all(itr -> all(isempty, itr), itrs) && (str *= "empty ")
+ all(itr -> all(!isempty, itr), itrs) && (str *= "nonempty ")
+ all(itr -> any(isempty, itr) && any(!isempty, itr), itrs) &&
+ (str *= "empty & nonempty ")
+ all(itr -> Base.issingletontype(typeof(itr)), itrs) && (str *= "singleton ")
+ all(itr -> !Base.issingletontype(typeof(itr)), itrs) &&
+ (str *= "non-singleton ")
+ str *= "Tuple"
+ all(itr -> length(itr) > 1, itrs) && (str *= "s")
+ return str
end
-for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
- itr1 = ntuple(i -> ntuple(Val, identical ? 0 : (i - 1) % 7), n)
- itr2 = ntuple(i -> ntuple(Val, identical ? 1 : (i - 1) % 7 + 1), n)
- itr3 = ntuple(i -> ntuple(identity, identical ? 1 : (i - 1) % 7 + 1), n)
- if n == 1
- str1 = "1 empty tuple"
- str2 = "1 nonempty singleton tuple"
- str3 = "1 nonempty non-singleton tuple"
- str12 = "1 singleton tuple"
- str23 = "1 nonempty tuple"
- str123 = "1 tuple"
- elseif identical
- str1 = "$n empty tuples"
- str2 = "$n identical nonempty singleton tuples"
- str3 = "$n identical nonempty non-singleton tuples"
- str12 = "$n identical singleton tuples"
- str23 = "$n identical nonempty tuples"
- str123 = "$n identical tuples"
- else
- str1 = "$n empty and nonempty singleton tuples"
- str2 = "$n nonempty singleton tuples"
- str3 = "$n nonempty non-singleton tuples"
- str12 = "$n singleton tuples"
- str23 = "$n nonempty tuples"
- str123 = "$n tuples"
- end
- @testset "iterators of $str123" begin
- for (itr, str) in ((itr1, str1), (itr2, str2), (itr3, str3))
- @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str
- @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str
+# TODO:
+# - Split recursively and generatively unrolled funcs by length
+# - Move empty iterator, no-iterator, and 8187 iterator tests into unit tests
+# - Make init accept Fields, turn reduce and accumulate into operators through
+# Base.Broadcast.broadcasted_kwsyntax
+
+title = "Individual unrolled functions"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
+
+for itr in (
+ tuple_of_tuples(1, 0, true, true),
+ tuple_of_tuples(1, 1, true, true),
+ tuple_of_tuples(1, 1, false, true),
+ map(n -> tuple_of_tuples(n, 0, true, true), (8, 32, 33, 128))...,
+ map(n -> tuple_of_tuples(n, 1, true, true), (8, 32, 33, 128))...,
+ map(n -> tuple_of_tuples(n, 1, false, true), (8, 32, 33, 128))...,
+ map(n -> tuple_of_tuples(n, 0, true, false), (8, 32, 33, 128))...,
+ map(n -> tuple_of_tuples(n, 1, true, false), (8, 32, 33, 128))...,
+ map(n -> tuple_of_tuples(n, 1, false, false), (8, 32, 33, 128))...,
+)
+ str = tuples_of_tuples_contents_str(itr)
+ itr_str = "a Tuple that contains $(length(itr)) $str"
+ @testset "individual unrolled functions of $itr_str" begin
+ @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str
+ @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str
+
+ @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str
+ @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str
- @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str
- @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str
+ @test_unrolled(
+ (itr,),
+ unrolled_foreach(x -> @assert(length(x) <= 7), itr),
+ foreach(x -> @assert(length(x) <= 7), itr),
+ str,
+ )
- @test_unrolled(
- (itr,),
- unrolled_foreach(x -> @assert(length(x) <= 7), itr),
- foreach(x -> @assert(length(x) <= 7), itr),
- str,
- )
+ @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str
- @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str
+ @test_unrolled(
+ (itr,),
+ unrolled_applyat(length, rand(1:7:length(itr)), itr),
+ length(itr[rand(1:7:length(itr))]),
+ str,
+ )
- @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str
- @test_unrolled(
- (itr,),
- unrolled_reduce(tuple, itr; init = ()),
- reduce(tuple, itr; init = ()),
- str,
- )
+ @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str
+ @test_unrolled(
+ (itr,),
+ unrolled_reduce(tuple, itr; init = ()),
+ reduce(tuple, itr; init = ()),
+ str,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_mapreduce(length, +, itr),
+ mapreduce(length, +, itr),
+ str,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_mapreduce(length, +, itr; init = 0),
+ mapreduce(length, +, itr; init = 0),
+ str,
+ )
+
+ if length(itr) <= 33
@test_unrolled(
(itr,),
- unrolled_mapreduce(length, +, itr),
- mapreduce(length, +, itr),
+ unrolled_accumulate(tuple, itr),
+ accumulate(tuple, itr),
str,
)
@test_unrolled(
(itr,),
- unrolled_mapreduce(length, +, itr; init = 0),
- mapreduce(length, +, itr; init = 0),
+ unrolled_accumulate(tuple, itr; init = ()),
+ accumulate(tuple, itr; init = ()),
str,
)
+ end # These can take half a minute to compile when the length is 128.
- @test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str
+ @test_unrolled (itr,) unrolled_push(itr, itr[1]) (itr..., itr[1]) str
+ @test_unrolled (itr,) unrolled_append(itr, itr) (itr..., itr...) str
- @test_unrolled (itr,) unrolled_enumerate(itr) Tuple(enumerate(itr)) str
+ @test_unrolled(
+ (itr,),
+ unrolled_take(itr, Val(length(itr) ÷ 2)),
+ itr[1:(length(itr) ÷ 2)],
+ str,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_drop(itr, Val(length(itr) ÷ 2)),
+ itr[(length(itr) ÷ 2 + 1):end],
+ str,
+ )
- @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
- @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
- @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str
+ @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
+ @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
+ @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str
- # unrolled_unique is only type-stable for singletons
- if Base.issingletontype(typeof(itr))
- @test_unrolled (itr,) unrolled_unique(itr) Tuple(unique(itr)) str
- end
+ @test_unrolled(
+ (itr,),
+ unrolled_unique(itr),
+ Tuple(unique(itr)),
+ str,
+ !Base.issingletontype(typeof(itr)),
+ !Base.issingletontype(typeof(itr)),
+ ) # unrolled_unique is type-unstable for non-singleton values
- @test_unrolled(
- (itr,),
- unrolled_filter(!isempty, itr),
- filter(!isempty, itr),
- str,
- )
+ @test_unrolled(
+ (itr,),
+ unrolled_filter(!isempty, itr),
+ filter(!isempty, itr),
+ str,
+ )
- @test_unrolled(
- (itr,),
- unrolled_split(isempty, itr),
- (filter(isempty, itr), filter(!isempty, itr)),
- str,
- )
+ @test_unrolled(
+ (itr,),
+ unrolled_split(isempty, itr),
+ (filter(isempty, itr), filter(!isempty, itr)),
+ str,
+ )
- @test_unrolled(
- (itr,),
- unrolled_flatten(itr),
- Tuple(Iterators.flatten(itr)),
- str,
- )
+ @test_unrolled(
+ (itr,),
+ unrolled_flatten(itr),
+ Tuple(Iterators.flatten(itr)),
+ str,
+ )
- @test_unrolled(
- (itr,),
- unrolled_flatmap(reverse, itr),
- Tuple(Iterators.flatmap(reverse, itr)),
- str,
- )
+ @test_unrolled(
+ (itr,),
+ unrolled_flatmap(reverse, itr),
+ Tuple(Iterators.flatmap(reverse, itr)),
+ str,
+ )
+ if length(itr) <= 33
@test_unrolled(
(itr,),
- unrolled_product(itr),
- Tuple(Iterators.product(itr)),
+ unrolled_product(itr, itr),
+ Tuple(Iterators.product(itr, itr)),
str,
)
-
+ end
+ if length(itr) <= 8
@test_unrolled(
(itr,),
- unrolled_applyat(
- x -> @assert(length(x) <= 7),
- rand(1:length(itr)),
- itr,
- ),
- @assert(length(itr[rand(1:length(itr))]) <= 7),
+ unrolled_product(itr, itr, itr),
+ Tuple(Iterators.product(itr, itr, itr)),
str,
)
+ end # This can take several minutes to compile when the length is 32.
+ end
+end
- if n > 1
- @test_unrolled(
- (itr,),
- unrolled_take(itr, Val(7)),
- itr[1:7],
- str,
- )
- @test_unrolled(
- (itr,),
- unrolled_drop(itr, Val(7)),
- itr[8:end],
- str,
- )
- end
- end
-
+title = "Nested unrolled functions"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
+
+for (itr1, itr2, itr3) in (
+ (
+ tuple_of_tuples(1, 0, true, true),
+ tuple_of_tuples(1, 1, true, true),
+ tuple_of_tuples(1, 1, false, true),
+ ),
+ zip(
+ map(n -> tuple_of_tuples(n, 0, true, true), (8, 32, 33, 128)),
+ map(n -> tuple_of_tuples(n, 1, true, true), (8, 32, 33, 128)),
+ map(n -> tuple_of_tuples(n, 1, false, true), (8, 32, 33, 128)),
+ )...,
+ zip(
+ map(n -> tuple_of_tuples(n, 0, true, false), (8, 32, 33, 128)),
+ map(n -> tuple_of_tuples(n, 1, true, false), (8, 32, 33, 128)),
+ map(n -> tuple_of_tuples(n, 1, false, false), (8, 32, 33, 128)),
+ )...,
+)
+ str3 = tuples_of_tuples_contents_str(itr3)
+ str12 = tuples_of_tuples_contents_str(itr1, itr2)
+ str23 = tuples_of_tuples_contents_str(itr2, itr3)
+ str123 = tuples_of_tuples_contents_str(itr1, itr2, itr3)
+ itr_str = "Tuples that contain $(length(itr1)) $str123"
+ @testset "nested unrolled functions of $itr_str" begin
@test_unrolled(
(itr3,),
unrolled_any(x -> unrolled_reduce(+, x) > 7, itr3),
@@ -434,11 +641,11 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
@test_unrolled(
(itr1, itr2),
unrolled_foreach(
- (x1, x2) -> @assert(length(x1) < length(x2)),
+ (x1, x2) -> @assert(x1 == unrolled_take(x2, Val(length(x1)))),
itr1,
itr2,
),
- foreach((x1, x2) -> @assert(length(x1) < length(x2)), itr1, itr2),
+ foreach((x1, x2) -> @assert(x1 == x2[1:length(x1)]), itr1, itr2),
str12,
)
@test_unrolled(
@@ -455,13 +662,13 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
@test_unrolled(
(itr1, itr2),
unrolled_applyat(
- (x1, x2) -> @assert(length(x1) < length(x2)),
+ (x1, x2) -> @assert(x1 == unrolled_take(x2, Val(length(x1)))),
rand(1:length(itr1)),
itr1,
itr2,
),
let n = rand(1:length(itr1))
- @assert(length(itr1[n]) < length(itr2[n]))
+ @assert(itr1[n] == itr2[n][1:length(itr1[n])])
end,
str12,
)
@@ -478,51 +685,25 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
end,
str23,
)
-
- @test_unrolled(
- (itr1, itr2),
- unrolled_zip(itr1, itr2),
- Tuple(zip(itr1, itr2)),
- str12,
- )
- @test_unrolled(
- (itr1, itr2, itr3),
- unrolled_zip(itr1, itr2, itr3),
- Tuple(zip(itr1, itr2, itr3)),
- str123,
- )
-
- # unrolled_product can take several minutes to compile when n is large
- if n <= 33
- @test_unrolled(
- (itr1, itr2),
- unrolled_product(itr1, itr2),
- Tuple(Iterators.product(itr1, itr2)),
- str12,
- )
- end
- if n <= 8
- @test_unrolled(
- (itr1, itr2, itr3),
- unrolled_product(itr1, itr2, itr3),
- Tuple(Iterators.product(itr1, itr2, itr3)),
- str123,
- )
- end
end
end
nested_iterator(depth, n, inner_n) =
depth == 1 ? ntuple(identity, n) :
- ntuple(inner_n) do _
- nested_iterator(depth - 1, Int(n / inner_n), inner_n)
- end
+ ntuple(
+ Returns(nested_iterator(depth - 1, Int(n / inner_n), inner_n)),
+ inner_n,
+ )
+
+title = "Recursive unrolled functions"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
for n in (8, 32, 128)
- @testset "iterators of $n values in nested tuples" begin
+ itr_str = "a Tuple that contains $n values in nested Tuples"
+ @testset "recursive unrolled functions of $itr_str" begin
for depth in (2, 3, 4:2:(Int(log2(n)) + 1)...)
itr = nested_iterator(depth, n, 2)
- str = "$n values in nested tuples of depth $depth"
+ str = "recursive unrolled functions of $itr_str of depth $depth"
# In the following definitions, use var"#self#" to avoid boxing:
# https://discourse.julialang.org/t/performant-recursive-anonymous-functions/90984/5
@test_unrolled(
@@ -561,3 +742,245 @@ for n in (8, 32, 128)
end
end
end
+
+title = "Nested unrolled closures"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
+
+@testset "nested unrolled closures of Tuples vs. StaticBitVectors" begin
+ for (itr, skip_allocations_test) in (
+ (ntuple(Returns(true), 32), false),
+ (ntuple(Returns(true), 33), true),
+ (StaticBitVector{256}(true), false),
+ (StaticBitVector{257}(true), true),
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_reduce(
+ (itr′, i) -> Base.setindex(itr′, !itr′[i], i),
+ StaticOneTo(length(itr));
+ init = itr,
+ ),
+ reduce(
+ (itr′, i) -> Base.setindex(itr′, !itr′[i], i),
+ Base.OneTo(length(itr));
+ init = itr,
+ ),
+ "Bools",
+ skip_allocations_test,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_reduce(
+ (itr′, i) -> unrolled_reduce(
+ (itr′′, j) ->
+ Base.setindex(itr′′, !itr′′[min(i, j)], j),
+ StaticOneTo(length(itr′));
+ init = itr′,
+ ),
+ StaticOneTo(length(itr));
+ init = itr,
+ ),
+ reduce(
+ (itr′, i) -> reduce(
+ (itr′′, j) ->
+ Base.setindex(itr′′, !itr′′[min(i, j)], j),
+ Base.OneTo(length(itr′));
+ init = itr′,
+ ),
+ Base.OneTo(length(itr));
+ init = itr,
+ ),
+ "Bools",
+ skip_allocations_test,
+ )
+ if length(itr) <= 256
+ @test_unrolled(
+ (itr,),
+ unrolled_reduce(
+ (itr′, i) -> unrolled_reduce(
+ (itr′′, j) -> unrolled_reduce(
+ (itr′′′, k) -> Base.setindex(
+ itr′′′,
+ !itr′′′[min(i, j, k)],
+ k,
+ ),
+ StaticOneTo(length(itr′′));
+ init = itr′′,
+ ),
+ StaticOneTo(length(itr′));
+ init = itr′,
+ ),
+ StaticOneTo(length(itr));
+ init = itr,
+ ),
+ reduce(
+ (itr′, i) -> reduce(
+ (itr′′, j) -> reduce(
+ (itr′′′, k) -> Base.setindex(
+ itr′′′,
+ !itr′′′[min(i, j, k)],
+ k,
+ ),
+ Base.OneTo(length(itr′′));
+ init = itr′′,
+ ),
+ Base.OneTo(length(itr′));
+ init = itr′,
+ ),
+ Base.OneTo(length(itr));
+ init = itr,
+ ),
+ "Bools",
+ skip_allocations_test,
+ )
+ end # The StaticBitVector{257} allocates over 2 GB for this test.
+ end
+end
+
+title = "Edge cases for unrolling"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
+
+@testset "unrolled functions of an empty Tuple" begin
+ itr = ()
+ str = "nothing"
+ @test_unrolled (itr,) unrolled_any(error, itr) any(error, itr) str
+ @test_unrolled (itr,) unrolled_all(error, itr) all(error, itr) str
+ @test_unrolled (itr,) unrolled_foreach(error, itr) foreach(error, itr) str
+ @test_unrolled (itr,) unrolled_map(error, itr) map(error, itr) str
+ @test_throws "init" unrolled_reduce(error, itr)
+ @test_unrolled(
+ (itr,),
+ unrolled_reduce(error, itr; init = 0),
+ reduce(error, itr; init = 0),
+ str,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_accumulate(error, itr),
+ accumulate(error, itr),
+ str,
+ )
+ @test_unrolled(
+ (itr,),
+ unrolled_accumulate(error, itr; init = 0),
+ accumulate(error, itr; init = 0),
+ str,
+ )
+end
+
+@testset "unrolled functions of a large Tuple or StaticOneTo" begin
+ itr = StaticOneTo(8186)
+ @test_unrolled((itr,), unrolled_reduce(+, itr), reduce(+, itr), "Ints",)
+ @test_unrolled(
+ (itr,),
+ unrolled_mapreduce(log, +, itr),
+ mapreduce(log, +, itr),
+ "Ints",
+ )
+
+ # The previous tests also work with ntuple(identity, 8186), but they take a
+ # very long time to compile.
+
+ @test_throws "gc handles" unrolled_reduce(+, StaticOneTo(8187))
+ @test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187))
+
+ @test_throws "gc handles" unrolled_reduce(+, ntuple(identity, 8187))
+ @test_throws "gc handles" unrolled_mapreduce(log, +, ntuple(identity, 8187))
+
+ # TODO: Why does the compiler throw an error when generating functions that
+ # require more than 8187 calls to getindex (or, rather, generic_getindex)?
+end
+
+title = "Generative vs. recursive unrolling"
+comparison_table_dict = (comparison_table_dicts[title] = OrderedDict())
+
+for itr in (
+ tuple_of_tuples(1, 0, true, true),
+ tuple_of_tuples(1, 1, true, true),
+ tuple_of_tuples(1, 1, false, true),
+ map(n -> tuple_of_tuples(n, 0, true, true), (8, 16, 32, 33, 128, 256))...,
+ map(n -> tuple_of_tuples(n, 1, true, true), (8, 16, 32, 33, 128, 256))...,
+ map(n -> tuple_of_tuples(n, 1, false, true), (8, 16, 32, 33, 128, 256))...,
+ map(n -> tuple_of_tuples(n, 0, true, false), (8, 16, 32, 33, 128, 256))...,
+ map(n -> tuple_of_tuples(n, 1, true, false), (8, 16, 32, 33, 128, 256))...,
+ map(n -> tuple_of_tuples(n, 1, false, false), (8, 16, 32, 33, 128, 256))...,
+)
+ str = tuples_of_tuples_contents_str(itr)
+ itr_str = "a Tuple that contains $(length(itr)) $str"
+ @testset "generatively vs. recursively unrolled functions of $itr_str" begin
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_any(isempty, itr),
+ UnrolledUtilities.rec_unrolled_any(isempty, itr),
+ str,
+ )
+
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_all(isempty, itr),
+ UnrolledUtilities.rec_unrolled_all(isempty, itr),
+ str,
+ )
+
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_foreach(
+ x -> @assert(length(x) <= 7),
+ itr,
+ ),
+ UnrolledUtilities.rec_unrolled_foreach(
+ x -> @assert(length(x) <= 7),
+ itr,
+ ),
+ str,
+ )
+
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_map(length, itr),
+ UnrolledUtilities.rec_unrolled_map(length, itr),
+ str,
+ )
+
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_applyat(
+ length,
+ rand(1:7:length(itr)),
+ itr,
+ ),
+ UnrolledUtilities.rec_unrolled_applyat(
+ length,
+ rand(1:7:length(itr)),
+ itr,
+ ),
+ str,
+ )
+
+ if length(itr) <= 33
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_reduce(tuple, itr, ()),
+ UnrolledUtilities.rec_unrolled_reduce(tuple, itr, ()),
+ str,
+ )
+
+ @test_unrolled(
+ (itr,),
+ UnrolledUtilities.gen_unrolled_accumulate(
+ tuple,
+ itr,
+ (),
+ identity,
+ ),
+ UnrolledUtilities.rec_unrolled_accumulate(
+ tuple,
+ itr,
+ (),
+ identity,
+ ),
+ str,
+ )
+ end # These can take over a minute to compile when the length is 128.
+ end
+end