From 81fc2b70a03fcae06e5211cb3002b7cc75bef7bc Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Fri, 16 Aug 2024 14:09:55 -0700 Subject: [PATCH] Add support for non-Tuple iterators --- Project.toml | 7 + docs/make.jl | 18 +- docs/src/index.md | 117 ++- ext/UnrolledUtilitiesStaticArraysExt.jl | 12 + src/StaticBitVector.jl | 155 ++++ src/StaticOneTo.jl | 18 + src/UnrolledUtilities.jl | 191 +++-- src/generatively_unrolled_functions.jl | 62 ++ src/recursion_limits.jl | 55 ++ src/recursively_unrolled_functions.jl | 47 ++ src/unrollable_iterator_interface.jl | 216 ++++++ test/aqua.jl | 1 + test/runtests.jl | 4 +- test/test_and_analyze.jl | 981 +++++++++++++++++------- 14 files changed, 1485 insertions(+), 399 deletions(-) create mode 100644 ext/UnrolledUtilitiesStaticArraysExt.jl create mode 100644 src/StaticBitVector.jl create mode 100644 src/StaticOneTo.jl create mode 100644 src/generatively_unrolled_functions.jl create mode 100644 src/recursion_limits.jl create mode 100644 src/recursively_unrolled_functions.jl create mode 100644 src/unrollable_iterator_interface.jl diff --git a/Project.toml b/Project.toml index 62f5f0a..646c778 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,13 @@ version = "0.1.2" [compat] julia = "1.10" +StaticArrays = "1" + +[weakdeps] +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +UnrolledUtilitiesStaticArraysExt = "StaticArrays" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/docs/make.jl b/docs/make.jl index ee728d5..2930c69 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,25 @@ using Documenter include(joinpath("..", "test", "test_and_analyze.jl")) -comparison_table_file = joinpath("docs", "src", "comparison_table.md") +comparison_table_file = joinpath("docs", "src", "comparison_tables.md") open(comparison_table_file, "w") do io - println(io, "# Comparison Table\n```@raw html") - println(io, "
") # use 80% of viewport - print_comparison_table(io, true) - println(io, "
") - println(io, "```") + println(io, "# Comparison Tables") + for (title, comparison_table_dict) in comparison_table_dicts + print_comparison_table(title, comparison_table_dict, io) + end end makedocs(; sitename = "UnrolledUtilities.jl", modules = [UnrolledUtilities], - pages = ["Home" => "index.md", "Comparison Table" => "comparison_table.md"], + pages = [ + "Home" => "index.md", + "Comparison Tables" => "comparison_tables.md", + ], format = Documenter.HTML( prettyurls = get(ENV, "CI", nothing) == "true", - size_threshold_ignore = ["comparison_table.md"], + size_threshold_ignore = ["comparison_tables.md"], ), clean = true, ) diff --git a/docs/src/index.md b/docs/src/index.md index 8aaec38..5fa4958 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,15 +1,28 @@ +```@meta +CurrentModule = UnrolledUtilities +``` + # UnrolledUtilities.jl -A collection of generated functions in which all loops are unrolled and inlined: +## Unrolled Functions + +This package exports the following functions, in which all loops are unrolled +and inlined: - `unrolled_any(f, itr)`: similar to `any` - `unrolled_all(f, itr)`: similar to `all` - `unrolled_foreach(f, itrs...)`: similar to `foreach` - `unrolled_map(f, itrs...)`: similar to `map` +- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)` - `unrolled_reduce(op, itr; [init])`: similar to `reduce` - `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce` -- `unrolled_zip(itrs...)`: similar to `zip` -- `unrolled_enumerate(itrs...)`: similar to `enumerate`, but with the ability to - handle multiple iterators +- `unrolled_accumulate(op, itr; [init], [transform])`: similar to `accumulate`, + but with an optional `transform` function applied to every accumulated value +- `unrolled_push(itr, item)`: similar to `push!`, but non-mutating +- `unrolled_append(itr1, itr2)`: similar to `append!`, but non-mutating +- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take` (and to + `itr[1:N]`), but with `N` wrapped in a `Val` +- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop` (and to + `itr[(N + 1):end]`), but with `N` wrapped in a `Val` - `unrolled_in(item, itr)`: similar to `in` - `unrolled_unique(itr)`: similar to `unique` - `unrolled_filter(f, itr)`: similar to `filter` @@ -18,11 +31,6 @@ A collection of generated functions in which all loops are unrolled and inlined: - `unrolled_flatten(itr)`: similar to `Iterators.flatten` - `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap` - `unrolled_product(itrs...)`: similar to `Iterators.product` -- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)` -- `unrolled_take(itr, ::Val{N})`: similar to `itr[1:N]` (and to - `Iterators.take`), but with `N` wrapped in a `Val` -- `unrolled_drop(itr, ::Val{N})`: similar to `itr[(N + 1):end]` (and to - `Iterators.drop`), but with `N` wrapped in a `Val` These functions are guaranteed to be type-stable whenever they are given iterators with inferrable lengths and element types, including when @@ -42,34 +50,77 @@ iterators have singleton element types (and when the result of calling `f` and/or `op` on these elements is inferrable). However, they can also be much more expensive to compile than their counterparts from `Base` and `Base.Iterators`, in which case they should not be used unless there is a clear -performance benefit. Some notable exceptions to this are `unrolled_zip`, -`unrolled_take`, and `unrolled_drop`, which tend to be easier to compile than -`zip`, `Iterators.take`, `Iterators.drop`, and standard indexing notation. +performance benefit. Two notable exceptions to this are `unrolled_take` and +`unrolled_drop`, which are faster to compile than their non-static versions. + +## Interface + +These functions can be used to unroll loops over all iterators with statically +inferrable lengths. Compatibility with any such iterator type can be added +through the following interface: + +```@docs +rec_unroll +generic_getindex +output_type_for_promotion +NoOutputType +ConditionalOutputType +output_promote_rule +constructor_from_tuple +``` + +This interface is used to provide built-in compatibility with +- statically sized iterators from `Base` (`Tuple` and `NamedTuple`) +- lazy iterators from `Base` (`enumerate`, `zip`, `Iterators.map`, and other + generator expressions) +- statically sized iterators from + [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) (`SVector` + and `MVector`) +- custom lazy and low-storage iterators (`StaticOneTo` and `StaticBitVector`) + +```@docs +StaticOneTo +StaticBitVector +``` + +## When to Unroll For a more precise indication of whether you should use `UnrolledUtilities`, -please consult the autogenerated [Comparison Table](@ref). This table contains a -comprehensive set of potential use cases, each with a measurement of performance -optimization, the time required for compilation, and the memory usage during -compilation. Most cases involve simple functions `f` and/or `op`, but the last -few demonstrate the benefits of unrolling with non-trivial recursive functions. +please consult the autogenerated [Comparison Tables](@ref). These tables contain +a comprehensive set of potential use cases, along with a few measurements that +summarize their performance, compilation, and allocations: +- run time (best of several trial measurements) +- compilation time (as reported by the compiler) +- overall level of optimization (type stability, constant propagation, etc.) and + allocations during run time (as reported by the garbage collector) +- total allocations during compilation and first run (as reported by the garbage + collector and, when possible, the Julia process's resident set size estimator) -The rows of the table are highlighted as follows: -- green indicates an improvement in performance and either no change in - compilation or easier compilation (i.e., either similar or smaller values of - compilation time and memory usage) -- dark blue indicates an improvement in performance and harder compilation - (i.e., larger values of compilation time and/or memory usage) -- light blue indicates no change in performance and easier compilation -- yellow indicates no change in performance and no change in compilation -- magenta indicates no change in performance, an increase in compilation time, - and a decrease in compilation memory usage -- red indicates no change in performance and harder compilation +The rows of the tables are highlighted as follows: +- light blue indicates an improvement in performance due to better optimization + and either an improvement or no change in compilation time and total + allocations +- green indicates either faster run time or fewer allocations during run time + and either an improvement or no change in compilation time and total + allocations +- dark blue indicates an improvement in performance due to better optimization + and either slower compilation or more total allocations +- yellow indicates either faster run time or fewer allocations during run time + and either slower compilation or more total allocations +- magenta indicates no change in performance and either an improvement or no + change in compilation time and total allocations +- light gray indicates no change in performance and no change in compilation + time and total allocations +- dark gray indicates no change in performance and either faster compilation + with more total allocations or slower compilation with fewer total allocations +- red indicates a deterioration in performance, or no change in + performance and either slower compilation or more total allocations -Rows highlighted in green and blue present a clear advantage for unrolling, -whereas those highlighted in yellow, magenta, and red either have no clear -advantage, or they have a clear disadvantage. It is recommended that you only -unroll when your use case is similar to a row in the first category. +Rows highlighted in gray present no clear advantage to unrolling, while those +highlighted in red present a clear disadvantage. It is recommended that you only +unroll when your use case is similar to a row in one of the remaining +categories, each of which demonstrates some advantage to unrolling. -The table is also printed out by this package's unit tests, so these +The tables are also printed out by this package's unit tests, so these measurements can be compared across different operating systems by checking the [CI pipeline](https://github.com/CliMA/UnrolledUtilities.jl/actions/workflows/ci.yml). diff --git a/ext/UnrolledUtilitiesStaticArraysExt.jl b/ext/UnrolledUtilitiesStaticArraysExt.jl new file mode 100644 index 0000000..67058a7 --- /dev/null +++ b/ext/UnrolledUtilitiesStaticArraysExt.jl @@ -0,0 +1,12 @@ +module UnrolledUtilitiesStaticArraysExt + +import UnrolledUtilities +import StaticArrays: SVector, MVector + +@inline UnrolledUtilities.output_type_for_promotion(::SVector) = SVector +@inline UnrolledUtilities.constructor_from_tuple(::Type{SVector}) = SVector + +@inline UnrolledUtilities.output_type_for_promotion(::MVector) = MVector +@inline UnrolledUtilities.constructor_from_tuple(::Type{MVector}) = MVector + +end diff --git a/src/StaticBitVector.jl b/src/StaticBitVector.jl new file mode 100644 index 0000000..5605faf --- /dev/null +++ b/src/StaticBitVector.jl @@ -0,0 +1,155 @@ +""" + StaticBitVector{N, [U]}(f) + StaticBitVector{N, [U]}([bit]) + +A statically-sized analogue of `BitVector` with `Unsigned` chunks of type `U`, +which can be constructed using either a function `f(n)` or a constant `bit`. By +default, `U` is set to `UInt8` and `bit` is set to `false`. + +This iterator can only store `Bool`s, so its `output_type_for_promotion` is a +`ConditionalOutputType`. Efficient methods are provided for `unrolled_map`, +`unrolled_accumulate`, `unrolled_take`, and `unrolled_drop`, though the methods +for `unrolled_map` and `unrolled_accumulate` only apply when their output's +first item is a `Bool`. No other unrolled functions can use `StaticBitVector`s +as output types. +""" +struct StaticBitVector{N, U <: Unsigned, I <: NTuple{<:Any, U}} <: + StaticSequence{N} + ints::I +end +@inline StaticBitVector{N, U}(ints) where {N, U} = + StaticBitVector{N, U, typeof(ints)}(ints) +@inline StaticBitVector{N}(args...) where {N} = + StaticBitVector{N, UInt8}(args...) + +@inline function StaticBitVector{N, U}(bit::Bool = false) where {N, U} + n_bits_per_int = 8 * sizeof(U) + n_ints = cld(N, n_bits_per_int) + ints = ntuple(Returns(bit ? ~zero(U) : zero(U)), Val(n_ints)) + return StaticBitVector{N, U}(ints) +end + +@inline function StaticBitVector{N, U}(f::Function) where {N, U} + n_bits_per_int = 8 * sizeof(U) + n_ints = cld(N, n_bits_per_int) + ints = ntuple(Val(n_ints)) do int_index + @inline + first_index = n_bits_per_int * (int_index - 1) + 1 + unrolled_reduce( + StaticOneTo(min(n_bits_per_int, N - first_index + 1)); + init = zero(U), + ) do int, bit_index + @inline + bit_offset = bit_index - 1 + int | U(f(first_index + bit_offset)::Bool) << bit_offset + end + end + return StaticBitVector{N, U}(ints) +end + +@inline function int_index_and_bit_offset(::Type{U}, n) where {U} + int_offset, bit_offset = divrem(n - 1, 8 * sizeof(U)) + return (int_offset + 1, bit_offset) +end + +@inline function generic_getindex( + itr::StaticBitVector{<:Any, U}, + n::Integer, +) where {U} + int_index, bit_offset = int_index_and_bit_offset(U, n) + int = itr.ints[int_index] + return Bool(int >> bit_offset & one(int)) +end + +@inline function Base.setindex( + itr::StaticBitVector{N, U}, + bit::Bool, + n::Integer, +) where {N, U} + int_index, bit_offset = int_index_and_bit_offset(U, n) + int = itr.ints[int_index] + int′ = int & ~(one(int) << bit_offset) | U(bit) << bit_offset + ints = Base.setindex(itr.ints, int′, int_index) + return StaticBitVector{N, U}(ints) +end + +@inline output_type_for_promotion(::StaticBitVector{<:Any, U}) where {U} = + ConditionalOutputType(Bool, StaticBitVector{<:Any, U}) + +@inline function unrolled_map_into( + ::Type{StaticBitVector{<:Any, U}}, + f, + itrs..., +) where {U} + lazy_itr = Iterators.map(f, itrs...) + N = length(lazy_itr) + return StaticBitVector{N, U}(Base.Fix1(generic_getindex, lazy_itr)) +end + +@inline function unrolled_accumulate_into( + ::Type{StaticBitVector{<:Any, U}}, + op, + itr, + init, + transform, +) where {U} + N = length(itr) + n_bits_per_int = 8 * sizeof(U) + n_ints = cld(N, n_bits_per_int) + ints = unrolled_accumulate_into_tuple( + StaticOneTo(n_ints); + init = (nothing, init), + transform = first, + ) do (_, init_value_for_new_int), int_index + @inline + first_index = n_bits_per_int * (int_index - 1) + 1 + unrolled_reduce( + StaticOneTo(min(n_bits_per_int, N - first_index + 1)); + init = (zero(U), init_value_for_new_int), + ) do (int, prev_value), bit_index + @inline + bit_offset = bit_index - 1 + item = generic_getindex(itr, first_index + bit_offset) + new_value = + first_index + bit_offset == 1 && prev_value isa NoInit ? + item : op(prev_value, item) + (int | U(transform(new_value)::Bool) << bit_offset, new_value) + end + end + return StaticBitVector{N, U}(ints) +end + +# TODO: Add unrolled_push and unrolled_append + +@inline function unrolled_take( + itr::StaticBitVector{<:Any, U}, + ::Val{N}, +) where {N, U} + n_bits_per_int = 8 * sizeof(U) + n_ints = cld(N, n_bits_per_int) + ints = unrolled_take(itr.ints, Val(n_ints)) + return StaticBitVector{N, U}(ints) +end + +@inline function unrolled_drop( + itr::StaticBitVector{N_old, U}, + ::Val{N}, +) where {N_old, N, U} + n_bits_per_int = 8 * sizeof(U) + n_ints = cld(N_old - N, n_bits_per_int) + n_dropped_ints = length(itr.ints) - n_ints + bit_offset = N - n_bits_per_int * n_dropped_ints + ints_without_offset = unrolled_drop(itr.ints, Val(n_dropped_ints)) + ints = if bit_offset == 0 + ints_without_offset + else + cur_ints = ints_without_offset + next_ints = unrolled_push(unrolled_drop(cur_ints, Val(1)), nothing) + unrolled_map_into_tuple(cur_ints, next_ints) do cur_int, next_int + @inline + isnothing(next_int) ? cur_int >> bit_offset : + cur_int >> bit_offset | next_int << (n_bits_per_int - bit_offset) + end + end + return StaticBitVector{N_old - N, U}(ints) +end diff --git a/src/StaticOneTo.jl b/src/StaticOneTo.jl new file mode 100644 index 0000000..2721169 --- /dev/null +++ b/src/StaticOneTo.jl @@ -0,0 +1,18 @@ +""" + StaticOneTo(N) + +A lazy and statically-sized analogue of `Base.OneTo(N)`. + +This iterator can only store the integers from 1 to `N`, so its +`output_type_for_promotion` is `NoOutputType()`. An efficient method is provided +for `unrolled_take`, but no other unrolled functions can use `StaticOneTo`s as +output types. +""" +struct StaticOneTo{N} <: StaticSequence{N} end +@inline StaticOneTo(N) = StaticOneTo{N}() + +@inline generic_getindex(::StaticOneTo, n) = n + +@inline output_type_for_promotion(::StaticOneTo) = NoOutputType() + +@inline unrolled_take(::StaticOneTo, ::Val{N}) where {N} = StaticOneTo(N) diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl index dc69559..52cf883 100644 --- a/src/UnrolledUtilities.jl +++ b/src/UnrolledUtilities.jl @@ -4,10 +4,14 @@ export unrolled_any, unrolled_all, unrolled_foreach, unrolled_map, + unrolled_applyat, unrolled_reduce, unrolled_mapreduce, - unrolled_zip, - unrolled_enumerate, + unrolled_accumulate, + unrolled_push, + unrolled_append, + unrolled_take, + unrolled_drop, unrolled_in, unrolled_unique, unrolled_filter, @@ -15,114 +19,145 @@ export unrolled_any, unrolled_flatten, unrolled_flatmap, unrolled_product, - unrolled_applyat, - unrolled_take, - unrolled_drop - -inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N -# We could also add support for statically-sized iterators that are not Tuples. - -f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type)) -@inline @generated unrolled_any(f, itr) = Expr(:||, f_exprs(itr)...) -@inline @generated unrolled_all(f, itr) = Expr(:&&, f_exprs(itr)...) - -function zipped_f_exprs(itr_types) - L = length(itr_types) - L == 0 && error("unrolled functions need at least one iterator as input") - N = minimum(inferred_length, itr_types) - return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N) -end -@inline @generated unrolled_foreach(f, itrs...) = - Expr(:block, zipped_f_exprs(itrs)..., nothing) -@inline @generated unrolled_map(f, itrs...) = - Expr(:tuple, zipped_f_exprs(itrs)...) - -function nested_op_expr(itr_type) - N = inferred_length(itr_type) - N == 0 && error("unrolled_reduce needs an `init` value for empty iterators") - item_exprs = (:(itr[$n]) for n in 1:N) - return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs) -end -@inline @generated unrolled_reduce_without_init(op, itr) = nested_op_expr(itr) - -struct NoInit end + StaticOneTo, + StaticBitVector + +include("unrollable_iterator_interface.jl") +include("recursively_unrolled_functions.jl") +include("generatively_unrolled_functions.jl") + +struct NoInit end # Analogue of Base._InitialValue for reduction/accumulation. + +@inline unrolled_any(f, itr) = + (rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr) +@inline unrolled_any(itr) = unrolled_any(identity, itr) + +@inline unrolled_all(f, itr) = + (rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr) +@inline unrolled_all(itr) = unrolled_all(identity, itr) + +@inline unrolled_foreach(f, itr) = + (rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr) +@inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...)) + +@inline unrolled_map_into_tuple(f, itr) = + (rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr) +@inline unrolled_map_into_tuple(f, itrs...) = + unrolled_map_into_tuple(splat(f), zip(itrs...)) +@inline unrolled_map_into(output_type, f, itrs...) = + constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itrs...)) +@inline unrolled_map(f, itrs...) = + unrolled_map_into(map_output_type(f, itrs...), f, itrs...) + +@inline unrolled_applyat(f, n, itr) = + (rec_unroll(itr) ? rec_unrolled_applyat : gen_unrolled_applyat)(f, n, itr) +@inline unrolled_applyat(f, n, itrs...) = + unrolled_applyat(splat(f), n, zip(itrs...)) +@inline unrolled_applyat_bounds_error() = + error("unrolled_applyat has detected an out-of-bounds index") + +@inline unrolled_reduce(op, itr, init) = + (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init) @inline unrolled_reduce(op, itr; init = NoInit()) = - unrolled_reduce_without_init(op, init isa NoInit ? itr : (init, itr...)) + isempty(itr) && init isa NoInit ? + error("unrolled_reduce requires an init value for empty iterators") : + unrolled_reduce(op, itr, init) @inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) = - unrolled_reduce(op, unrolled_map(f, itrs...); init) + unrolled_reduce(op, Iterators.map(f, itrs...), init) + +@inline unrolled_accumulate_into_tuple(op, itr, init, transform) = + (rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)( + op, + itr, + init, + transform, + ) +@inline unrolled_accumulate_into(output_type, op, itr, init, transform) = + constructor_from_tuple(output_type)( + unrolled_accumulate_into_tuple(op, itr, init, transform), + ) +@inline unrolled_accumulate(op, itr; init = NoInit(), transform = identity) = + unrolled_accumulate_into( + accumulate_output_type(op, itr, init, transform), + op, + itr, + init, + transform, + ) -@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...) +@inline unrolled_push(itr, item) = + inferred_constructor_from_tuple(itr)((itr..., item)) -@inline unrolled_enumerate(itrs...) = - unrolled_zip(ntuple(identity, Val(length(itrs[1]))), itrs...) +@inline unrolled_append_into(output_type, itr1, itr2) = + constructor_from_tuple(output_type)((itr1..., itr2...)) +@inline unrolled_append(itr1, itr2) = + unrolled_append_into(promoted_output_type((itr1, itr2)), itr1, itr2) + +@inline unrolled_take(itr, ::Val{N}) where {N} = + inferred_constructor_from_tuple(itr)( + ntuple(Base.Fix1(generic_getindex, itr), Val(N)), + ) + +@inline unrolled_drop(itr, ::Val{N}) where {N} = + inferred_constructor_from_tuple(itr)( + ntuple(n -> generic_getindex(itr, N + n), Val(length(itr) - N)), + ) @inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr) # Using === instead of == or isequal improves type stability for singletons. @inline unrolled_unique(itr) = - unrolled_reduce(itr; init = ()) do unique_items, item + unrolled_reduce(itr; init = inferred_empty(itr)) do unique_items, item @inline - unrolled_in(item, unique_items) ? unique_items : (unique_items..., item) + unrolled_in(item, unique_items) ? unique_items : + unrolled_push(unique_items, item) end @inline unrolled_filter(f, itr) = - unrolled_reduce(itr; init = ()) do filtered_items, item + unrolled_reduce(itr; init = inferred_empty(itr)) do items_with_true_f, item @inline - f(item) ? (filtered_items..., item) : filtered_items + f(item) ? unrolled_push(items_with_true_f, item) : items_with_true_f end @inline unrolled_split(f, itr) = - unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item + unrolled_reduce( + itr; + init = (inferred_empty(itr), inferred_empty(itr)), + ) do (items_with_true_f, items_with_false_f), item @inline - f(item) ? ((f_items..., item), not_f_items) : - (f_items, (not_f_items..., item)) + f(item) ? (unrolled_push(items_with_true_f, item), items_with_false_f) : + (items_with_true_f, unrolled_push(items_with_false_f, item)) end @inline unrolled_flatten(itr) = - unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ()) + unrolled_reduce(unrolled_append, itr; init = promoted_empty(itr)) @inline unrolled_flatmap(f, itrs...) = - unrolled_flatten(unrolled_map(f, itrs...)) + unrolled_flatten(Iterators.map(f, itrs...)) @inline unrolled_product(itrs...) = - unrolled_reduce(itrs; init = ((),)) do product_itr, itr + unrolled_reduce(itrs; init = (promoted_empty(itrs),)) do product_itr, itr @inline unrolled_flatmap(itr) do item @inline - unrolled_map(product_tuple -> (product_tuple..., item), product_itr) + unrolled_map_into_tuple(Base.Fix2(unrolled_push, item), product_itr) end end -@inline unrolled_applyat(f, n, itrs...) = unrolled_foreach( - (i, items...) -> i == n && f(items...), - unrolled_enumerate(itrs...), -) +abstract type StaticSequence{N} end -@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N)) -@inline unrolled_drop(itr, ::Val{N}) where {N} = - ntuple(i -> itr[N + i], Val(length(itr) - N)) -# When its second argument is a Val, ntuple is unrolled via Base.@ntuple. - -@static if hasfield(Method, :recursion_relation) - # Remove recursion limits for functions whose arguments are also functions. - for func in ( - unrolled_any, - unrolled_all, - unrolled_foreach, - unrolled_map, - unrolled_reduce_without_init, - unrolled_reduce, - unrolled_mapreduce, - unrolled_filter, - unrolled_split, - unrolled_flatmap, - unrolled_applyat, - ) - for method in methods(func) - method.recursion_relation = (_...) -> true - end - end -end +@inline Base.length(::StaticSequence{N}) where {N} = N +@inline Base.firstindex(::StaticSequence) = 1 +@inline Base.lastindex(itr::StaticSequence) = length(itr) +@inline Base.getindex(itr::StaticSequence, n::Integer) = + generic_getindex(itr, n) +@inline Base.iterate(itr::StaticSequence, n = 1) = + n > length(itr) ? nothing : (generic_getindex(itr, n), n + 1) + +include("StaticOneTo.jl") +include("StaticBitVector.jl") + +include("recursion_limits.jl") # This must be included at the end of the module. end diff --git a/src/generatively_unrolled_functions.jl b/src/generatively_unrolled_functions.jl new file mode 100644 index 0000000..e8afe0b --- /dev/null +++ b/src/generatively_unrolled_functions.jl @@ -0,0 +1,62 @@ +f_exprs(N) = (:(f(generic_getindex(itr, $n))) for n in 1:N) + +@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = + Expr(:||, f_exprs(N)...) +@inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr) + +@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = + Expr(:&&, f_exprs(N)...) +@inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr) + +@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = + Expr(:block, f_exprs(N)..., nothing) +@inline gen_unrolled_foreach(f, itr) = + _gen_unrolled_foreach(Val(length(itr)), f, itr) + +@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = + Expr(:tuple, f_exprs(N)...) +@inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr) + +@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr( + :block, + map(((n, expr),) -> :(n == $n && return $expr), enumerate(f_exprs(N)))..., + :(unrolled_applyat_bounds_error()), +) # This block should get optimized into a switch statement during LLVM codegen. +@inline gen_unrolled_applyat(f, n, itr) = + _gen_unrolled_applyat(Val(length(itr)), f, n, itr) + +function nested_op_expr(N, init_type) + init_expr = init_type <: NoInit ? :(generic_getindex(itr, 1)) : :init + n_range = init_type <: NoInit ? (2:N) : (1:N) + return foldl(n_range; init = init_expr) do prev_op_expr, n + :(op($prev_op_expr, generic_getindex(itr, $n))) + end # Use foldl instead of reduce to guarantee left associativity. +end + +@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = + nested_op_expr(N, init) +@inline gen_unrolled_reduce(op, itr, init) = + _gen_unrolled_reduce(Val(length(itr)), op, itr, init) + +function transformed_sequential_op_exprs(N, init_type) + first_item_expr = :(generic_getindex(itr, 1)) + init_expr = + init_type <: NoInit ? first_item_expr : :(op(init, $first_item_expr)) + transformed_exprs_and_next_op_exprs = + accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n + var = gensym() + next_op_expr = :(op($var, generic_getindex(itr, $(n + 1)))) + (:($var = $prev_op_expr; transform($var)), next_op_expr) + end + return map(first, transformed_exprs_and_next_op_exprs) +end + +@inline @generated _gen_unrolled_accumulate( + ::Val{N}, + op, + itr, + init, + transform, +) where {N} = Expr(:tuple, transformed_sequential_op_exprs(N, init)...) +@inline gen_unrolled_accumulate(op, itr, init, transform) = + _gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform) diff --git a/src/recursion_limits.jl b/src/recursion_limits.jl new file mode 100644 index 0000000..2098fe9 --- /dev/null +++ b/src/recursion_limits.jl @@ -0,0 +1,55 @@ +# Remove recursion limits from functions that call themselves, and also from all +# functions whose arguments can be arbitrary functions (including themselves). +@static if hasfield(Method, :recursion_relation) + for func in ( + generic_getindex, + output_type_for_promotion, + _rec_unrolled_any, + _rec_unrolled_all, + _rec_unrolled_foreach, + _rec_unrolled_map, + _rec_unrolled_applyat, + _rec_unrolled_reduce, + _rec_unrolled_accumulate, + rec_unrolled_any, + rec_unrolled_all, + rec_unrolled_foreach, + rec_unrolled_map, + rec_unrolled_applyat, + rec_unrolled_reduce, + rec_unrolled_accumulate, + _gen_unrolled_any, + _gen_unrolled_all, + _gen_unrolled_foreach, + _gen_unrolled_map, + _gen_unrolled_applyat, + _gen_unrolled_reduce, + _gen_unrolled_accumulate, + gen_unrolled_any, + gen_unrolled_all, + gen_unrolled_foreach, + gen_unrolled_map, + gen_unrolled_applyat, + gen_unrolled_reduce, + gen_unrolled_accumulate, + unrolled_any, + unrolled_all, + unrolled_foreach, + unrolled_map_into_tuple, + unrolled_map_into, + unrolled_map, + unrolled_applyat, + unrolled_reduce, + unrolled_mapreduce, + unrolled_accumulate_into_tuple, + unrolled_accumulate_into, + unrolled_accumulate, + unrolled_filter, + unrolled_split, + unrolled_flatmap, + ) + for method in methods(func) + method.recursion_relation = Returns(true) + end + end +end diff --git a/src/recursively_unrolled_functions.jl b/src/recursively_unrolled_functions.jl new file mode 100644 index 0000000..db88bef --- /dev/null +++ b/src/recursively_unrolled_functions.jl @@ -0,0 +1,47 @@ +@inline _rec_unrolled_any(f) = false +@inline _rec_unrolled_any(f, item, items...) = + f(item) || _rec_unrolled_any(f, items...) +@inline rec_unrolled_any(f, itr) = _rec_unrolled_any(f, itr...) + +@inline _rec_unrolled_all(f) = true +@inline _rec_unrolled_all(f, item, items...) = + f(item) && _rec_unrolled_all(f, items...) +@inline rec_unrolled_all(f, itr) = _rec_unrolled_all(f, itr...) + +@inline _rec_unrolled_foreach(f) = nothing +@inline _rec_unrolled_foreach(f, item, items...) = + (f(item); _rec_unrolled_foreach(f, items...)) +@inline rec_unrolled_foreach(f, itr) = _rec_unrolled_foreach(f, itr...) + +@inline _rec_unrolled_map(f) = () +@inline _rec_unrolled_map(f, item, items...) = + (f(item), _rec_unrolled_map(f, items...)...) +@inline rec_unrolled_map(f, itr) = _rec_unrolled_map(f, itr...) + +@inline _rec_unrolled_applyat(f, offset_n) = unrolled_applyat_bounds_error() +@inline _rec_unrolled_applyat(f, offset_n, item, items...) = + offset_n == 1 ? f(item) : _rec_unrolled_applyat(f, offset_n - 1, items...) +@inline rec_unrolled_applyat(f, n, itr) = _rec_unrolled_applyat(f, n, itr...) + +@inline _rec_unrolled_reduce(op, prev_value) = prev_value +@inline _rec_unrolled_reduce(op, prev_value, item, items...) = + _rec_unrolled_reduce(op, op(prev_value, item), items...) +@inline rec_unrolled_reduce(op, itr, init) = + init isa NoInit ? _rec_unrolled_reduce(op, itr...) : + _rec_unrolled_reduce(op, init, itr...) + +@inline _rec_unrolled_accumulate(op, transform, prev_value) = + (transform(prev_value),) +@inline _rec_unrolled_accumulate(op, transform, prev_value, item, items...) = ( + transform(prev_value), + _rec_unrolled_accumulate(op, transform, op(prev_value, item), items...)..., +) +@inline rec_unrolled_accumulate(op, itr, init, transform) = + isempty(itr) ? () : + init isa NoInit ? _rec_unrolled_accumulate(op, transform, itr...) : + _rec_unrolled_accumulate( + op, + transform, + op(init, generic_getindex(itr, 1)), + unrolled_drop(itr, Val(1))..., + ) diff --git a/src/unrollable_iterator_interface.jl b/src/unrollable_iterator_interface.jl new file mode 100644 index 0000000..9784996 --- /dev/null +++ b/src/unrollable_iterator_interface.jl @@ -0,0 +1,216 @@ +#= +To unroll over a statically-sized iterator of type T, follow these steps: +- Add a method for either getindex(::T, n) or generic_getindex(::T, n) +- If every unrolled function that needs to construct an iterator when given an + iterator of type T can return a Tuple instead, stop here; otherwise, to return + a non-Tuple iterator whenever possible, follow these steps: + - Add a method for output_type_for_promotion(::T) = O, where O can be T, a + supertype of T, or some other type or AmbiguousOutputType + - If an output of type O can be used together with an output of type O′, add + a method for output_promote_rule(O, O′) + - If an output of type O can be efficiently constructed from a Tuple, add a + method for constructor_from_tuple(O); otherwise, add a method for each + unrolled function that can efficiently generate an output of type O + without temporarily storing it as a Tuple +=# + +""" + rec_unroll(itr) + +Whether to use recursive loop unrolling instead of generative loop unrolling for +the iterator `itr`. + +In general, recursive loop unrolling is faster to compile for small iterators, +but it becomes extremely slow to compile for large iterators, and it usually +generates suboptimal LLVM code for large iterators. On the other hand, +generative loop unrolling is slow to compile for small iterators, but its +compilation time does not grow as rapidly with respect to iterator size, and it +always generates optimal LLVM code. The default is currently to use recursive +unrolling for iterators lengths up to 16, based on the latest compilation and +performance comparisons of recursively and generatively unrolled functions. +""" +@inline rec_unroll(itr) = length(itr) <= 16 + +""" + generic_getindex(itr, n) + +An alternative to `Base.getindex` that can also handle internal types such as +`Base.Generator` and `Base.Iterators.Enumerate`. Defaults to `Base.getindex`. +""" +@inline generic_getindex(itr, n) = getindex(itr, n) +@inline generic_getindex(itr::Base.Generator, n) = + itr.f(generic_getindex(itr.iter, n)) +@inline generic_getindex(itr::Base.Iterators.Enumerate, n) = + (n, generic_getindex(itr.itr, n)) +@inline generic_getindex(itr::Base.Iterators.Zip, n) = + unrolled_map_into_tuple(Base.Fix2(generic_getindex, n), itr.is) + +@inline first_item_type(itr) = + Base.promote_op(Base.Fix2(generic_getindex, 1), typeof(itr)) +@inline second_item_type(itr) = + Base.promote_op(Base.Fix2(generic_getindex, 2), typeof(itr)) + +""" + output_type_for_promotion(itr) + +The type of output that unrolled functions should try to generate for the input +iterator `itr`, or a `ConditionalOutputType` if the output type depends on the +type of items that need to be stored in it, or `NoOutputType()` if `itr` is a +lazy iterator without any associated output type. Defaults to `Tuple`. +""" +@inline output_type_for_promotion(_) = Tuple +@inline output_type_for_promotion(::NamedTuple{names}) where {names} = + NamedTuple{names} +@inline output_type_for_promotion(itr::Base.Generator) = + output_type_for_promotion(itr.iter) +@inline output_type_for_promotion(itr::Base.Iterators.Enumerate) = + output_type_for_promotion(itr.itr) +@inline output_type_for_promotion(itr::Base.Iterators.Zip) = + maybe_ambiguous_promoted_output_type(itr.is) + +abstract type AmbiguousOutputType end + +""" + NoOutputType() + +The `output_type_for_promotion` of lazy iterators. +""" +struct NoOutputType <: AmbiguousOutputType end + +""" + ConditionalOutputType(allowed_item_type, output_type, [fallback_type]) + +An `output_type_for_promotion` that can have one of two possible values. If the +first item in the output is a subtype of `allowed_item_type`, the output will +have the type `output_type`; otherwise, it will have the type `fallback_type`, +which is set to `Tuple` by default. +""" +struct ConditionalOutputType{I, O, O′} <: AmbiguousOutputType end +@inline ConditionalOutputType( + allowed_item_type::Type, + output_type::Type, + fallback_type::Type = Tuple, +) = ConditionalOutputType{allowed_item_type, output_type, fallback_type}() + +@inline unambiguous_output_type(_, ::Type{O}) where {O} = O +@inline unambiguous_output_type(_, ::NoOutputType) = Tuple +@inline unambiguous_output_type( + get_first_item_type, + ::ConditionalOutputType{I, O, O′}, +) where {I, O, O′} = get_first_item_type() <: I ? O : O′ + +""" + output_promote_rule(output_type1, output_type2) + +The type of output that should be generated when two iterators do not have the +same `output_type_for_promotion`, or `Union{}` if these iterators should not be +used together. Only one method of `output_promote_rule` needs to be defined for +any pair of output types. + +By default, all types take precedence over `NoOutputType()`, and the conditional +part of any `ConditionalOutputType` takes precedence over an unconditional type +(so that only the `fallback_type` of any conditional type gets promoted). The +default result for all other pairs of unequal output types is `Union{}`. +""" +@inline output_promote_rule(_, _) = Union{} +@inline output_promote_rule(::Type{O}, ::Type{O}) where {O} = O +@inline output_promote_rule(::NoOutputType, output_type) = output_type +@inline output_promote_rule( + ::ConditionalOutputType{I, O, O′}, + ::Type{O′′}, +) where {I, O, O′, O′′} = + ConditionalOutputType(I, O, output_promote_rule(O′, O′′)) +@inline output_promote_rule( + ::Type{O′}, + ::ConditionalOutputType{I, O, O′′}, +) where {I, O, O′, O′′} = + ConditionalOutputType(I, O, output_promote_rule(O′, O′′)) +@inline output_promote_rule( + ::ConditionalOutputType{I, O, O′}, + ::ConditionalOutputType{I, O, O′′}, +) where {I, O, O′, O′′} = + ConditionalOutputType(I, O, output_promote_rule(O′, O′′)) + +@inline function output_promote_result(O1, O2) + O12 = output_promote_rule(O1, O2) + O21 = output_promote_rule(O2, O1) + O12 == O21 == Union{} && + error("output_promote_rule is undefined for $O1 and $O2") + (O12 == O21 || O21 == Union{}) && return O12 + O12 == Union{} && return O21 + error("output_promote_rule yields inconsistent results for $O1 and $O2: \ + $O12 for $O1 followed by $O2, versus $O21 for $O2 followed by $O1") +end + +@inline maybe_ambiguous_promoted_output_type(itrs) = + isempty(itrs) ? Tuple : # Generate a Tuple when given 0 inputs. + unrolled_mapreduce(output_type_for_promotion, output_promote_result, itrs) + +@inline inferred_output_type(itr) = + unambiguous_output_type(output_type_for_promotion(itr)) do + @inline + first_item_type(itr) + end +@inline inferred_output_type(itr::Base.Generator) = + unambiguous_output_type(output_type_for_promotion(itr.iter)) do + @inline + Base.promote_op(itr.f, first_item_type(itr.iter)) + end +@inline inferred_output_type(itr::Base.Iterators.Enumerate) = + unambiguous_output_type(output_type_for_promotion(itr.itr)) do + @inline + Tuple{Int, first_item_type(itr.itr)} + end +@inline inferred_output_type(itr::Base.Iterators.Zip) = + unambiguous_output_type(maybe_ambiguous_promoted_output_type(itr.is)) do + @inline + Tuple{unrolled_map_into_tuple(first_item_type, itr.is)...} + end + +@inline promoted_output_type(itrs) = + unambiguous_output_type(maybe_ambiguous_promoted_output_type(itrs)) do + @inline + first_item_type(generic_getindex(itrs, 1)) + end + +@inline map_output_type(f, itrs...) = + inferred_output_type(Iterators.map(f, itrs...)) + +@inline accumulate_output_type(op, itr, init, transform) = + unambiguous_output_type(output_type_for_promotion(itr)) do + @inline + no_init = init isa NoInit + arg1_type = no_init ? first_item_type(itr) : typeof(init) + arg2_type = no_init ? second_item_type(itr) : first_item_type(itr) + Base.promote_op(transform, Base.promote_op(op, arg1_type, arg2_type)) + end + +""" + constructor_from_tuple(output_type) + +A function that can be used to efficiently construct an output of type +`output_type` from a `Tuple`, or `identity` if such an output should not be +constructed from a `Tuple`. Defaults to `identity`, which also handles the case +where `output_type` is already `Tuple`. The `output_type` here is guaranteed to +be a `Type`, rather than a `ConditionalOutputType` or `NoOutputType`. + +Many statically-sized iterators (e.g., `SVector`s) are essentially wrappers for +`Tuple`s, and their constructors for `Tuple`s can be reduced to no-ops. The main +exceptions are [`StaticOneTo`](@ref UnrolledUtilities.StaticOneTo)s and +[`StaticBitVector`](@ref UnrolledUtilities.StaticBitVector)s, which do not +provide constructors for `Tuple`s because there is no performance benefit to +making a lazy or low-storage data structure once a corresponding high-storage +data structure has already been constructed. +""" +@inline constructor_from_tuple(::Type) = identity +@inline constructor_from_tuple(::Type{NT}) where {NT <: NamedTuple} = NT + +@inline inferred_constructor_from_tuple(itr) = + constructor_from_tuple(inferred_output_type(itr)) + +@inline promoted_constructor_from_tuple(itrs) = + constructor_from_tuple(promoted_output_type(itrs)) + +@inline inferred_empty(itr) = inferred_constructor_from_tuple(itr)(()) + +@inline promoted_empty(itrs) = promoted_constructor_from_tuple(itrs)(()) diff --git a/test/aqua.jl b/test/aqua.jl index d7becf1..ff1edd1 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -1,3 +1,4 @@ +using Test import Aqua, UnrolledUtilities # This is separate from all the other tests because Aqua.test_all checks for diff --git a/test/runtests.jl b/test/runtests.jl index 631181c..0cfddab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,9 @@ using SafeTestsets @safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") - print_comparison_table() + for (title, comparison_table_dict) in comparison_table_dicts + print_comparison_table(title, comparison_table_dict) + end end @safetestset "Aqua" begin @time include("aqua.jl") diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl index 70415e3..f14da78 100644 --- a/test/test_and_analyze.jl +++ b/test/test_and_analyze.jl @@ -6,45 +6,78 @@ using InteractiveUtils using UnrolledUtilities -comparison_table_dict = OrderedDict() +comparison_table_dicts = OrderedDict() -function print_comparison_table(io = stdout, generate_html = false) +function print_comparison_table(title, comparison_table_dict, io = stdout) table_data = mapreduce(vcat, collect(comparison_table_dict)) do (key, entries) stack(entry -> (key..., entry...), entries; dims = 1) end - highlighter(f, color) = - generate_html ? HtmlHighlighter(f, HtmlDecoration(; color)) : - Highlighter(f, Crayon(; foreground = Symbol(color))) - - better_performance_but_harder_to_compile = - highlighter(generate_html ? "royalblue" : "blue") do data, i, j - data[i, 4] != data[i, 5] && - (endswith(data[i, 6], "slower") || endswith(data[i, 7], "more")) - end - better_performance = - highlighter(generate_html ? "mediumseagreen" : "green") do data, i, j - data[i, 4] != data[i, 5] - end - mixed_compilation = - highlighter(generate_html ? "mediumorchid" : "magenta") do data, i, j - (endswith(data[i, 6], "slower") && endswith(data[i, 7], "less")) || - (endswith(data[i, 6], "faster") && endswith(data[i, 7], "more")) - end - harder_to_compile = - highlighter(generate_html ? "indianred" : "red") do data, i, j - endswith(data[i, 6], "slower") || endswith(data[i, 7], "more") - end - easier_to_compile = - highlighter(generate_html ? "darkturquoise" : "cyan") do data, i, j - endswith(data[i, 6], "faster") || endswith(data[i, 7], "less") + writing_to_docs = io isa IOStream + + color(color_str) = + writing_to_docs ? HtmlDecoration(; color_str) : + Crayon(; foreground = Symbol(color_str)) + highlighter_color(optimization, run_time, compile_time, allocs) = + if ( + contains(optimization, "better") || + contains(optimization, "fewer allocs") + ) && !contains(run_time, "more") || + contains(optimization, "identical") && contains(run_time, "less") + # better performance + if !contains(compile_time, "more") && !contains(allocs, "more") + # similar or better compilation and total allocations + if contains(optimization, "better") + # better optimization + color(writing_to_docs ? "darkturquoise" : "cyan") + else + # faster run time or fewer allocations at run time + color(writing_to_docs ? "mediumseagreen" : "green") + end + else + # worse compilation or total allocations + if contains(optimization, "better") + # better optimization + color(writing_to_docs ? "royalblue" : "blue") + else + # faster run time or fewer allocations at run time + color(writing_to_docs ? "khaki" : "yellow") + end + end + elseif contains(optimization, "identical") && + contains(run_time, "similar") + # similar performance + if contains(compile_time, "less") && !contains(allocs, "more") || + !contains(compile_time, "more") && contains(allocs, "less") + # better compilation or total allocations + color(writing_to_docs ? "mediumorchid" : "magenta") + elseif contains(compile_time, "similar") && + contains(allocs, "similar") + # similar compilation and total allocations + color(writing_to_docs ? "lightgray" : "light_gray") + elseif contains(compile_time, "less") && contains(allocs, "more") || + contains(compile_time, "more") && contains(allocs, "less") + # mixed compilation and total allocations + color(writing_to_docs ? "darkgray" : "dark_gray") + else + # worse compilation and total allocations + color(writing_to_docs ? "indianred" : "red") + end + else + # worse performance + color(writing_to_docs ? "indianred" : "red") end - no_difference = - highlighter((data, i, j) -> true, generate_html ? "khaki" : "yellow") + highlighter = (writing_to_docs ? HtmlHighlighter : Highlighter)( + Returns(true), + (_, data, row, _) -> highlighter_color(data[row, 6:9]...), + ) + + # TODO: Why does Sys.maxrss() always seem to be 0 on Ubuntu systems? + has_rss = any(contains('['), table_data[:, 9]) other_kwargs = - generate_html ? + writing_to_docs ? (; backend = Val(:html), table_style = Dict( @@ -53,38 +86,86 @@ function print_comparison_table(io = stdout, generate_html = false) ), ) : (; + title, + title_alignment = :c, title_same_width_as_table = true, - columns_width = [45, 45, 0, 0, 0, 0, 0], + columns_width = [45, 45, 15, 10, 30, 25, 20, 20, has_rss ? 30 : 20], linebreaks = true, autowrap = true, crop = :none, ) + if writing_to_docs + println(io, "## $title") + println(io, "```@raw html") + println(io, "
") # 80% of viewport + end pretty_table( io, table_data; - title = "Comparison of UnrolledUtilities to Base and Base.Iterators", - title_alignment = :c, alignment = :l, header = [ "Unrolled Expression", "Reference Expression", - "Iterator Contents", - "Unrolled Performance", - "Reference Performance", - "Unrolled Compilation Time", - "Unrolled Compilation Memory", + "Itr Type", + "Itr Length", + "Itr Contents", + "Optimization", + "Run Time", + "Compilation Time", + "Total $(has_rss ? "GC [and RSS] " : "")Allocations", ], - highlighters = ( - better_performance_but_harder_to_compile, - better_performance, - mixed_compilation, - harder_to_compile, - easier_to_compile, - no_difference, - ), + highlighters = highlighter, other_kwargs..., ) + if writing_to_docs + println(io, "
") + println(io, "```") + else + println(io) + end +end + +function time_string(nanoseconds) + nanoseconds == 0 && return "$nanoseconds ns" + n_decimal_digits = floor(Int, log10(nanoseconds) + 1) + return if n_decimal_digits <= 3 + "$nanoseconds ns" + elseif n_decimal_digits <= 6 + "$(round(Int, nanoseconds / 10^3)) μs" + elseif n_decimal_digits <= 9 + "$(round(Int, nanoseconds / 10^6)) ms" + else + "$(round(Int, nanoseconds / 10^9)) s" + end +end + +function memory_string(bytes) + bytes == 0 && return "$bytes B" + n_binary_digits = floor(Int, log2(bytes) + 1) + return if n_binary_digits <= 10 + "$bytes B" + elseif n_binary_digits <= 20 + "$(round(Int, bytes / 2^10)) kB" + elseif n_binary_digits <= 30 + "$(round(Int, bytes / 2^20)) MB" + else + "$(round(Int, bytes / 2^30)) GB" + end +end + +function comparison_string(value1, value2, to_string, to_number = identity) + ratio = to_number(value1) / to_number(value2) + itr_str = if ratio >= 2 + floored_ratio = ratio == Inf ? Inf : floor(Int, ratio) + "$floored_ratio times more" + elseif inv(ratio) >= 2 + floored_inv_ratio = ratio == 0 ? Inf : floor(Int, inv(ratio)) + "$floored_inv_ratio times less" + else + "similar" + end + return "$itr_str ($(to_string(value1)) vs. $(to_string(value2)))" end function drop_line_numbers(expr) @@ -118,18 +199,61 @@ function code_instance(f, args...) end end -macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str) +macro benchmark(expression) + return quote + prev_time = time_ns() + $(esc(expression)) + new_time = time_ns() + best_time = new_time - prev_time + + # Benchmark for at most 0.1 s (10^8 ns), ignoring the first call above. + n_trials = 0 + start_time = new_time + while n_trials < 10^4 && new_time - start_time < 10^8 + prev_time = time_ns() + $(esc(expression)) + new_time = time_ns() + best_time = min(best_time, new_time - prev_time) + n_trials += 1 + end + + best_time + end +end + +macro test_unrolled( + args_expr, + unrolled_expr, + reference_expr, + itr_contents_str, + skip_allocations_test = false, + skip_type_stability_test = false, + reference_preamble_expr = nothing, +) @assert Meta.isexpr(args_expr, :tuple) arg_names = args_expr.args @assert all(arg_name -> arg_name isa Symbol, arg_names) args = map(esc, arg_names) unrolled_expr_str = simplified_expression_string(unrolled_expr) reference_expr_str = simplified_expression_string(reference_expr) - expr_info_str = - length(args) == 1 ? "$unrolled_expr_str with 1 iterator that contains" : - "$unrolled_expr_str with $(length(args)) iterators that each contain" + reference_preamble_str = + isnothing(reference_preamble_expr) ? "" : + string(reference_preamble_expr) * '\n' + contains_str = length(args) == 1 ? " that contains" : "s that each contain" quote - @info "Testing $($expr_info_str) $($(esc(contents_info_str)))" + itr_types = map(arg -> typeof(arg).name.wrapper, ($(args...),)) + itr_lengths = map(length, ($(args...),)) + + itr_type_str = + length(unique(itr_types)) == 1 ? string(itr_types[1]) : + join(itr_types, '/') + itr_length_str = + length(unique(itr_lengths)) == 1 ? string(itr_lengths[1]) : + join(itr_lengths, '/') + itr_str = "$itr_type_str$($contains_str) $itr_length_str \ + $($(esc(itr_contents_str)))" + + @info "Testing $($unrolled_expr_str) with $($(length(args))) $itr_str" unrolled_func($(arg_names...)) = $(esc(unrolled_expr)) reference_func($(arg_names...)) = $(esc(reference_expr)) @@ -146,26 +270,27 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str) reference_func_and_nothing($(args...)) # Test for allocations. - @test (@allocated unrolled_func_and_nothing($(args...))) == 0 - is_reference_non_allocating = - (@allocated reference_func_and_nothing($(args...))) == 0 + unrolled_run_memory = @allocated unrolled_func_and_nothing($(args...)) + reference_run_memory = @allocated reference_func_and_nothing($(args...)) + $(esc(skip_allocations_test)) || @test unrolled_run_memory == 0 # Test for type-stability. - @test_opt unrolled_func($(args...)) + is_unrolled_stable = + isempty(JET.get_reports(@report_opt unrolled_func($(args...)))) is_reference_stable = isempty(JET.get_reports(@report_opt reference_func($(args...)))) - - unrolled_instance = code_instance(unrolled_func, $(args...)) - reference_instance = code_instance(reference_func, $(args...)) + $(esc(skip_type_stability_test)) || @test_opt unrolled_func($(args...)) # Test for constant propagation. - is_unrolled_const = isdefined(unrolled_instance, :rettype_const) - Base.issingletontype(typeof(($(args...),))) && @test is_unrolled_const - is_reference_const = isdefined(reference_instance, :rettype_const) + is_unrolled_const = + isdefined(code_instance(unrolled_func, $(args...)), :rettype_const) + is_reference_const = + isdefined(code_instance(reference_func, $(args...)), :rettype_const) + # Base.issingletontype(typeof(($(args...),))) && @test is_unrolled_const buffer = IOBuffer() - # Check whether the functions are fully optimized out. + # Determine whether the functions are fully optimized out. args_type = Tuple{map(typeof, ($(args...),))...} code_llvm(buffer, unrolled_func, args_type; debuginfo = :none) is_unrolled_optimized_out = @@ -174,86 +299,116 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str) is_reference_optimized_out = length(split(String(take!(buffer)), '\n')) == 5 + # Test the overall level of optimization. + unrolled_opt_str, unrolled_opt_score = if unrolled_run_memory > 0 + "$(memory_string(unrolled_run_memory)) allocs", 1 / unrolled_run_memory + elseif !is_unrolled_stable + "type-unstable", 2 + elseif !is_unrolled_const && !is_unrolled_optimized_out + "type-stable", 3 + elseif !is_unrolled_optimized_out + "constant", 4 + else + "optimized out", 5 + end + reference_opt_str, reference_opt_score = if reference_run_memory > 0 + "$(memory_string(reference_run_memory)) allocs", + 1 / reference_run_memory + elseif !is_reference_stable + "type-unstable", 2 + elseif !is_reference_const && !is_reference_optimized_out + "type-stable", 3 + elseif !is_reference_optimized_out + "constant", 4 + else + "optimized out", 5 + end + $(esc(skip_type_stability_test)) || + @test unrolled_opt_score >= reference_opt_score + + # Measure the run times. + unrolled_run_time = @benchmark unrolled_func($(args...)) + reference_run_time = @benchmark reference_func($(args...)) + + # Measure the compilation times and memory allocations in separate + # processes to ensure that they are not under-counted. arg_name_strs = ($(map(string, arg_names)...),) arg_names_str = join(arg_name_strs, ", ") arg_definition_strs = map((name, value) -> "$name = $value", arg_name_strs, ($(args...),)) arg_definitions_str = join(arg_definition_strs, '\n') - unrolled_command_str = """ + command_str(func_str) = """ using UnrolledUtilities - unrolled_func($arg_names_str) = $($(string(unrolled_expr))) $arg_definitions_str - stats1 = @timed unrolled_func($arg_names_str) - stats2 = @timed unrolled_func($arg_names_str) - print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes) - """ - reference_command_str = """ - reference_func($arg_names_str) = $($(string(reference_expr))) - $arg_definitions_str - stats1 = @timed reference_func($arg_names_str) - stats2 = @timed reference_func($arg_names_str) - print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes) + Base.cumulative_compile_timing(true) + nanoseconds1 = Base.cumulative_compile_time_ns()[1] + rss_bytes_1 = Sys.maxrss() + Δgc_bytes = @allocated $func_str + rss_bytes_2 = Sys.maxrss() + nanoseconds2 = Base.cumulative_compile_time_ns()[1] + Base.cumulative_compile_timing(false) + Δnanoseconds = nanoseconds2 - nanoseconds1 + Δrss_bytes = rss_bytes_2 - rss_bytes_1 + print(Δnanoseconds, ", ", Δgc_bytes, ", ", Δrss_bytes) """ - # Get the unrolled function's time-to-first-run and its memory usage. + unrolled_command_str = command_str($(string(unrolled_expr))) run(pipeline(`julia --project -e $unrolled_command_str`, buffer)) - unrolled_time, unrolled_memory = - parse.((Float64, Int), split(String(take!(buffer)), ',')) + unrolled_compile_time, unrolled_total_memory, unrolled_total_rss = + parse.((Int, Int, Int), split(String(take!(buffer)), ',')) # Make a new buffer to avoid a potential data race: # https://discourse.julialang.org/t/iobuffer-becomes-not-writable-after-run/92323/3 close(buffer) buffer = IOBuffer() - # Get the reference function's time-to-first-run and its memory usage. + reference_command_str = + $reference_preamble_str * command_str($(string(reference_expr))) run(pipeline(`julia --project -e $reference_command_str`, buffer)) - reference_time, reference_memory = - parse.((Float64, Int), split(String(take!(buffer)), ',')) + reference_compile_time, reference_total_memory, reference_total_rss = + parse.((Int, Int, Int), split(String(take!(buffer)), ',')) close(buffer) - # Record all relevant information in comparison_table_dict. - unrolled_performance_str = if !is_unrolled_const - "type-stable" - elseif !is_unrolled_optimized_out - "const return value" - else - "fully optimized out" - end - reference_performance_str = if !is_reference_non_allocating - "allocating" - elseif !is_reference_stable - "type-unstable" - elseif !is_reference_const - "type-stable" - elseif !is_reference_optimized_out - "const return value" - else - "fully optimized out" - end - time_ratio = unrolled_time / reference_time - time_ratio_str = if time_ratio >= 1.5 - "$(round(Int, time_ratio)) times slower" - elseif inv(time_ratio) >= 1.5 - "$(round(Int, inv(time_ratio))) times faster" - else - "similar" - end - memory_ratio = unrolled_memory / reference_memory - memory_ratio_str = if memory_ratio >= 1.5 - "$(round(Int, memory_ratio)) times more" - elseif inv(memory_ratio) >= 1.5 - "$(round(Int, inv(memory_ratio))) times less" + optimization_str = if unrolled_opt_score > reference_opt_score + if unrolled_opt_score <= 1 + "fewer allocs ($unrolled_opt_str vs. $reference_opt_str)" + else + "better ($unrolled_opt_str vs. $reference_opt_str)" + end + elseif unrolled_opt_score < reference_opt_score + "worse ($unrolled_opt_str vs. $reference_opt_str)" else - "similar" + "identical ($unrolled_opt_str)" end + run_time_str = comparison_string( + unrolled_run_time, + reference_run_time, + time_string, + ) + compile_time_str = comparison_string( + unrolled_compile_time, + reference_compile_time, + time_string, + ) + memory_str = comparison_string( + (unrolled_total_memory, unrolled_total_rss), + (reference_total_memory, reference_total_rss), + ((gc_bytes, rss_bytes),) -> + rss_bytes == 0 ? memory_string(gc_bytes) : + "$(memory_string(gc_bytes)) [$(memory_string(rss_bytes))]", + first, # Use GC value for comparison since RSS might be unavailable. + ) + dict_key = ($unrolled_expr_str, $reference_expr_str) dict_entry = ( - $(esc(contents_info_str)), - unrolled_performance_str, - reference_performance_str, - time_ratio_str, - memory_ratio_str, + itr_type_str, + itr_length_str, + $(esc(itr_contents_str)), + optimization_str, + run_time_str, + compile_time_str, + memory_str, ) if dict_key in keys(comparison_table_dict) push!(comparison_table_dict[dict_key], dict_entry) @@ -263,160 +418,212 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str) end end -@testset "empty iterators" begin - itr = () - str = "nothing" - @test_unrolled (itr,) unrolled_any(error, itr) any(error, itr) str - @test_unrolled (itr,) unrolled_all(error, itr) all(error, itr) str - @test_unrolled (itr,) unrolled_foreach(error, itr) foreach(error, itr) str - @test_unrolled (itr,) unrolled_map(error, itr, itr) map(error, itr, itr) str - @test_unrolled( - (itr,), - unrolled_reduce(error, itr; init = 0), - reduce(error, itr; init = 0), - str, - ) +tuple_of_tuples(num_tuples, min_tuple_length, singleton, identical) = + ntuple(num_tuples) do index + tuple_length = min_tuple_length + (identical ? 0 : (index - 1) % 7) + ntuple(singleton ? Val : identity, tuple_length) + end +function tuples_of_tuples_contents_str(itrs...) + str = "" + all(itr -> length(itr) > 1 && length(unique(itr)) == 1, itrs) && + (str *= "identical ") + all(itr -> length(itr) > 1 && length(unique(itr)) != 1, itrs) && + (str *= "distinct ") + all(itr -> all(isempty, itr), itrs) && (str *= "empty ") + all(itr -> all(!isempty, itr), itrs) && (str *= "nonempty ") + all(itr -> any(isempty, itr) && any(!isempty, itr), itrs) && + (str *= "empty & nonempty ") + all(itr -> Base.issingletontype(typeof(itr)), itrs) && (str *= "singleton ") + all(itr -> !Base.issingletontype(typeof(itr)), itrs) && + (str *= "non-singleton ") + str *= "Tuple" + all(itr -> length(itr) > 1, itrs) && (str *= "s") + return str end -for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) - itr1 = ntuple(i -> ntuple(Val, identical ? 0 : (i - 1) % 7), n) - itr2 = ntuple(i -> ntuple(Val, identical ? 1 : (i - 1) % 7 + 1), n) - itr3 = ntuple(i -> ntuple(identity, identical ? 1 : (i - 1) % 7 + 1), n) - if n == 1 - str1 = "1 empty tuple" - str2 = "1 nonempty singleton tuple" - str3 = "1 nonempty non-singleton tuple" - str12 = "1 singleton tuple" - str23 = "1 nonempty tuple" - str123 = "1 tuple" - elseif identical - str1 = "$n empty tuples" - str2 = "$n identical nonempty singleton tuples" - str3 = "$n identical nonempty non-singleton tuples" - str12 = "$n identical singleton tuples" - str23 = "$n identical nonempty tuples" - str123 = "$n identical tuples" - else - str1 = "$n empty and nonempty singleton tuples" - str2 = "$n nonempty singleton tuples" - str3 = "$n nonempty non-singleton tuples" - str12 = "$n singleton tuples" - str23 = "$n nonempty tuples" - str123 = "$n tuples" - end - @testset "iterators of $str123" begin - for (itr, str) in ((itr1, str1), (itr2, str2), (itr3, str3)) - @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str - @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str +# TODO: +# - Split recursively and generatively unrolled funcs by length +# - Move empty iterator, no-iterator, and 8187 iterator tests into unit tests +# - Make init accept Fields, turn reduce and accumulate into operators through +# Base.Broadcast.broadcasted_kwsyntax + +title = "Individual unrolled functions" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) + +for itr in ( + tuple_of_tuples(1, 0, true, true), + tuple_of_tuples(1, 1, true, true), + tuple_of_tuples(1, 1, false, true), + map(n -> tuple_of_tuples(n, 0, true, true), (8, 32, 33, 128))..., + map(n -> tuple_of_tuples(n, 1, true, true), (8, 32, 33, 128))..., + map(n -> tuple_of_tuples(n, 1, false, true), (8, 32, 33, 128))..., + map(n -> tuple_of_tuples(n, 0, true, false), (8, 32, 33, 128))..., + map(n -> tuple_of_tuples(n, 1, true, false), (8, 32, 33, 128))..., + map(n -> tuple_of_tuples(n, 1, false, false), (8, 32, 33, 128))..., +) + str = tuples_of_tuples_contents_str(itr) + itr_str = "a Tuple that contains $(length(itr)) $str" + @testset "individual unrolled functions of $itr_str" begin + @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str + @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str + + @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str + @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str - @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str - @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str + @test_unrolled( + (itr,), + unrolled_foreach(x -> @assert(length(x) <= 7), itr), + foreach(x -> @assert(length(x) <= 7), itr), + str, + ) - @test_unrolled( - (itr,), - unrolled_foreach(x -> @assert(length(x) <= 7), itr), - foreach(x -> @assert(length(x) <= 7), itr), - str, - ) + @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str - @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str + @test_unrolled( + (itr,), + unrolled_applyat(length, rand(1:7:length(itr)), itr), + length(itr[rand(1:7:length(itr))]), + str, + ) - @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str - @test_unrolled( - (itr,), - unrolled_reduce(tuple, itr; init = ()), - reduce(tuple, itr; init = ()), - str, - ) + @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str + @test_unrolled( + (itr,), + unrolled_reduce(tuple, itr; init = ()), + reduce(tuple, itr; init = ()), + str, + ) + @test_unrolled( + (itr,), + unrolled_mapreduce(length, +, itr), + mapreduce(length, +, itr), + str, + ) + @test_unrolled( + (itr,), + unrolled_mapreduce(length, +, itr; init = 0), + mapreduce(length, +, itr; init = 0), + str, + ) + + if length(itr) <= 33 @test_unrolled( (itr,), - unrolled_mapreduce(length, +, itr), - mapreduce(length, +, itr), + unrolled_accumulate(tuple, itr), + accumulate(tuple, itr), str, ) @test_unrolled( (itr,), - unrolled_mapreduce(length, +, itr; init = 0), - mapreduce(length, +, itr; init = 0), + unrolled_accumulate(tuple, itr; init = ()), + accumulate(tuple, itr; init = ()), str, ) + end # These can take half a minute to compile when the length is 128. - @test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str + @test_unrolled (itr,) unrolled_push(itr, itr[1]) (itr..., itr[1]) str + @test_unrolled (itr,) unrolled_append(itr, itr) (itr..., itr...) str - @test_unrolled (itr,) unrolled_enumerate(itr) Tuple(enumerate(itr)) str + @test_unrolled( + (itr,), + unrolled_take(itr, Val(length(itr) ÷ 2)), + itr[1:(length(itr) ÷ 2)], + str, + ) + @test_unrolled( + (itr,), + unrolled_drop(itr, Val(length(itr) ÷ 2)), + itr[(length(itr) ÷ 2 + 1):end], + str, + ) - @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str - @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str - @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str + @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str + @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str + @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str - # unrolled_unique is only type-stable for singletons - if Base.issingletontype(typeof(itr)) - @test_unrolled (itr,) unrolled_unique(itr) Tuple(unique(itr)) str - end + @test_unrolled( + (itr,), + unrolled_unique(itr), + Tuple(unique(itr)), + str, + !Base.issingletontype(typeof(itr)), + !Base.issingletontype(typeof(itr)), + ) # unrolled_unique is type-unstable for non-singleton values - @test_unrolled( - (itr,), - unrolled_filter(!isempty, itr), - filter(!isempty, itr), - str, - ) + @test_unrolled( + (itr,), + unrolled_filter(!isempty, itr), + filter(!isempty, itr), + str, + ) - @test_unrolled( - (itr,), - unrolled_split(isempty, itr), - (filter(isempty, itr), filter(!isempty, itr)), - str, - ) + @test_unrolled( + (itr,), + unrolled_split(isempty, itr), + (filter(isempty, itr), filter(!isempty, itr)), + str, + ) - @test_unrolled( - (itr,), - unrolled_flatten(itr), - Tuple(Iterators.flatten(itr)), - str, - ) + @test_unrolled( + (itr,), + unrolled_flatten(itr), + Tuple(Iterators.flatten(itr)), + str, + ) - @test_unrolled( - (itr,), - unrolled_flatmap(reverse, itr), - Tuple(Iterators.flatmap(reverse, itr)), - str, - ) + @test_unrolled( + (itr,), + unrolled_flatmap(reverse, itr), + Tuple(Iterators.flatmap(reverse, itr)), + str, + ) + if length(itr) <= 33 @test_unrolled( (itr,), - unrolled_product(itr), - Tuple(Iterators.product(itr)), + unrolled_product(itr, itr), + Tuple(Iterators.product(itr, itr)), str, ) - + end + if length(itr) <= 8 @test_unrolled( (itr,), - unrolled_applyat( - x -> @assert(length(x) <= 7), - rand(1:length(itr)), - itr, - ), - @assert(length(itr[rand(1:length(itr))]) <= 7), + unrolled_product(itr, itr, itr), + Tuple(Iterators.product(itr, itr, itr)), str, ) + end # This can take several minutes to compile when the length is 32. + end +end - if n > 1 - @test_unrolled( - (itr,), - unrolled_take(itr, Val(7)), - itr[1:7], - str, - ) - @test_unrolled( - (itr,), - unrolled_drop(itr, Val(7)), - itr[8:end], - str, - ) - end - end - +title = "Nested unrolled functions" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) + +for (itr1, itr2, itr3) in ( + ( + tuple_of_tuples(1, 0, true, true), + tuple_of_tuples(1, 1, true, true), + tuple_of_tuples(1, 1, false, true), + ), + zip( + map(n -> tuple_of_tuples(n, 0, true, true), (8, 32, 33, 128)), + map(n -> tuple_of_tuples(n, 1, true, true), (8, 32, 33, 128)), + map(n -> tuple_of_tuples(n, 1, false, true), (8, 32, 33, 128)), + )..., + zip( + map(n -> tuple_of_tuples(n, 0, true, false), (8, 32, 33, 128)), + map(n -> tuple_of_tuples(n, 1, true, false), (8, 32, 33, 128)), + map(n -> tuple_of_tuples(n, 1, false, false), (8, 32, 33, 128)), + )..., +) + str3 = tuples_of_tuples_contents_str(itr3) + str12 = tuples_of_tuples_contents_str(itr1, itr2) + str23 = tuples_of_tuples_contents_str(itr2, itr3) + str123 = tuples_of_tuples_contents_str(itr1, itr2, itr3) + itr_str = "Tuples that contain $(length(itr1)) $str123" + @testset "nested unrolled functions of $itr_str" begin @test_unrolled( (itr3,), unrolled_any(x -> unrolled_reduce(+, x) > 7, itr3), @@ -434,11 +641,11 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) @test_unrolled( (itr1, itr2), unrolled_foreach( - (x1, x2) -> @assert(length(x1) < length(x2)), + (x1, x2) -> @assert(x1 == unrolled_take(x2, Val(length(x1)))), itr1, itr2, ), - foreach((x1, x2) -> @assert(length(x1) < length(x2)), itr1, itr2), + foreach((x1, x2) -> @assert(x1 == x2[1:length(x1)]), itr1, itr2), str12, ) @test_unrolled( @@ -455,13 +662,13 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) @test_unrolled( (itr1, itr2), unrolled_applyat( - (x1, x2) -> @assert(length(x1) < length(x2)), + (x1, x2) -> @assert(x1 == unrolled_take(x2, Val(length(x1)))), rand(1:length(itr1)), itr1, itr2, ), let n = rand(1:length(itr1)) - @assert(length(itr1[n]) < length(itr2[n])) + @assert(itr1[n] == itr2[n][1:length(itr1[n])]) end, str12, ) @@ -478,51 +685,25 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) end, str23, ) - - @test_unrolled( - (itr1, itr2), - unrolled_zip(itr1, itr2), - Tuple(zip(itr1, itr2)), - str12, - ) - @test_unrolled( - (itr1, itr2, itr3), - unrolled_zip(itr1, itr2, itr3), - Tuple(zip(itr1, itr2, itr3)), - str123, - ) - - # unrolled_product can take several minutes to compile when n is large - if n <= 33 - @test_unrolled( - (itr1, itr2), - unrolled_product(itr1, itr2), - Tuple(Iterators.product(itr1, itr2)), - str12, - ) - end - if n <= 8 - @test_unrolled( - (itr1, itr2, itr3), - unrolled_product(itr1, itr2, itr3), - Tuple(Iterators.product(itr1, itr2, itr3)), - str123, - ) - end end end nested_iterator(depth, n, inner_n) = depth == 1 ? ntuple(identity, n) : - ntuple(inner_n) do _ - nested_iterator(depth - 1, Int(n / inner_n), inner_n) - end + ntuple( + Returns(nested_iterator(depth - 1, Int(n / inner_n), inner_n)), + inner_n, + ) + +title = "Recursive unrolled functions" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) for n in (8, 32, 128) - @testset "iterators of $n values in nested tuples" begin + itr_str = "a Tuple that contains $n values in nested Tuples" + @testset "recursive unrolled functions of $itr_str" begin for depth in (2, 3, 4:2:(Int(log2(n)) + 1)...) itr = nested_iterator(depth, n, 2) - str = "$n values in nested tuples of depth $depth" + str = "recursive unrolled functions of $itr_str of depth $depth" # In the following definitions, use var"#self#" to avoid boxing: # https://discourse.julialang.org/t/performant-recursive-anonymous-functions/90984/5 @test_unrolled( @@ -561,3 +742,245 @@ for n in (8, 32, 128) end end end + +title = "Nested unrolled closures" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) + +@testset "nested unrolled closures of Tuples vs. StaticBitVectors" begin + for (itr, skip_allocations_test) in ( + (ntuple(Returns(true), 32), false), + (ntuple(Returns(true), 33), true), + (StaticBitVector{256}(true), false), + (StaticBitVector{257}(true), true), + ) + @test_unrolled( + (itr,), + unrolled_reduce( + (itr′, i) -> Base.setindex(itr′, !itr′[i], i), + StaticOneTo(length(itr)); + init = itr, + ), + reduce( + (itr′, i) -> Base.setindex(itr′, !itr′[i], i), + Base.OneTo(length(itr)); + init = itr, + ), + "Bools", + skip_allocations_test, + ) + @test_unrolled( + (itr,), + unrolled_reduce( + (itr′, i) -> unrolled_reduce( + (itr′′, j) -> + Base.setindex(itr′′, !itr′′[min(i, j)], j), + StaticOneTo(length(itr′)); + init = itr′, + ), + StaticOneTo(length(itr)); + init = itr, + ), + reduce( + (itr′, i) -> reduce( + (itr′′, j) -> + Base.setindex(itr′′, !itr′′[min(i, j)], j), + Base.OneTo(length(itr′)); + init = itr′, + ), + Base.OneTo(length(itr)); + init = itr, + ), + "Bools", + skip_allocations_test, + ) + if length(itr) <= 256 + @test_unrolled( + (itr,), + unrolled_reduce( + (itr′, i) -> unrolled_reduce( + (itr′′, j) -> unrolled_reduce( + (itr′′′, k) -> Base.setindex( + itr′′′, + !itr′′′[min(i, j, k)], + k, + ), + StaticOneTo(length(itr′′)); + init = itr′′, + ), + StaticOneTo(length(itr′)); + init = itr′, + ), + StaticOneTo(length(itr)); + init = itr, + ), + reduce( + (itr′, i) -> reduce( + (itr′′, j) -> reduce( + (itr′′′, k) -> Base.setindex( + itr′′′, + !itr′′′[min(i, j, k)], + k, + ), + Base.OneTo(length(itr′′)); + init = itr′′, + ), + Base.OneTo(length(itr′)); + init = itr′, + ), + Base.OneTo(length(itr)); + init = itr, + ), + "Bools", + skip_allocations_test, + ) + end # The StaticBitVector{257} allocates over 2 GB for this test. + end +end + +title = "Edge cases for unrolling" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) + +@testset "unrolled functions of an empty Tuple" begin + itr = () + str = "nothing" + @test_unrolled (itr,) unrolled_any(error, itr) any(error, itr) str + @test_unrolled (itr,) unrolled_all(error, itr) all(error, itr) str + @test_unrolled (itr,) unrolled_foreach(error, itr) foreach(error, itr) str + @test_unrolled (itr,) unrolled_map(error, itr) map(error, itr) str + @test_throws "init" unrolled_reduce(error, itr) + @test_unrolled( + (itr,), + unrolled_reduce(error, itr; init = 0), + reduce(error, itr; init = 0), + str, + ) + @test_unrolled( + (itr,), + unrolled_accumulate(error, itr), + accumulate(error, itr), + str, + ) + @test_unrolled( + (itr,), + unrolled_accumulate(error, itr; init = 0), + accumulate(error, itr; init = 0), + str, + ) +end + +@testset "unrolled functions of a large Tuple or StaticOneTo" begin + itr = StaticOneTo(8186) + @test_unrolled((itr,), unrolled_reduce(+, itr), reduce(+, itr), "Ints",) + @test_unrolled( + (itr,), + unrolled_mapreduce(log, +, itr), + mapreduce(log, +, itr), + "Ints", + ) + + # The previous tests also work with ntuple(identity, 8186), but they take a + # very long time to compile. + + @test_throws "gc handles" unrolled_reduce(+, StaticOneTo(8187)) + @test_throws "gc handles" unrolled_mapreduce(log, +, StaticOneTo(8187)) + + @test_throws "gc handles" unrolled_reduce(+, ntuple(identity, 8187)) + @test_throws "gc handles" unrolled_mapreduce(log, +, ntuple(identity, 8187)) + + # TODO: Why does the compiler throw an error when generating functions that + # require more than 8187 calls to getindex (or, rather, generic_getindex)? +end + +title = "Generative vs. recursive unrolling" +comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) + +for itr in ( + tuple_of_tuples(1, 0, true, true), + tuple_of_tuples(1, 1, true, true), + tuple_of_tuples(1, 1, false, true), + map(n -> tuple_of_tuples(n, 0, true, true), (8, 16, 32, 33, 128, 256))..., + map(n -> tuple_of_tuples(n, 1, true, true), (8, 16, 32, 33, 128, 256))..., + map(n -> tuple_of_tuples(n, 1, false, true), (8, 16, 32, 33, 128, 256))..., + map(n -> tuple_of_tuples(n, 0, true, false), (8, 16, 32, 33, 128, 256))..., + map(n -> tuple_of_tuples(n, 1, true, false), (8, 16, 32, 33, 128, 256))..., + map(n -> tuple_of_tuples(n, 1, false, false), (8, 16, 32, 33, 128, 256))..., +) + str = tuples_of_tuples_contents_str(itr) + itr_str = "a Tuple that contains $(length(itr)) $str" + @testset "generatively vs. recursively unrolled functions of $itr_str" begin + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_any(isempty, itr), + UnrolledUtilities.rec_unrolled_any(isempty, itr), + str, + ) + + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_all(isempty, itr), + UnrolledUtilities.rec_unrolled_all(isempty, itr), + str, + ) + + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_foreach( + x -> @assert(length(x) <= 7), + itr, + ), + UnrolledUtilities.rec_unrolled_foreach( + x -> @assert(length(x) <= 7), + itr, + ), + str, + ) + + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_map(length, itr), + UnrolledUtilities.rec_unrolled_map(length, itr), + str, + ) + + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_applyat( + length, + rand(1:7:length(itr)), + itr, + ), + UnrolledUtilities.rec_unrolled_applyat( + length, + rand(1:7:length(itr)), + itr, + ), + str, + ) + + if length(itr) <= 33 + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_reduce(tuple, itr, ()), + UnrolledUtilities.rec_unrolled_reduce(tuple, itr, ()), + str, + ) + + @test_unrolled( + (itr,), + UnrolledUtilities.gen_unrolled_accumulate( + tuple, + itr, + (), + identity, + ), + UnrolledUtilities.rec_unrolled_accumulate( + tuple, + itr, + (), + identity, + ), + str, + ) + end # These can take over a minute to compile when the length is 128. + end +end