Skip to content

Commit

Permalink
Add unrolled_enumerate and unrolled_applyat
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Apr 4, 2024
1 parent 5c18833 commit 72bfe93
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 10 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
name = "UnrolledUtilities"
uuid = "0fe1646c-419e-43be-ac14-22321958931b"
authors = ["CliMA Contributors <[email protected]>"]
version = "0.1.1"

[deps]
version = "0.1.2"

[compat]
julia = "1.10"
Expand Down
11 changes: 7 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ A collection of generated functions in which all loops are unrolled and inlined:
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_zip(itrs...)`: similar to `zip`
- `unrolled_enumerate(itrs...)`: similar to `enumerate`, but with the ability to
handle multiple iterators
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
Expand All @@ -16,10 +18,11 @@ A collection of generated functions in which all loops are unrolled and inlined:
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take` or `itr[1:N]`, but
with `N` wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop` or
`itr[(N + 1):end]`, but with `N` wrapped in a `Val`
- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)`
- `unrolled_take(itr, ::Val{N})`: similar to `itr[1:N]` (and to
`Iterators.take`), but with `N` wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `itr[(N + 1):end]` (and to
`Iterators.drop`), but with `N` wrapped in a `Val`

These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
Expand Down
13 changes: 12 additions & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ export unrolled_any,
unrolled_reduce,
unrolled_mapreduce,
unrolled_zip,
unrolled_enumerate,
unrolled_in,
unrolled_unique,
unrolled_filter,
unrolled_split,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
unrolled_applyat,
unrolled_take,
unrolled_drop

inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types)
inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N
# We could also add support for statically-sized iterators that are not Tuples.

f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type))
Expand Down Expand Up @@ -52,6 +54,9 @@ struct NoInit end

@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)

@inline unrolled_enumerate(itrs...) =
unrolled_zip(ntuple(identity, Val(length(itrs[1]))), itrs...)

@inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Using === instead of == or isequal improves type stability for singletons.

Expand Down Expand Up @@ -89,6 +94,11 @@ struct NoInit end
end
end

@inline unrolled_applyat(f, n, itrs...) = unrolled_foreach(
(i, items...) -> i == n && f(items...),
unrolled_enumerate(itrs...),
)

@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
@inline unrolled_drop(itr, ::Val{N}) where {N} =
ntuple(i -> itr[N + i], Val(length(itr) - N))
Expand All @@ -107,6 +117,7 @@ struct NoInit end
unrolled_filter,
unrolled_split,
unrolled_flatmap,
unrolled_applyat,
)
for method in methods(func)
method.recursion_relation = (_...) -> true
Expand Down
44 changes: 42 additions & 2 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
arg_definitions_str = join(arg_definition_strs, '\n')
unrolled_command_str = """
using UnrolledUtilities
unrolled_func($arg_names_str) = $($unrolled_expr_str)
unrolled_func($arg_names_str) = $($(string(unrolled_expr)))
$arg_definitions_str
stats1 = @timed unrolled_func($arg_names_str)
stats2 = @timed unrolled_func($arg_names_str)
print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes)
"""
reference_command_str = """
reference_func($arg_names_str) = $($reference_expr_str)
reference_func($arg_names_str) = $($(string(reference_expr)))
$arg_definitions_str
stats1 = @timed reference_func($arg_names_str)
stats2 = @timed reference_func($arg_names_str)
Expand Down Expand Up @@ -344,6 +344,8 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))

@test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str

@test_unrolled (itr,) unrolled_enumerate(itr) Tuple(enumerate(itr)) str

@test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
@test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
@test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str
Expand Down Expand Up @@ -388,6 +390,17 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
str,
)

@test_unrolled(
(itr,),
unrolled_applyat(
x -> @assert(length(x) <= 7),
rand(1:length(itr)),
itr,
),
@assert(length(itr[rand(1:length(itr))]) <= 7),
str,
)

if n > 1
@test_unrolled(
(itr,),
Expand Down Expand Up @@ -439,6 +452,33 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
str23,
)

@test_unrolled(
(itr1, itr2),
unrolled_applyat(
(x1, x2) -> @assert(length(x1) < length(x2)),
rand(1:length(itr1)),
itr1,
itr2,
),
let n = rand(1:length(itr1))
@assert(length(itr1[n]) < length(itr2[n]))
end,
str12,
)
@test_unrolled(
(itr2, itr3),
unrolled_applyat(
(x2, x3) -> @assert(x2 == unrolled_map(Val, x3)),
rand(1:length(itr2)),
itr2,
itr3,
),
let n = rand(1:length(itr2))
@assert(itr2[n] == map(Val, itr3[n]))
end,
str23,
)

@test_unrolled(
(itr1, itr2),
unrolled_zip(itr1, itr2),
Expand Down

0 comments on commit 72bfe93

Please sign in to comment.