Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unrolled_enumerate and unrolled_applyat #5

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
name = "UnrolledUtilities"
uuid = "0fe1646c-419e-43be-ac14-22321958931b"
authors = ["CliMA Contributors <[email protected]>"]
version = "0.1.1"

[deps]
version = "0.1.2"

[compat]
julia = "1.10"
Expand Down
11 changes: 7 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ A collection of generated functions in which all loops are unrolled and inlined:
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_zip(itrs...)`: similar to `zip`
- `unrolled_enumerate(itrs...)`: similar to `enumerate`, but with the ability to
handle multiple iterators
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
Expand All @@ -16,10 +18,11 @@ A collection of generated functions in which all loops are unrolled and inlined:
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take` or `itr[1:N]`, but
with `N` wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop` or
`itr[(N + 1):end]`, but with `N` wrapped in a `Val`
- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)`
- `unrolled_take(itr, ::Val{N})`: similar to `itr[1:N]` (and to
`Iterators.take`), but with `N` wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `itr[(N + 1):end]` (and to
`Iterators.drop`), but with `N` wrapped in a `Val`

These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
Expand Down
13 changes: 12 additions & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ export unrolled_any,
unrolled_reduce,
unrolled_mapreduce,
unrolled_zip,
unrolled_enumerate,
unrolled_in,
unrolled_unique,
unrolled_filter,
unrolled_split,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
unrolled_applyat,
unrolled_take,
unrolled_drop

inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types)
inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N
# We could also add support for statically-sized iterators that are not Tuples.

f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type))
Expand Down Expand Up @@ -52,6 +54,9 @@ struct NoInit end

@inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)

@inline unrolled_enumerate(itrs...) =
unrolled_zip(ntuple(identity, Val(length(itrs[1]))), itrs...)

@inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Using === instead of == or isequal improves type stability for singletons.

Expand Down Expand Up @@ -89,6 +94,11 @@ struct NoInit end
end
end

@inline unrolled_applyat(f, n, itrs...) = unrolled_foreach(
(i, items...) -> i == n && f(items...),
unrolled_enumerate(itrs...),
)

@inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
@inline unrolled_drop(itr, ::Val{N}) where {N} =
ntuple(i -> itr[N + i], Val(length(itr) - N))
Expand All @@ -107,6 +117,7 @@ struct NoInit end
unrolled_filter,
unrolled_split,
unrolled_flatmap,
unrolled_applyat,
)
for method in methods(func)
method.recursion_relation = (_...) -> true
Expand Down
44 changes: 42 additions & 2 deletions test/test_and_analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str)
arg_definitions_str = join(arg_definition_strs, '\n')
unrolled_command_str = """
using UnrolledUtilities
unrolled_func($arg_names_str) = $($unrolled_expr_str)
unrolled_func($arg_names_str) = $($(string(unrolled_expr)))
$arg_definitions_str
stats1 = @timed unrolled_func($arg_names_str)
stats2 = @timed unrolled_func($arg_names_str)
print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes)
"""
reference_command_str = """
reference_func($arg_names_str) = $($reference_expr_str)
reference_func($arg_names_str) = $($(string(reference_expr)))
$arg_definitions_str
stats1 = @timed reference_func($arg_names_str)
stats2 = @timed reference_func($arg_names_str)
Expand Down Expand Up @@ -344,6 +344,8 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))

@test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str

@test_unrolled (itr,) unrolled_enumerate(itr) Tuple(enumerate(itr)) str

@test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
@test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
@test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str
Expand Down Expand Up @@ -388,6 +390,17 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
str,
)

@test_unrolled(
(itr,),
unrolled_applyat(
x -> @assert(length(x) <= 7),
rand(1:length(itr)),
itr,
),
@assert(length(itr[rand(1:length(itr))]) <= 7),
str,
)

if n > 1
@test_unrolled(
(itr,),
Expand Down Expand Up @@ -439,6 +452,33 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false))
str23,
)

@test_unrolled(
(itr1, itr2),
unrolled_applyat(
(x1, x2) -> @assert(length(x1) < length(x2)),
rand(1:length(itr1)),
itr1,
itr2,
),
let n = rand(1:length(itr1))
@assert(length(itr1[n]) < length(itr2[n]))
end,
str12,
)
@test_unrolled(
(itr2, itr3),
unrolled_applyat(
(x2, x3) -> @assert(x2 == unrolled_map(Val, x3)),
rand(1:length(itr2)),
itr2,
itr3,
),
let n = rand(1:length(itr2))
@assert(itr2[n] == map(Val, itr3[n]))
end,
str23,
)

@test_unrolled(
(itr1, itr2),
unrolled_zip(itr1, itr2),
Expand Down
Loading