Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GPU compilation bugs that required val_unrolled_reduce workaround #18

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/StaticBitVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ end
@inline
first_index = n_bits_per_int * (int_index - 1) + 1
unrolled_reduce(
StaticOneTo(min(n_bits_per_int, N - first_index + 1));
init = zero(U),
StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
zero(U),
) do int, bit_index
@inline
bit_offset = bit_index - 1
Expand Down Expand Up @@ -93,15 +93,15 @@ end
n_bits_per_int = 8 * sizeof(U)
n_ints = cld(N, n_bits_per_int)
ints = unrolled_accumulate(
StaticOneTo(n_ints);
init = (nothing, init),
transform = first,
StaticOneTo(n_ints),
(nothing, init),
first,
) do (_, init_value_for_new_int), int_index
@inline
first_index = n_bits_per_int * (int_index - 1) + 1
unrolled_reduce(
StaticOneTo(min(n_bits_per_int, N - first_index + 1));
init = (zero(U), init_value_for_new_int),
StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
(zero(U), init_value_for_new_int),
) do (int, prev_value), bit_index
@inline
bit_offset = bit_index - 1
Expand Down
73 changes: 45 additions & 28 deletions src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,71 +28,88 @@ include("unrollable_iterator_interface.jl")
include("recursively_unrolled_functions.jl")
include("generatively_unrolled_functions.jl")

@inline unrolled_any(f, itr) =
@inline unrolled_any(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
@inline unrolled_any(itr) = unrolled_any(identity, itr)

@inline unrolled_all(f, itr) =
@inline unrolled_all(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
@inline unrolled_all(itr) = unrolled_all(identity, itr)

@inline unrolled_foreach(f, itr) =
@inline unrolled_foreach(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
@inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))

@inline unrolled_map_into_tuple(f, itr) =
@inline unrolled_map_into_tuple(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
@inline unrolled_map_into(output_type, f, itr) =
@inline unrolled_map_into(output_type, f::F, itr) where {F} =
constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itr))
@inline unrolled_map(f, itr) =
@inline unrolled_map(f::F, itr) where {F} =
unrolled_map_into(inferred_output_type(Iterators.map(f, itr)), f, itr)
@inline unrolled_map(f, itrs...) = unrolled_map(splat(f), zip(itrs...))
@inline unrolled_map(f::F, itrs...) where {F} =
unrolled_map(splat(f), zip(itrs...))

@inline unrolled_applyat(f, n, itr) =
@inline unrolled_applyat(f::F, n, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_applyat : gen_unrolled_applyat)(f, n, itr)
@inline unrolled_applyat(f, n, itrs...) =
@inline unrolled_applyat(f::F, n, itrs...) where {F} =
unrolled_applyat(splat(f), n, zip(itrs...))
@inline unrolled_applyat_bounds_error() =
error("unrolled_applyat has detected an out-of-bounds index")

@inline unrolled_reduce(op, itr, init) =
@inline unrolled_reduce(op::O, itr, init) where {O} =
isempty(itr) && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
@inline unrolled_reduce(op::O, itr; init = NoInit()) where {O} =
unrolled_reduce(op, itr, init)

# TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than
# unrolled_reduce(op, StaticOneTo(N), init) for the non-orographic gravity wave
# parametrization test in ClimaAtmos, to the point where the StaticOneTo version
# completely hangs while the Val version compiles in only a few seconds.
@inline unrolled_reduce(op, val_N::Val, init) =
@inline unrolled_reduce(op::O, val_N::Val, init) where {O} =
val_N isa Val{0} && init isa NoInit ?
error("unrolled_reduce requires an init value for Val(0)") :
val_unrolled_reduce(op, val_N, init)

@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
@inline unrolled_mapreduce(f::F, op::O, itrs...; init = NoInit()) where {F, O} =
unrolled_reduce(op, unrolled_map(f, itrs...), init)

@inline unrolled_accumulate_into_tuple(op, itr, init, transform) =
@inline unrolled_accumulate_into_tuple(
op::O,
itr,
init,
transform::T,
) where {O, T} =
(rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
op,
itr,
init,
transform,
)
@inline unrolled_accumulate_into(output_type, op, itr, init, transform) =
constructor_from_tuple(output_type)(
unrolled_accumulate_into_tuple(op, itr, init, transform),
)
@inline unrolled_accumulate(op, itr; init = NoInit(), transform = identity) =
@inline unrolled_accumulate_into(
output_type,
op::O,
itr,
init,
transform::T,
) where {O, T} = constructor_from_tuple(output_type)(
unrolled_accumulate_into_tuple(op, itr, init, transform),
)
@inline unrolled_accumulate(op::O, itr, init, transform::T) where {O, T} =
unrolled_accumulate_into(
accumulate_output_type(op, itr, init, transform),
op,
itr,
init,
transform,
)
@inline unrolled_accumulate(
op::O,
itr;
init = NoInit(),
transform::T = identity,
) where {O, T} = unrolled_accumulate(op, itr, init, transform)

@inline unrolled_push_into(output_type, itr, item) =
constructor_from_tuple(output_type)((itr..., item))
Expand Down Expand Up @@ -122,36 +139,36 @@ include("generatively_unrolled_functions.jl")
# Using === instead of == or isequal improves type stability for singletons.

@inline unrolled_unique(itr) =
unrolled_reduce(itr; init = inferred_empty(itr)) do unique_items, item
unrolled_reduce(itr, inferred_empty(itr)) do unique_items, item
@inline
unrolled_in(item, unique_items) ? unique_items :
unrolled_push(unique_items, item)
end

@inline unrolled_filter(f, itr) =
unrolled_reduce(itr; init = inferred_empty(itr)) do items_with_true_f, item
@inline unrolled_filter(f::F, itr) where {F} =
unrolled_reduce(itr, inferred_empty(itr)) do items_with_true_f, item
@inline
f(item) ? unrolled_push(items_with_true_f, item) : items_with_true_f
end

@inline unrolled_split(f, itr) =
@inline unrolled_split(f::F, itr) where {F} =
unrolled_reduce(
itr;
init = (inferred_empty(itr), inferred_empty(itr)),
itr,
(inferred_empty(itr), inferred_empty(itr)),
) do (items_with_true_f, items_with_false_f), item
@inline
f(item) ? (unrolled_push(items_with_true_f, item), items_with_false_f) :
(items_with_true_f, unrolled_push(items_with_false_f, item))
end

@inline unrolled_flatten(itr) =
unrolled_reduce(unrolled_append, itr; init = promoted_empty(itr))
unrolled_reduce(unrolled_append, itr, promoted_empty(itr))

@inline unrolled_flatmap(f, itrs...) =
@inline unrolled_flatmap(f::F, itrs...) where {F} =
unrolled_flatten(unrolled_map(f, itrs...))

@inline unrolled_product(itrs...) =
unrolled_reduce(itrs; init = (promoted_empty(itrs),)) do product_itr, itr
unrolled_reduce(itrs, (promoted_empty(itrs),)) do product_itr, itr
@inline
unrolled_flatmap(itr) do item
@inline
Expand Down
Loading