From 538449220d79423804e206c549147ca92db050f1 Mon Sep 17 00:00:00 2001
From: Dennis Yatunin
Date: Fri, 11 Oct 2024 18:31:26 -0700
Subject: [PATCH] Fix GPU compilation bugs that required val_unrolled_reduce
 workaround

---
 src/StaticBitVector.jl   | 14 ++++----
 src/UnrolledUtilities.jl | 73 +++++++++++++++++++++++++---------------
 2 files changed, 52 insertions(+), 35 deletions(-)

diff --git a/src/StaticBitVector.jl b/src/StaticBitVector.jl
index d41bffa..3584624 100644
--- a/src/StaticBitVector.jl
+++ b/src/StaticBitVector.jl
@@ -34,8 +34,8 @@ end
         @inline
         first_index = n_bits_per_int * (int_index - 1) + 1
         unrolled_reduce(
-            StaticOneTo(min(n_bits_per_int, N - first_index + 1));
-            init = zero(U),
+            StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
+            zero(U),
         ) do int, bit_index
             @inline
             bit_offset = bit_index - 1
@@ -93,15 +93,15 @@ end
     n_bits_per_int = 8 * sizeof(U)
     n_ints = cld(N, n_bits_per_int)
     ints = unrolled_accumulate(
-        StaticOneTo(n_ints);
-        init = (nothing, init),
-        transform = first,
+        StaticOneTo(n_ints),
+        (nothing, init),
+        first,
     ) do (_, init_value_for_new_int), int_index
         @inline
         first_index = n_bits_per_int * (int_index - 1) + 1
         unrolled_reduce(
-            StaticOneTo(min(n_bits_per_int, N - first_index + 1));
-            init = (zero(U), init_value_for_new_int),
+            StaticOneTo(min(n_bits_per_int, N - first_index + 1)),
+            (zero(U), init_value_for_new_int),
         ) do (int, prev_value), bit_index
             @inline
             bit_offset = bit_index - 1
diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl
index 8377705..8f810f1 100644
--- a/src/UnrolledUtilities.jl
+++ b/src/UnrolledUtilities.jl
@@ -28,64 +28,75 @@ include("unrollable_iterator_interface.jl")
 include("recursively_unrolled_functions.jl")
 include("generatively_unrolled_functions.jl")
 
-@inline unrolled_any(f, itr) =
+@inline unrolled_any(f::F, itr) where {F} =
     (rec_unroll(itr) ? rec_unrolled_any : gen_unrolled_any)(f, itr)
 @inline unrolled_any(itr) = unrolled_any(identity, itr)
 
-@inline unrolled_all(f, itr) =
+@inline unrolled_all(f::F, itr) where {F} =
     (rec_unroll(itr) ? rec_unrolled_all : gen_unrolled_all)(f, itr)
 @inline unrolled_all(itr) = unrolled_all(identity, itr)
 
-@inline unrolled_foreach(f, itr) =
+@inline unrolled_foreach(f::F, itr) where {F} =
     (rec_unroll(itr) ? rec_unrolled_foreach : gen_unrolled_foreach)(f, itr)
 @inline unrolled_foreach(f, itrs...) = unrolled_foreach(splat(f), zip(itrs...))
 
-@inline unrolled_map_into_tuple(f, itr) =
+@inline unrolled_map_into_tuple(f::F, itr) where {F} =
     (rec_unroll(itr) ? rec_unrolled_map : gen_unrolled_map)(f, itr)
-@inline unrolled_map_into(output_type, f, itr) =
+@inline unrolled_map_into(output_type, f::F, itr) where {F} =
     constructor_from_tuple(output_type)(unrolled_map_into_tuple(f, itr))
-@inline unrolled_map(f, itr) =
+@inline unrolled_map(f::F, itr) where {F} =
     unrolled_map_into(inferred_output_type(Iterators.map(f, itr)), f, itr)
-@inline unrolled_map(f, itrs...) = unrolled_map(splat(f), zip(itrs...))
+@inline unrolled_map(f::F, itrs...) where {F} =
+    unrolled_map(splat(f), zip(itrs...))
 
-@inline unrolled_applyat(f, n, itr) =
+@inline unrolled_applyat(f::F, n, itr) where {F} =
     (rec_unroll(itr) ? rec_unrolled_applyat : gen_unrolled_applyat)(f, n, itr)
-@inline unrolled_applyat(f, n, itrs...) =
+@inline unrolled_applyat(f::F, n, itrs...) where {F} =
     unrolled_applyat(splat(f), n, zip(itrs...))
 @inline unrolled_applyat_bounds_error() =
     error("unrolled_applyat has detected an out-of-bounds index")
 
-@inline unrolled_reduce(op, itr, init) =
+@inline unrolled_reduce(op::O, itr, init) where {O} =
     isempty(itr) && init isa NoInit ?
     error("unrolled_reduce requires an init value for empty iterators") :
     (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init)
-@inline unrolled_reduce(op, itr; init = NoInit()) =
+@inline unrolled_reduce(op::O, itr; init = NoInit()) where {O} =
     unrolled_reduce(op, itr, init)
 
 # TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than
 # unrolled_reduce(op, StaticOneTo(N), init) for the non-orographic gravity wave
 # parametrization test in ClimaAtmos, to the point where the StaticOneTo version
 # completely hangs while the Val version compiles in only a few seconds.
-@inline unrolled_reduce(op, val_N::Val, init) =
+@inline unrolled_reduce(op::O, val_N::Val, init) where {O} =
     val_N isa Val{0} && init isa NoInit ?
     error("unrolled_reduce requires an init value for Val(0)") :
     val_unrolled_reduce(op, val_N, init)
 
-@inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) =
+@inline unrolled_mapreduce(f::F, op::O, itrs...; init = NoInit()) where {F, O} =
     unrolled_reduce(op, unrolled_map(f, itrs...), init)
 
-@inline unrolled_accumulate_into_tuple(op, itr, init, transform) =
+@inline unrolled_accumulate_into_tuple(
+    op::O,
+    itr,
+    init,
+    transform::T,
+) where {O, T} =
     (rec_unroll(itr) ? rec_unrolled_accumulate : gen_unrolled_accumulate)(
         op,
         itr,
         init,
         transform,
     )
-@inline unrolled_accumulate_into(output_type, op, itr, init, transform) =
-    constructor_from_tuple(output_type)(
-        unrolled_accumulate_into_tuple(op, itr, init, transform),
-    )
-@inline unrolled_accumulate(op, itr; init = NoInit(), transform = identity) =
+@inline unrolled_accumulate_into(
+    output_type,
+    op::O,
+    itr,
+    init,
+    transform::T,
+) where {O, T} = constructor_from_tuple(output_type)(
+    unrolled_accumulate_into_tuple(op, itr, init, transform),
+)
+@inline unrolled_accumulate(op::O, itr, init, transform::T) where {O, T} =
    unrolled_accumulate_into(
        accumulate_output_type(op, itr, init, transform),
        op,
@@ -93,6 +104,12 @@ include("generatively_unrolled_functions.jl")
         init,
         transform,
     )
+@inline unrolled_accumulate(
+    op::O,
+    itr;
+    init = NoInit(),
+    transform::T = identity,
+) where {O, T} = unrolled_accumulate(op, itr, init, transform)
 
 @inline unrolled_push_into(output_type, itr, item) =
     constructor_from_tuple(output_type)((itr..., item))
@@ -122,22 +139,22 @@ include("generatively_unrolled_functions.jl")
 # Using === instead of == or isequal improves type stability for singletons.
 @inline unrolled_unique(itr) =
-    unrolled_reduce(itr; init = inferred_empty(itr)) do unique_items, item
+    unrolled_reduce(itr, inferred_empty(itr)) do unique_items, item
         @inline
         unrolled_in(item, unique_items) ? unique_items :
         unrolled_push(unique_items, item)
     end
 
-@inline unrolled_filter(f, itr) =
-    unrolled_reduce(itr; init = inferred_empty(itr)) do items_with_true_f, item
+@inline unrolled_filter(f::F, itr) where {F} =
+    unrolled_reduce(itr, inferred_empty(itr)) do items_with_true_f, item
         @inline
         f(item) ? unrolled_push(items_with_true_f, item) : items_with_true_f
     end
 
-@inline unrolled_split(f, itr) =
+@inline unrolled_split(f::F, itr) where {F} =
     unrolled_reduce(
-        itr;
-        init = (inferred_empty(itr), inferred_empty(itr)),
+        itr,
+        (inferred_empty(itr), inferred_empty(itr)),
     ) do (items_with_true_f, items_with_false_f), item
         @inline
         f(item) ?
         (unrolled_push(items_with_true_f, item), items_with_false_f) :
@@ -145,13 +162,13 @@ include("generatively_unrolled_functions.jl")
     end
 
 @inline unrolled_flatten(itr) =
-    unrolled_reduce(unrolled_append, itr; init = promoted_empty(itr))
+    unrolled_reduce(unrolled_append, itr, promoted_empty(itr))
 
-@inline unrolled_flatmap(f, itrs...) =
+@inline unrolled_flatmap(f::F, itrs...) where {F} =
     unrolled_flatten(unrolled_map(f, itrs...))
 
 @inline unrolled_product(itrs...) =
-    unrolled_reduce(itrs; init = (promoted_empty(itrs),)) do product_itr, itr
+    unrolled_reduce(itrs, (promoted_empty(itrs),)) do product_itr, itr
         @inline
         unrolled_flatmap(itr) do item
             @inline
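Note on the `::F`/`::O`/`::T` annotations above: Julia heuristically avoids
specializing on a function-typed argument that is only passed through to
another call, which leaves a dynamic dispatch in the compiled code, and
dynamic calls do not compile inside GPU kernels. Writing `f::F` together with
`where {F}` forces a specialized method instance per concrete function type.
A minimal sketch of the heuristic, with hypothetical names that are not part
of this patch:

    # `f` is only passed through, so Julia may compile a single method
    # instance for ::Function and dispatch on `f` dynamically at run time:
    passthrough(f, xs) = map(f, xs)

    # The `::F ... where {F}` annotation forces one specialized instance per
    # concrete typeof(f), eliminating the dynamic call:
    specialized(f::F, xs) where {F} = map(f, xs)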
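Note on the keyword-to-positional change: the keyword methods of
unrolled_reduce and unrolled_accumulate now forward to new positional
methods, and the internal calls in StaticBitVector.jl use the positional
forms directly. A usage sketch of the equivalence, assuming the package is
loaded; the commented results are what the reductions should produce:

    using UnrolledUtilities

    # Public keyword forms, unchanged for callers:
    unrolled_reduce(+, (1, 2, 3); init = 0)         # 6
    unrolled_accumulate(+, (1, 2, 3); init = 0)     # (1, 3, 6)

    # Equivalent positional forms introduced by this patch, which avoid
    # keyword-argument handling in GPU-compiled code:
    unrolled_reduce(+, (1, 2, 3), 0)                # 6
    unrolled_accumulate(+, (1, 2, 3), 0, identity)  # (1, 3, 6)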
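Note on the `unrolled_reduce(op, val_N::Val, init)` method: per the TODO in
the diff, reducing over `Val(N)` appears interchangeable with reducing over
`StaticOneTo(N)` but compiles far faster in the ClimaAtmos test. A sketch of
the two calls, assuming both forms reduce over the integers 1:N:

    unrolled_reduce(+, StaticOneTo(3), 0)  # reduce over 1:3; 6
    unrolled_reduce(+, Val(3), 0)          # same result via val_unrolled_reduce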