Skip to content

Commit

Permalink
Fix GPU compilation bugs that required val_unrolled_reduce workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Oct 18, 2024
1 parent 2d471e1 commit 5384492
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
14 changes: 7 additions & 7 deletions src/StaticBitVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ end
@inline
first_index = n_bits_per_int * (int_index - 1) + 1
unrolled_reduce(
StaticOneTo(min(n_bits_per_int, N - first_index + 1));
init = zero(U),
StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
zero(U),
) do int, bit_index
@inline
bit_offset = bit_index - 1
Expand Down Expand Up @@ -93,15 +93,15 @@ end
n_bits_per_int = 8 * sizeof(U)
n_ints = cld(N, n_bits_per_int)
ints = unrolled_accumulate(
StaticOneTo(n_ints);
init = (nothing, init),
transform = first,
StaticOneTo(n_ints),
(nothing, init),
first,
) do (_, init_value_for_new_int), int_index
@inline
first_index = n_bits_per_int * (int_index - 1) + 1
unrolled_reduce(
StaticOneTo(min(n_bits_per_int, N - first_index + 1));
init = (zero(U), init_value_for_new_int),
StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
(zero(U), init_value_for_new_int),
) do (int, prev_value), bit_index
@inline
bit_offset = bit_index - 1
Expand Down
73 changes: 45 additions & 28 deletions src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,71 +28,88 @@ include("unrollable_iterator_interface.jl")
include("recursively_unrolled_functions.jl")
include("generatively_unrolled_functions.jl")

@inline unrolled_any(f, itr) =
@inline unrolled_any(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
@inline unrolled_any(itr) = unrolled_any(identity, itr)

@inline unrolled_all(f, itr) =
@inline unrolled_all(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
@inline unrolled_all(itr) = unrolled_all(identity, itr)

@inline unrolled_foreach(f, itr) =
@inline unrolled_foreach(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
@inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))

@inline unrolled_map_into_tuple(f, itr) =
@inline unrolled_map_into_tuple(f::F, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
@inline unrolled_map_into(output_type, f, itr) =
@inline unrolled_map_into(output_type, f::F, itr) where {F} =
constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itr))
@inline unrolled_map(f, itr) =
@inline unrolled_map(f::F, itr) where {F} =
unrolled_map_into(inferred_output_type(Iterators.map(f, itr)), f, itr)
@inline unrolled_map(f, itrs...) = unrolled_map(splat(f), zip(itrs...))
@inline unrolled_map(f::F, itrs...) where {F} =
unrolled_map(splat(f), zip(itrs...))

@inline unrolled_applyat(f, n, itr) =
@inline unrolled_applyat(f::F, n, itr) where {F} =
(rec_unroll(itr) ? rec_unrolled_applyat : gen_unrolled_applyat)(f, n, itr)
@inline unrolled_applyat(f, n, itrs...) =
@inline unrolled_applyat(f::F, n, itrs...) where {F} =
unrolled_applyat(splat(f), n, zip(itrs...))
@inline unrolled_applyat_bounds_error() =
error("unrolled_applyat has detected an out-of-bounds index")

@inline unrolled_reduce(op, itr, init) =
@inline unrolled_reduce(op::O, itr, init) where {O} =
isempty(itr) && init isa NoInit ?
error("unrolled_reduce requires an init value for empty iterators") :
(rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
@inline unrolled_reduce(op, itr; init = NoInit()) =
@inline unrolled_reduce(op::O, itr; init = NoInit()) where {O} =
unrolled_reduce(op, itr, init)

# TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than
# unrolled_reduce(op, StaticOneTo(N), init) for the non-orographic gravity wave
# parametrization test in ClimaAtmos, to the point where the StaticOneTo version
# completely hangs while the Val version compiles in only a few seconds.
@inline unrolled_reduce(op, val_N::Val, init) =
@inline unrolled_reduce(op::O, val_N::Val, init) where {O} =
val_N isa Val{0} && init isa NoInit ?
error("unrolled_reduce requires an init value for Val(0)") :
val_unrolled_reduce(op, val_N, init)

@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
@inline unrolled_mapreduce(f::F, op::O, itrs...; init = NoInit()) where {F, O} =
unrolled_reduce(op, unrolled_map(f, itrs...), init)

@inline unrolled_accumulate_into_tuple(op, itr, init, transform) =
@inline unrolled_accumulate_into_tuple(
op::O,
itr,
init,
transform::T,
) where {O, T} =
(rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
op,
itr,
init,
transform,
)
@inline unrolled_accumulate_into(output_type, op, itr, init, transform) =
constructor_from_tuple(output_type)(
unrolled_accumulate_into_tuple(op, itr, init, transform),
)
@inline unrolled_accumulate(op, itr; init = NoInit(), transform = identity) =
@inline unrolled_accumulate_into(
output_type,
op::O,
itr,
init,
transform::T,
) where {O, T} = constructor_from_tuple(output_type)(
unrolled_accumulate_into_tuple(op, itr, init, transform),
)
@inline unrolled_accumulate(op::O, itr, init, transform::T) where {O, T} =
unrolled_accumulate_into(
accumulate_output_type(op, itr, init, transform),
op,
itr,
init,
transform,
)
@inline unrolled_accumulate(
op::O,
itr;
init = NoInit(),
transform::T = identity,
) where {O, T} = unrolled_accumulate(op, itr, init, transform)

@inline unrolled_push_into(output_type, itr, item) =
constructor_from_tuple(output_type)((itr..., item))
Expand Down Expand Up @@ -122,36 +139,36 @@ include("generatively_unrolled_functions.jl")
# Using === instead of == or isequal improves type stability for singletons.

@inline unrolled_unique(itr) =
unrolled_reduce(itr; init = inferred_empty(itr)) do unique_items, item
unrolled_reduce(itr, inferred_empty(itr)) do unique_items, item
@inline
unrolled_in(item, unique_items) ? unique_items :
unrolled_push(unique_items, item)
end

@inline unrolled_filter(f, itr) =
unrolled_reduce(itr; init = inferred_empty(itr)) do items_with_true_f, item
@inline unrolled_filter(f::F, itr) where {F} =
unrolled_reduce(itr, inferred_empty(itr)) do items_with_true_f, item
@inline
f(item) ? unrolled_push(items_with_true_f, item) : items_with_true_f
end

@inline unrolled_split(f, itr) =
@inline unrolled_split(f::F, itr) where {F} =
unrolled_reduce(
itr;
init = (inferred_empty(itr), inferred_empty(itr)),
itr,
(inferred_empty(itr), inferred_empty(itr)),
) do (items_with_true_f, items_with_false_f), item
@inline
f(item) ? (unrolled_push(items_with_true_f, item), items_with_false_f) :
(items_with_true_f, unrolled_push(items_with_false_f, item))
end

@inline unrolled_flatten(itr) =
unrolled_reduce(unrolled_append, itr; init = promoted_empty(itr))
unrolled_reduce(unrolled_append, itr, promoted_empty(itr))

@inline unrolled_flatmap(f, itrs...) =
@inline unrolled_flatmap(f::F, itrs...) where {F} =
unrolled_flatten(unrolled_map(f, itrs...))

@inline unrolled_product(itrs...) =
unrolled_reduce(itrs; init = (promoted_empty(itrs),)) do product_itr, itr
unrolled_reduce(itrs, (promoted_empty(itrs),)) do product_itr, itr
@inline
unrolled_flatmap(itr) do item
@inline
Expand Down

0 comments on commit 5384492

Please sign in to comment.