Skip to content

Commit

Permalink
Add to_arrow overload for 32-bit decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
vyasr committed Sep 12, 2023
1 parent 7324b1d commit 7a75d19
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions cpp/src/interop/to_arrow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,46 @@ struct dispatch_to_arrow {
}
};

template <>
std::shared_ptr<arrow::Array> dispatch_to_arrow::operator()<numeric::decimal32>(
column_view input,
cudf::type_id,
column_metadata const&,
arrow::MemoryPool* ar_mr,
rmm::cuda_stream_view stream)
{
using DeviceType = int32_t;
size_type const BIT_WIDTH_RATIO = 4; // Array::Type:type::DECIMAL (128) / int32_t

rmm::device_uvector<__int128_t> buf(input.size() * BIT_WIDTH_RATIO, stream);

auto count = thrust::make_counting_iterator(0);

thrust::for_each(rmm::exec_policy(cudf::get_default_stream()),
count,
count + input.size(),
[in = input.begin<DeviceType>(), out = buf.data()] __device__(auto in_idx) {
auto const out_idx = in_idx;
auto unsigned_value = in[in_idx] < 0 ? -in[in_idx] : in[in_idx];
auto unsigned_128bit = static_cast<__int128_t>(unsigned_value);
auto signed_128bit = in[in_idx] < 0 ? -unsigned_128bit : unsigned_128bit;
out[out_idx] = signed_128bit;
});

auto const buf_size_in_bytes = buf.size() * sizeof(DeviceType);
auto data_buffer = allocate_arrow_buffer(buf_size_in_bytes, ar_mr);

CUDF_CUDA_TRY(cudaMemcpyAsync(
data_buffer->mutable_data(), buf.data(), buf_size_in_bytes, cudaMemcpyDefault, stream.value()));

auto type = arrow::decimal(9, -input.type().scale());
auto mask = fetch_mask_buffer(input, ar_mr, stream);
auto buffers = std::vector<std::shared_ptr<arrow::Buffer>>{mask, std::move(data_buffer)};
auto data = std::make_shared<arrow::ArrayData>(type, input.size(), buffers);

return std::make_shared<arrow::Decimal128Array>(data);
}

template <>
std::shared_ptr<arrow::Array> dispatch_to_arrow::operator()<numeric::decimal64>(
column_view input,
Expand Down

0 comments on commit 7a75d19

Please sign in to comment.