Skip to content

Commit

Permalink
Try removing StaticOneTo inlining hack
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Aug 29, 2024
1 parent 6cd498c commit 69727e5
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 98 deletions.
10 changes: 5 additions & 5 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ and inlined:
- `unrolled_any(f, itr)`: similar to `any`
- `unrolled_all(f, itr)`: similar to `all`
- `unrolled_foreach(f, itrs...)`: similar to `foreach`
- `unrolled_push(itr, item)`: similar to `push!`, but non-mutating
- `unrolled_append(itr, item)`: similar to `append!`, but non-mutating
- `unrolled_map(f, itrs...)`: similar to `map`
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_accumulate(op, itr; [init], [transform])`: similar to `accumulate`,
but with an optional `transform` function applied to every accumulated value
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_push(itr, item)`: similar to `push!`, but non-mutating
- `unrolled_append(itr, item)`: similar to `append!`, but non-mutating
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but
Expand Down Expand Up @@ -63,8 +63,8 @@ through the following interface:
type_length
generic_getindex
output_type_for_promotion
ConditionalOutputType
NoOutputType
ConditionalOutputType
output_promote_rule
constructor_from_tuple
```
Expand Down Expand Up @@ -98,9 +98,9 @@ summarize their performance, compilation, and allocations:
The rows of the tables are highlighted as follows:
- green indicates an improvement in performance and either an improvement or
no change in compilation and allocations
- dark blue indicates an improvement in performance and either slower
- light blue indicates an improvement in performance and either slower
compilation or more allocations
- light blue indicates no change in performance and either faster compilation or
- dark blue indicates no change in performance and either faster compilation or
fewer allocations
- magenta indicates no change in performance and either faster compilation with
more allocations or slower compilation with fewer allocations
Expand Down
9 changes: 8 additions & 1 deletion ext/UnrolledUtilitiesStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
module UnrolledUtilitiesStaticArraysExt

import UnrolledUtilities
import StaticArrays: SVector, MVector
import StaticArrays: SOneTo, SVector, MVector

# The number of items in an `SOneTo{N}` is its type parameter `N`.
@inline function UnrolledUtilities.type_length(::Type{SOneTo{N}}) where {N}
    return N
end
# Indexing an `SOneTo` at `n` yields `n` itself; the range argument is never
# consulted and no bounds check is performed here.
@inline function UnrolledUtilities.generic_getindex(::SOneTo, n)
    return n
end
# An `SOneTo` reports `NoOutputType()` as its output type for promotion.
@inline function UnrolledUtilities.output_type_for_promotion(::SOneTo)
    return UnrolledUtilities.NoOutputType()
end
# Taking `N` items from an `SOneTo` returns `SOneTo(N)`; the input range itself
# is not inspected.
@inline function UnrolledUtilities.unrolled_take(::SOneTo, ::Val{N}) where {N}
    return SOneTo(N)
end

# The number of items in an `SVector{N}` type is its parameter `N`.
@inline function UnrolledUtilities.type_length(::Type{<:SVector{N}}) where {N}
    return N
end
# An `SVector` reports `SVector` itself as its output type for promotion.
@inline function UnrolledUtilities.output_type_for_promotion(::SVector)
    return SVector
end
Expand Down
3 changes: 1 addition & 2 deletions src/StaticBitVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ end
# Construct a `StaticBitVector{N, U}` in which every stored integer of type `U`
# has all of its bits set when `bit` is `true`, and all of them cleared when
# `bit` is `false` (the default).
@inline function StaticBitVector{N, U}(bit::Bool = false) where {N, U}
    n_bits_per_int = 8 * sizeof(U)
    # Round up so that a partially used final integer is still included.
    n_ints = cld(N, n_bits_per_int)
    # Every integer holds the same pattern, so the tuple is built exactly once
    # from a constant-returning callable.
    fill_int = bit ? ~zero(U) : zero(U)
    ints = ntuple(Returns(fill_int), Val(n_ints))
    return StaticBitVector{N, U}(ints)
end

Expand Down
2 changes: 1 addition & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ include("StaticBitVector.jl")
unrolled_applyat,
)
for method in methods(func)
method.recursion_relation = (_...) -> true
method.recursion_relation = Returns(true)
end
end
end
Expand Down
16 changes: 6 additions & 10 deletions src/generated_functions.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# Expression for item `n` of an iterator whose type is `itr_type`. For a
# StaticOneTo the index `n` is spliced in directly instead of a
# `generic_getindex` call; manually inlining the value of (::StaticOneTo)[n]
# this way improves performance.
function item_expr(itr_type, itr_expr, n)
    itr_type <: StaticOneTo && return n
    return :(generic_getindex($itr_expr, $n))
end

# One `f(generic_getindex(itr, n))` call expression for every index `n` from 1
# to `type_length(itr_type)`, produced lazily as a generator. The expressions
# refer to the names `f` and `itr`, so they are meant to be spliced into a
# function body where both are in scope.
f_exprs(itr_type) =
    (:(f(generic_getindex(itr, $n))) for n in 1:type_length(itr_type))
# Unrolled `any`: the generated body is a single `||` chain over the call
# expressions that `f_exprs` builds for the type of `itr`.
@inline @generated function unrolled_any(f, itr)
    return Expr(:||, f_exprs(itr)...)
end
# Unrolled `all`: the generated body is a single `&&` chain over the call
# expressions that `f_exprs` builds for the type of `itr`.
@inline @generated function unrolled_all(f, itr)
    return Expr(:&&, f_exprs(itr)...)
end

Expand All @@ -14,7 +10,7 @@ function zipped_f_exprs(itr_types)
L = length(itr_types)
N = isempty(itr_types) ? 0 : mapreduce(type_length, min, itr_types)
return (
:(f($((item_expr(itr_types[l], :(itrs[$l]), n) for l in 1:L)...))) for
:(f($((:(generic_getindex(itrs[$l], $n)) for l in 1:L)...))) for
n in 1:N
)
end
Expand All @@ -27,10 +23,10 @@ function nested_op_expr(itr_type, init_type)
N = type_length(itr_type)
(N == 0 && init_type <: NoInit) &&
error("unrolled_reduce requires an init value for empty iterators")
init_expr = init_type <: NoInit ? item_expr(itr_type, :itr, 1) : :init
init_expr = init_type <: NoInit ? :(generic_getindex(itr, 1)) : :init
n_range = init_type <: NoInit ? (2:N) : (1:N)
return foldl(n_range; init = init_expr) do prev_op_expr, n
:(op($prev_op_expr, $(item_expr(itr_type, :itr, n))))
:(op($prev_op_expr, generic_getindex(itr, $n)))
end
end
# Unrolled reduction: the generated body is the expression that
# `nested_op_expr` returns for the types of `itr` and `init`.
@inline @generated function unrolled_reduce(op, itr, init)
    return nested_op_expr(itr, init)
end
Expand All @@ -39,13 +35,13 @@ function transformed_sequential_op_exprs(itr_type, init_type)
N = type_length(itr_type)
(N == 0 && init_type <: NoInit) &&
error("unrolled_accumulate requires an init value for empty iterators")
first_item_expr = item_expr(itr_type, :itr, 1)
first_item_expr = :(generic_getindex(itr, 1))
init_expr =
init_type <: NoInit ? first_item_expr : :(op(init, $first_item_expr))
transformed_exprs_and_next_op_exprs =
accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n
var = gensym()
next_op_expr = :(op($var, $(item_expr(itr_type, :itr, n + 1))))
next_op_expr = :(op($var, generic_getindex(itr, $(n + 1))))
(:($var = $prev_op_expr; transform($var)), next_op_expr)
end
return map(first, transformed_exprs_and_next_op_exprs)
Expand Down
Loading

0 comments on commit 69727e5

Please sign in to comment.