Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: David Wendt <[email protected]>
Co-authored-by: Nghia Truong <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent 6f4db6e commit 8afef8e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cpp/src/groupby/hash/compute_shared_memory_aggs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ __device__ void initialize_shmem_aggregations(cooperative_groups::thread_block c
{
for (auto col_idx = col_start; col_idx < col_end; col_idx++) {
for (auto idx = block.thread_rank(); idx < cardinality; idx += block.num_threads()) {
cuda::std::byte* target =
auto target =
reinterpret_cast<cuda::std::byte*>(shmem_agg_storage + shmem_agg_res_offsets[col_idx]);
bool* target_mask =
auto target_mask =
reinterpret_cast<bool*>(shmem_agg_storage + shmem_agg_mask_offsets[col_idx]);
cudf::detail::dispatch_type_and_aggregation(output_values.column(col_idx).type(),
d_agg_kinds[col_idx],
Expand Down
10 changes: 3 additions & 7 deletions cpp/src/groupby/hash/single_pass_functors.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ template <typename T, cudf::aggregation::Kind k>
__device__ T get_identity()
{
if ((k == cudf::aggregation::ARGMAX) || (k == cudf::aggregation::ARGMIN)) {
if constexpr (cudf::is_timestamp<T>())
if constexpr (cudf::is_timestamp<T>()) {
return k == cudf::aggregation::ARGMAX
? T{typename T::duration(cudf::detail::ARGMAX_SENTINEL)}
: T{typename T::duration(cudf::detail::ARGMIN_SENTINEL)};
else {
} else {
using DeviceType = cudf::device_storage_type_t<T>;
return k == cudf::aggregation::ARGMAX
? static_cast<DeviceType>(cudf::detail::ARGMAX_SENTINEL)
Expand Down Expand Up @@ -90,11 +90,7 @@ struct initialize_target_element<Target, k, std::enable_if_t<is_supported<Target

target_casted[idx] = get_identity<DeviceType, k>();

if (k == cudf::aggregation::COUNT_ALL || k == cudf::aggregation::COUNT_VALID) {
target_mask[idx] = true;
} else {
target_mask[idx] = false;
}
target_mask[idx] = (k == cudf::aggregation::COUNT_ALL || k == cudf::aggregation::COUNT_VALID);
}
};

Expand Down

0 comments on commit 8afef8e

Please sign in to comment.