From 91f7943e61c75dffff4bcff382dd62b2be167978 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 12 Sep 2023 18:28:58 -0700 Subject: [PATCH] add DELTA_BINARY_PACKED encoder --- cpp/src/io/parquet/delta_binary.cuh | 6 - cpp/src/io/parquet/delta_enc.cuh | 269 ++++++++ cpp/src/io/parquet/page_enc.cu | 983 ++++++++++++++++++++-------- cpp/src/io/parquet/parquet_gpu.hpp | 18 + cpp/tests/io/parquet_test.cpp | 18 +- 5 files changed, 999 insertions(+), 295 deletions(-) create mode 100644 cpp/src/io/parquet/delta_enc.cuh diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index 4fc8b9cfb8e..7aecc7f01e0 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -46,12 +46,6 @@ namespace cudf::io::parquet::gpu { // encoded with DELTA_LENGTH_BYTE_ARRAY encoding, which is a DELTA_BINARY_PACKED list of suffix // lengths, followed by the concatenated suffix data. -// TODO: The delta encodings use ULEB128 integers, but for now we're only -// using max 64 bits. Need to see what the performance impact is of using -// __int128_t rather than int64_t. -using uleb128_t = uint64_t; -using zigzag128_t = int64_t; - // we decode one mini-block at a time. max mini-block size seen is 64. constexpr int delta_rolling_buf_size = 128; diff --git a/cpp/src/io/parquet/delta_enc.cuh b/cpp/src/io/parquet/delta_enc.cuh new file mode 100644 index 00000000000..164849edd63 --- /dev/null +++ b/cpp/src/io/parquet/delta_enc.cuh @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "parquet_gpu.hpp" + +#include + +#include + +namespace cudf::io::parquet::gpu { + +namespace delta { + +inline __device__ void put_uleb128(uint8_t*& p, uleb128_t v) +{ + while (v > 0x7f) { + *p++ = v | 0x80; + v >>= 7; + } + *p++ = v; +} + +inline __device__ uint8_t* put_zz128(uint8_t*& p, zigzag128_t v) +{ + zigzag128_t s = (v < 0); + put_uleb128(p, (v ^ -s) * 2 + s); +} + +// a block size of 128, with 4 mini-blocks of 32 values each fits nicely without consuming +// too much shared memory. +constexpr int block_size = 128; +constexpr int num_mini_blocks = 4; +constexpr int values_per_mini_block = block_size / num_mini_blocks; +constexpr int buffer_size = 2 * block_size; + +using block_reduce = cub::BlockReduce; +using warp_reduce = cub::WarpReduce; +using index_scan = cub::BlockScan; + +constexpr int rolling_idx(int index) { return rolling_index(index); } + +// version of bit packer that can handle up to 64 bits values. 
+template +inline __device__ void bitpack_mini_block( + uint8_t* dst, T val, uint32_t count, uint8_t nbits, void* temp_space) +{ + // typing for atomicOr is annoying + using scratch_type = + std::conditional_t, unsigned long long, uint32_t>; + using cudf::detail::warp_size; + T constexpr mask = sizeof(T) * 8 - 1; + auto constexpr div = sizeof(T) * 8; + + auto const lane_id = threadIdx.x % warp_size; + auto const warp_id = threadIdx.x / warp_size; + + scratch_type* scratch = reinterpret_cast(temp_space) + warp_id * warp_size; + + // zero out scratch + scratch[lane_id] = 0; + __syncwarp(); + + // why use bit packing when there's no savings??? + if (nbits == div) { + if (lane_id < count) { + for (int i = 0; i < sizeof(T); i++) { + dst[lane_id * sizeof(T) + i] = val & 0xff; + if constexpr (sizeof(T) > 1) { val >>= 8; } + } + } + __syncwarp(); + return; + } + + if (lane_id <= count) { + // shift symbol left by up to mask bits + WideType v2 = val; + v2 <<= (lane_id * nbits) & mask; + + // Copy N bit word into two N/2 bit words while following C++ strict aliasing rules. + T v1[2]; + memcpy(&v1, &v2, sizeof(WideType)); + + // Atomically write result to scratch + if (v1[0]) { atomicOr(scratch + ((lane_id * nbits) / div), v1[0]); } + if (v1[1]) { atomicOr(scratch + ((lane_id * nbits) / div) + 1, v1[1]); } + } + __syncwarp(); + + // Copy scratch data to final destination + auto available_bytes = (count * nbits + 7) / 8; + + auto scratch_bytes = reinterpret_cast(scratch); + for (uint32_t i = lane_id; i < available_bytes; i += warp_size) { + dst[i] = scratch_bytes[i]; + } + __syncwarp(); +} + +} // namespace delta + +// Object used to turn a stream of integers into a DELTA_BINARY_PACKED stream. This takes as input +// 128 values with validity at a time, saving them until there are enough values for a block +// to be written. +// +// T can only be uint32_t or uint64_t since the DELTA_BINARY_PACKED encoding is only defined for +// INT32 and INT64 physical types +template +class DeltaBinaryPacker { + private: + // static_assert(std::is_same_v || std::is_same_v); + + uint8_t* _dst; // sink to dump encoded values to + size_type _current_idx; // index of first value in buffer + uint32_t _num_values; // total number of values to encode + size_type _values_in_buffer; // current number of values stored in _buffer + T* _buffer; // buffer to store values to be encoded + uint8_t _mb_bits[delta::num_mini_blocks]; // bitwidth for each mini-block + + // pointers to shared scratch memory for the warp and block scans/reduces + delta::index_scan::TempStorage* _scan_tmp; + delta::warp_reduce::TempStorage* _warp_tmp; + delta::block_reduce::TempStorage* _block_tmp; + + void* _bitpack_tmp; // pointer to shared scratch memory used in bitpacking + + // write the delta binary header. only call from thread 0 + inline __device__ void write_header(T first_value) + { + delta::put_uleb128(_dst, delta::block_size); + delta::put_uleb128(_dst, delta::num_mini_blocks); + delta::put_uleb128(_dst, _num_values); + delta::put_zz128(_dst, first_value); + } + + // write the block header. only call from thread 0 + inline __device__ void write_block_header(zigzag128_t block_min) + { + delta::put_zz128(_dst, block_min); + memcpy(_dst, _mb_bits, 4); + _dst += 4; + } + + public: + inline __device__ auto num_values() const { return _num_values; } + + // initialize the object. 
only call from thread 0 + inline __device__ void init(uint8_t* dest, uint32_t num_values, T* buffer, void* temp_storage) + { + _dst = dest; + _num_values = num_values; + _buffer = buffer; + _scan_tmp = reinterpret_cast(temp_storage); + _warp_tmp = reinterpret_cast(temp_storage); + _block_tmp = reinterpret_cast(temp_storage); + _bitpack_tmp = _buffer + delta::buffer_size; + _current_idx = 0; + _values_in_buffer = 0; + } + + // each thread calls this to add it's current value + inline __device__ void add_value(T value, bool is_valid) + { + // figure out the correct position for the given value + size_type const valid = is_valid; + size_type pos; + size_type num_valid; + delta::index_scan(*_scan_tmp).ExclusiveSum(valid, pos, num_valid); + + if (is_valid) { _buffer[delta::rolling_idx(pos + _current_idx + _values_in_buffer)] = value; } + __syncthreads(); + + if (threadIdx.x == 0) { + _values_in_buffer += num_valid; + // if first pass write header + if (_current_idx == 0) { + write_header(_buffer[0]); + _current_idx = 1; + _values_in_buffer -= 1; + } + } + __syncthreads(); + + if (_values_in_buffer >= delta::block_size) { flush(); } + } + + // called by each thread to flush data to the sink. + inline __device__ uint8_t const* flush() + { + using cudf::detail::warp_size; + __shared__ zigzag128_t block_min; + + int const t = threadIdx.x; + int const warp_id = t / warp_size; + int const lane_id = t % warp_size; + + if (_values_in_buffer <= 0) { return _dst; } + + // calculate delta for this thread + size_type const idx = _current_idx + t; + zigzag128_t const delta = + idx < _num_values ? _buffer[delta::rolling_idx(idx)] - _buffer[delta::rolling_idx(idx - 1)] + : std::numeric_limits::max(); + + // find min delta for the block + auto const min_delta = delta::block_reduce(*_block_tmp).Reduce(delta, cub::Min()); + + if (t == 0) { block_min = min_delta; } + __syncthreads(); + + // compute frame of reference for the block + uleb128_t const norm_delta = idx < _num_values ? 
delta - block_min : 0; + + // get max normalized delta for each warp, and use that to determine how many bits to use + // for the bitpacking of this warp + zigzag128_t const warp_max = + delta::warp_reduce(_warp_tmp[warp_id]).Reduce(norm_delta, cub::Max()); + + if (lane_id == 0) { _mb_bits[warp_id] = sizeof(zigzag128_t) * 8 - __clzll(warp_max); } + __syncthreads(); + + // write block header + if (t == 0) { write_block_header(block_min); } + __syncthreads(); + + // now each warp encodes it's data...can calculate starting offset with _mb_bits + uint8_t* mb_ptr = _dst; + switch (warp_id) { + case 3: mb_ptr += _mb_bits[2] * delta::values_per_mini_block / 8; [[fallthrough]]; + case 2: mb_ptr += _mb_bits[1] * delta::values_per_mini_block / 8; [[fallthrough]]; + case 1: mb_ptr += _mb_bits[0] * delta::values_per_mini_block / 8; + } + + // encoding happens here....will have to update pack literals to deal with larger numbers + auto const warp_idx = _current_idx + warp_id * delta::values_per_mini_block; + if (warp_idx < _num_values) { + auto const num_enc = min(delta::values_per_mini_block, _num_values - warp_idx); + delta::bitpack_mini_block( + mb_ptr, norm_delta, num_enc, _mb_bits[warp_id], _bitpack_tmp); + } + + // last lane updates global delta ptr + if (warp_id == delta::num_mini_blocks - 1 && lane_id == 0) { + _dst = mb_ptr + _mb_bits[warp_id] * delta::values_per_mini_block / 8; + _current_idx = min(warp_idx + delta::values_per_mini_block, _num_values); + _values_in_buffer = max(_values_in_buffer - delta::block_size, 0U); + } + __syncthreads(); + + return _dst; + } +}; + +} // namespace cudf::io::parquet::gpu diff --git a/cpp/src/io/parquet/page_enc.cu b/cpp/src/io/parquet/page_enc.cu index 0af561be8da..fe212ec6714 100644 --- a/cpp/src/io/parquet/page_enc.cu +++ b/cpp/src/io/parquet/page_enc.cu @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "delta_enc.cuh" #include "parquet_gpu.cuh" #include @@ -21,6 +22,7 @@ #include #include #include +#include #include #include @@ -41,6 +43,8 @@ #include #include +#include + namespace cudf { namespace io { namespace parquet { @@ -50,7 +54,11 @@ namespace { using ::cudf::detail::device_2dspan; -constexpr uint32_t rle_buffer_size = (1 << 9); +constexpr int encode_block_size = 128; +constexpr int rle_buffer_size = 2 * encode_block_size; +constexpr int num_encode_warps = encode_block_size / cudf::detail::warp_size; + +constexpr int rolling_idx(int pos) { return rolling_index(pos); } // do not truncate statistics constexpr int32_t NO_TRUNC_STATS = 0; @@ -72,6 +80,7 @@ struct frag_init_state_s { PageFragment frag; }; +template struct page_enc_state_s { uint8_t* cur; //!< current output ptr uint8_t* rle_out; //!< current RLE write ptr @@ -84,12 +93,11 @@ struct page_enc_state_s { uint32_t rle_rpt_count; uint32_t page_start_val; uint32_t chunk_start_val; - volatile uint32_t rpt_map[4]; - volatile uint32_t scratch_red[32]; + volatile uint32_t rpt_map[num_encode_warps]; EncPage page; EncColumnChunk ck; parquet_column_device_view col; - uint32_t vals[rle_buffer_size]; + uint32_t vals[rle_buf_size]; }; /** @@ -239,6 +247,49 @@ struct BitwiseOr { } }; +// T is the parquet physical type +// W is double the bitwidth of T +// I is the column type from the input table +// F is a function that computes validity and the src index for a given input position +template +struct delta_enc { + page_enc_state_s<0>* s; + uint32_t valid_count; + F& f; + uint64_t* buffer; + void* temp_space; + + __device__ uint8_t const* encode() + { + __shared__ DeltaBinaryPacker packer; + + auto const t = threadIdx.x; + + if (t == 0) { packer.init(s->cur, valid_count, reinterpret_cast(buffer), temp_space); } + __syncthreads(); + + // FIXME int the plain encoder the scaling is a little different for INT32 than INT64. + // might need to patch this up some. + int32_t const scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, delta::block_size); + + auto [is_valid, val_idx] = f(cur_val_idx); + cur_val_idx += nvals; + + T v = s->col.leaf_column->element(val_idx); + if (scale < 0) { + v /= -scale; + } else { + v *= scale; + } + packer.add_value(v, is_valid); + } + + return packer.flush(); + } +}; + } // anonymous namespace // blockDim {512,1,1} @@ -326,6 +377,31 @@ __global__ void __launch_bounds__(128) } } +__device__ size_t delta_data_len(parquet::Type physical_type, + cudf::type_id type_id, + uint32_t num_values) +{ + auto const dtype_len_out = physical_type_len(physical_type, type_id); + auto const dtype_len = [&]() -> uint32_t { + if (physical_type == INT32) { return int32_logical_len(type_id); } + if (physical_type == INT96) { return sizeof(int64_t); } + return dtype_len_out; + }(); + + auto const vals_per_block = delta::block_size; + size_t const num_blocks = util::div_rounding_up_unsafe(num_values, vals_per_block); + // need max dtype_len_in + 1 bytes for min_delta + // one byte per mini block for the bitwidth + // and block_size * dtype_len_in bytes for the actual encoded data + auto const block_size = dtype_len + 1 + delta::num_mini_blocks + vals_per_block * dtype_len; + + // delta header is 2 bytes for the block_size, 1 byte for number of mini-blocks, + // max 5 bytes for number of values, and max dtype_len_in + 1 for first value. 
+ auto const header_size = 2 + 1 + 5 + dtype_len + 1; + + return header_size + num_blocks * block_size; +} + // blockDim {128,1,1} __global__ void __launch_bounds__(128) gpuInitPages(device_2dspan chunks, @@ -357,6 +433,14 @@ __global__ void __launch_bounds__(128) page_g = {}; } __syncthreads(); + + // if writing delta encoded values, we're going to need to know the data length to get a guess + // at the worst case number of bytes needed to encode. + auto const physical_type = col_g.physical_type; + auto const type_id = col_g.leaf_column->type().id(); + auto const is_use_delta = + write_v2_headers && !ck_g.use_dictionary && (physical_type == INT32 || physical_type == INT64); + if (t < 32) { uint32_t fragments_in_chunk = 0; uint32_t rows_in_page = 0; @@ -406,9 +490,12 @@ __global__ void __launch_bounds__(128) } __syncwarp(); if (t == 0) { - if (not pages.empty()) pages[ck_g.first_page] = page_g; - if (not page_sizes.empty()) page_sizes[ck_g.first_page] = page_g.max_data_size; - if (page_grstats) page_grstats[ck_g.first_page] = pagestats_g; + if (not pages.empty()) { + page_g.kernel_mask = ENC_MASK_PLAIN; + pages[ck_g.first_page] = page_g; + } + if (not page_sizes.empty()) { page_sizes[ck_g.first_page] = page_g.max_data_size; } + if (page_grstats) { page_grstats[ck_g.first_page] = pagestats_g; } } num_pages = 1; } @@ -508,7 +595,12 @@ __global__ void __launch_bounds__(128) page_g.num_values = values_in_page; auto const def_level_size = max_RLE_page_size(col_g.num_def_level_bits(), values_in_page); auto const rep_level_size = max_RLE_page_size(col_g.num_rep_level_bits(), values_in_page); - auto const max_data_size = page_size + def_level_size + rep_level_size + rle_pad; + // get a different bound if using delta encoding + if (is_use_delta) { + page_size = + max(page_size, delta_data_len(physical_type, type_id, page_g.num_leaf_values)); + } + auto const max_data_size = page_size + def_level_size + rep_level_size + rle_pad; // page size must fit in 32-bit signed integer if (max_data_size > std::numeric_limits::max()) { CUDF_UNREACHABLE("page size exceeds maximum for i32"); @@ -528,7 +620,16 @@ __global__ void __launch_bounds__(128) } __syncwarp(); if (t == 0) { - if (not pages.empty()) { pages[ck_g.first_page + num_pages] = page_g; } + if (not pages.empty()) { + if (is_use_delta) { + page_g.kernel_mask = ENC_MASK_DELTA_BINARY; + } else if (ck_g.use_dictionary || physical_type == BOOLEAN) { + page_g.kernel_mask = ENC_MASK_DICTIONARY; + } else { + page_g.kernel_mask = ENC_MASK_PLAIN; + } + pages[ck_g.first_page + num_pages] = page_g; + } if (not page_sizes.empty()) { page_sizes[ck_g.first_page + num_pages] = page_g.max_data_size; } @@ -791,9 +892,14 @@ inline __device__ void PackLiterals( * @param[in] flush nonzero if last batch in block * @param[in] t thread id (0..127) */ +template static __device__ void RleEncode( - page_enc_state_s* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t) + state_buf* s, uint32_t numvals, uint32_t nbits, uint32_t flush, uint32_t t) { + using cudf::detail::warp_size; + auto const lane_id = t % warp_size; + auto const warp_id = t / warp_size; + uint32_t rle_pos = s->rle_pos; uint32_t rle_run = s->rle_run; @@ -801,20 +907,20 @@ static __device__ void RleEncode( uint32_t pos = rle_pos + t; if (rle_run > 0 && !(rle_run & 1)) { // Currently in a long repeat run - uint32_t mask = ballot(pos < numvals && s->vals[pos & (rle_buffer_size - 1)] == s->run_val); + uint32_t mask = ballot(pos < numvals && s->vals[rolling_idx(pos)] == s->run_val); uint32_t 
rle_rpt_count, max_rpt_count; - if (!(t & 0x1f)) { s->rpt_map[t >> 5] = mask; } + if (lane_id == 0) { s->rpt_map[warp_id] = mask; } __syncthreads(); - if (t < 32) { + if (t < warp_size) { uint32_t c32 = ballot(t >= 4 || s->rpt_map[t] != 0xffff'ffffu); - if (!t) { + if (t == 0) { uint32_t last_idx = __ffs(c32) - 1; s->rle_rpt_count = - last_idx * 32 + ((last_idx < 4) ? __ffs(~s->rpt_map[last_idx]) - 1 : 0); + last_idx * warp_size + ((last_idx < 4) ? __ffs(~s->rpt_map[last_idx]) - 1 : 0); } } __syncthreads(); - max_rpt_count = min(numvals - rle_pos, 128); + max_rpt_count = min(numvals - rle_pos, encode_block_size); rle_rpt_count = s->rle_rpt_count; rle_run += rle_rpt_count << 1; rle_pos += rle_rpt_count; @@ -831,17 +937,17 @@ static __device__ void RleEncode( } } else { // New run or in a literal run - uint32_t v0 = s->vals[pos & (rle_buffer_size - 1)]; - uint32_t v1 = s->vals[(pos + 1) & (rle_buffer_size - 1)]; + uint32_t v0 = s->vals[rolling_idx(pos)]; + uint32_t v1 = s->vals[rolling_idx(pos + 1)]; uint32_t mask = ballot(pos + 1 < numvals && v0 == v1); - uint32_t maxvals = min(numvals - rle_pos, 128); + uint32_t maxvals = min(numvals - rle_pos, encode_block_size); uint32_t rle_lit_count, rle_rpt_count; - if (!(t & 0x1f)) { s->rpt_map[t >> 5] = mask; } + if (lane_id == 0) { s->rpt_map[warp_id] = mask; } __syncthreads(); - if (t < 32) { + if (t < warp_size) { // Repeat run can only start on a multiple of 8 values - uint32_t idx8 = (t * 8) >> 5; - uint32_t pos8 = (t * 8) & 0x1f; + uint32_t idx8 = (t * 8) / warp_size; + uint32_t pos8 = (t * 8) % warp_size; uint32_t m0 = (idx8 < 4) ? s->rpt_map[idx8] : 0; uint32_t m1 = (idx8 < 3) ? s->rpt_map[idx8 + 1] : 0; uint32_t needed_mask = kRleRunMask[nbits - 1]; @@ -850,8 +956,8 @@ static __device__ void RleEncode( uint32_t rle_run_start = (mask != 0) ? min((__ffs(mask) - 1) * 8, maxvals) : maxvals; uint32_t rpt_len = 0; if (rle_run_start < maxvals) { - uint32_t idx_cur = rle_run_start >> 5; - uint32_t idx_ofs = rle_run_start & 0x1f; + uint32_t idx_cur = rle_run_start / warp_size; + uint32_t idx_ofs = rle_run_start % warp_size; while (idx_cur < 4) { m0 = (idx_cur < 4) ? s->rpt_map[idx_cur] : 0; m1 = (idx_cur < 3) ? s->rpt_map[idx_cur + 1] : 0; @@ -860,7 +966,7 @@ static __device__ void RleEncode( rpt_len += __ffs(mask) - 1; break; } - rpt_len += 32; + rpt_len += warp_size; idx_cur++; } } @@ -931,17 +1037,15 @@ static __device__ void RleEncode( * @param[in] flush nonzero if last batch in block * @param[in] t thread id (0..127) */ -static __device__ void PlainBoolEncode(page_enc_state_s* s, - uint32_t numvals, - uint32_t flush, - uint32_t t) +template +static __device__ void PlainBoolEncode(state_buf* s, uint32_t numvals, uint32_t flush, uint32_t t) { uint32_t rle_pos = s->rle_pos; uint8_t* dst = s->rle_out; while (rle_pos < numvals) { uint32_t pos = rle_pos + t; - uint32_t v = (pos < numvals) ? s->vals[pos & (rle_buffer_size - 1)] : 0; + uint32_t v = (pos < numvals) ? s->vals[rolling_idx(pos)] : 0; uint32_t n = min(numvals - rle_pos, 128); uint32_t nbytes = (n + ((flush) ? 7 : 0)) >> 3; if (!nbytes) { break; } @@ -995,28 +1099,22 @@ __device__ auto julian_days_with_time(int64_t v) return std::make_pair(dur_time_of_day_nanos, julian_days); } +// this has been split out into its own kernel because of the amount of shared memory required +// for the state buffer. encode kernels that don't use the RLE buffer can get started while +// the level data is encoded. +// FIXME: what should the args to launch_bounds be now? 
// blockDim(128, 1, 1) -template -__global__ void __launch_bounds__(128, 8) - gpuEncodePages(device_span pages, - device_span> comp_in, - device_span> comp_out, - device_span comp_results, - bool write_v2_headers) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodePageLevels(device_span pages, bool write_v2_headers) { - __shared__ __align__(8) page_enc_state_s state_g; - using block_reduce = cub::BlockReduce; - using block_scan = cub::BlockScan; - __shared__ union { - typename block_reduce::TempStorage reduce_storage; - typename block_scan::TempStorage scan_storage; - } temp_storage; + __shared__ __align__(8) page_enc_state_s state_g; - page_enc_state_s* const s = &state_g; - auto const t = threadIdx.x; + auto* const s = &state_g; + uint32_t const t = threadIdx.x; if (t == 0) { - state_g = page_enc_state_s{}; + state_g = page_enc_state_s{}; s->page = pages[blockIdx.x]; s->ck = *s->page.chunk; s->col = *s->ck.col_desc; @@ -1029,6 +1127,8 @@ __global__ void __launch_bounds__(128, 8) } __syncthreads(); + if ((s->page.kernel_mask & kernel_mask) == 0) { return; } + auto const is_v2 = s->page.page_type == PageType::DATA_PAGE_V2; // Encode Repetition and Definition levels @@ -1081,7 +1181,7 @@ __global__ void __launch_bounds__(128, 8) } while (is_col_struct); return def; }(); - s->vals[(rle_numvals + t) & (rle_buffer_size - 1)] = def_lvl; + s->vals[rolling_idx(rle_numvals + t)] = def_lvl; __syncthreads(); rle_numvals += nrows; RleEncode(s, rle_numvals, def_lvl_bits, (rle_numvals == s->page.num_rows), t); @@ -1091,13 +1191,12 @@ __global__ void __launch_bounds__(128, 8) uint8_t* const cur = s->cur; uint8_t* const rle_out = s->rle_out; uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 0 : 4); - if (is_v2 && t == 0) { + if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + __syncwarp(); + if (t == 0) { + s->cur = rle_out; s->page.def_lvl_bytes = rle_bytes; - } else if (not is_v2 && t < 4) { - cur[t] = rle_bytes >> (t * 8); } - __syncwarp(); - if (t == 0) { s->cur = rle_out; } } } } else if (s->page.page_type != PageType::DICTIONARY_PAGE && @@ -1124,7 +1223,7 @@ __global__ void __launch_bounds__(128, 8) uint32_t idx = page_first_val_idx + rle_numvals + t; uint32_t lvl_val = (rle_numvals + t < s->page.num_values && idx < col_last_val_idx) ? lvl_val_data[idx] : 0; - s->vals[(rle_numvals + t) & (rle_buffer_size - 1)] = lvl_val; + s->vals[rolling_idx(rle_numvals + t)] = lvl_val; __syncthreads(); rle_numvals += nvals; RleEncode(s, rle_numvals, nbits, (rle_numvals == s->page.num_values), t); @@ -1134,19 +1233,109 @@ __global__ void __launch_bounds__(128, 8) uint8_t* const cur = s->cur; uint8_t* const rle_out = s->rle_out; uint32_t const rle_bytes = static_cast(rle_out - cur) - (is_v2 ? 
0 : 4); - if (is_v2 && t == 0) { + if (not is_v2 && t < 4) { cur[t] = rle_bytes >> (t * 8); } + __syncwarp(); + if (t == 0) { + s->cur = rle_out; lvl_bytes = rle_bytes; - } else if (not is_v2 && t < 4) { - cur[t] = rle_bytes >> (t * 8); } - __syncwarp(); - if (t == 0) { s->cur = rle_out; } } }; encode_levels(s->col.rep_values, s->col.num_rep_level_bits(), s->page.rep_lvl_bytes); __syncthreads(); encode_levels(s->col.def_values, s->col.num_def_level_bits(), s->page.def_lvl_bytes); } + + if (t == 0) { pages[blockIdx.x] = s->page; } +} + +template +__device__ void finish_page_encode(state_buf* s, + uint32_t valid_count, + uint8_t const* end_ptr, + device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + auto const t = threadIdx.x; + + // V2 does not compress rep and def level data + size_t const skip_comp_size = + write_v2_headers ? s->page.def_lvl_bytes + s->page.rep_lvl_bytes : 0; + + if (t == 0) { + // only need num_nulls for v2 data page headers + if (write_v2_headers) { s->page.num_nulls = s->page.num_values - valid_count; } + uint8_t const* const base = s->page.page_data + s->page.max_hdr_size; + auto const actual_data_size = static_cast(end_ptr - base); + if (actual_data_size > s->page.max_data_size) { + printf("data corruption %d %d\n", actual_data_size, s->page.max_data_size); + CUDF_UNREACHABLE("detected possible page data corruption"); + } + s->page.max_data_size = actual_data_size; + if (not comp_in.empty()) { + comp_in[blockIdx.x] = {base + skip_comp_size, actual_data_size - skip_comp_size}; + comp_out[blockIdx.x] = {s->page.compressed_data + s->page.max_hdr_size + skip_comp_size, + 0}; // size is unused + } + pages[blockIdx.x] = s->page; + if (not comp_results.empty()) { + comp_results[blockIdx.x] = {0, compression_status::FAILURE}; + pages[blockIdx.x].comp_res = &comp_results[blockIdx.x]; + } + } + + // copy uncompressed bytes over + if (skip_comp_size != 0 && not comp_in.empty()) { + uint8_t* src = s->page.page_data + s->page.max_hdr_size; + uint8_t* dst = s->page.compressed_data + s->page.max_hdr_size; + for (int i = t; i < skip_comp_size; i += block_size) { + dst[i] = src[i]; + } + } +} + +// FIXME: what should the args to launch_bounds be now? 
+// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodePages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + __shared__ __align__(8) page_enc_state_s<0> state_g; + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; + + if (t == 0) { + state_g = page_enc_state_s<0>{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + if (s->page.page_type == PageType::DATA_PAGE) { + if (s->col.num_def_level_bits() != 0) { s->cur += 4; } + if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + } + } + __syncthreads(); + + if ((s->page.kernel_mask & ENC_MASK_PLAIN) == 0) { return; } + // Encode data values __syncthreads(); auto const physical_type = s->col.physical_type; @@ -1158,10 +1347,6 @@ __global__ void __launch_bounds__(128, 8) return dtype_len_out; }(); - auto const dict_bits = (physical_type == BOOLEAN) ? 1 - : (s->ck.use_dictionary and s->page.page_type != PageType::DICTIONARY_PAGE) - ? s->ck.dict_rle_bits - : -1; if (t == 0) { uint8_t* dst = s->cur; s->rle_run = 0; @@ -1170,219 +1355,314 @@ __global__ void __launch_bounds__(128, 8) s->rle_out = dst; s->page.encoding = determine_encoding(s->page.page_type, physical_type, s->ck.use_dictionary, write_v2_headers); - if (dict_bits >= 0 && physical_type != BOOLEAN) { - dst[0] = dict_bits; - s->rle_out = dst + 1; - } else if (is_v2 && physical_type == BOOLEAN) { - // save space for RLE length. we don't know the total length yet. - s->rle_out = dst + RLE_LENGTH_FIELD_LEN; - s->rle_len_pos = dst; - } s->page_start_val = row_to_value_idx(s->page.start_row, s->col); s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); } __syncthreads(); + uint32_t num_valid = 0; for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { - uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, 128); + uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); uint32_t len, pos; auto [is_valid, val_idx] = [&]() { uint32_t val_idx; uint32_t is_valid; - size_type val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_block = cur_val_idx + t; if (s->page.page_type == PageType::DICTIONARY_PAGE) { val_idx = val_idx_in_block; is_valid = (val_idx < s->page.num_leaf_values); if (is_valid) { val_idx = s->ck.dict_data[val_idx]; } } else { - size_type val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && val_idx_in_block < s->page.num_leaf_values) ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) : 0; - val_idx = - (s->ck.use_dictionary) ? 
val_idx_in_leaf_col - s->chunk_start_val : val_idx_in_leaf_col; + val_idx = val_idx_in_leaf_col; } return std::make_tuple(is_valid, val_idx); }(); - if (is_valid) num_valid++; - + if (is_valid) { num_valid++; } cur_val_idx += nvals; - if (dict_bits >= 0) { - // Dictionary encoding - if (dict_bits > 0) { - uint32_t rle_numvals; - uint32_t rle_numvals_in_block; - block_scan(temp_storage.scan_storage).ExclusiveSum(is_valid, pos, rle_numvals_in_block); - rle_numvals = s->rle_numvals; - if (is_valid) { - uint32_t v; - if (physical_type == BOOLEAN) { - v = s->col.leaf_column->element(val_idx); - } else { - v = s->ck.dict_index[val_idx]; - } - s->vals[(rle_numvals + pos) & (rle_buffer_size - 1)] = v; - } - rle_numvals += rle_numvals_in_block; - __syncthreads(); - if (!is_v2 && physical_type == BOOLEAN) { - PlainBoolEncode(s, rle_numvals, (cur_val_idx == s->page.num_leaf_values), t); - } else { - RleEncode(s, rle_numvals, dict_bits, (cur_val_idx == s->page.num_leaf_values), t); + + // Non-dictionary encoding + uint8_t* dst = s->cur; + + if (is_valid) { + len = dtype_len_out; + if (physical_type == BYTE_ARRAY) { + if (type_id == type_id::STRING) { + len += s->col.leaf_column->element(val_idx).size_bytes(); + } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { + len += + get_element(*s->col.leaf_column, val_idx).size_bytes(); } - __syncthreads(); } - if (t == 0) { s->cur = s->rle_out; } - __syncthreads(); } else { - // Non-dictionary encoding - uint8_t* dst = s->cur; - - if (is_valid) { - len = dtype_len_out; - if (physical_type == BYTE_ARRAY) { - if (type_id == type_id::STRING) { - len += s->col.leaf_column->element(val_idx).size_bytes(); - } else if (s->col.output_as_byte_array && type_id == type_id::LIST) { - len += - get_element(*s->col.leaf_column, val_idx).size_bytes(); + len = 0; + } + uint32_t total_len = 0; + block_scan(temp_storage.scan_storage).ExclusiveSum(len, pos, total_len); + __syncthreads(); + if (t == 0) { s->cur = dst + total_len; } + if (is_valid) { + switch (physical_type) { + case INT32: [[fallthrough]]; + case FLOAT: { + auto const v = [dtype_len = dtype_len_in, + idx = val_idx, + col = s->col.leaf_column, + scale = s->col.ts_scale == 0 ? 
1 : s->col.ts_scale]() -> int32_t { + switch (dtype_len) { + case 8: return col->element(idx) * scale; + case 4: return col->element(idx) * scale; + case 2: return col->element(idx) * scale; + default: return col->element(idx) * scale; + } + }(); + + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + } break; + case INT64: { + int64_t v = s->col.leaf_column->element(val_idx); + int32_t ts_scale = s->col.ts_scale; + if (ts_scale != 0) { + if (ts_scale < 0) { + v /= -ts_scale; + } else { + v *= ts_scale; + } + } + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + dst[pos + 4] = v >> 32; + dst[pos + 5] = v >> 40; + dst[pos + 6] = v >> 48; + dst[pos + 7] = v >> 56; + } break; + case INT96: { + int64_t v = s->col.leaf_column->element(val_idx); + int32_t ts_scale = s->col.ts_scale; + if (ts_scale != 0) { + if (ts_scale < 0) { + v /= -ts_scale; + } else { + v *= ts_scale; + } } - } - } else { - len = 0; - } - uint32_t total_len = 0; - block_scan(temp_storage.scan_storage).ExclusiveSum(len, pos, total_len); - __syncthreads(); - if (t == 0) { s->cur = dst + total_len; } - if (is_valid) { - switch (physical_type) { - case INT32: [[fallthrough]]; - case FLOAT: { - auto const v = [dtype_len = dtype_len_in, - idx = val_idx, - col = s->col.leaf_column, - scale = s->col.ts_scale == 0 ? 1 : s->col.ts_scale]() -> int32_t { - switch (dtype_len) { - case 8: return col->element(idx) * scale; - case 4: return col->element(idx) * scale; - case 2: return col->element(idx) * scale; - default: return col->element(idx) * scale; - } - }(); - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - } break; - case INT64: { - int64_t v = s->col.leaf_column->element(val_idx); - int32_t ts_scale = s->col.ts_scale; - if (ts_scale != 0) { - if (ts_scale < 0) { - v /= -ts_scale; - } else { - v *= ts_scale; - } + auto const [last_day_nanos, julian_days] = [&] { + using namespace cuda::std::chrono; + switch (s->col.leaf_column->type().id()) { + case type_id::TIMESTAMP_SECONDS: + case type_id::TIMESTAMP_MILLISECONDS: { + return julian_days_with_time(v); + } break; + case type_id::TIMESTAMP_MICROSECONDS: + case type_id::TIMESTAMP_NANOSECONDS: { + return julian_days_with_time(v); + } break; } - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - dst[pos + 4] = v >> 32; - dst[pos + 5] = v >> 40; - dst[pos + 6] = v >> 48; - dst[pos + 7] = v >> 56; - } break; - case INT96: { - int64_t v = s->col.leaf_column->element(val_idx); - int32_t ts_scale = s->col.ts_scale; - if (ts_scale != 0) { - if (ts_scale < 0) { - v /= -ts_scale; - } else { - v *= ts_scale; - } + return julian_days_with_time(0); + }(); + + // the 12 bytes of fixed length data. 
+ v = last_day_nanos.count(); + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + dst[pos + 4] = v >> 32; + dst[pos + 5] = v >> 40; + dst[pos + 6] = v >> 48; + dst[pos + 7] = v >> 56; + uint32_t w = julian_days.count(); + dst[pos + 8] = w; + dst[pos + 9] = w >> 8; + dst[pos + 10] = w >> 16; + dst[pos + 11] = w >> 24; + } break; + + case DOUBLE: { + auto v = s->col.leaf_column->element(val_idx); + memcpy(dst + pos, &v, 8); + } break; + case BYTE_ARRAY: { + auto const bytes = [](cudf::type_id const type_id, + column_device_view const* leaf_column, + uint32_t const val_idx) -> void const* { + switch (type_id) { + case type_id::STRING: + return reinterpret_cast( + leaf_column->element(val_idx).data()); + case type_id::LIST: + return reinterpret_cast( + get_element(*(leaf_column), val_idx).data()); + default: CUDF_UNREACHABLE("invalid type id for byte array writing!"); } + }(type_id, s->col.leaf_column, val_idx); + uint32_t v = len - 4; // string length + dst[pos + 0] = v; + dst[pos + 1] = v >> 8; + dst[pos + 2] = v >> 16; + dst[pos + 3] = v >> 24; + if (v != 0) memcpy(dst + pos + 4, bytes, v); + } break; + case FIXED_LEN_BYTE_ARRAY: { + if (type_id == type_id::DECIMAL128) { + // When using FIXED_LEN_BYTE_ARRAY for decimals, the rep is encoded in big-endian + auto const v = s->col.leaf_column->element(val_idx).value(); + auto const v_char_ptr = reinterpret_cast(&v); + thrust::copy(thrust::seq, + thrust::make_reverse_iterator(v_char_ptr + sizeof(v)), + thrust::make_reverse_iterator(v_char_ptr), + dst + pos); + } + } break; + } + } + __syncthreads(); + } - auto const [last_day_nanos, julian_days] = [&] { - using namespace cuda::std::chrono; - switch (s->col.leaf_column->type().id()) { - case type_id::TIMESTAMP_SECONDS: - case type_id::TIMESTAMP_MILLISECONDS: { - return julian_days_with_time(v); - } break; - case type_id::TIMESTAMP_MICROSECONDS: - case type_id::TIMESTAMP_NANOSECONDS: { - return julian_days_with_time(v); - } break; - } - return julian_days_with_time(0); - }(); - - // the 12 bytes of fixed length data. 
- v = last_day_nanos.count(); - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - dst[pos + 4] = v >> 32; - dst[pos + 5] = v >> 40; - dst[pos + 6] = v >> 48; - dst[pos + 7] = v >> 56; - uint32_t w = julian_days.count(); - dst[pos + 8] = w; - dst[pos + 9] = w >> 8; - dst[pos + 10] = w >> 16; - dst[pos + 11] = w >> 24; - } break; + uint32_t const valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); - case DOUBLE: { - auto v = s->col.leaf_column->element(val_idx); - memcpy(dst + pos, &v, 8); - } break; - case BYTE_ARRAY: { - auto const bytes = [](cudf::type_id const type_id, - column_device_view const* leaf_column, - uint32_t const val_idx) -> void const* { - switch (type_id) { - case type_id::STRING: - return reinterpret_cast( - leaf_column->element(val_idx).data()); - case type_id::LIST: - return reinterpret_cast( - get_element(*(leaf_column), val_idx).data()); - default: CUDF_UNREACHABLE("invalid type id for byte array writing!"); - } - }(type_id, s->col.leaf_column, val_idx); - uint32_t v = len - 4; // string length - dst[pos + 0] = v; - dst[pos + 1] = v >> 8; - dst[pos + 2] = v >> 16; - dst[pos + 3] = v >> 24; - if (v != 0) memcpy(dst + pos + 4, bytes, v); - } break; - case FIXED_LEN_BYTE_ARRAY: { - if (type_id == type_id::DECIMAL128) { - // When using FIXED_LEN_BYTE_ARRAY for decimals, the rep is encoded in big-endian - auto const v = s->col.leaf_column->element(val_idx).value(); - auto const v_char_ptr = reinterpret_cast(&v); - thrust::copy(thrust::seq, - thrust::make_reverse_iterator(v_char_ptr + sizeof(v)), - thrust::make_reverse_iterator(v_char_ptr), - dst + pos); - } - } break; + finish_page_encode( + s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); +} + +// FIXME: what should the args to launch_bounds be now? +// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodeDictPages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results, + bool write_v2_headers) +{ + __shared__ __align__(8) page_enc_state_s state_g; + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; + + if (t == 0) { + state_g = page_enc_state_s{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + if (s->page.page_type == PageType::DATA_PAGE) { + if (s->col.num_def_level_bits() != 0) { s->cur += 4; } + if (s->col.num_rep_level_bits() != 0) { s->cur += 4; } + } + } + __syncthreads(); + + if ((s->page.kernel_mask & ENC_MASK_DICTIONARY) == 0) { return; } + + // Encode data values + __syncthreads(); + auto const physical_type = s->col.physical_type; + auto const type_id = s->col.leaf_column->type().id(); + auto const dtype_len_out = physical_type_len(physical_type, type_id); + auto const dtype_len_in = [&]() -> uint32_t { + if (physical_type == INT32) { return int32_logical_len(type_id); } + if (physical_type == INT96) { return sizeof(int64_t); } + return dtype_len_out; + }(); + + // TODO assert dict_bits >= 0 + auto const dict_bits = (physical_type == BOOLEAN) ? 
1 + : (s->ck.use_dictionary and s->page.page_type != PageType::DICTIONARY_PAGE) + ? s->ck.dict_rle_bits + : -1; + if (t == 0) { + uint8_t* dst = s->cur; + s->rle_run = 0; + s->rle_pos = 0; + s->rle_numvals = 0; + s->rle_out = dst; + s->page.encoding = + determine_encoding(s->page.page_type, physical_type, s->ck.use_dictionary, write_v2_headers); + if (dict_bits >= 0 && physical_type != BOOLEAN) { + dst[0] = dict_bits; + s->rle_out = dst + 1; + } else if (write_v2_headers && physical_type == BOOLEAN) { + // save space for RLE length. we don't know the total length yet. + s->rle_out = dst + RLE_LENGTH_FIELD_LEN; + s->rle_len_pos = dst; + } + s->page_start_val = row_to_value_idx(s->page.start_row, s->col); + s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); + } + __syncthreads(); + + uint32_t num_valid = 0; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); + + auto [is_valid, val_idx] = [&]() { + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) + : 0; + // need to test for use_dictionary because it might be boolean + uint32_t const val_idx = + (s->ck.use_dictionary) ? val_idx_in_leaf_col - s->chunk_start_val : val_idx_in_leaf_col; + return std::make_tuple(is_valid, val_idx); + }(); + + if (is_valid) { num_valid++; } + cur_val_idx += nvals; + + // Dictionary encoding + if (dict_bits > 0) { + uint32_t rle_numvals; + uint32_t rle_numvals_in_block; + uint32_t pos; + block_scan(temp_storage.scan_storage).ExclusiveSum(is_valid, pos, rle_numvals_in_block); + rle_numvals = s->rle_numvals; + if (is_valid) { + uint32_t v; + if (physical_type == BOOLEAN) { + v = s->col.leaf_column->element(val_idx); + } else { + v = s->ck.dict_index[val_idx]; } + s->vals[rolling_idx(rle_numvals + pos)] = v; + } + rle_numvals += rle_numvals_in_block; + __syncthreads(); + if ((!write_v2_headers) && (physical_type == BOOLEAN)) { + PlainBoolEncode(s, rle_numvals, (cur_val_idx == s->page.num_leaf_values), t); + } else { + RleEncode(s, rle_numvals, dict_bits, (cur_val_idx == s->page.num_leaf_values), t); } __syncthreads(); } + if (t == 0) { s->cur = s->rle_out; } + __syncthreads(); } uint32_t const valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); @@ -1390,42 +1670,139 @@ __global__ void __launch_bounds__(128, 8) // save RLE length if necessary if (s->rle_len_pos != nullptr && t < 32) { // size doesn't include the 4 bytes for the length - auto const rle_size = static_cast(s->cur - s->rle_len_pos) - RLE_LENGTH_FIELD_LEN; - if (t < RLE_LENGTH_FIELD_LEN) { s->rle_len_pos[t] = rle_size >> (t * 8); } + auto const rle_size = static_cast(s->cur - s->rle_len_pos) - 4; + if (t < 4) { s->rle_len_pos[t] = rle_size >> (t * 8); } __syncwarp(); } - // V2 does not compress rep and def level data - size_t const skip_comp_size = s->page.def_lvl_bytes + s->page.rep_lvl_bytes; + finish_page_encode( + s, valid_count, s->cur, pages, comp_in, comp_out, comp_results, write_v2_headers); +} + +// FIXME: what should the args to launch_bounds be now? 
+// blockDim(128, 1, 1) +template +__global__ void __launch_bounds__(block_size, 8) + gpuEncodeDeltaBinaryPages(device_span pages, + device_span> comp_in, + device_span> comp_out, + device_span comp_results) +{ + // block of shared memory for value storage and bit packing + // TODO add constant that's the sum of buffer_size and block_size + __shared__ uint64_t delta_shared[delta::buffer_size + delta::block_size]; + __shared__ __align__(8) page_enc_state_s<0> state_g; + using block_reduce = cub::BlockReduce; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename delta::index_scan::TempStorage delta_index_tmp; + typename delta::block_reduce::TempStorage delta_reduce_tmp; + typename delta::warp_reduce::TempStorage delta_warp_red_tmp[delta::num_mini_blocks]; + } temp_storage; + + auto* const s = &state_g; + uint32_t t = threadIdx.x; if (t == 0) { - s->page.num_nulls = s->page.num_values - valid_count; - uint8_t* const base = s->page.page_data + s->page.max_hdr_size; - auto const actual_data_size = static_cast(s->cur - base); - if (actual_data_size > s->page.max_data_size) { - CUDF_UNREACHABLE("detected possible page data corruption"); - } - s->page.max_data_size = actual_data_size; - if (not comp_in.empty()) { - comp_in[blockIdx.x] = {base + skip_comp_size, actual_data_size - skip_comp_size}; - comp_out[blockIdx.x] = {s->page.compressed_data + s->page.max_hdr_size + skip_comp_size, - 0}; // size is unused - } - pages[blockIdx.x] = s->page; - if (not comp_results.empty()) { - comp_results[blockIdx.x] = {0, compression_status::FAILURE}; - pages[blockIdx.x].comp_res = &comp_results[blockIdx.x]; - } + state_g = page_enc_state_s<0>{}; + s->page = pages[blockIdx.x]; + s->ck = *s->page.chunk; + s->col = *s->ck.col_desc; + s->rle_len_pos = nullptr; + // get s->cur back to where it was at the end of encoding the rep and def level data + s->cur = + s->page.page_data + s->page.max_hdr_size + s->page.def_lvl_bytes + s->page.rep_lvl_bytes; } + __syncthreads(); - // copy over uncompressed data - if (skip_comp_size != 0 && not comp_in.empty()) { - uint8_t const* const src = s->page.page_data + s->page.max_hdr_size; - uint8_t* const dst = s->page.compressed_data + s->page.max_hdr_size; - for (int i = t; i < skip_comp_size; i += block_size) { - dst[i] = src[i]; + if ((s->page.kernel_mask & ENC_MASK_DELTA_BINARY) == 0) { return; } + + // Encode data values + __syncthreads(); + auto const physical_type = s->col.physical_type; + auto const type_id = s->col.leaf_column->type().id(); + auto const dtype_len_out = physical_type_len(physical_type, type_id); + auto const dtype_len_in = [&]() -> uint32_t { + if (physical_type == INT32) { return int32_logical_len(type_id); } + if (physical_type == INT96) { return sizeof(int64_t); } + return dtype_len_out; + }(); + + if (t == 0) { + uint8_t* dst = s->cur; + s->rle_run = 0; + s->rle_pos = 0; + s->rle_numvals = 0; + s->rle_out = dst; + s->page.encoding = Encoding::DELTA_BINARY_PACKED; + s->page_start_val = row_to_value_idx(s->page.start_row, s->col); + s->chunk_start_val = row_to_value_idx(s->ck.start_row, s->col); + } + __syncthreads(); + + // need to know the number of valid values for the null values calculation and to size + // the delta binary encoder. 
+ uint32_t valid_count = 0; + if (not s->col.leaf_column->nullable()) { + valid_count = s->page.num_leaf_values; + } else { + uint32_t num_valid = 0; + for (uint32_t cur_val_idx = 0; cur_val_idx < s->page.num_leaf_values;) { + uint32_t const nvals = min(s->page.num_leaf_values - cur_val_idx, block_size); + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + if (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values && + s->col.leaf_column->is_valid(val_idx_in_leaf_col)) { + num_valid++; + } + cur_val_idx += nvals; + } + valid_count = block_reduce(temp_storage.reduce_storage).Sum(num_valid); + } + + auto calc_idx_and_validity = [&](uint32_t cur_val_idx) { + size_type const val_idx_in_block = cur_val_idx + t; + size_type const val_idx_in_leaf_col = s->page_start_val + val_idx_in_block; + + uint32_t const is_valid = (val_idx_in_leaf_col < s->col.leaf_column->size() && + val_idx_in_block < s->page.num_leaf_values) + ? s->col.leaf_column->is_valid(val_idx_in_leaf_col) + : 0; + + return std::make_tuple(is_valid, val_idx_in_leaf_col); + }; + + uint8_t const* delta_ptr = nullptr; // this will be the end of delta block pointer + + if (physical_type == INT32) { + // FIXME need to handle all the time scaling stuff here too + if (dtype_len_in == 4) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else if (dtype_len_in == 2) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else if (dtype_len_in == 8) { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); + } else { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); } + } else { + delta_enc encoder{ + s, valid_count, calc_idx_and_validity, delta_shared, &temp_storage}; + delta_ptr = encoder.encode(); } + + finish_page_encode( + s, valid_count, delta_ptr, pages, comp_in, comp_out, comp_results, true); } constexpr int decide_compression_warps_in_block = 4; @@ -1460,7 +1837,8 @@ __global__ void __launch_bounds__(decide_compression_block_size) for (auto page_id = lane_id; page_id < num_pages; page_id += cudf::detail::warp_size) { auto const& curr_page = ck_g[warp_id].pages[page_id]; auto const page_data_size = curr_page.max_data_size; - auto const lvl_bytes = curr_page.def_lvl_bytes + curr_page.rep_lvl_bytes; + auto const is_v2 = curr_page.page_type == PageType::DATA_PAGE_V2; + auto const lvl_bytes = is_v2 ? curr_page.def_lvl_bytes + curr_page.rep_lvl_bytes : 0; uncompressed_data_size += page_data_size; if (auto comp_res = curr_page.comp_res; comp_res != nullptr) { compressed_data_size += comp_res->bytes_written + lvl_bytes; @@ -1923,7 +2301,8 @@ __global__ void __launch_bounds__(128) } uncompressed_page_size = page_g.max_data_size; if (ck_g.is_compressed) { - auto const lvl_bytes = page_g.def_lvl_bytes + page_g.rep_lvl_bytes; + auto const is_v2 = page_g.page_type == PageType::DATA_PAGE_V2; + auto const lvl_bytes = is_v2 ? 
page_g.def_lvl_bytes + page_g.rep_lvl_bytes : 0; hdr_start = page_g.compressed_data; compressed_page_size = static_cast(comp_results[blockIdx.x].bytes_written) + lvl_bytes; @@ -2158,6 +2537,10 @@ constexpr __device__ void* align8(void* ptr) return static_cast(ptr) - algn; } +struct mask_tform { + __device__ uint32_t operator()(EncPage const& p) { return p.kernel_mask; } +}; + } // namespace // blockDim(1, 1, 1) @@ -2260,8 +2643,9 @@ void InitFragmentStatistics(device_span groups, rmm::cuda_stream_view stream) { int const num_fragments = fragments.size(); - int const dim = util::div_rounding_up_safe(num_fragments, 128 / cudf::detail::warp_size); - gpuInitFragmentStats<<>>(groups, fragments); + int const dim = + util::div_rounding_up_safe(num_fragments, encode_block_size / cudf::detail::warp_size); + gpuInitFragmentStats<<>>(groups, fragments); } void InitEncoderPages(device_2dspan chunks, @@ -2280,18 +2664,18 @@ void InitEncoderPages(device_2dspan chunks, { auto num_rowgroups = chunks.size().first; dim3 dim_grid(num_columns, num_rowgroups); // 1 threadblock per rowgroup - gpuInitPages<<>>(chunks, - pages, - page_sizes, - comp_page_sizes, - col_desc, - page_grstats, - chunk_grstats, - num_columns, - max_page_size_bytes, - max_page_size_rows, - page_align, - write_v2_headers); + gpuInitPages<<>>(chunks, + pages, + page_sizes, + comp_page_sizes, + col_desc, + page_grstats, + chunk_grstats, + num_columns, + max_page_size_bytes, + max_page_size_rows, + page_align, + write_v2_headers); } void EncodePages(device_span pages, @@ -2302,10 +2686,43 @@ void EncodePages(device_span pages, rmm::cuda_stream_view stream) { auto num_pages = pages.size(); + + // determine which kernels to invoke + auto mask_iter = thrust::make_transform_iterator(pages.begin(), mask_tform{}); + int kernel_mask = thrust::reduce( + rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); + + // get the number of streams we need from the pool + int nkernels = std::bitset<32>(kernel_mask).count(); + auto streams = cudf::detail::fork_streams(stream, nkernels); + // A page is part of one column. This is launching 1 block per page. 1 block will exclusively // deal with one datatype. - gpuEncodePages<128><<>>( - pages, comp_in, comp_out, comp_results, write_v2_headers); + + int s_idx = 0; + if ((kernel_mask & ENC_MASK_PLAIN) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodePages<<>>( + pages, comp_in, comp_out, comp_results, write_v2_headers); + } + if ((kernel_mask & ENC_MASK_DELTA_BINARY) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodeDeltaBinaryPages + <<>>(pages, comp_in, comp_out, comp_results); + } + if ((kernel_mask & ENC_MASK_DICTIONARY) != 0) { + auto const strm = streams[s_idx++]; + gpuEncodePageLevels + <<>>(pages, write_v2_headers); + gpuEncodeDictPages<<>>( + pages, comp_in, comp_out, comp_results, write_v2_headers); + } + + cudf::detail::join_streams(streams, stream); } void DecideCompression(device_span chunks, rmm::cuda_stream_view stream) @@ -2323,7 +2740,7 @@ void EncodePageHeaders(device_span pages, { // TODO: single thread task. No need for 128 threads/block. 
Earlier it used to employ rest of the // threads to coop load structs - gpuEncodePageHeaders<<>>( + gpuEncodePageHeaders<<>>( pages, comp_results, page_stats, chunk_stats); } diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index e82b6abc13d..c2892ed6495 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -75,6 +75,12 @@ struct input_column_info { namespace gpu { +// TODO: The delta encodings use ULEB128 integers, but for now we're only +// using max 64 bits. Need to see what the performance impact is of using +// __int128_t rather than int64_t. +using uleb128_t = uint64_t; +using zigzag128_t = int64_t; + /** * @brief Enums for the flags in the page header */ @@ -390,6 +396,17 @@ constexpr uint32_t encoding_to_mask(Encoding encoding) return 1 << static_cast(encoding); } +/** + * @brief Enum of mask bits for the EncPage kernel_mask + * + * Used to control which encode kernels to run. + */ +enum encoder_kernel_mask_bits { + ENC_MASK_PLAIN = (1 << 0), // Run plain encoding kernel + ENC_MASK_DICTIONARY = (1 << 1), // Run dictionary encoding kernel + ENC_MASK_DELTA_BINARY = (1 << 2) // Run DELTA_BINARY_PACKED encoding kernel +}; + /** * @brief Struct describing an encoder column chunk */ @@ -452,6 +469,7 @@ struct EncPage { uint32_t rep_lvl_bytes; //!< Number of bytes of encoded repetition level data (V2 only) compression_result* comp_res; //!< Ptr to compression result uint32_t num_nulls; //!< Number of null values (V2 only) (down here for alignment) + uint32_t kernel_mask; //!< Mask used to control which encoding kernels to run }; /** diff --git a/cpp/tests/io/parquet_test.cpp b/cpp/tests/io/parquet_test.cpp index 64aca091686..907805d0abd 100644 --- a/cpp/tests/io/parquet_test.cpp +++ b/cpp/tests/io/parquet_test.cpp @@ -540,7 +540,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, Timestamps) auto filepath = temp_env->get_temp_filepath("Timestamps.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -566,7 +568,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampsWithNulls) auto filepath = temp_env->get_temp_filepath("TimestampsWithNulls.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -590,7 +594,9 @@ TYPED_TEST(ParquetWriterTimestampTypeTest, TimestampOverflow) auto filepath = temp_env->get_temp_filepath("ParquetTimestampOverflow.parquet"); cudf::io::parquet_writer_options out_opts = - cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected); + cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, expected) + .write_v2_headers(true) + .dictionary_policy(cudf::io::dictionary_policy::NEVER); cudf::io::write_parquet(out_opts); cudf::io::parquet_reader_options in_opts = @@ -6672,7 +6678,7 @@ TEST_P(ParquetV2Test, CheckEncodings) // data should be PLAIN for v1, RLE for V2 auto col0_data = 
cudf::detail::make_counting_transform_iterator(0, [](auto i) -> bool { return i % 2 == 0; });
-  // data should be PLAIN for both
+  // data should be PLAIN for v1, DELTA_BINARY_PACKED for v2
   auto col1_data = random_values(num_rows);
   // data should be PLAIN_DICTIONARY for v1, PLAIN and RLE_DICTIONARY for v2
   auto col2_data = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return 1; });
@@ -6707,10 +6713,10 @@ TEST_P(ParquetV2Test, CheckEncodings)
   // col0 should have RLE for rep/def and data
   EXPECT_TRUE(chunk0_enc.size() == 1);
   EXPECT_TRUE(contains(chunk0_enc, Encoding::RLE));
-  // col1 should have RLE for rep/def and PLAIN for data
+  // col1 should have RLE for rep/def and DELTA_BINARY_PACKED for data
   EXPECT_TRUE(chunk1_enc.size() == 2);
   EXPECT_TRUE(contains(chunk1_enc, Encoding::RLE));
-  EXPECT_TRUE(contains(chunk1_enc, Encoding::PLAIN));
+  EXPECT_TRUE(contains(chunk1_enc, Encoding::DELTA_BINARY_PACKED));
   // col2 should have RLE for rep/def, PLAIN for dict, and RLE_DICTIONARY for data
   EXPECT_TRUE(chunk2_enc.size() == 3);
   EXPECT_TRUE(contains(chunk2_enc, Encoding::RLE));
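
Reviewer note (not part of the patch): the worst-case page size that the new delta_data_len() device function budgets for in gpuInitPages can be sanity-checked on the host. The sketch below is a minimal host-side mirror of that estimate, assuming the same layout constants as delta_enc.cuh (blocks of 128 values, 4 mini-blocks of 32); the function name delta_worst_case_bytes is made up here for illustration and does not exist in libcudf.

// Host-side mirror of the device-side delta_data_len() worst-case estimate.
#include <cstddef>

constexpr std::size_t delta_block_size      = 128;  // values per DELTA_BINARY_PACKED block
constexpr std::size_t delta_num_mini_blocks = 4;    // mini-blocks per block (32 values each)

constexpr std::size_t delta_worst_case_bytes(std::size_t num_values, std::size_t dtype_len)
{
  std::size_t const num_blocks = (num_values + delta_block_size - 1) / delta_block_size;
  // per block: zig-zag min_delta (up to dtype_len + 1 bytes), one bit-width byte per
  // mini-block, and, in the worst case, every delta packed at the full value width
  std::size_t const block_bytes =
    (dtype_len + 1) + delta_num_mini_blocks + delta_block_size * dtype_len;
  // header: ULEB128 block size (2 bytes), mini-block count (1), value count (up to 5),
  // plus the zig-zag encoded first value (up to dtype_len + 1 bytes)
  std::size_t const header_bytes = 2 + 1 + 5 + (dtype_len + 1);
  return header_bytes + num_blocks * block_bytes;
}

// e.g. a 20,000-value INT32 page is budgeted at 13 + 157 * 521 bytes before compression
static_assert(delta_worst_case_bytes(20'000, 4) == 13 + 157 * 521, "worst-case budget");

This is the bound gpuInitPages takes via max(page_size, delta_data_len(...)) when is_use_delta is set; the encoded output is normally much smaller, which is why finish_page_encode() replaces max_data_size with the actual encoded length once the kernel finishes.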