diff --git a/CMakeLists.txt b/CMakeLists.txt index 3157d3a1ec49..04fa1c485801 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -187,9 +187,6 @@ if (COMPILER_CLANG) set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Xclang -fuse-ctor-homing") endif() endif() - - # set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") - # set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") endif () # If compiler has support for -Wreserved-identifier. It is difficult to detect by clang version, diff --git a/utils/local-engine/Builder/SerializedPlanBuilder.cpp b/utils/local-engine/Builder/SerializedPlanBuilder.cpp index b51a4a810555..7b2f8c721f55 100644 --- a/utils/local-engine/Builder/SerializedPlanBuilder.cpp +++ b/utils/local-engine/Builder/SerializedPlanBuilder.cpp @@ -1,24 +1,7 @@ #include "SerializedPlanBuilder.h" -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int UNKNOWN_TYPE; -} -} namespace dbms { - -using namespace DB; SchemaPtr SerializedSchemaBuilder::build() { for (const auto & [name, type] : this->type_map) @@ -188,11 +171,9 @@ std::unique_ptr SerializedPlanBuilder::build() { return std::move(this->plan); } - SerializedPlanBuilder::SerializedPlanBuilder() : plan(std::make_unique()) { } - SerializedPlanBuilder & SerializedPlanBuilder::aggregate(std::vector /*keys*/, std::vector aggregates) { substrait::Rel * rel = new substrait::Rel(); @@ -207,7 +188,6 @@ SerializedPlanBuilder & SerializedPlanBuilder::aggregate(std::vector / this->prev_rel = rel; return *this; } - SerializedPlanBuilder & SerializedPlanBuilder::project(std::vector projections) { substrait::Rel * project = new substrait::Rel(); @@ -220,94 +200,6 @@ SerializedPlanBuilder & SerializedPlanBuilder::project(std::vector SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type) -{ - const auto * ch_type_nullable = checkAndGetDataType(ch_type.get()); - const bool is_nullable = (ch_type_nullable != nullptr); - auto type_nullability - = is_nullable ? 
substrait::Type_Nullability_NULLABILITY_NULLABLE : substrait::Type_Nullability_NULLABILITY_REQUIRED; - - const auto ch_type_without_nullable = DB::removeNullable(ch_type); - const DB::WhichDataType which(ch_type_without_nullable); - - auto res = std::make_shared(); - if (which.isUInt8()) - res->mutable_bool_()->set_nullability(type_nullability); - else if (which.isInt8()) - res->mutable_i8()->set_nullability(type_nullability); - else if (which.isInt16()) - res->mutable_i16()->set_nullability(type_nullability); - else if (which.isInt32()) - res->mutable_i32()->set_nullability(type_nullability); - else if (which.isInt64()) - res->mutable_i64()->set_nullability(type_nullability); - else if (which.isString() || which.isAggregateFunction()) - res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type - else if (which.isFloat32()) - res->mutable_fp32()->set_nullability(type_nullability); - else if (which.isFloat64()) - res->mutable_fp64()->set_nullability(type_nullability); - else if (which.isFloat64()) - res->mutable_fp64()->set_nullability(type_nullability); - else if (which.isDateTime64()) - { - const auto * ch_type_datetime64 = checkAndGetDataType(ch_type_without_nullable.get()); - if (ch_type_datetime64->getScale() != 6) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); - res->mutable_timestamp()->set_nullability(type_nullability); - } - else if (which.isDate32()) - res->mutable_date()->set_nullability(type_nullability); - else if (which.isDecimal()) - { - if (which.isDecimal256()) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); - - const auto scale = getDecimalScale(*ch_type_without_nullable, 0); - const auto precision = getDecimalPrecision(*ch_type_without_nullable); - if (scale == 0 && precision == 0) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); - res->mutable_decimal()->set_nullability(type_nullability); - res->mutable_decimal()->set_scale(scale); - res->mutable_decimal()->set_precision(precision); - } - else if (which.isTuple()) - { - const auto * ch_tuple_type = checkAndGetDataType(ch_type_without_nullable.get()); - const auto & ch_field_types = ch_tuple_type->getElements(); - res->mutable_struct_()->set_nullability(type_nullability); - for (const auto & ch_field_type: ch_field_types) - res->mutable_struct_()->mutable_types()->Add(std::move(*buildType(ch_field_type))); - } - else if (which.isArray()) - { - const auto * ch_array_type = checkAndGetDataType(ch_type_without_nullable.get()); - const auto & ch_nested_type = ch_array_type->getNestedType(); - res->mutable_list()->set_nullability(type_nullability); - *(res->mutable_list()->mutable_type()) = *buildType(ch_nested_type); - } - else if (which.isMap()) - { - const auto & ch_map_type = checkAndGetDataType(ch_type_without_nullable.get()); - const auto & ch_key_type = ch_map_type->getKeyType(); - const auto & ch_val_type = ch_map_type->getValueType(); - res->mutable_map()->set_nullability(type_nullability); - *(res->mutable_map()->mutable_key()) = *buildType(ch_key_type); - *(res->mutable_map()->mutable_value()) = *buildType(ch_val_type); - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); - - return std::move(res); -} - -void SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type, String & substrait_type) -{ - auto pb = 
buildType(ch_type); - substrait_type = pb->SerializeAsString(); -} - - substrait::Expression * selection(int32_t field_id) { substrait::Expression * rel = new substrait::Expression(); diff --git a/utils/local-engine/Builder/SerializedPlanBuilder.h b/utils/local-engine/Builder/SerializedPlanBuilder.h index 3b0638a3eeb4..66345c55260a 100644 --- a/utils/local-engine/Builder/SerializedPlanBuilder.h +++ b/utils/local-engine/Builder/SerializedPlanBuilder.h @@ -53,9 +53,6 @@ class SerializedPlanBuilder SchemaPtr schema); std::unique_ptr build(); - static std::shared_ptr buildType(const DB::DataTypePtr & ch_type); - static void buildType(const DB::DataTypePtr & ch_type, String & substrait_type); - private: void setInputToPrev(substrait::Rel * input); substrait::Rel * prev_rel = nullptr; diff --git a/utils/local-engine/CMakeLists.txt b/utils/local-engine/CMakeLists.txt index eade3b81bb74..766610a49261 100644 --- a/utils/local-engine/CMakeLists.txt +++ b/utils/local-engine/CMakeLists.txt @@ -111,7 +111,6 @@ target_compile_options(_icui18n PRIVATE -fPIC) target_compile_options(_cpuid PRIVATE -fPIC) target_compile_options(re2_st PRIVATE -fPIC) target_compile_options(_boost_program_options PRIVATE -fPIC) -target_compile_options(_boost_context PRIVATE -fPIC) target_compile_options(clickhouse_common_io PRIVATE -fPIC) target_compile_options(clickhouse_dictionaries_embedded PRIVATE -fPIC) target_compile_options(clickhouse_common_zookeeper PRIVATE -fPIC) diff --git a/utils/local-engine/Common/DebugUtils.cpp b/utils/local-engine/Common/DebugUtils.cpp index aecb1dc827a5..9606e20ee33d 100644 --- a/utils/local-engine/Common/DebugUtils.cpp +++ b/utils/local-engine/Common/DebugUtils.cpp @@ -17,9 +17,10 @@ void headBlock(const DB::Block & block, size_t count) std::cerr << block.dumpStructure() << std::endl; // print header for (const auto& name : block.getNames()) + { std::cerr << name << "\t"; + } std::cerr << std::endl; - // print rows for (size_t row = 0; row < std::min(count, block.rows()); ++row) { @@ -27,10 +28,70 @@ void headBlock(const DB::Block & block, size_t count) { const auto type = block.getByPosition(column).type; auto col = block.getByPosition(column).column; - - if (column > 0) - std::cerr << "\t"; - std::cerr << toString((*col)[row]); + auto nested_col = col; + DB::DataTypePtr nested_type = type; + if (const auto *nullable = DB::checkAndGetDataType(type.get())) + { + nested_type = nullable->getNestedType(); + const auto *nullable_column = DB::checkAndGetColumn(*col); + nested_col = nullable_column->getNestedColumnPtr(); + } + DB::WhichDataType which(nested_type); + if (col->isNullAt(row)) + { + std::cerr << "null" << "\t"; + } + else if (which.isUInt()) + { + auto value = nested_col->getUInt(row); + std::cerr << std::to_string(value) << "\t"; + } + else if (which.isString()) + { + auto value = DB::checkAndGetColumn(*nested_col)->getDataAt(row).toString(); + std::cerr << value << "\t"; + } + else if (which.isInt()) + { + auto value = nested_col->getInt(row); + std::cerr << std::to_string(value) << "\t"; + } + else if (which.isFloat32()) + { + auto value = nested_col->getFloat32(row); + std::cerr << std::to_string(value) << "\t"; + } + else if (which.isFloat64()) + { + auto value = nested_col->getFloat64(row); + std::cerr << std::to_string(value) << "\t"; + } + else if (which.isDate()) + { + const auto * date_type = DB::checkAndGetDataType(nested_type.get()); + String date_string; + DB::WriteBufferFromString wb(date_string); + 
date_type->getSerialization(DB::ISerialization::Kind::DEFAULT)->serializeText(*nested_col, row, wb, {}); + std::cerr << date_string.substr(0, 10) << "\t"; + } + else if (which.isDate32()) + { + const auto * date_type = DB::checkAndGetDataType(nested_type.get()); + String date_string; + DB::WriteBufferFromString wb(date_string); + date_type->getSerialization(DB::ISerialization::Kind::DEFAULT)->serializeText(*nested_col, row, wb, {}); + std::cerr << date_string.substr(0, 10) << "\t"; + } + else if (which.isDateTime64()) + { + const auto * datetime64_type = DB::checkAndGetDataType(nested_type.get()); + String datetime64_string; + DB::WriteBufferFromString wb(datetime64_string); + datetime64_type->getSerialization(DB::ISerialization::Kind::DEFAULT)->serializeText(*nested_col, row, wb, {}); + std::cerr << datetime64_string << "\t"; + } + else + std::cerr << "N/A" << "\t"; } std::cerr << std::endl; } @@ -39,14 +100,49 @@ void headBlock(const DB::Block & block, size_t count) void headColumn(const DB::ColumnPtr column, size_t count) { std::cerr << "============Column============" << std::endl; - // print header + std::cerr << column->getName() << "\t"; std::cerr << std::endl; - // print rows for (size_t row = 0; row < std::min(count, column->size()); ++row) - std::cerr << toString((*column)[row]) << std::endl; + { + auto type = column->getDataType(); + const auto& col = column; + DB::WhichDataType which(type); + if (col->isNullAt(row)) + { + std::cerr << "null" << "\t"; + } + else if (which.isUInt()) + { + auto value = col->getUInt(row); + std::cerr << std::to_string(value) << std::endl; + } + else if (which.isString()) + { + auto value = DB::checkAndGetColumn(*col)->getDataAt(row).toString(); + std::cerr << value << std::endl; + } + else if (which.isInt()) + { + auto value = col->getInt(row); + std::cerr << std::to_string(value) << std::endl; + } + else if (which.isFloat32()) + { + auto value = col->getFloat32(row); + std::cerr << std::to_string(value) << std::endl; + } + else if (which.isFloat64()) + { + auto value = col->getFloat64(row); + std::cerr << std::to_string(value) << std::endl; + } + else + { + std::cerr << "N/A" << std::endl; + } + } } - } diff --git a/utils/local-engine/Parser/CHColumnToSparkRow.cpp b/utils/local-engine/Parser/CHColumnToSparkRow.cpp index 8ed0d168b0d0..96d093a5d36a 100644 --- a/utils/local-engine/Parser/CHColumnToSparkRow.cpp +++ b/utils/local-engine/Parser/CHColumnToSparkRow.cpp @@ -1,16 +1,10 @@ #include "CHColumnToSparkRow.h" -#include -#include #include #include #include #include -#include -#include -#include -#include #include -#include +#include #include namespace DB @@ -21,810 +15,299 @@ namespace ErrorCodes } } +#define WRITE_VECTOR_COLUMN(TYPE, PRIME_TYPE, GETTER) \ + const auto * type_col = checkAndGetColumn>(*nested_col); \ + for (auto i = 0; i < num_rows; i++) \ + { \ + bool is_null = nullable_column && nullable_column->isNullAt(i); \ + if (is_null) \ + { \ + setNullAt(buffer_address, offsets[i], field_offset, col_index); \ + } \ + else \ + { \ + auto * pointer = reinterpret_cast(buffer_address + offsets[i] + field_offset); \ + pointer[0] = type_col->GETTER(i);\ + } \ + } namespace local_engine { using namespace DB; - int64_t calculateBitSetWidthInBytes(int32_t num_fields) { return ((num_fields + 63) / 64) * 8; } -static int64_t calculatedFixeSizePerRow(int64_t num_cols) +int64_t calculatedFixeSizePerRow(DB::Block & header, int64_t num_cols) { - return calculateBitSetWidthInBytes(num_cols) + num_cols * 8; + auto fields = header.getNamesAndTypesList(); + 
// Calculate the decimal col num when the precision >18 + int32_t count = 0; + for (auto i = 0; i < num_cols; i++) + { + auto type = removeNullable(fields.getTypes()[i]); + DB::WhichDataType which(type); + if (which.isDecimal128()) + { + const auto & dtype = typeid_cast *>(type.get()); + int32_t precision = dtype->getPrecision(); + if (precision > 18) + count++; + } + } + + int64_t fixed_size = calculateBitSetWidthInBytes(num_cols) + num_cols * 8; + int64_t decimal_cols_size = count * 16; + return fixed_size + decimal_cols_size; } -int64_t roundNumberOfBytesToNearestWord(int64_t num_bytes) +int64_t roundNumberOfBytesToNearestWord(int64_t numBytes) { - auto remainder = num_bytes & 0x07; // This is equivalent to `numBytes % 8` - return num_bytes + ((8 - remainder) & 0x7); + int64_t remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) + { + return numBytes; + } + else + { + return numBytes + (8 - remainder); + } } +int64_t getFieldOffset(int64_t nullBitsetWidthInBytes, int32_t index) +{ + return nullBitsetWidthInBytes + 8L * index; +} -void bitSet(char * bitmap, int32_t index) +void bitSet(uint8_t * buffer_address, int32_t index) { int64_t mask = 1L << (index & 0x3f); // mod 64 and shift int64_t word_offset = (index >> 6) * 8; int64_t word; - memcpy(&word, bitmap + word_offset, sizeof(int64_t)); + memcpy(&word, buffer_address + word_offset, sizeof(int64_t)); int64_t value = word | mask; - memcpy(bitmap + word_offset, &value, sizeof(int64_t)); + memcpy(buffer_address + word_offset, &value, sizeof(int64_t)); } -ALWAYS_INLINE bool isBitSet(const char * bitmap, int32_t index) +void setNullAt(uint8_t * buffer_address, int64_t row_offset, int64_t field_offset, int32_t col_index) { - assert(index >= 0); - int64_t mask = 1 << (index & 63); - int64_t word_offset = static_cast(index >> 6) * 8L; - int64_t word = *reinterpret_cast(bitmap + word_offset); - return word & mask; + bitSet(buffer_address + row_offset, col_index); + // set the value to 0 + memset(buffer_address + row_offset + field_offset, 0, sizeof(int64_t)); } -static void writeFixedLengthNonNullableValue( - char * buffer_address, +void writeValue( + uint8_t * buffer_address, int64_t field_offset, - const ColumnWithTypeAndName & col, - int64_t num_rows, - const std::vector & offsets) -{ - FixedLengthDataWriter writer(col.type); - for (size_t i = 0; i < static_cast(num_rows); i++) - writer.unsafeWrite(col.column->getDataAt(i), buffer_address + offsets[i] + field_offset); -} - -static void writeFixedLengthNullableValue( - char * buffer_address, - int64_t field_offset, - const ColumnWithTypeAndName & col, + ColumnWithTypeAndName & col, int32_t col_index, int64_t num_rows, - const std::vector & offsets) + std::vector & offsets, + std::vector & buffer_cursor) { + ColumnPtr nested_col = col.column; const auto * nullable_column = checkAndGetColumn(*col.column); - const auto & null_map = nullable_column->getNullMapData(); - const auto & nested_column = nullable_column->getNestedColumn(); - FixedLengthDataWriter writer(col.type); - for (size_t i = 0; i < static_cast(num_rows); i++) + if (nullable_column) { - if (null_map[i]) - bitSet(buffer_address + offsets[i], col_index); - else - writer.unsafeWrite(nested_column.getDataAt(i), buffer_address + offsets[i] + field_offset); + nested_col = nullable_column->getNestedColumnPtr(); } -} - -static void writeVariableLengthNonNullableValue( - char * buffer_address, - int64_t field_offset, - const ColumnWithTypeAndName & col, - int64_t num_rows, - const std::vector & 
offsets, - std::vector & buffer_cursor) -{ - const auto type_without_nullable{std::move(removeNullable(col.type))}; - const bool use_raw_data = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); - VariableLengthDataWriter writer(col.type, buffer_address, offsets, buffer_cursor); - if (use_raw_data) + nested_col = nested_col->convertToFullColumnIfConst(); + WhichDataType which(nested_col->getDataType()); + if (which.isUInt8()) { - for (size_t i = 0; i < static_cast(num_rows); i++) - { - StringRef str = col.column->getDataAt(i); - int64_t offset_and_size = writer.writeUnalignedBytes(i, str.data, str.size, 0); - memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); - } + WRITE_VECTOR_COLUMN(UInt8, uint8_t, getInt) } - else + else if (which.isInt8()) { - Field field; - for (size_t i = 0; i < static_cast(num_rows); i++) - { - field = std::move((*col.column)[i]); - int64_t offset_and_size = writer.write(i, field, 0); - memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); - } + WRITE_VECTOR_COLUMN(Int8, int8_t, getInt) } -} - -static void writeVariableLengthNullableValue( - char * buffer_address, - int64_t field_offset, - const ColumnWithTypeAndName & col, - int32_t col_index, - int64_t num_rows, - const std::vector & offsets, - std::vector & buffer_cursor) -{ - const auto * nullable_column = checkAndGetColumn(*col.column); - const auto & null_map = nullable_column->getNullMapData(); - const auto & nested_column = nullable_column->getNestedColumn(); - const auto type_without_nullable{std::move(removeNullable(col.type))}; - const bool use_raw_data = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); - VariableLengthDataWriter writer(col.type, buffer_address, offsets, buffer_cursor); - if (use_raw_data) + else if (which.isInt16()) + { + WRITE_VECTOR_COLUMN(Int16, int16_t, getInt) + } + else if (which.isUInt16()) + { + WRITE_VECTOR_COLUMN(UInt16, uint16_t , get64) + } + else if (which.isInt32()) + { + WRITE_VECTOR_COLUMN(Int32, int32_t, getInt) + } + else if (which.isInt64()) + { + WRITE_VECTOR_COLUMN(Int64, int64_t, getInt) + } + else if (which.isUInt64()) { - for (size_t i = 0; i < static_cast(num_rows); i++) + WRITE_VECTOR_COLUMN(UInt64, int64_t, get64) + } + else if (which.isFloat32()) + { + WRITE_VECTOR_COLUMN(Float32, float_t, getFloat32) + } + else if (which.isFloat64()) + { + WRITE_VECTOR_COLUMN(Float64, double_t, getFloat64) + } + else if (which.isDate()) + { + WRITE_VECTOR_COLUMN(UInt16, uint16_t, get64) + } + else if (which.isDate32()) + { + WRITE_VECTOR_COLUMN(UInt32, uint32_t, get64) + } + else if (which.isDateTime64()) + { + using ColumnDateTime64 = ColumnDecimal; + const auto * datetime64_col = checkAndGetColumn(*nested_col); + for (auto i=0; iisNullAt(i); + if (is_null) + { + setNullAt(buffer_address, offsets[i], field_offset, col_index); + } else { - StringRef str = nested_column.getDataAt(i); - int64_t offset_and_size = writer.writeUnalignedBytes(i, str.data, str.size, 0); - memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + auto * pointer = reinterpret_cast(buffer_address + offsets[i] + field_offset); + pointer[0] = datetime64_col->getInt(i); } } } - else + else if (which.isString()) { - Field field; - for (size_t i = 0; i < static_cast(num_rows); i++) + const auto * string_col = checkAndGetColumn(*nested_col); + for (auto i = 0; i < num_rows; i++) { - if (null_map[i]) - bitSet(buffer_address + offsets[i], col_index); + bool is_null = nullable_column && 
nullable_column->isNullAt(i); + if (is_null) + { + setNullAt(buffer_address, offsets[i], field_offset, col_index); + } else { - field = std::move(nested_column[i]); - int64_t offset_and_size = writer.write(i, field, 0); - memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + StringRef string_value = string_col->getDataAt(i); + // write the variable value + memcpy(buffer_address + offsets[i] + buffer_cursor[i], string_value.data, string_value.size); + // write the offset and size + int64_t offset_and_size = (buffer_cursor[i] << 32) | string_value.size; + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, sizeof(int64_t)); + buffer_cursor[i] += string_value.size; } } } -} - - -static void writeValue( - char * buffer_address, - int64_t field_offset, - const ColumnWithTypeAndName & col, - int32_t col_index, - int64_t num_rows, - const std::vector & offsets, - std::vector & buffer_cursor) -{ - const auto type_without_nullable{std::move(removeNullable(col.type))}; - const auto is_nullable = isColumnNullable(*col.column); - if (BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) - { - if (is_nullable) - writeFixedLengthNullableValue(buffer_address, field_offset, col, col_index, num_rows, offsets); - else - writeFixedLengthNonNullableValue(buffer_address, field_offset, col, num_rows, offsets); - } - else if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + else { - if (is_nullable) - writeVariableLengthNullableValue(buffer_address, field_offset, col, col_index, num_rows, offsets, buffer_cursor); - else - writeVariableLengthNonNullableValue(buffer_address, field_offset, col, num_rows, offsets, buffer_cursor); + throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from ch to spark" ,magic_enum::enum_name(nested_col->getDataType())); } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for writeValue", col.type->getName()); } -SparkRowInfo::SparkRowInfo(const Block & block) - : types(std::move(block.getDataTypes())) - , num_rows(block.rows()) - , num_cols(block.columns()) - , null_bitset_width_in_bytes(calculateBitSetWidthInBytes(num_cols)) - , total_bytes(0) - , offsets(num_rows, 0) - , lengths(num_rows, 0) - , buffer_cursor(num_rows, 0) - , buffer_address(nullptr) +SparkRowInfo::SparkRowInfo(DB::Block & block) { - int64_t fixed_size_per_row = calculatedFixeSizePerRow(num_cols); - - /// Initialize lengths and buffer_cursor - for (int64_t i = 0; i < num_rows; i++) + num_rows = block.rows(); + num_cols = block.columns(); + null_bitset_width_in_bytes = calculateBitSetWidthInBytes(num_cols); + int64_t fixed_size_per_row = calculatedFixeSizePerRow(block, num_cols); + // Initialize the offsets_ , lengths_, buffer_cursor_ + for (auto i = 0; i < num_rows; i++) { - lengths[i] = fixed_size_per_row; - buffer_cursor[i] = fixed_size_per_row; + lengths.push_back(fixed_size_per_row); + offsets.push_back(0); + buffer_cursor.push_back(null_bitset_width_in_bytes + 8 * num_cols); } - - for (int64_t col_idx = 0; col_idx < num_cols; ++col_idx) + // Calculated the lengths_ + for (auto i = 0; i < num_cols; i++) { - const auto & col = block.getByPosition(col_idx); - - /// No need to calculate backing data length for fixed length types - const auto type_without_nullable = removeNullable(col.type); - if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + auto col = block.getByPosition(i); + if (isStringOrFixedString(removeNullable(col.type))) { - if 
(BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable)) + size_t length; + for (auto j = 0; j < num_rows; j++) { - const auto * nullable_column = checkAndGetColumn(*col.column); - if (nullable_column) - { - const auto & nested_column = nullable_column->getNestedColumn(); - const auto & null_map = nullable_column->getNullMapData(); - for (auto row_idx = 0; row_idx < num_rows; ++row_idx) - if (!null_map[row_idx]) - lengths[row_idx] += roundNumberOfBytesToNearestWord(nested_column.getDataAt(row_idx).size); - } - else - { - for (auto row_idx = 0; row_idx < num_rows; ++row_idx) - lengths[row_idx] += roundNumberOfBytesToNearestWord(col.column->getDataAt(row_idx).size); - } - } - else - { - BackingDataLengthCalculator calculator(col.type); - for (auto row_idx = 0; row_idx < num_rows; ++row_idx) - { - const auto field = (*col.column)[row_idx]; - lengths[row_idx] += calculator.calculate(field); - } + length = col.column->getDataAt(j).size; + lengths[j] += roundNumberOfBytesToNearestWord(length); } } } - - /// Initialize offsets - for (int64_t i = 1; i < num_rows; ++i) - offsets[i] = offsets[i - 1] + lengths[i - 1]; - - /// Initialize total_bytes - for (int64_t i = 0; i < num_rows; ++i) - total_bytes += lengths[i]; } -const DB::DataTypes & SparkRowInfo::getDataTypes() const -{ - return types; -} - -int64_t SparkRowInfo::getFieldOffset(int32_t col_idx) const -{ - return null_bitset_width_in_bytes + 8L * col_idx; -} - -int64_t SparkRowInfo::getNullBitsetWidthInBytes() const +int64_t local_engine::SparkRowInfo::getNullBitsetWidthInBytes() const { return null_bitset_width_in_bytes; } -void SparkRowInfo::setNullBitsetWidthInBytes(int64_t null_bitset_width_in_bytes_) +void local_engine::SparkRowInfo::setNullBitsetWidthInBytes(int64_t null_bitset_width_in_bytes_) { null_bitset_width_in_bytes = null_bitset_width_in_bytes_; } - -int64_t SparkRowInfo::getNumCols() const +int64_t local_engine::SparkRowInfo::getNumCols() const { return num_cols; } - -void SparkRowInfo::setNumCols(int64_t num_cols_) +void local_engine::SparkRowInfo::setNumCols(int64_t numCols) { - num_cols = num_cols_; + num_cols = numCols; } - -int64_t SparkRowInfo::getNumRows() const +int64_t local_engine::SparkRowInfo::getNumRows() const { return num_rows; } - -void SparkRowInfo::setNumRows(int64_t num_rows_) +void local_engine::SparkRowInfo::setNumRows(int64_t numRows) { - num_rows = num_rows_; + num_rows = numRows; } - -char * SparkRowInfo::getBufferAddress() const +unsigned char * local_engine::SparkRowInfo::getBufferAddress() const { return buffer_address; } - -void SparkRowInfo::setBufferAddress(char * buffer_address_) +void local_engine::SparkRowInfo::setBufferAddress(unsigned char * bufferAddress) { - buffer_address = buffer_address_; + buffer_address = bufferAddress; } - -const std::vector & SparkRowInfo::getOffsets() const +const std::vector & local_engine::SparkRowInfo::getOffsets() const { return offsets; } - -const std::vector & SparkRowInfo::getLengths() const +const std::vector & local_engine::SparkRowInfo::getLengths() const { return lengths; } - -std::vector & SparkRowInfo::getBufferCursor() -{ - return buffer_cursor; -} - int64_t SparkRowInfo::getTotalBytes() const { return total_bytes; } - -std::unique_ptr CHColumnToSparkRow::convertCHColumnToSparkRow(const Block & block) +std::unique_ptr local_engine::CHColumnToSparkRow::convertCHColumnToSparkRow(Block & block) { - if (!block.rows() || !block.columns()) - return {}; - std::unique_ptr spark_row_info = std::make_unique(block); - 
spark_row_info->setBufferAddress(reinterpret_cast(alloc(spark_row_info->getTotalBytes(), 64))); - // spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(), 64)); - memset(spark_row_info->getBufferAddress(), 0, spark_row_info->getTotalBytes()); - for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++) + // Calculated the offsets_ and total memory size based on lengths_ + int64_t total_memory_size = spark_row_info->lengths[0]; + for (auto i = 1; i < spark_row_info->num_rows; i++) { - const auto & col = block.getByPosition(col_idx); - int64_t field_offset = spark_row_info->getFieldOffset(col_idx); + spark_row_info->offsets[i] = spark_row_info->offsets[i - 1] + spark_row_info->lengths[i - 1]; + total_memory_size += spark_row_info->lengths[i]; + } + spark_row_info->total_bytes = total_memory_size; + spark_row_info->buffer_address = reinterpret_cast(alloc(total_memory_size)); + memset(spark_row_info->buffer_address, 0, sizeof(int8_t) * spark_row_info->total_bytes); + for (auto i = 0; i < spark_row_info->num_cols; i++) + { + auto array = block.getByPosition(i); + int64_t field_offset = getFieldOffset(spark_row_info->null_bitset_width_in_bytes, i); writeValue( - spark_row_info->getBufferAddress(), + spark_row_info->buffer_address, field_offset, - col, - col_idx, - spark_row_info->getNumRows(), - spark_row_info->getOffsets(), - spark_row_info->getBufferCursor()); + array, + i, + spark_row_info->num_rows, + spark_row_info->offsets, + spark_row_info->buffer_cursor); } return spark_row_info; } - -void CHColumnToSparkRow::freeMem(char * address, size_t size) +void CHColumnToSparkRow::freeMem(uint8_t * address, size_t size) { free(address, size); - // rollback(size); -} - -BackingDataLengthCalculator::BackingDataLengthCalculator(const DataTypePtr & type_) - : type_without_nullable(removeNullable(type_)), which(type_without_nullable) -{ - if (!isFixedLengthDataType(type_without_nullable) && !isVariableLengthDataType(type_without_nullable)) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingDataLengthCalculator", type_without_nullable->getName()); -} - -int64_t BackingDataLengthCalculator::calculate(const Field & field) const -{ - if (field.isNull()) - return 0; - - if (which.isNativeInt() || which.isNativeUInt() || which.isFloat() || which.isDateOrDate32() || which.isDateTime64() - || which.isDecimal32() || which.isDecimal64()) - return 0; - - if (which.isStringOrFixedString()) - { - const auto & str = field.get(); - return roundNumberOfBytesToNearestWord(str.size()); - } - - if (which.isDecimal128()) - return 16; - - if (which.isArray()) - { - /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing buffer - const auto & array = field.get(); /// Array can not be wrapped with Nullable - const auto num_elems = array.size(); - int64_t res = 8 + calculateBitSetWidthInBytes(num_elems); - - const auto * array_type = typeid_cast(type_without_nullable.get()); - const auto & nested_type = array_type->getNestedType(); - res += roundNumberOfBytesToNearestWord(getArrayElementSize(nested_type) * num_elems); - - BackingDataLengthCalculator calculator(nested_type); - for (size_t i = 0; i < array.size(); ++i) - res += calculator.calculate(array[i]); - return res; - } - - if (which.isMap()) - { - /// 内存布局:Length of UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value - int64_t res = 8; - - /// Construct Array of keys and values from Map - const auto & map = field.get(); /// Map can not be 
wrapped with Nullable - const auto num_keys = map.size(); - auto array_key = Array(); - auto array_val = Array(); - array_key.reserve(num_keys); - array_val.reserve(num_keys); - for (size_t i = 0; i < num_keys; ++i) - { - const auto & pair = map[i].get(); - array_key.push_back(pair[0]); - array_val.push_back(pair[1]); - } - - const auto * map_type = typeid_cast(type_without_nullable.get()); - - const auto & key_type = map_type->getKeyType(); - const auto key_array_type = std::make_shared(key_type); - BackingDataLengthCalculator calculator_key(key_array_type); - res += calculator_key.calculate(array_key); - - const auto & val_type = map_type->getValueType(); - const auto type_array_val = std::make_shared(val_type); - BackingDataLengthCalculator calculator_val(type_array_val); - res += calculator_val.calculate(array_val); - return res; - } - - if (which.isTuple()) - { - /// 内存布局:null_bitmap(字节数与字段数成正比) | field1 value(8B) | field2 value(8B) | ... | fieldn value(8B) | backing buffer - const auto & tuple = field.get(); /// Tuple can not be wrapped with Nullable - const auto * type_tuple = typeid_cast(type_without_nullable.get()); - const auto & type_fields = type_tuple->getElements(); - const auto num_fields = type_fields.size(); - int64_t res = calculateBitSetWidthInBytes(num_fields) + 8 * num_fields; - for (size_t i = 0; i < num_fields; ++i) - { - BackingDataLengthCalculator calculator(type_fields[i]); - res += calculator.calculate(tuple[i]); - } - return res; - } - - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingBufferLengthCalculator", type_without_nullable->getName()); } - -int64_t BackingDataLengthCalculator::getArrayElementSize(const DataTypePtr & nested_type) -{ - const WhichDataType nested_which(removeNullable(nested_type)); - if (nested_which.isUInt8() || nested_which.isInt8()) - return 1; - else if (nested_which.isUInt16() || nested_which.isInt16() || nested_which.isDate()) - return 2; - else if ( - nested_which.isUInt32() || nested_which.isInt32() || nested_which.isFloat32() || nested_which.isDate32() - || nested_which.isDecimal32()) - return 4; - else if ( - nested_which.isUInt64() || nested_which.isInt64() || nested_which.isFloat64() || nested_which.isDateTime64() - || nested_which.isDecimal64()) - return 8; - else - return 8; -} - -bool BackingDataLengthCalculator::isFixedLengthDataType(const DataTypePtr & type_without_nullable) -{ - const WhichDataType which(type_without_nullable); - return which.isUInt8() || which.isInt8() || which.isUInt16() || which.isInt16() || which.isDate() || which.isUInt32() || which.isInt32() - || which.isFloat32() || which.isDate32() || which.isDecimal32() || which.isUInt64() || which.isInt64() || which.isFloat64() - || which.isDateTime64() || which.isDecimal64(); -} - -bool BackingDataLengthCalculator::isVariableLengthDataType(const DataTypePtr & type_without_nullable) -{ - const WhichDataType which(type_without_nullable); - return which.isStringOrFixedString() || which.isDecimal128() || which.isArray() || which.isMap() || which.isTuple(); -} - -bool BackingDataLengthCalculator::isDataTypeSupportRawData(const DB::DataTypePtr & type_without_nullable) -{ - const WhichDataType which(type_without_nullable); - return isFixedLengthDataType(type_without_nullable) || which.isStringOrFixedString() || which.isDecimal128(); -} - - -VariableLengthDataWriter::VariableLengthDataWriter( - const DataTypePtr & type_, char * buffer_address_, const std::vector & offsets_, std::vector & buffer_cursor_) - : 
type_without_nullable(removeNullable(type_)) - , which(type_without_nullable) - , buffer_address(buffer_address_) - , offsets(offsets_) - , buffer_cursor(buffer_cursor_) -{ - assert(buffer_address); - assert(!offsets.empty()); - assert(!buffer_cursor.empty()); - assert(offsets.size() == buffer_cursor.size()); - - if (!BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataWriter doesn't support type {}", type_without_nullable->getName()); -} - -int64_t VariableLengthDataWriter::writeArray(size_t row_idx, const DB::Array & array, int64_t parent_offset) -{ - /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing data - const auto & offset = offsets[row_idx]; - auto & cursor = buffer_cursor[row_idx]; - const auto num_elems = array.size(); - const auto * array_type = typeid_cast(type_without_nullable.get()); - const auto & nested_type = array_type->getNestedType(); - - /// Write numElements(8B) - const auto start = cursor; - memcpy(buffer_address + offset + cursor, &num_elems, 8); - cursor += 8; - if (num_elems == 0) - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 8); - - /// Skip null_bitmap(already reset to zero) - const auto len_null_bitmap = calculateBitSetWidthInBytes(num_elems); - cursor += len_null_bitmap; - - /// Skip values(already reset to zero) - const auto elem_size = BackingDataLengthCalculator::getArrayElementSize(nested_type); - const auto len_values = roundNumberOfBytesToNearestWord(elem_size * num_elems); - cursor += len_values; - - if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(nested_type))) - { - /// If nested type is fixed-length data type, update null_bitmap and values in place - FixedLengthDataWriter writer(nested_type); - for (size_t i = 0; i < num_elems; ++i) - { - const auto & elem = array[i]; - if (elem.isNull()) - bitSet(buffer_address + offset + start + 8, i); - else - // writer.write(elem, buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); - writer.unsafeWrite(&elem.reinterpret(), buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); - } - } - else - { - /// If nested type is not fixed-length data type, update null_bitmap in place - /// And append values in backing data recursively - VariableLengthDataWriter writer(nested_type, buffer_address, offsets, buffer_cursor); - for (size_t i = 0; i < num_elems; ++i) - { - const auto & elem = array[i]; - if (elem.isNull()) - bitSet(buffer_address + offset + start + 8, i); - else - { - const auto offset_and_size = writer.write(row_idx, elem, start); - memcpy(buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size, &offset_and_size, 8); - } - } - } - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); -} - -int64_t VariableLengthDataWriter::writeMap(size_t row_idx, const DB::Map & map, int64_t parent_offset) -{ - /// 内存布局:Length of UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value - const auto & offset = offsets[row_idx]; - auto & cursor = buffer_cursor[row_idx]; - - /// Skip length of UnsafeArrayData of key(8B) - const auto start = cursor; - cursor += 8; - - /// If Map is empty, return in advance - const auto num_pairs = map.size(); - if (num_pairs == 0) - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 8); - - /// Construct array of keys and array of values from map - auto key_array = Array(); - auto 
val_array = Array(); - key_array.reserve(num_pairs); - val_array.reserve(num_pairs); - for (size_t i = 0; i < num_pairs; ++i) - { - const auto & pair = map[i].get(); - key_array.push_back(pair[0]); - val_array.push_back(pair[1]); - } - - const auto * map_type = typeid_cast(type_without_nullable.get()); - - /// Append UnsafeArrayData of key - const auto & key_type = map_type->getKeyType(); - const auto key_array_type = std::make_shared(key_type); - VariableLengthDataWriter key_writer(key_array_type, buffer_address, offsets, buffer_cursor); - const auto key_array_size = BackingDataLengthCalculator::extractSize(key_writer.write(row_idx, key_array, start + 8)); - - /// Fill length of UnsafeArrayData of key - memcpy(buffer_address + offset + start, &key_array_size, 8); - - /// Append UnsafeArrayData of value - const auto & val_type = map_type->getValueType(); - const auto val_array_type = std::make_shared(val_type); - VariableLengthDataWriter val_writer(val_array_type, buffer_address, offsets, buffer_cursor); - val_writer.write(row_idx, val_array, start + 8 + key_array_size); - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); -} - -int64_t VariableLengthDataWriter::writeStruct(size_t row_idx, const DB::Tuple & tuple, int64_t parent_offset) -{ - /// 内存布局:null_bitmap(字节数与字段数成正比) | values(num_fields * 8B) | backing data - const auto & offset = offsets[row_idx]; - auto & cursor = buffer_cursor[row_idx]; - const auto start = cursor; - - /// Skip null_bitmap - const auto * tuple_type = typeid_cast(type_without_nullable.get()); - const auto & field_types = tuple_type->getElements(); - const auto num_fields = field_types.size(); - if (num_fields == 0) - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 0); - const auto len_null_bitmap = calculateBitSetWidthInBytes(num_fields); - cursor += len_null_bitmap; - - /// Skip values - cursor += num_fields * 8; - - /// If field type is fixed-length, fill field value in values region - /// else append it to backing data region, and update offset_and_size in values region - for (size_t i = 0; i < num_fields; ++i) - { - const auto & field_value = tuple[i]; - const auto & field_type = field_types[i]; - if (field_value.isNull()) - { - bitSet(buffer_address + offset + start, i); - continue; - } - - if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(field_type))) - { - FixedLengthDataWriter writer(field_type); - // writer.write(field_value, buffer_address + offset + start + len_null_bitmap + i * 8); - writer.unsafeWrite(&field_value.reinterpret(), buffer_address + offset + start + len_null_bitmap + i * 8); - } - else - { - VariableLengthDataWriter writer(field_type, buffer_address, offsets, buffer_cursor); - const auto offset_and_size = writer.write(row_idx, field_value, start); - memcpy(buffer_address + offset + start + len_null_bitmap + 8 * i, &offset_and_size, 8); - } - } - return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); -} - -int64_t VariableLengthDataWriter::write(size_t row_idx, const DB::Field & field, int64_t parent_offset) -{ - assert(row_idx < offsets.size()); - - if (field.isNull()) - return 0; - - if (which.isStringOrFixedString()) - { - const auto & str = field.get(); - return writeUnalignedBytes(row_idx, str.data(), str.size(), parent_offset); - } - - if (which.isDecimal128()) - { - // const auto & decimal = field.get>(); - // const auto value = decimal.getValue(); - return writeUnalignedBytes(row_idx, 
&field.reinterpret(), sizeof(Decimal128), parent_offset); - } - - if (which.isArray()) - { - const auto & array = field.get(); - return writeArray(row_idx, array, parent_offset); - } - - if (which.isMap()) - { - const auto & map = field.get(); - return writeMap(row_idx, map, parent_offset); - } - - if (which.isTuple()) - { - const auto & tuple = field.get(); - return writeStruct(row_idx, tuple, parent_offset); - } - - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingDataWriter", type_without_nullable->getName()); -} - -int64_t BackingDataLengthCalculator::getOffsetAndSize(int64_t cursor, int64_t size) -{ - return (cursor << 32) | size; -} - -int64_t BackingDataLengthCalculator::extractOffset(int64_t offset_and_size) -{ - return offset_and_size >> 32; -} - -int64_t BackingDataLengthCalculator::extractSize(int64_t offset_and_size) -{ - return offset_and_size & 0xffffffff; -} - -int64_t VariableLengthDataWriter::writeUnalignedBytes(size_t row_idx, const char * src, size_t size, int64_t parent_offset) -{ - memcpy(buffer_address + offsets[row_idx] + buffer_cursor[row_idx], src, size); - auto res = BackingDataLengthCalculator::getOffsetAndSize(buffer_cursor[row_idx] - parent_offset, size); - buffer_cursor[row_idx] += roundNumberOfBytesToNearestWord(size); - return res; -} - - -FixedLengthDataWriter::FixedLengthDataWriter(const DB::DataTypePtr & type_) - : type_without_nullable(removeNullable(type_)), which(type_without_nullable) -{ - if (!BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthWriter doesn't support type {}", type_without_nullable->getName()); -} - -void FixedLengthDataWriter::write(const DB::Field & field, char * buffer) -{ - /// Skip null value - if (field.isNull()) - return; - - if (which.isUInt8()) - { - const auto value = UInt8(field.get()); - memcpy(buffer, &value, 1); - } - else if (which.isUInt16() || which.isDate()) - { - const auto value = UInt16(field.get()); - memcpy(buffer, &value, 2); - } - else if (which.isUInt32() || which.isDate32()) - { - const auto value = UInt32(field.get()); - memcpy(buffer, &value, 4); - } - else if (which.isUInt64()) - { - const auto & value = field.get(); - memcpy(buffer, &value, 8); - } - else if (which.isInt8()) - { - const auto value = Int8(field.get()); - memcpy(buffer, &value, 1); - } - else if (which.isInt16()) - { - const auto value = Int16(field.get()); - memcpy(buffer, &value, 2); - } - else if (which.isInt32()) - { - const auto value = Int32(field.get()); - memcpy(buffer, &value, 4); - } - else if (which.isInt64()) - { - const auto & value = field.get(); - memcpy(buffer, &value, 8); - } - else if (which.isFloat32()) - { - const auto value = Float32(field.get()); - memcpy(buffer, &value, 4); - } - else if (which.isFloat64()) - { - const auto & value = field.get(); - memcpy(buffer, &value, 8); - } - else if (which.isDecimal32()) - { - const auto & value = field.get(); - const auto decimal = value.getValue(); - memcpy(buffer, &decimal, 4); - } - else if (which.isDecimal64() || which.isDateTime64()) - { - const auto & value = field.get(); - auto decimal = value.getValue(); - memcpy(buffer, &decimal, 8); - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthDataWriter doesn't support type {}", type_without_nullable->getName()); -} - -void FixedLengthDataWriter::unsafeWrite(const StringRef & str, char * buffer) -{ - memcpy(buffer, str.data, str.size); -} - -void FixedLengthDataWriter::unsafeWrite(const char * __restrict 
src, char * __restrict buffer) -{ - memcpy(buffer, src, type_without_nullable->getSizeOfValueInMemory()); -} - } diff --git a/utils/local-engine/Parser/CHColumnToSparkRow.h b/utils/local-engine/Parser/CHColumnToSparkRow.h index 5b10aa3abe53..ec49cec6eebf 100644 --- a/utils/local-engine/Parser/CHColumnToSparkRow.h +++ b/utils/local-engine/Parser/CHColumnToSparkRow.h @@ -1,166 +1,51 @@ #pragma once #include #include -#include #include -#include - namespace local_engine { int64_t calculateBitSetWidthInBytes(int32_t num_fields); -int64_t roundNumberOfBytesToNearestWord(int64_t num_bytes); -void bitSet(char * bitmap, int32_t index); -bool isBitSet(const char * bitmap, int32_t index); class CHColumnToSparkRow; class SparkRowToCHColumn; -class SparkRowInfo : public boost::noncopyable +class SparkRowInfo { friend CHColumnToSparkRow; friend SparkRowToCHColumn; public: - explicit SparkRowInfo(const DB::Block & block); - - const DB::DataTypes & getDataTypes() const; - - int64_t getFieldOffset(int32_t col_idx) const; - + explicit SparkRowInfo(DB::Block & block); int64_t getNullBitsetWidthInBytes() const; - void setNullBitsetWidthInBytes(int64_t null_bitset_width_in_bytes_); - + void setNullBitsetWidthInBytes(int64_t nullBitsetWidthInBytes); int64_t getNumCols() const; - void setNumCols(int64_t num_cols_); - + void setNumCols(int64_t numCols); int64_t getNumRows() const; - void setNumRows(int64_t num_rows_); - - char * getBufferAddress() const; - void setBufferAddress(char * buffer_address); - + void setNumRows(int64_t numRows); + unsigned char * getBufferAddress() const; + void setBufferAddress(unsigned char * bufferAddress); const std::vector & getOffsets() const; const std::vector & getLengths() const; - std::vector & getBufferCursor(); int64_t getTotalBytes() const; private: - const DB::DataTypes types; - int64_t num_rows; - int64_t num_cols; - int64_t null_bitset_width_in_bytes; int64_t total_bytes; - + int64_t null_bitset_width_in_bytes; + int64_t num_cols; + int64_t num_rows; + std::vector buffer_cursor; + uint8_t * buffer_address; std::vector offsets; std::vector lengths; - std::vector buffer_cursor; - char * buffer_address; }; using SparkRowInfoPtr = std::unique_ptr; -class CHColumnToSparkRow : private Allocator -// class CHColumnToSparkRow : public DB::Arena -{ -public: - std::unique_ptr convertCHColumnToSparkRow(const DB::Block & block); - void freeMem(char * address, size_t size); -}; - -/// Return backing data length of values with variable-length type in bytes -class BackingDataLengthCalculator -{ -public: - static constexpr size_t DECIMAL_MAX_INT64_DIGITS = 18; - - explicit BackingDataLengthCalculator(const DB::DataTypePtr & type_); - virtual ~BackingDataLengthCalculator() = default; - - /// Return length is guranteed to round up to 8 - virtual int64_t calculate(const DB::Field & field) const; - - static int64_t getArrayElementSize(const DB::DataTypePtr & nested_type); - - /// Is CH DataType can be converted to fixed-length data type in Spark? - static bool isFixedLengthDataType(const DB::DataTypePtr & type_without_nullable); - - /// Is CH DataType can be converted to variable-length data type in Spark? 
- static bool isVariableLengthDataType(const DB::DataTypePtr & type_without_nullable); - - /// If Data Type can use raw data between CH Column and Spark Row if value is not null - static bool isDataTypeSupportRawData(const DB::DataTypePtr & type_without_nullable); - - static int64_t getOffsetAndSize(int64_t cursor, int64_t size); - static int64_t extractOffset(int64_t offset_and_size); - static int64_t extractSize(int64_t offset_and_size); - -private: - // const DB::DataTypePtr type; - const DB::DataTypePtr type_without_nullable; - const DB::WhichDataType which; -}; - -/// Writing variable-length typed values to backing data region of Spark Row -/// User who calls VariableLengthDataWriter is responsible to write offset_and_size -/// returned by VariableLengthDataWriter::write to field value in Spark Row -class VariableLengthDataWriter +class CHColumnToSparkRow : private Allocator { public: - VariableLengthDataWriter( - const DB::DataTypePtr & type_, - char * buffer_address_, - const std::vector & offsets_, - std::vector & buffer_cursor_); - - virtual ~VariableLengthDataWriter() = default; - - /// Write value of variable-length to backing data region of structure(row or array) and return offset and size in backing data region - /// It's caller's duty to make sure that row fields or array elements are written in order - /// parent_offset: the starting offset of current structure in which we are updating it's backing data region - virtual int64_t write(size_t row_idx, const DB::Field & field, int64_t parent_offset); - - /// Only support String/FixedString/Decimal32/Decimal64 - int64_t writeUnalignedBytes(size_t row_idx, const char * src, size_t size, int64_t parent_offset); -private: - int64_t writeArray(size_t row_idx, const DB::Array & array, int64_t parent_offset); - int64_t writeMap(size_t row_idx, const DB::Map & map, int64_t parent_offset); - int64_t writeStruct(size_t row_idx, const DB::Tuple & tuple, int64_t parent_offset); - - // const DB::DataTypePtr type; - const DB::DataTypePtr type_without_nullable; - const DB::WhichDataType which; - - /// Global buffer of spark rows - char * const buffer_address; - /// Offsets of each spark row - const std::vector & offsets; - /// Cursors of backing data in each spark row, relative to offsets - std::vector & buffer_cursor; -}; - -class FixedLengthDataWriter -{ -public: - explicit FixedLengthDataWriter(const DB::DataTypePtr & type_); - virtual ~FixedLengthDataWriter() = default; - - /// Write value of fixed-length to values region of structure(struct or array) - /// It's caller's duty to make sure that struct fields or array elements are written in order - virtual void write(const DB::Field & field, char * buffer); - - /// Copy memory chunk of Fixed length typed CH Column directory to buffer for performance. - /// It is unsafe unless you know what you are doing. - virtual void unsafeWrite(const StringRef & str, char * buffer); - - /// Copy memory chunk of in fixed length typed Field directory to buffer for performance. - /// It is unsafe unless you know what you are doing. 
- virtual void unsafeWrite(const char * __restrict src, char * __restrict buffer); - -private: - // const DB::DataTypePtr type; - const DB::DataTypePtr type_without_nullable; - const DB::WhichDataType which; + std::unique_ptr convertCHColumnToSparkRow(DB::Block & block); + void freeMem(uint8_t * address, size_t size); }; - } diff --git a/utils/local-engine/Parser/SerializedPlanParser.cpp b/utils/local-engine/Parser/SerializedPlanParser.cpp index d1f26006208b..075a89ed9b1d 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.cpp +++ b/utils/local-engine/Parser/SerializedPlanParser.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -11,12 +10,8 @@ #include #include #include -#include #include #include -#include -#include -#include #include #include #include @@ -59,7 +54,6 @@ namespace ErrorCodes extern const int BAD_ARGUMENTS; extern const int NO_SUCH_DATA_PART; extern const int UNKNOWN_FUNCTION; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; } } @@ -78,10 +72,71 @@ void join(ActionsDAG::NodeRawConstPtrs v, char c, std::string & s) } } -bool isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type) +std::string typeName(const substrait::Type & type) { - const auto parsed_ch_type = SerializedPlanParser::parseType(substrait_type); - return parsed_ch_type->equals(*ch_type); + if (type.has_string()) + { + return "String"; + } + else if (type.has_i8()) + { + return "I8"; + } + else if (type.has_i16()) + { + return "I16"; + } + else if (type.has_i32()) + { + return "I32"; + } + else if (type.has_i64()) + { + return "I64"; + } + else if (type.has_fp32()) + { + return "FP32"; + } + else if (type.has_fp64()) + { + return "FP64"; + } + else if (type.has_bool_()) + { + return "Boolean"; + } + else if (type.has_date()) + { + return "Date"; + } + else if (type.has_timestamp()) + { + return "Timestamp"; + } + + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unknown type {}", magic_enum::enum_name(type.kind_case())); +} + +bool isTypeSame(const substrait::Type & type, DataTypePtr data_type) +{ + static const std::map type_mapping + = {{"I8", "Int8"}, + {"I16", "Int16"}, + {"I32", "Int32"}, + {"I64", "Int64"}, + {"FP32", "Float32"}, + {"FP64", "Float64"}, + {"Date", "Date32"}, + {"Timestamp", "DateTime64(6)"}, + {"String", "String"}, + {"Boolean", "UInt8"}}; + + std::string type_name = typeName(type); + if (!type_mapping.contains(type_name)) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Unknown type {}", type_name); + + return type_mapping.at(type_name) == data_type->getName(); } std::string getCastFunction(const substrait::Type & type) @@ -131,8 +186,6 @@ std::string getCastFunction(const substrait::Type & type) else throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support cast type {}", type.DebugString()); - /// TODO(taiyang-li): implement cast functions of other types - return ch_function_name; } @@ -257,116 +310,77 @@ Block SerializedPlanParser::parseNameStruct(const substrait::NamedStruct & struc return Block(*std::move(internal_cols)); } -DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) -{ - return wrapNullableType(nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE, nested_type); -} - -DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) +static DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) { - if (nullable) + if (nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE) + { return std::make_shared(nested_type); + } else + { return 
nested_type; + } } -DataTypePtr SerializedPlanParser::parseType(const substrait::Type & substrait_type) +DataTypePtr SerializedPlanParser::parseType(const substrait::Type & type) { - DataTypePtr ch_type; - if (substrait_type.has_bool_()) - { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.bool_().nullability(), ch_type); - } - else if (substrait_type.has_i8()) - { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.i8().nullability(), ch_type); - } - else if (substrait_type.has_i16()) - { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.i16().nullability(), ch_type); - } - else if (substrait_type.has_i32()) + DataTypePtr internal_type = nullptr; + auto & factory = DataTypeFactory::instance(); + if (type.has_bool_()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.i32().nullability(), ch_type); + internal_type = factory.get("UInt8"); + internal_type = wrapNullableType(type.bool_().nullability(), internal_type); } - else if (substrait_type.has_i64()) + else if (type.has_i8()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.i64().nullability(), ch_type); + internal_type = factory.get("Int8"); + internal_type = wrapNullableType(type.i8().nullability(), internal_type); } - else if (substrait_type.has_string() || substrait_type.has_binary()) + else if (type.has_i16()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.string().nullability(), ch_type); + internal_type = factory.get("Int16"); + internal_type = wrapNullableType(type.i16().nullability(), internal_type); } - else if (substrait_type.has_fp32()) + else if (type.has_i32()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.fp32().nullability(), ch_type); + internal_type = factory.get("Int32"); + internal_type = wrapNullableType(type.i32().nullability(), internal_type); } - else if (substrait_type.has_fp64()) + else if (type.has_i64()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.fp64().nullability(), ch_type); + internal_type = factory.get("Int64"); + internal_type = wrapNullableType(type.i64().nullability(), internal_type); } - else if (substrait_type.has_timestamp()) + else if (type.has_string()) { - ch_type = std::make_shared(6); - ch_type = wrapNullableType(substrait_type.timestamp().nullability(), ch_type); + internal_type = factory.get("String"); + internal_type = wrapNullableType(type.string().nullability(), internal_type); } - else if (substrait_type.has_date()) + else if (type.has_fp32()) { - ch_type = std::make_shared(); - ch_type = wrapNullableType(substrait_type.date().nullability(), ch_type); + internal_type = factory.get("Float32"); + internal_type = wrapNullableType(type.fp32().nullability(), internal_type); } - else if (substrait_type.has_decimal()) + else if (type.has_fp64()) { - UInt32 precision = substrait_type.decimal().precision(); - UInt32 scale = substrait_type.decimal().scale(); - if (precision <= DataTypeDecimal32::maxPrecision()) - ch_type = std::make_shared(precision, scale); - else if (precision <= DataTypeDecimal64::maxPrecision()) - ch_type = std::make_shared(precision, scale); - else if (precision <= DataTypeDecimal128::maxPrecision()) - ch_type = std::make_shared(precision, scale); - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); - - ch_type = wrapNullableType(substrait_type.decimal().nullability(), ch_type); + 
internal_type = factory.get("Float64"); + internal_type = wrapNullableType(type.fp64().nullability(), internal_type); } - else if (substrait_type.has_struct_()) + else if (type.has_date()) { - assert(substrait_type.struct_().nullability() == substrait::Type_Nullability_NULLABILITY_REQUIRED); - - DataTypes ch_field_types(substrait_type.struct_().types().size()); - for (size_t i = 0; i < ch_field_types.size(); ++i) - ch_field_types[i] = std::move(parseType(substrait_type.struct_().types()[i])); - ch_type = std::make_shared(ch_field_types); + internal_type = factory.get("Date32"); + internal_type = wrapNullableType(type.date().nullability(), internal_type); } - else if (substrait_type.has_list()) + else if (type.has_timestamp()) { - assert(substrait_type.struct_().nullability() == substrait::Type_Nullability_NULLABILITY_REQUIRED); - - auto ch_nested_type = parseType(substrait_type.list().type()); - ch_type = std::make_shared(ch_nested_type); + internal_type = factory.get("DateTime64(6)"); + internal_type = wrapNullableType(type.timestamp().nullability(), internal_type); } - else if (substrait_type.has_map()) + else { - assert(substrait_type.map().nullability() == substrait::Type_Nullability_NULLABILITY_REQUIRED); - - auto ch_key_type = parseType(substrait_type.map().key()); - auto ch_val_type = parseType(substrait_type.map().value()); - ch_type = std::make_shared(ch_key_type, ch_val_type); + throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {}", type.DebugString()); } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support type {}", substrait_type.DebugString()); - - /// TODO(taiyang-li): consider Time/IntervalYear/IntervalDay/TimestampTZ/UUID/FixedChar/VarChar/FixedBinary/UserDefined - return std::move(ch_type); + return internal_type; } QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr plan) { @@ -567,7 +581,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel) for (size_t i = 0; i < measure_positions.size(); i++) { - if (!isTypeMatched(measure_types[i], source[measure_positions[i]].type)) + if (!isTypeSame(measure_types[i], source[measure_positions[i]].type)) { auto target_type = parseType(measure_types[i]); target[measure_positions[i]].type = target_type; @@ -796,8 +810,8 @@ std::string SerializedPlanParser::getFunctionName(const std::string & function_s { const auto & output_type = function.output_type(); auto args = function.arguments(); - auto pos = function_signature.find(':'); - auto function_name = function_signature.substr(0, pos); + auto function_name_idx = function_signature.find(':'); + auto function_name = function_signature.substr(0, function_name_idx); if (!SCALAR_FUNCTIONS.contains(function_name)) throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", function_name); @@ -889,20 +903,18 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( { required_columns.emplace_back(args[0]->result_name); } - if (function_signature.find("extract:", 0) != function_signature.npos) { // delete the first arg args.erase(args.begin()); } - auto function_builder = FunctionFactory::instance().get(function_name, this->context); std::string args_name; join(args, ',', args_name); result_name = function_name + "(" + args_name + ")"; const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); result_node = function_node; - if (!isTypeMatched(rel.scalar_function().output_type(), function_node->result_type)) + if (!isTypeSame(rel.scalar_function().output_type(), 
function_node->result_type)) { auto cast_function = getCastFunction(rel.scalar_function().output_type()); DB::ActionsDAG::NodeRawConstPtrs cast_args({function_node}); @@ -946,151 +958,234 @@ SerializedPlanParser::toFunctionNode(ActionsDAGPtr action_dag, const String & fu return function_node; } -std::pair SerializedPlanParser::parseLiteral(const substrait::Expression_Literal & literal) -{ - DataTypePtr type; - Field field; - - switch (literal.literal_type_case()) - { - case substrait::Expression_Literal::kFp64: { - type = std::make_shared(); - field = literal.fp64(); - break; - } - case substrait::Expression_Literal::kFp32: { - type = std::make_shared(); - field = literal.fp32(); - break; - } - case substrait::Expression_Literal::kString: { - type = std::make_shared(); - field = literal.string(); - break; - } - case substrait::Expression_Literal::kBinary: { - type = std::make_shared(); - field = literal.binary(); - break; - } - case substrait::Expression_Literal::kI64: { - type = std::make_shared(); - field = literal.i64(); - break; - } - case substrait::Expression_Literal::kI32: { - type = std::make_shared(); - field = literal.i32(); - break; - } - case substrait::Expression_Literal::kBoolean: { - type = std::make_shared(); - field = literal.boolean() ? UInt8(1) : UInt8(0); - break; - } - case substrait::Expression_Literal::kI16: { - type = std::make_shared(); - field = literal.i16(); - break; - } - case substrait::Expression_Literal::kI8: { - type = std::make_shared(); - field = literal.i8(); - break; - } - case substrait::Expression_Literal::kDate: { - type = std::make_shared(); - field = literal.date(); - break; - } - case substrait::Expression_Literal::kTimestamp: { - type = std::make_shared(6); - field = DecimalField(literal.timestamp(), 6); - break; - } - case substrait::Expression_Literal::kDecimal: { - UInt32 precision = literal.decimal().precision(); - UInt32 scale = literal.decimal().scale(); - const auto & bytes = literal.decimal().value(); - - if (precision <= DataTypeDecimal32::maxPrecision()) - { - type = std::make_shared(precision, scale); - auto value = *reinterpret_cast(bytes.data()); - field = DecimalField(value, scale); - } - else if (precision <= DataTypeDecimal64::maxPrecision()) - { - type = std::make_shared(precision, scale); - auto value = *reinterpret_cast(bytes.data()); - field = DecimalField(value, scale); - } - else if (precision <= DataTypeDecimal128::maxPrecision()) - { - type = std::make_shared(precision, scale); - auto value = *reinterpret_cast(bytes.data()); - field = DecimalField(value, scale); - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); - break; - } - /// TODO(taiyang-li) Other type: Struct/Map/List - case substrait::Expression_Literal::kList: { - /// TODO(taiyang-li) Implement empty list - if (literal.has_empty_list()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Empty list not support!"); - - DataTypePtr first_type; - std::tie(first_type, std::ignore) = parseLiteral(literal.list().values(0)); - - size_t list_len = literal.list().values_size(); - Array array(list_len); - for (size_t i = 0; i < list_len; ++i) - { - auto type_and_field = std::move(parseLiteral(literal.list().values(i))); - if (!first_type->equals(*type_and_field.first)) - throw Exception( - ErrorCodes::LOGICAL_ERROR, - "Literal list type mismatch:{} and {}", - first_type->getName(), - type_and_field.first->getName()); - array[i] = std::move(type_and_field.second); - } - - type = std::make_shared(first_type); - 
field = std::move(array); - break; - } - case substrait::Expression_Literal::kNull: { - type = parseType(literal.null()); - field = std::move(Field{}); - break; - } - default: { - throw Exception( - ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); - } - } - return std::make_pair(std::move(type), std::move(field)); -} - const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr action_dag, const substrait::Expression & rel) { - auto add_column = [&](const DataTypePtr & type, const Field & field) -> auto - { - return &action_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, field), type, getUniqueName(toString(field)))); - }; - switch (rel.rex_type_case()) { case substrait::Expression::RexTypeCase::kLiteral: { - DataTypePtr type; - Field field; - std::tie(type, field) = parseLiteral(rel.literal()); - return add_column(type, field); + const auto & literal = rel.literal(); + switch (literal.literal_type_case()) + { + case substrait::Expression_Literal::kFp64: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.fp64()), type, getUniqueName(std::to_string(literal.fp64())))); + } + case substrait::Expression_Literal::kFp32: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.fp32()), type, getUniqueName(std::to_string(literal.fp32())))); + } + case substrait::Expression_Literal::kString: { + auto type = std::make_shared(); + return &action_dag->addColumn( + ColumnWithTypeAndName(type->createColumnConst(1, literal.string()), type, getUniqueName(literal.string()))); + } + case substrait::Expression_Literal::kI64: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.i64()), type, getUniqueName(std::to_string(literal.i64())))); + } + case substrait::Expression_Literal::kI32: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.i32()), type, getUniqueName(std::to_string(literal.i32())))); + } + case substrait::Expression_Literal::kBoolean: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.boolean() ? 
1 : 0), type, getUniqueName(std::to_string(literal.boolean())))); + } + case substrait::Expression_Literal::kI16: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.i16()), type, getUniqueName(std::to_string(literal.i16())))); + } + case substrait::Expression_Literal::kI8: { + auto type = std::make_shared(); + return &action_dag->addColumn( + ColumnWithTypeAndName(type->createColumnConst(1, literal.i8()), type, getUniqueName(std::to_string(literal.i8())))); + } + case substrait::Expression_Literal::kDate: { + auto type = std::make_shared(); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, literal.date()), type, getUniqueName(std::to_string(literal.date())))); + } + case substrait::Expression_Literal::kTimestamp: { + auto type = std::make_shared(6); + auto field = DecimalField(literal.timestamp(), 6); + return &action_dag->addColumn(ColumnWithTypeAndName( + type->createColumnConst(1, field), type, getUniqueName(std::to_string(literal.timestamp())))); + } + case substrait::Expression_Literal::kList: { + SizeLimits limit; + if (literal.has_empty_list()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "empty list not support!"); + } + MutableColumnPtr values; + DataTypePtr type; + auto first_value = literal.list().values(0); + if (first_value.has_boolean()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).boolean() ? 1 : 0); + } + } + else if (first_value.has_i8()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).i8()); + } + } + else if (first_value.has_i16()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).i16()); + } + } + else if (first_value.has_i32()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).i32()); + } + } + else if (first_value.has_i64()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).i64()); + } + } + else if (first_value.has_fp32()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).fp32()); + } + } + else if (first_value.has_fp64()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).fp64()); + } + } + else if (first_value.has_date()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).date()); + } + } + else if (first_value.has_timestamp()) + { + type = std::make_shared(6); + values = type->createColumn(); + for (int i = 0; i < literal.list().values_size(); ++i) + { + auto field = DecimalField(literal.list().values(i).timestamp(), 6); + values->insert(literal.list().values(i).timestamp()); + } + } + else if (first_value.has_string()) + { + type = std::make_shared(); + values = type->createColumn(); + for (int i = 0; i < 
literal.list().values_size(); ++i) + { + values->insert(literal.list().values(i).string()); + } + } + else + { + throw Exception( + ErrorCodes::UNKNOWN_TYPE, + "unsupported literal list type. {}", + magic_enum::enum_name(first_value.literal_type_case())); + } + auto set = std::make_shared(limit, true, false); + Block values_block; + auto name = getUniqueName("__set"); + values_block.insert(ColumnWithTypeAndName(std::move(values), type, name)); + set->setHeader(values_block.getColumnsWithTypeAndName()); + set->insertFromBlock(values_block.getColumnsWithTypeAndName()); + set->finishInsert(); + + auto arg = ColumnSet::create(set->getTotalRowCount(), set); + return &action_dag->addColumn(ColumnWithTypeAndName(std::move(arg), std::make_shared(), name)); + } + case substrait::Expression_Literal::kNull: { + DataTypePtr nested_type; + if (literal.null().has_i8()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_i16()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_i32()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_i64()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_bool_()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_fp32()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_fp64()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_date()) + { + nested_type = std::make_shared(); + } + else if (literal.null().has_timestamp()) + { + nested_type = std::make_shared(6); + } + else if (literal.null().has_string()) + { + nested_type = std::make_shared(); + } + auto type = std::make_shared(nested_type); + return &action_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, Field()), type, getUniqueName("null"))); + } + default: { + throw Exception( + ErrorCodes::UNKNOWN_TYPE, "unsupported constant type {}", magic_enum::enum_name(literal.literal_type_case())); + } + } } - case substrait::Expression::RexTypeCase::kSelection: { if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field()) { @@ -1099,7 +1194,6 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio const auto * field = action_dag->getInputs()[rel.selection().direct_reference().struct_field().field()]; return action_dag->tryFindInIndex(field->result_name); } - case substrait::Expression::RexTypeCase::kCast: { if (!rel.cast().has_type() || !rel.cast().has_input()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); @@ -1130,7 +1224,6 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio action_dag->addOrReplaceInIndex(*function_node); return function_node; } - case substrait::Expression::RexTypeCase::kIfThen: { const auto & if_then = rel.if_then(); auto function_multi_if = DB::FunctionFactory::instance().get("multiIf", this->context); @@ -1165,85 +1258,118 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio action_dag->addOrReplaceInIndex(*function_node); return function_node; } - case substrait::Expression::RexTypeCase::kScalarFunction: { std::string result; std::vector useless; return parseFunctionWithDAG(rel, result, useless, action_dag, false); } - case substrait::Expression::RexTypeCase::kSingularOrList: { DB::ActionsDAG::NodeRawConstPtrs args; args.emplace_back(parseArgument(action_dag, rel.singular_or_list().value())); - - /// options should be non-empty and literals const auto & 
options = rel.singular_or_list().options(); + + SizeLimits limit; if (options.empty()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Empty SingularOrList not supported"); + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "empty list not support!"); + } + ColumnPtr values; + DataTypePtr type; if (!options[0].has_literal()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type"); - - DataTypePtr elem_type; - std::tie(elem_type, std::ignore) = parseLiteral(options[0].literal()); - - size_t options_len = options.size(); - MutableColumnPtr elem_column = elem_type->createColumn(); - elem_column->reserve(options_len); - for (size_t i = 0; i < options_len; ++i) { - if (!options[i].has_literal()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); - - auto type_and_field = std::move(parseLiteral(options[i].literal())); - if (!elem_type->equals(*type_and_field.first)) - throw Exception( - ErrorCodes::LOGICAL_ERROR, - "SingularOrList options type mismatch:{} and {}", - elem_type->getName(), - type_and_field.first->getName()); - - elem_column->insert(type_and_field.second); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); } - - MutableColumns elem_columns; - elem_columns.emplace_back(std::move(elem_column)); - + auto first_value = options[0].literal(); + using FieldGetter = std::function; + auto fill_values = [options](DataTypePtr type, FieldGetter getter) -> ColumnPtr { + auto values = type->createColumn(); + for (const auto & v : options) + { + values->insert(getter(v.literal())); + } + return values; + }; + if (first_value.has_boolean()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.boolean() ? 1 : 0;}); + } + else if (first_value.has_i8()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.i8();}); + } + else if (first_value.has_i16()) + { + type = std::make_shared(); + values = type->createColumn(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.i16();}); + } + else if (first_value.has_i32()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.i32();}); + } + else if (first_value.has_i64()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.i64();}); + } + else if (first_value.has_fp32()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.fp32();}); + } + else if (first_value.has_fp64()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.fp64();}); + } + else if (first_value.has_date()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.date();}); + } + else if (first_value.has_string()) + { + type = std::make_shared(); + values = fill_values(type, [](substrait::Expression_Literal expr) -> Field {return expr.string();}); + } + else + { + throw Exception( + ErrorCodes::UNKNOWN_TYPE, + "unsupported literal list type. 
{}", + magic_enum::enum_name(first_value.literal_type_case())); + } + auto set = std::make_shared(limit, true, false); + Block values_block; auto name = getUniqueName("__set"); - Block elem_block; - elem_block.insert(ColumnWithTypeAndName(nullptr, elem_type, name)); - elem_block.setColumns(std::move(elem_columns)); + values_block.insert(ColumnWithTypeAndName(values, type, name)); + set->setHeader(values_block.getColumnsWithTypeAndName()); + set->insertFromBlock(values_block.getColumnsWithTypeAndName()); + set->finishInsert(); - SizeLimits limit; - auto elem_set = std::make_shared(limit, true, false); - elem_set->setHeader(elem_block.getColumnsWithTypeAndName()); - elem_set->insertFromBlock(elem_block.getColumnsWithTypeAndName()); - elem_set->finishInsert(); - - auto arg = ColumnSet::create(elem_set->getTotalRowCount(), elem_set); + auto arg = ColumnSet::create(set->getTotalRowCount(), set); args.emplace_back(&action_dag->addColumn(ColumnWithTypeAndName(std::move(arg), std::make_shared(), name))); const auto * function_node = toFunctionNode(action_dag, "in", args); action_dag->addOrReplaceInIndex(*function_node); return function_node; } - - default: + default: { throw Exception( - ErrorCodes::UNKNOWN_TYPE, - "Unsupported spark expression type {} : {}", - magic_enum::enum_name(rel.rex_type_case()), - rel.DebugString()); + ErrorCodes::BAD_ARGUMENTS, "unsupported arg type {} : {}", magic_enum::enum_name(rel.rex_type_case()), rel.DebugString()); + } } } QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) { auto plan_ptr = std::make_unique(); - auto ok = plan_ptr->ParseFromString(plan); - if (!ok) - throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - - return std::move(parse(std::move(plan_ptr))); + plan_ptr->ParseFromString(plan); + LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse plan \n{}", plan_ptr->DebugString()); + return parse(std::move(plan_ptr)); } void SerializedPlanParser::initFunctionEnv() { diff --git a/utils/local-engine/Parser/SerializedPlanParser.h b/utils/local-engine/Parser/SerializedPlanParser.h index 0da7363d4382..22a2f5428f39 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.h +++ b/utils/local-engine/Parser/SerializedPlanParser.h @@ -78,9 +78,6 @@ struct QueryContext std::shared_ptr custom_storage_merge_tree; }; -DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type); -DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); - class SerializedPlanParser { public: @@ -135,8 +132,6 @@ class SerializedPlanParser void wrapNullable(std::vector columns, ActionsDAGPtr actionsDag); std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } - static std::pair parseLiteral(const substrait::Expression_Literal & literal); - static Aggregator::Params getAggregateParam(const Block & header, const ColumnNumbers & keys, const AggregateDescriptions & aggregates) { Settings settings; @@ -174,7 +169,7 @@ class SerializedPlanParser struct SparkBuffer { - char * address; + uint8_t * address; size_t size; }; diff --git a/utils/local-engine/Parser/SparkRowToCHColumn.cpp b/utils/local-engine/Parser/SparkRowToCHColumn.cpp index 7eb9a7bd4187..5b6c46812c78 100644 --- a/utils/local-engine/Parser/SparkRowToCHColumn.cpp +++ b/utils/local-engine/Parser/SparkRowToCHColumn.cpp @@ -2,11 +2,6 @@ #include #include #include -#include -#include -#include -#include -#include #include namespace DB @@ -14,7 +9,6 @@ namespace DB 
namespace ErrorCodes { extern const int UNKNOWN_TYPE; - extern const int LOGICAL_ERROR; } } @@ -27,356 +21,139 @@ jmethodID SparkRowToCHColumn::spark_row_interator_hasNext = nullptr; jmethodID SparkRowToCHColumn::spark_row_interator_next = nullptr; jmethodID SparkRowToCHColumn::spark_row_iterator_nextBatch = nullptr; -ALWAYS_INLINE static void writeRowToColumns(std::vector & columns, const SparkRowReader & spark_row_reader) +int64_t getStringColumnTotalSize(int ordinal, SparkRowInfo & spark_row_info) { - auto num_fields = columns.size(); - const auto & field_types = spark_row_reader.getFieldTypes(); - for (size_t i = 0; i < num_fields; i++) + SparkRowReader reader(spark_row_info.getNumCols()); + int64_t size = 0; + for (int64_t i = 0; i < spark_row_info.getNumRows(); i++) { - if (spark_row_reader.supportRawData(i)) - { - const StringRef str{std::move(spark_row_reader.getStringRef(i))}; - columns[i]->insertData(str != EMPTY_STRING_REF ? str.data : nullptr, str.size); - } - else - columns[i]->insert(spark_row_reader.getField(i)); + reader.pointTo( + reinterpret_cast(spark_row_info.getBufferAddress() + spark_row_info.getOffsets()[i]), spark_row_info.getLengths()[i]); + size += (reader.getStringSize(ordinal) + 1); } + return size; } -std::unique_ptr -SparkRowToCHColumn::convertSparkRowInfoToCHColumn(const SparkRowInfo & spark_row_info, const Block & header) +void writeRowToColumns(std::vector & columns,std::vector& types, SparkRowReader & spark_row_reader) { - auto block = std::make_unique(); - *block = std::move(header.cloneEmpty()); - MutableColumns mutable_columns{std::move(block->mutateColumns())}; - const auto num_rows = spark_row_info.getNumRows(); - for (size_t col_i = 0; col_i < header.columns(); ++col_i) - mutable_columns[col_i]->reserve(num_rows); - - DataTypes types{std::move(header.getDataTypes())}; - SparkRowReader row_reader(types); - for (int64_t i = 0; i < num_rows; i++) + int32_t num_fields = columns.size(); + [[maybe_unused]] bool is_nullable = false; + for (int32_t i = 0; i < num_fields; i++) { - row_reader.pointTo(spark_row_info.getBufferAddress() + spark_row_info.getOffsets()[i], spark_row_info.getLengths()[i]); - writeRowToColumns(mutable_columns, row_reader); - } - block->setColumns(std::move(mutable_columns)); - return std::move(block); -} - -void SparkRowToCHColumn::appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, char * buffer, int32_t length) -{ - SparkRowReader row_reader(helper.data_types); - row_reader.pointTo(buffer, length); - writeRowToColumns(helper.mutable_columns, row_reader); -} - -Block * SparkRowToCHColumn::getBlock(SparkRowToCHColumnHelper & helper) -{ - auto * block = new Block(); - *block = std::move(helper.header.cloneEmpty()); - block->setColumns(std::move(helper.mutable_columns)); - return block; -} - -VariableLengthDataReader::VariableLengthDataReader(const DataTypePtr & type_) - : type(type_), type_without_nullable(removeNullable(type)), which(type_without_nullable) -{ - if (!BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); -} - -Field VariableLengthDataReader::read(const char *buffer, size_t length) const -{ - if (which.isStringOrFixedString()) - return std::move(readString(buffer, length)); - - if (which.isDecimal128()) - return std::move(readDecimal(buffer, length)); - - if (which.isArray()) - return std::move(readArray(buffer, length)); - - if (which.isMap()) - return 
std::move(readMap(buffer, length)); - - if (which.isTuple()) - return std::move(readStruct(buffer, length)); - - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); -} - -StringRef VariableLengthDataReader::readUnalignedBytes(const char * buffer, size_t length) const -{ - return {buffer, length}; -} - -Field VariableLengthDataReader::readDecimal(const char * buffer, size_t length) const -{ - assert(sizeof(Decimal128) == length); - - Decimal128 value; - memcpy(&value, buffer, length); - - const auto * decimal128_type = typeid_cast(type_without_nullable.get()); - return std::move(DecimalField(value, decimal128_type->getScale())); -} - -Field VariableLengthDataReader::readString(const char * buffer, size_t length) const -{ - String str(buffer, length); - return std::move(Field(std::move(str))); -} - -Field VariableLengthDataReader::readArray(const char * buffer, [[maybe_unused]] size_t length) const -{ - /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing data - /// Read numElements - int64_t num_elems = 0; - memcpy(&num_elems, buffer, 8); - if (num_elems == 0) - return Array(); - - /// Skip null_bitmap - const auto len_null_bitmap = calculateBitSetWidthInBytes(num_elems); + WhichDataType which(columns[i]->getDataType()); + if (which.isNullable()) + { + const auto * nullable = checkAndGetDataType(types[i].get()); + which = WhichDataType(nullable->getNestedType()); + is_nullable = true; + } - /// Read values - const auto * array_type = typeid_cast(type.get()); - const auto & nested_type = array_type->getNestedType(); - const auto elem_size = BackingDataLengthCalculator::getArrayElementSize(nested_type); - const auto len_values = roundNumberOfBytesToNearestWord(elem_size * num_elems); - Array array; - array.reserve(num_elems); + if (spark_row_reader.isNullAt(i)) { + assert(is_nullable); + ColumnNullable & column = assert_cast(*columns[i]); + column.insertData(nullptr, 0); + continue; + } - if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(nested_type))) - { - FixedLengthDataReader reader(nested_type); - for (int64_t i = 0; i < num_elems; ++i) + if (which.isUInt8()) { - if (isBitSet(buffer + 8, i)) - { - array.emplace_back(std::move(Null{})); - } - else - { - const auto elem = reader.read(buffer + 8 + len_null_bitmap + i * elem_size); - array.emplace_back(elem); - } + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(uint8_t)); } - } - else if (BackingDataLengthCalculator::isVariableLengthDataType(removeNullable(nested_type))) - { - VariableLengthDataReader reader(nested_type); - for (int64_t i = 0; i < num_elems; ++i) + else if (which.isInt8()) { - if (isBitSet(buffer + 8, i)) - { - array.emplace_back(std::move(Null{})); - } - else - { - int64_t offset_and_size = 0; - memcpy(&offset_and_size, buffer + 8 + len_null_bitmap + i * 8, 8); - const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); - const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); - - const auto elem = reader.read(buffer + offset, size); - array.emplace_back(elem); - } + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(int8_t)); } - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", nested_type->getName()); - - return std::move(array); -} - -Field VariableLengthDataReader::readMap(const char * buffer, size_t length) const -{ - /// 内存布局:Length of 
UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value - /// Read Length of UnsafeArrayData of key - int64_t key_array_size = 0; - memcpy(&key_array_size, buffer, 8); - if (key_array_size == 0) - return std::move(Map()); - - /// Read UnsafeArrayData of keys - const auto * map_type = typeid_cast(type.get()); - const auto & key_type = map_type->getKeyType(); - const auto key_array_type = std::make_shared(key_type); - VariableLengthDataReader key_reader(key_array_type); - auto key_field = key_reader.read(buffer + 8, key_array_size); - auto & key_array = key_field.safeGet(); - - /// Read UnsafeArrayData of values - const auto & val_type = map_type->getValueType(); - const auto val_array_type = std::make_shared(val_type); - VariableLengthDataReader val_reader(val_array_type); - auto val_field = val_reader.read(buffer + 8 + key_array_size, length - 8 - key_array_size); - auto & val_array = val_field.safeGet(); - - /// Construct map in CH way [(k1, v1), (k2, v2), ...] - if (key_array.size() != val_array.size()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Key size {} not equal to value size {} in map", key_array.size(), val_array.size()); - Map map(key_array.size()); - for (size_t i = 0; i < key_array.size(); ++i) - { - Tuple tuple(2); - tuple[0] = std::move(key_array[i]); - tuple[1] = std::move(val_array[i]); - - map[i] = std::move(tuple); - } - return std::move(map); -} - -Field VariableLengthDataReader::readStruct(const char * buffer, size_t /*length*/) const -{ - /// 内存布局:null_bitmap(字节数与字段数成正比) | values(num_fields * 8B) | backing data - const auto * tuple_type = typeid_cast(type.get()); - const auto & field_types = tuple_type->getElements(); - const auto num_fields = field_types.size(); - if (num_fields == 0) - return std::move(Tuple()); - - const auto len_null_bitmap = calculateBitSetWidthInBytes(num_fields); - - Tuple tuple(num_fields); - for (size_t i=0; iinsertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(int16_t)); } - - if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(field_type))) + else if (which.isInt32()) { - FixedLengthDataReader reader(field_type); - tuple[i] = std::move(reader.read(buffer + len_null_bitmap + i * 8)); + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(int32_t)); } - else if (BackingDataLengthCalculator::isVariableLengthDataType(removeNullable(field_type))) + else if (which.isInt64() || which.isDateTime64()) { - int64_t offset_and_size = 0; - memcpy(&offset_and_size, buffer + len_null_bitmap + i * 8, 8); - const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); - const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); - - VariableLengthDataReader reader(field_type); - tuple[i] = std::move(reader.read(buffer + offset, size)); + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(int64_t)); + } + else if (which.isFloat32()) + { + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(float_t)); + } + else if (which.isFloat64()) + { + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(double_t)); + } + else if (which.isDate() || which.isUInt16()) + { + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(uint16_t)); + } + else if (which.isDate32() || which.isUInt32()) + { + columns[i]->insertData(spark_row_reader.getRawDataForFixedNumber(i), sizeof(uint32_t)); + } + else if (which.isString()) + { + StringRef data = 
spark_row_reader.getString(i); + columns[i]->insertData(data.data, data.size); } else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", field_type->getName()); + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support type {} convert from spark row to ch columnar" , + magic_enum::enum_name(columns[i]->getDataType())); + } } - return std::move(tuple); -} - -FixedLengthDataReader::FixedLengthDataReader(const DataTypePtr & type_) - : type(type_), type_without_nullable(removeNullable(type)), which(type_without_nullable) -{ - if (!BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable) || !type_without_nullable->isValueRepresentedByNumber()) - throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); - - value_size = type_without_nullable->getSizeOfValueInMemory(); } -StringRef FixedLengthDataReader::unsafeRead(const char * buffer) const -{ - return {buffer, value_size}; -} - -Field FixedLengthDataReader::read(const char * buffer) const +std::unique_ptr +local_engine::SparkRowToCHColumn::convertSparkRowInfoToCHColumn(local_engine::SparkRowInfo & spark_row_info, DB::Block & header) { - if (which.isUInt8()) + auto columns_list = std::make_unique(); + columns_list->reserve(header.columns()); + std::vector mutable_columns; + std::vector types; + for (size_t column_i = 0, columns = header.columns(); column_i < columns; ++column_i) { - UInt8 value = 0; - memcpy(&value, buffer, 1); - return value; + const ColumnWithTypeAndName & header_column = header.getByPosition(column_i); + MutableColumnPtr read_column = header_column.type->createColumn(); + read_column->reserve(spark_row_info.getNumRows()); + mutable_columns.push_back(std::move(read_column)); + types.push_back(header_column.type); } - - if (which.isUInt16() || which.isDate()) + SparkRowReader row_reader(header.columns()); + for (int64_t i = 0; i < spark_row_info.getNumRows(); i++) { - UInt16 value = 0; - memcpy(&value, buffer, 2); - return value; + row_reader.pointTo( + reinterpret_cast(spark_row_info.getBufferAddress() + spark_row_info.getOffsets()[i]), spark_row_info.getLengths()[i]); + writeRowToColumns(mutable_columns, types, row_reader); } - - if (which.isUInt32()) + auto block = std::make_unique(*std::move(columns_list)); + for (size_t column_i = 0, columns = mutable_columns.size(); column_i < columns; ++column_i) { - UInt32 value = 0; - memcpy(&value, buffer, 4); - return value; + const ColumnWithTypeAndName & header_column = header.getByPosition(column_i); + ColumnWithTypeAndName column(std::move(mutable_columns[column_i]), header_column.type, header_column.name); + block->insert(column); } + mutable_columns.clear(); + return block; +} - if (which.isUInt64()) - { - UInt64 value = 0; - memcpy(&value, buffer, 8); - return value; - } - - if (which.isInt8()) - { - Int8 value = 0; - memcpy(&value, buffer, 1); - return value; - } - - if (which.isInt16()) - { - Int16 value = 0; - memcpy(&value, buffer, 2); - return value; - } - - if (which.isInt32() || which.isDate32()) - { - Int32 value = 0; - memcpy(&value, buffer, 4); - return value; - } - - if (which.isInt64()) - { - Int64 value = 0; - memcpy(&value, buffer, 8); - return value; - } - - if (which.isFloat32()) - { - Float32 value = 0.0; - memcpy(&value, buffer, 4); - return value; - } - - if (which.isFloat64()) - { - Float64 value = 0.0; - memcpy(&value, buffer, 8); - return value; - } - - if (which.isDecimal32()) - { - Decimal32 value = 0; - memcpy(&value, buffer, 4); - 
- const auto * decimal32_type = typeid_cast(type_without_nullable.get()); - return std::move(DecimalField{value, decimal32_type->getScale()}); - } +void local_engine::SparkRowToCHColumn::appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, int64_t address, int32_t size) +{ + SparkRowReader row_reader(helper.header->columns()); + row_reader.pointTo(address, size); + writeRowToColumns(*helper.cols, *helper.typePtrs, row_reader); +} - if (which.isDecimal64() || which.isDateTime64()) +Block * local_engine::SparkRowToCHColumn::getWrittenBlock(SparkRowToCHColumnHelper & helper) +{ + auto * block = new Block(); + for (size_t column_i = 0, columns = helper.cols->size(); column_i < columns; ++column_i) { - Decimal64 value = 0; - memcpy(&value, buffer, 8); - - UInt32 scale = which.isDecimal64() ? typeid_cast(type_without_nullable.get())->getScale() - : typeid_cast(type_without_nullable.get())->getScale(); - return std::move(DecimalField{value, scale}); + const ColumnWithTypeAndName & header_column = helper.header->getByPosition(column_i); + ColumnWithTypeAndName column(std::move(helper.cols->operator[](column_i)), header_column.type, header_column.name); + block->insert(column); } - throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthDataReader doesn't support type {}", type->getName()); + return block; } } diff --git a/utils/local-engine/Parser/SparkRowToCHColumn.h b/utils/local-engine/Parser/SparkRowToCHColumn.h index 1c84c69269ab..c303bdedf7a8 100644 --- a/utils/local-engine/Parser/SparkRowToCHColumn.h +++ b/utils/local-engine/Parser/SparkRowToCHColumn.h @@ -5,62 +5,123 @@ #include #include #include -#include -#include #include #include #include -#include #include -namespace DB -{ -namespace ErrorCodes -{ - extern const int UNKNOWN_TYPE; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; -} -} namespace local_engine { using namespace DB; using namespace std; + + struct SparkRowToCHColumnHelper { - DataTypes data_types; - Block header; - MutableColumns mutable_columns; - - SparkRowToCHColumnHelper(vector & names, vector & types) - : data_types(names.size()) + SparkRowToCHColumnHelper(vector& names, vector& types, vector& isNullables) { - assert(names.size() == types.size()); - - ColumnsWithTypeAndName columns(names.size()); + internal_cols = std::make_unique>(); + internal_cols->reserve(names.size()); + typePtrs = std::make_unique>(); + typePtrs->reserve(names.size()); for (size_t i = 0; i < names.size(); ++i) { - data_types[i] = parseType(types[i]); - columns[i] = std::move(ColumnWithTypeAndName(data_types[i], names[i])); + const auto & name = names[i]; + const auto & type = types[i]; + const bool is_nullable = isNullables[i]; + auto data_type = parseType(type, is_nullable); + internal_cols->push_back(ColumnWithTypeAndName(data_type, name)); + typePtrs->push_back(data_type); } - - header = std::move(Block(columns)); - resetMutableColumns(); + header = std::make_shared(*std::move(internal_cols)); + resetWrittenColumns(); } - ~SparkRowToCHColumnHelper() = default; + unique_ptr> internal_cols; //for headers + unique_ptr> cols; + unique_ptr> typePtrs; + shared_ptr header; - void resetMutableColumns() + void resetWrittenColumns() { - mutable_columns = std::move(header.mutateColumns()); + cols = make_unique>(); + for (size_t i = 0; i < internal_cols->size(); i++) + { + cols->push_back(internal_cols->at(i).type->createColumn()); + } } - static DataTypePtr parseType(const string & type) + static DataTypePtr inline wrapNullableType(bool isNullable, DataTypePtr nested_type) { - auto substrait_type = 
std::make_unique(); - auto ok = substrait_type->ParseFromString(type); - if (!ok) - throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Type from string failed"); - return std::move(SerializedPlanParser::parseType(*substrait_type)); + if (isNullable) + { + return std::make_shared(nested_type); + } + else + { + return nested_type; + } + } + + //parse Spark type name to CH DataType + DataTypePtr parseType(const string & type, const bool isNullable) + { + DataTypePtr internal_type = nullptr; + auto & factory = DataTypeFactory::instance(); + if ("boolean" == type) + { + internal_type = factory.get("UInt8"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("byte" == type) + { + internal_type = factory.get("Int8"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("short" == type) + { + internal_type = factory.get("Int16"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("integer" == type) + { + internal_type = factory.get("Int32"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("long" == type) + { + internal_type = factory.get("Int64"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("string" == type) + { + internal_type = factory.get("String"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("float" == type) + { + internal_type = factory.get("Float32"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("double" == type) + { + internal_type = factory.get("Float64"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("date" == type) + { + internal_type = factory.get("Date32"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else if ("timestamp" == type) + { + internal_type = factory.get("DateTime64(6)"); + internal_type = wrapNullableType(isNullable, internal_type); + } + else + throw Exception(0, "doesn't support spark type {}", type); + + return internal_type; } }; @@ -73,13 +134,12 @@ class SparkRowToCHColumn static jmethodID spark_row_iterator_nextBatch; // case 1: rows are batched (this is often directly converted from Block) - static std::unique_ptr convertSparkRowInfoToCHColumn(const SparkRowInfo & spark_row_info, const Block & header); + static std::unique_ptr convertSparkRowInfoToCHColumn(SparkRowInfo & spark_row_info, Block & header); // case 2: provided with a sequence of spark UnsafeRow, convert them to a Block - static Block * - convertSparkRowItrToCHColumn(jobject java_iter, vector & names, vector & types) + static Block* convertSparkRowItrToCHColumn(jobject java_iter, vector& names, vector& types, vector& isNullables) { - SparkRowToCHColumnHelper helper(names, types); + SparkRowToCHColumnHelper helper(names, types, isNullables); int attached; JNIEnv * env = JNIUtils::getENV(&attached); @@ -93,110 +153,33 @@ class SparkRowToCHColumn while (len > 0) { rows_buf_ptr += 4; - appendSparkRowToCHColumn(helper, rows_buf_ptr, len); + appendSparkRowToCHColumn(helper, reinterpret_cast(rows_buf_ptr), len); rows_buf_ptr += len; len = *(reinterpret_cast(rows_buf_ptr)); } // Try to release reference. 
env->DeleteLocalRef(rows_buf); } - return getBlock(helper); + return getWrittenBlock(helper); } - static void freeBlock(Block * block) - { - delete block; - block = nullptr; - } - -private: - static void appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, char * buffer, int32_t length); - static Block * getBlock(SparkRowToCHColumnHelper & helper); -}; - -class VariableLengthDataReader -{ -public: - explicit VariableLengthDataReader(const DataTypePtr& type_); - virtual ~VariableLengthDataReader() = default; - - virtual Field read(const char * buffer, size_t length) const; - virtual StringRef readUnalignedBytes(const char * buffer, size_t length) const; + static void freeBlock(Block * block) { delete block; } private: - virtual Field readDecimal(const char * buffer, size_t length) const; - virtual Field readString(const char * buffer, size_t length) const; - virtual Field readArray(const char * buffer, size_t length) const; - virtual Field readMap(const char * buffer, size_t length) const; - virtual Field readStruct(const char * buffer, size_t length) const; - - const DataTypePtr type; - const DataTypePtr type_without_nullable; - const WhichDataType which; + static void appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, int64_t address, int32_t size); + static Block* getWrittenBlock(SparkRowToCHColumnHelper & helper); }; -class FixedLengthDataReader -{ -public: - explicit FixedLengthDataReader(const DB::DataTypePtr & type_); - virtual ~FixedLengthDataReader() = default; - - virtual Field read(const char * buffer) const; - virtual StringRef unsafeRead(const char * buffer) const; - -private: - const DB::DataTypePtr type; - const DB::DataTypePtr type_without_nullable; - const DB::WhichDataType which; - size_t value_size; -}; class SparkRowReader { public: - explicit SparkRowReader(const DataTypes & field_types_) - : field_types(field_types_) - , num_fields(field_types.size()) - , bit_set_width_in_bytes(calculateBitSetWidthInBytes(num_fields)) - , field_offsets(num_fields) - , support_raw_datas(num_fields) - , fixed_length_data_readers(num_fields) - , variable_length_data_readers(num_fields) - { - for (auto ordinal = 0; ordinal < num_fields; ++ordinal) - { - const auto type_without_nullable = removeNullable(field_types[ordinal]); - field_offsets[ordinal] = bit_set_width_in_bytes + ordinal * 8L; - support_raw_datas[ordinal] = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); - if (BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) - fixed_length_data_readers[ordinal] = std::make_shared(field_types[ordinal]); - else if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) - variable_length_data_readers[ordinal] = std::make_shared(field_types[ordinal]); - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader doesn't support type {}", field_types[ordinal]->getName()); - } - } - - const DataTypes & getFieldTypes() const - { - return field_types; - } - - bool supportRawData(int ordinal) const - { - assertIndexIsValid(ordinal); - return support_raw_datas[ordinal]; - } - - std::shared_ptr getFixedLengthDataReader(int ordinal) const + bool isSet(int index) const { - assertIndexIsValid(ordinal); - return fixed_length_data_readers[ordinal]; - } - - std::shared_ptr getVariableLengthDataReader(int ordinal) const - { - assertIndexIsValid(ordinal); - return variable_length_data_readers[ordinal]; + assert(index >= 0); + int64_t mask = 1 << (index & 63); + int64_t word_offset = base_offset + static_cast(index 
>> 6) * 8L; + int64_t word = *reinterpret_cast(word_offset); + return (word & mask) != 0; } void assertIndexIsValid([[maybe_unused]] int index) const @@ -208,153 +191,103 @@ class SparkRowReader bool isNullAt(int ordinal) const { assertIndexIsValid(ordinal); - return isBitSet(buffer, ordinal); + return isSet(ordinal); } - const char* getRawDataForFixedNumber(int ordinal) const + char* getRawDataForFixedNumber(int ordinal) { assertIndexIsValid(ordinal); - return reinterpret_cast(getFieldOffset(ordinal)); + return reinterpret_cast(getFieldOffset(ordinal)); } - int8_t getByte(int ordinal) const + int8_t getByte(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - uint8_t getUnsignedByte(int ordinal) const + uint8_t getUnsignedByte(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - int16_t getShort(int ordinal) const + + int16_t getShort(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - uint16_t getUnsignedShort(int ordinal) const + uint16_t getUnsignedShort(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - int32_t getInt(int ordinal) const + int32_t getInt(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - uint32_t getUnsignedInt(int ordinal) const + uint32_t getUnsignedInt(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - int64_t getLong(int ordinal) const + int64_t getLong(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - float_t getFloat(int ordinal) const + float_t getFloat(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - double_t getDouble(int ordinal) const + double_t getDouble(int ordinal) { assertIndexIsValid(ordinal); - return *reinterpret_cast(getFieldOffset(ordinal)); + return *reinterpret_cast(getFieldOffset(ordinal)); } - StringRef getString(int ordinal) const + StringRef getString(int ordinal) { assertIndexIsValid(ordinal); int64_t offset_and_size = getLong(ordinal); int32_t offset = static_cast(offset_and_size >> 32); int32_t size = static_cast(offset_and_size); - return StringRef(reinterpret_cast(this->buffer + offset), size); + return StringRef(reinterpret_cast(this->base_offset + offset), size); } - int32_t getStringSize(int ordinal) const + int32_t getStringSize(int ordinal) { assertIndexIsValid(ordinal); return static_cast(getLong(ordinal)); } - void pointTo(const char * buffer_, int32_t length_) + void pointTo(int64_t base_offset_, int32_t size_in_bytes_) { - buffer = buffer_; - length = length_; - } - - StringRef getStringRef(int ordinal) const - { - assertIndexIsValid(ordinal); - if (!support_raw_datas[ordinal]) - throw Exception( - ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getStringRef doesn't support type {}", field_types[ordinal]->getName()); - - if (isNullAt(ordinal)) - return EMPTY_STRING_REF; - - const auto & fixed_length_data_reader = 
fixed_length_data_readers[ordinal]; - const auto & variable_length_data_reader = variable_length_data_readers[ordinal]; - if (fixed_length_data_reader) - return std::move(fixed_length_data_reader->unsafeRead(getFieldOffset(ordinal))); - else if (variable_length_data_reader) - { - int64_t offset_and_size = 0; - memcpy(&offset_and_size, buffer + bit_set_width_in_bytes + ordinal * 8, 8); - const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); - const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); - return std::move(variable_length_data_reader->readUnalignedBytes(buffer + offset, size)); - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getStringRef doesn't support type {}", field_types[ordinal]->getName()); + this->base_offset = base_offset_; + this->size_in_bytes = size_in_bytes_; } - Field getField(int ordinal) const + explicit SparkRowReader(int32_t numFields) : num_fields(numFields) { - assertIndexIsValid(ordinal); - - if (isNullAt(ordinal)) - return std::move(Null{}); - - const auto & fixed_length_data_reader = fixed_length_data_readers[ordinal]; - const auto & variable_length_data_reader = variable_length_data_readers[ordinal]; - - if (fixed_length_data_reader) - return std::move(fixed_length_data_reader->read(getFieldOffset(ordinal))); - else if (variable_length_data_reader) - { - int64_t offset_and_size = 0; - memcpy(&offset_and_size, buffer + bit_set_width_in_bytes + ordinal * 8, 8); - const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); - const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); - return std::move(variable_length_data_reader->read(buffer + offset, size)); - } - else - throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getField doesn't support type {}", field_types[ordinal]->getName()); + this->bit_set_width_in_bytes = local_engine::calculateBitSetWidthInBytes(numFields); } private: - const char * getFieldOffset(int ordinal) const { return buffer + field_offsets[ordinal]; } - - const DataTypes field_types; - const int32_t num_fields; - const int32_t bit_set_width_in_bytes; - std::vector field_offsets; - std::vector support_raw_datas; - std::vector> fixed_length_data_readers; - std::vector> variable_length_data_readers; - - const char * buffer; - int32_t length; + int64_t getFieldOffset(int ordinal) const { return base_offset + bit_set_width_in_bytes + ordinal * 8L; } + + int64_t base_offset; + [[maybe_unused]] int32_t num_fields; + int32_t size_in_bytes; + int32_t bit_set_width_in_bytes; }; } diff --git a/utils/local-engine/jni/jni_common.cpp b/utils/local-engine/jni/jni_common.cpp index 7df7d34c841b..a2423eb7e170 100644 --- a/utils/local-engine/jni/jni_common.cpp +++ b/utils/local-engine/jni/jni_common.cpp @@ -66,12 +66,4 @@ jstring charTojstring(JNIEnv* env, const char* pat) { env->DeleteLocalRef(encoding); return result; } - -jbyteArray stringTojbyteArray(JNIEnv* env, const std::string & str) { - const auto * ptr = reinterpret_cast(str.c_str()) ; - jbyteArray jarray = env->NewByteArray(str.size()); - env->SetByteArrayRegion(jarray, 0, str.size(), ptr); - return jarray; -} - } diff --git a/utils/local-engine/jni/jni_common.h b/utils/local-engine/jni/jni_common.h index 13afc8120f97..02d92cd154ca 100644 --- a/utils/local-engine/jni/jni_common.h +++ b/utils/local-engine/jni/jni_common.h @@ -2,7 +2,6 @@ #include #include #include -#include #include namespace DB @@ -25,8 +24,6 @@ jmethodID GetStaticMethodID(JNIEnv * env, jclass 
this_class, const char * name, jstring charTojstring(JNIEnv* env, const char* pat); -jbyteArray stringTojbyteArray(JNIEnv* env, const std::string & str); - #define LOCAL_ENGINE_JNI_JMETHOD_START #define LOCAL_ENGINE_JNI_JMETHOD_END(env) \ if ((env)->ExceptionCheck())\ diff --git a/utils/local-engine/local_engine_jni.cpp b/utils/local-engine/local_engine_jni.cpp index 185a95b41b4f..7ebd8dd1e256 100644 --- a/utils/local-engine/local_engine_jni.cpp +++ b/utils/local-engine/local_engine_jni.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -44,7 +43,7 @@ std::vector stringSplit(const std::string & str, char delim) } } -DB::ColumnWithTypeAndName inline getColumnFromColumnVector(JNIEnv * /*env*/, jobject /*obj*/, jlong block_address, jint column_position) +DB::ColumnWithTypeAndName inline getColumnFromColumnVector(JNIEnv * /*env*/, jobject obj, jlong block_address, jint column_position) { DB::Block * block = reinterpret_cast(block_address); return block->getByPosition(column_position); @@ -480,15 +479,80 @@ jint Java_io_glutenproject_vectorized_CHNativeBlock_nativeNumColumns(JNIEnv * en LOCAL_ENGINE_JNI_METHOD_END(env, -1) } -jbyteArray Java_io_glutenproject_vectorized_CHNativeBlock_nativeColumnType(JNIEnv * env, jobject /*obj*/, jlong block_address, jint position) +jstring Java_io_glutenproject_vectorized_CHNativeBlock_nativeColumnType(JNIEnv * env, jobject /*obj*/, jlong block_address, jint position) { LOCAL_ENGINE_JNI_METHOD_START auto * block = reinterpret_cast(block_address); - const auto & col = block->getByPosition(position); - std::string substrait_type; - dbms::SerializedPlanBuilder::buildType(col.type, substrait_type); - return local_engine::stringTojbyteArray(env, substrait_type); - LOCAL_ENGINE_JNI_METHOD_END(env, local_engine::stringTojbyteArray(env, "")) + DB::WhichDataType which(block->getByPosition(position).type); + std::string type; + if (which.isNullable()) + { + const auto * nullable = checkAndGetDataType(block->getByPosition(position).type.get()); + which = DB::WhichDataType(nullable->getNestedType()); + } + + if (which.isDate32()) + { + type = "Date"; + } + else if (which.isDateTime64()) + { + type = "Timestamp"; + } + else if (which.isFloat32()) + { + type = "Float"; + } + else if (which.isFloat64()) + { + type = "Double"; + } + else if (which.isInt32()) + { + type = "Integer"; + } + else if (which.isInt64()) + { + type = "Long"; + } + else if (which.isUInt64()) + { + type = "Long"; + } + else if (which.isInt8()) + { + type = "Byte"; + } + else if (which.isInt16()) + { + type = "Short"; + } + else if (which.isUInt16()) + { + type = "Integer"; + } + else if (which.isUInt8()) + { + type = "Boolean"; + } + else if (which.isString()) + { + type = "String"; + } + else if (which.isAggregateFunction()) + { + type = "Binary"; + } + else + { + auto type_name = std::string(block->getByPosition(position).type->getName()); + auto col_name = block->getByPosition(position).name; + LOG_ERROR(&Poco::Logger::get("jni"), "column {}, unsupported datatype {}", col_name, type_name); + throw std::runtime_error("unsupported datatype " + type_name); + } + + return local_engine::charTojstring(env, type.c_str()); + LOCAL_ENGINE_JNI_METHOD_END(env, local_engine::charTojstring(env, "")) } jlong Java_io_glutenproject_vectorized_CHNativeBlock_nativeTotalBytes(JNIEnv * env, jobject /*obj*/, jlong block_address) @@ -632,13 +696,13 @@ jobject Java_io_glutenproject_vectorized_CHShuffleSplitterJniWrapper_stop(JNIEnv local_engine::SplitterHolder * splitter = 
reinterpret_cast(splitterId); auto result = splitter->splitter->stop(); const auto & partition_lengths = result.partition_length; - auto *partition_length_arr = env->NewLongArray(partition_lengths.size()); - const auto *src = reinterpret_cast(partition_lengths.data()); + auto partition_length_arr = env->NewLongArray(partition_lengths.size()); + auto src = reinterpret_cast(partition_lengths.data()); env->SetLongArrayRegion(partition_length_arr, 0, partition_lengths.size(), src); const auto & raw_partition_lengths = result.raw_partition_length; - auto *raw_partition_length_arr = env->NewLongArray(raw_partition_lengths.size()); - const auto *raw_src = reinterpret_cast(raw_partition_lengths.data()); + auto raw_partition_length_arr = env->NewLongArray(raw_partition_lengths.size()); + auto raw_src = reinterpret_cast(raw_partition_lengths.data()); env->SetLongArrayRegion(raw_partition_length_arr, 0, raw_partition_lengths.size(), raw_src); jobject split_result = env->NewObject( @@ -694,37 +758,35 @@ void Java_io_glutenproject_vectorized_BlockNativeConverter_freeMemory(JNIEnv * e { LOCAL_ENGINE_JNI_METHOD_START local_engine::CHColumnToSparkRow converter; - converter.freeMem(reinterpret_cast(address), size); + converter.freeMem(reinterpret_cast(address), size); LOCAL_ENGINE_JNI_METHOD_END(env,) } jlong Java_io_glutenproject_vectorized_BlockNativeConverter_convertSparkRowsToCHColumn( - JNIEnv * env, jobject, jobject java_iter, jobjectArray names, jobjectArray types) + JNIEnv * env, jobject, jobject java_iter, jobjectArray names, jobjectArray types, jbooleanArray is_nullables) { LOCAL_ENGINE_JNI_METHOD_START using namespace std; + int column_size = env->GetArrayLength(names); - int num_columns = env->GetArrayLength(names); vector c_names; vector c_types; - c_names.reserve(num_columns); - for (int i = 0; i < num_columns; i++) + vector c_isnullables; + jboolean * p_booleans = env->GetBooleanArrayElements(is_nullables, nullptr); + for (int i = 0; i < column_size; i++) { auto * name = static_cast(env->GetObjectArrayElement(names, i)); - c_names.emplace_back(std::move(jstring2string(env, name))); - - auto * type = static_cast(env->GetObjectArrayElement(types, i)); - auto type_length = env->GetArrayLength(type); - jbyte * type_ptr = env->GetByteArrayElements(type, nullptr); - string str_type(reinterpret_cast(type_ptr), type_length); - c_types.emplace_back(std::move(str_type)); + auto * type = static_cast(env->GetObjectArrayElement(types, i)); + c_names.push_back(jstring2string(env, name)); + c_types.push_back(jstring2string(env, type)); + c_isnullables.push_back(p_booleans[i] == JNI_TRUE); - env->ReleaseByteArrayElements(type, type_ptr, JNI_ABORT); env->DeleteLocalRef(name); env->DeleteLocalRef(type); } + env->ReleaseBooleanArrayElements(is_nullables, p_booleans, JNI_ABORT); local_engine::SparkRowToCHColumn converter; - return reinterpret_cast(converter.convertSparkRowItrToCHColumn(java_iter, c_names, c_types)); + return reinterpret_cast(converter.convertSparkRowItrToCHColumn(java_iter, c_names, c_types, c_isnullables)); LOCAL_ENGINE_JNI_METHOD_END(env, -1) } diff --git a/utils/local-engine/tests/CMakeLists.txt b/utils/local-engine/tests/CMakeLists.txt index 231c564ade74..1c6b1cb2c4d9 100644 --- a/utils/local-engine/tests/CMakeLists.txt +++ b/utils/local-engine/tests/CMakeLists.txt @@ -40,7 +40,7 @@ grep_gtest_sources("${ClickHouse_SOURCE_DIR}/utils/local_engine/tests" local_eng add_executable(unit_tests_local_engine ${local_engine_gtest_sources} ) -add_executable(benchmark_local_engine 
benchmark_local_engine.cpp benchmark_parquet_read.cpp benchmark_spark_row.cpp) +add_executable(benchmark_local_engine benchmark_local_engine.cpp benchmark_parquet_read.cpp) target_include_directories(unit_tests_local_engine PRIVATE ${GTEST_INCLUDE_DIRS}/include diff --git a/utils/local-engine/tests/benchmark_local_engine.cpp b/utils/local-engine/tests/benchmark_local_engine.cpp index 26769e00be78..3674faa60b52 100644 --- a/utils/local-engine/tests/benchmark_local_engine.cpp +++ b/utils/local-engine/tests/benchmark_local_engine.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include @@ -1252,6 +1251,7 @@ class FasterCompressedReadBuffer : public FasterCompressedReadBufferBase, public #include + [[maybe_unused]] static void BM_CHColumnToSparkRowNew(benchmark::State & state) { std::shared_ptr metadata = std::make_shared(); @@ -1504,9 +1504,7 @@ int main(int argc, char ** argv) SharedContextHolder shared_context = Context::createShared(); global_context = Context::createGlobal(shared_context.get()); global_context->makeGlobalContext(); - - auto config = Poco::AutoPtr(new Poco::Util::MapConfiguration()); - global_context->setConfig(config); + global_context->setConfig(local_engine::SerializedPlanParser::config); const std::string path = "/"; global_context->setPath(path); SerializedPlanParser::global_context = global_context; diff --git a/utils/local-engine/tests/benchmark_spark_row.cpp b/utils/local-engine/tests/benchmark_spark_row.cpp deleted file mode 100644 index 2d332f1632b6..000000000000 --- a/utils/local-engine/tests/benchmark_spark_row.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -using namespace DB; -using namespace local_engine; - -struct NameType -{ - String name; - String type; -}; - -using NameTypes = std::vector; - -static Block getLineitemHeader(const NameTypes & name_types) -{ - auto & factory = DataTypeFactory::instance(); - ColumnsWithTypeAndName columns(name_types.size()); - for (size_t i=0; i(file); - FormatSettings format_settings; - auto format = std::make_shared(*in, header, format_settings); - auto pipeline = QueryPipeline(std::move(format)); - auto reader = std::make_unique(pipeline); - while (reader->pull(block)) - return; -} - -static void BM_CHColumnToSparkRow_Lineitem(benchmark::State& state) -{ - const NameTypes name_types = { - {"l_orderkey", "Nullable(Int64)"}, - {"l_partkey", "Nullable(Int64)"}, - {"l_suppkey", "Nullable(Int64)"}, - {"l_linenumber", "Nullable(Int64)"}, - {"l_quantity", "Nullable(Float64)"}, - {"l_extendedprice", "Nullable(Float64)"}, - {"l_discount", "Nullable(Float64)"}, - {"l_tax", "Nullable(Float64)"}, - {"l_returnflag", "Nullable(String)"}, - {"l_linestatus", "Nullable(String)"}, - {"l_shipdate", "Nullable(Date32)"}, - {"l_commitdate", "Nullable(Date32)"}, - {"l_receiptdate", "Nullable(Date32)"}, - {"l_shipinstruct", "Nullable(String)"}, - {"l_shipmode", "Nullable(String)"}, - {"l_comment", "Nullable(String)"}, - }; - - const Block header = std::move(getLineitemHeader(name_types)); - const String file = "/data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/" - "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; - Block block; - readParquetFile(header, file, block); - // std::cerr << "read_rows:" << block.rows() << std::endl; - CHColumnToSparkRow converter; - for (auto _ : state) - { - auto spark_row_info = 
converter.convertCHColumnToSparkRow(block); - converter.freeMem(spark_row_info->getBufferAddress(), spark_row_info->getTotalBytes()); - } -} - - -static void BM_SparkRowToCHColumn_Lineitem(benchmark::State& state) -{ - const NameTypes name_types = { - {"l_orderkey", "Nullable(Int64)"}, - {"l_partkey", "Nullable(Int64)"}, - {"l_suppkey", "Nullable(Int64)"}, - {"l_linenumber", "Nullable(Int64)"}, - {"l_quantity", "Nullable(Float64)"}, - {"l_extendedprice", "Nullable(Float64)"}, - {"l_discount", "Nullable(Float64)"}, - {"l_tax", "Nullable(Float64)"}, - {"l_returnflag", "Nullable(String)"}, - {"l_linestatus", "Nullable(String)"}, - {"l_shipdate", "Nullable(Date32)"}, - {"l_commitdate", "Nullable(Date32)"}, - {"l_receiptdate", "Nullable(Date32)"}, - {"l_shipinstruct", "Nullable(String)"}, - {"l_shipmode", "Nullable(String)"}, - {"l_comment", "Nullable(String)"}, - }; - - const Block header = std::move(getLineitemHeader(name_types)); - const String file = "/data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/" - "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; - Block in_block; - readParquetFile(header, file, in_block); - - CHColumnToSparkRow spark_row_converter; - auto spark_row_info = spark_row_converter.convertCHColumnToSparkRow(in_block); - for (auto _ : state) - [[maybe_unused]] auto out_block = SparkRowToCHColumn::convertSparkRowInfoToCHColumn(*spark_row_info, header); -} - -BENCHMARK(BM_CHColumnToSparkRow_Lineitem)->Unit(benchmark::kMillisecond)->Iterations(10); -BENCHMARK(BM_SparkRowToCHColumn_Lineitem)->Unit(benchmark::kMillisecond)->Iterations(10); diff --git a/utils/local-engine/tests/data/array.parquet b/utils/local-engine/tests/data/array.parquet deleted file mode 100644 index d989f3d7cbc1..000000000000 Binary files a/utils/local-engine/tests/data/array.parquet and /dev/null differ diff --git a/utils/local-engine/tests/data/decimal.parquet b/utils/local-engine/tests/data/decimal.parquet deleted file mode 100644 index e1981938866e..000000000000 Binary files a/utils/local-engine/tests/data/decimal.parquet and /dev/null differ diff --git a/utils/local-engine/tests/data/map.parquet b/utils/local-engine/tests/data/map.parquet deleted file mode 100644 index def9242ee305..000000000000 Binary files a/utils/local-engine/tests/data/map.parquet and /dev/null differ diff --git a/utils/local-engine/tests/data/struct.parquet b/utils/local-engine/tests/data/struct.parquet deleted file mode 100644 index 7a90433ae703..000000000000 Binary files a/utils/local-engine/tests/data/struct.parquet and /dev/null differ diff --git a/utils/local-engine/tests/gtest_local_engine.cpp b/utils/local-engine/tests/gtest_local_engine.cpp index 361da32e7341..49a4daf880f6 100644 --- a/utils/local-engine/tests/gtest_local_engine.cpp +++ b/utils/local-engine/tests/gtest_local_engine.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -21,9 +20,9 @@ #include #include #include -#include #include #include +#include "Storages/CustomStorageMergeTree.h" #include "testConfig.h" using namespace local_engine; @@ -40,10 +39,11 @@ TEST(TestSelect, ReadRel) .column("type_string", "String") .build(); dbms::SerializedPlanBuilder plan_builder; - auto plan = plan_builder.read(TEST_DATA(/data/iris.parquet), std::move(schema)).build(); + auto plan = plan_builder.read(TEST_DATA(/ data / iris.parquet), std::move(schema)).build(); ASSERT_TRUE(plan->relations(0).root().input().has_read()); ASSERT_EQ(plan->relations_size(), 1); + std::cout << "start execute" << std::endl; 
local_engine::LocalExecutor local_executor; local_engine::SerializedPlanParser parser(local_engine::SerializedPlanParser::global_context); auto query_plan = parser.parse(std::move(plan)); @@ -51,6 +51,7 @@ TEST(TestSelect, ReadRel) ASSERT_TRUE(local_executor.hasNext()); while (local_executor.hasNext()) { + std::cout << "fetch batch" << std::endl; local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); ASSERT_GT(spark_row_info->getNumRows(), 0); local_engine::SparkRowToCHColumn converter; @@ -64,10 +65,11 @@ TEST(TestSelect, ReadDate) dbms::SerializedSchemaBuilder schema_builder; auto * schema = schema_builder.column("date", "Date").build(); dbms::SerializedPlanBuilder plan_builder; - auto plan = plan_builder.read(TEST_DATA(/data/date.parquet), std::move(schema)).build(); + auto plan = plan_builder.read(TEST_DATA(/ data / date.parquet), std::move(schema)).build(); ASSERT_TRUE(plan->relations(0).root().input().has_read()); ASSERT_EQ(plan->relations_size(), 1); + std::cout << "start execute" << std::endl; local_engine::LocalExecutor local_executor; local_engine::SerializedPlanParser parser(local_engine::SerializedPlanParser::global_context); auto query_plan = parser.parse(std::move(plan)); @@ -75,6 +77,7 @@ TEST(TestSelect, ReadDate) ASSERT_TRUE(local_executor.hasNext()); while (local_executor.hasNext()) { + std::cout << "fetch batch" << std::endl; local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); ASSERT_GT(spark_row_info->getNumRows(), 0); local_engine::SparkRowToCHColumn converter; @@ -104,8 +107,9 @@ TEST(TestSelect, TestFilter) auto * type_0 = dbms::scalarFunction(dbms::EQUAL_TO, {dbms::selection(5), dbms::literal("类型1")}); auto * filter = dbms::scalarFunction(dbms::AND, {less_exp, type_0}); - auto plan = plan_builder.registerSupportedFunctions().filter(filter).read(TEST_DATA(/data/iris.parquet), std::move(schema)).build(); + auto plan = plan_builder.registerSupportedFunctions().filter(filter).read(TEST_DATA(/ data / iris.parquet), std::move(schema)).build(); ASSERT_EQ(plan->relations_size(), 1); + std::cout << "start execute" << std::endl; local_engine::LocalExecutor local_executor; local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); auto query_plan = parser.parse(std::move(plan)); @@ -113,6 +117,7 @@ TEST(TestSelect, TestFilter) ASSERT_TRUE(local_executor.hasNext()); while (local_executor.hasNext()) { + std::cout << "fetch batch" << std::endl; local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); ASSERT_EQ(spark_row_info->getNumRows(), 1); local_engine::SparkRowToCHColumn converter; @@ -139,9 +144,10 @@ TEST(TestSelect, TestAgg) auto plan = plan_builder.registerSupportedFunctions() .aggregate({}, {measure}) .filter(less_exp) - .read(TEST_DATA(/data/iris.parquet), std::move(schema)) + .read(TEST_DATA(/ data / iris.parquet), std::move(schema)) .build(); ASSERT_EQ(plan->relations_size(), 1); + std::cout << "start execute" << std::endl; local_engine::LocalExecutor local_executor; local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); auto query_plan = parser.parse(std::move(plan)); @@ -149,14 +155,17 @@ TEST(TestSelect, TestAgg) ASSERT_TRUE(local_executor.hasNext()); while (local_executor.hasNext()) { + std::cout << "fetch batch" << std::endl; local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); ASSERT_EQ(spark_row_info->getNumRows(), 1); ASSERT_EQ(spark_row_info->getNumCols(), 1); local_engine::SparkRowToCHColumn converter; auto block = 
converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); ASSERT_EQ(spark_row_info->getNumRows(), block->rows()); - auto reader = SparkRowReader(block->getDataTypes()); - reader.pointTo(spark_row_info->getBufferAddress() + spark_row_info->getOffsets()[1], spark_row_info->getLengths()[0]); + auto reader = SparkRowReader(spark_row_info->getNumCols()); + reader.pointTo( + reinterpret_cast(spark_row_info->getBufferAddress() + spark_row_info->getOffsets()[1]), + spark_row_info->getLengths()[0]); ASSERT_EQ(reader.getDouble(0), 103.2); } } @@ -321,8 +330,7 @@ int main(int argc, char ** argv) SharedContextHolder shared_context = Context::createShared(); local_engine::SerializedPlanParser::global_context = Context::createGlobal(shared_context.get()); local_engine::SerializedPlanParser::global_context->makeGlobalContext(); - auto config = Poco::AutoPtr(new Poco::Util::MapConfiguration()); - local_engine::SerializedPlanParser::global_context->setConfig(config); + local_engine::SerializedPlanParser::global_context->setConfig(local_engine::SerializedPlanParser::config); local_engine::SerializedPlanParser::global_context->setPath("/tmp"); local_engine::SerializedPlanParser::global_context->getDisksMap().emplace(); local_engine::SerializedPlanParser::initFunctionEnv(); diff --git a/utils/local-engine/tests/gtest_spark_row.cpp b/utils/local-engine/tests/gtest_spark_row.cpp deleted file mode 100644 index 313c62de39a3..000000000000 --- a/utils/local-engine/tests/gtest_spark_row.cpp +++ /dev/null @@ -1,443 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace local_engine; -using namespace DB; - - -struct DataTypeAndField -{ - DataTypePtr type; - Field field; -}; -using DataTypeAndFields = std::vector; - -using SparkRowAndBlock = std::pair; - -static SparkRowAndBlock mockSparkRowInfoAndBlock(const DataTypeAndFields & type_and_fields) -{ - /// Initialize types - ColumnsWithTypeAndName columns(type_and_fields.size()); - for (size_t i=0; iinsert(type_and_fields[i].field); - block.setColumns(std::move(mutable_colums)); - - auto converter = CHColumnToSparkRow(); - auto spark_row_info = converter.convertCHColumnToSparkRow(block); - return std::make_tuple(std::move(spark_row_info), std::make_shared(std::move(block))); -} - -static Int32 getDayNum(const String & date) -{ - ExtendedDayNum res; - ReadBufferFromString in(date); - readDateText(res, in); - return res; -} - -static DateTime64 getDateTime64(const String & datetime64, UInt32 scale) -{ - DateTime64 res; - ReadBufferFromString in(datetime64); - readDateTime64Text(res, scale, in); - return res; -} - -static void assertReadConsistentWithWritten(const SparkRowInfo & spark_row_info, const Block & in, const DataTypeAndFields type_and_fields) -{ - /// Check if output of SparkRowReader is consistent with types_and_fields - { - auto reader = SparkRowReader(spark_row_info.getDataTypes()); - reader.pointTo(spark_row_info.getBufferAddress(), spark_row_info.getTotalBytes()); - for (size_t i = 0; i < type_and_fields.size(); ++i) - { - /* - const auto read_field{std::move(reader.getField(i))}; - const auto & written_field = type_and_fields[i].field; - std::cout << "read_field:" << read_field.getType() << "," << toString(read_field) << std::endl; - std::cout << "written_field:" << written_field.getType() << "," << toString(written_field) << std::endl; - */ - EXPECT_TRUE(reader.getField(i) == 
type_and_fields[i].field); - } - } - - /// check if output of SparkRowToCHColumn is consistents with initial block. - { - auto block = SparkRowToCHColumn::convertSparkRowInfoToCHColumn(spark_row_info, in.cloneEmpty()); - const auto & out = *block; - EXPECT_TRUE(in.rows() == out.rows()); - EXPECT_TRUE(in.columns() == out.columns()); - for (size_t col_idx = 0; col_idx < in.columns(); ++col_idx) - { - const auto & in_col = in.getByPosition(col_idx); - const auto & out_col = out.getByPosition(col_idx); - for (size_t row_idx = 0; row_idx < in.rows(); ++row_idx) - { - const auto in_field = (*in_col.column)[row_idx]; - const auto out_field = (*out_col.column)[row_idx]; - EXPECT_TRUE(in_field == out_field); - } - } - } -} - -TEST(SparkRow, BitSetWidthCalculation) -{ - EXPECT_TRUE(calculateBitSetWidthInBytes(0) == 0); - EXPECT_TRUE(calculateBitSetWidthInBytes(1) == 8); - EXPECT_TRUE(calculateBitSetWidthInBytes(32) == 8); - EXPECT_TRUE(calculateBitSetWidthInBytes(64) == 8); - EXPECT_TRUE(calculateBitSetWidthInBytes(65) == 16); - EXPECT_TRUE(calculateBitSetWidthInBytes(128) == 16); -} - -TEST(SparkRow, GetArrayElementSize) -{ - const std::map type_to_size = { - {std::make_shared(), 1}, - {std::make_shared(), 1}, - {std::make_shared(), 2}, - {std::make_shared(), 2}, - {std::make_shared(), 2}, - {std::make_shared(), 4}, - {std::make_shared(), 4}, - {std::make_shared(), 4}, - {std::make_shared(), 4}, - {std::make_shared(9, 4), 4}, - {std::make_shared(), 8}, - {std::make_shared(), 8}, - {std::make_shared(), 8}, - {std::make_shared(6), 8}, - {std::make_shared(18, 4), 8}, - - {std::make_shared(), 8}, - {std::make_shared(38, 4), 8}, - {std::make_shared(std::make_shared(), std::make_shared()), 8}, - {std::make_shared(std::make_shared()), 8}, - {std::make_shared(DataTypes{std::make_shared(), std::make_shared()}), 8}, - }; - - for (const auto & [type, size] : type_to_size) - { - EXPECT_TRUE(BackingDataLengthCalculator::getArrayElementSize(type) == size); - if (type->canBeInsideNullable()) - { - const auto type_with_nullable = std::make_shared(type); - EXPECT_TRUE(BackingDataLengthCalculator::getArrayElementSize(type_with_nullable) == size); - } - } -} - -TEST(SparkRow, PrimitiveTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(), -1}, - {std::make_shared(), UInt64(1)}, - {std::make_shared(), -2}, - {std::make_shared(), UInt32(2)}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + 4 * 8); -} - -TEST(SparkRow, PrimitiveStringTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(), -1}, - {std::make_shared(), UInt64(1)}, - {std::make_shared(), "Hello World"}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 3) + roundNumberOfBytesToNearestWord(strlen("Hello World"))); -} - -TEST(SparkRow, PrimitiveStringDateTimestampTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(), -1}, - {std::make_shared(), UInt64(1)}, - {std::make_shared(), "Hello World"}, - {std::make_shared(), getDayNum("2015-06-22")}, - {std::make_shared(0), getDateTime64("2015-05-08 08:10:25", 0)}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - 
std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 5) + roundNumberOfBytesToNearestWord(strlen("Hello World"))); -} - - -TEST(SparkRow, DecimalTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(9, 2), DecimalField(1234, 2)}, - {std::make_shared(18, 2), DecimalField(5678, 2)}, - {std::make_shared(38, 2), DecimalField(Decimal128(Int128(12345678)), 2)}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 3) + 16); -} - - -TEST(SparkRow, NullHandling) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - {std::make_shared(std::make_shared()), Null{}}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * type_and_fields.size())); -} - -TEST(SparkRow, StructTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(DataTypes{std::make_shared()}), Tuple{Int32(1)}}, - {std::make_shared(DataTypes{std::make_shared(DataTypes{std::make_shared()})}), - []() -> Field - { - Tuple t(1); - t.back() = Tuple{Int64(2)}; - return std::move(t); - }()}, - }; - - /* - for (size_t i=0; igetName() << ",field:" << toString(type_and_fields[i].field) - << std::endl; - } - */ - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); -} - -TEST(SparkRow, ArrayTypes) -{ - DataTypeAndFields type_and_fields = { - {std::make_shared(std::make_shared()), Array{Int32(1), Int32(2)}}, - {std::make_shared(std::make_shared(std::make_shared())), - []() -> Field - { - Array array(1); - array.back() = Array{Int32(1), Int32(2)}; - return std::move(array); - }()}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 - + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); -} - -TEST(SparkRow, MapTypes) -{ - const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); - 
DataTypeAndFields type_and_fields = { - {map_type, - []() -> Field - { - Map map(2); - map[0] = std::move(Tuple{Int32(1), Int32(2)}); - map[1] = std::move(Tuple{Int32(3), Int32(4)}); - return std::move(map); - }()}, - {std::make_shared(std::make_shared(), map_type), - []() -> Field - { - Map inner_map(2); - inner_map[0] = std::move(Tuple{Int32(5), Int32(6)}); - inner_map[1] = std::move(Tuple{Int32(7), Int32(8)}); - - Map map(1); - map.back() = std::move(Tuple{Int32(9), std::move(inner_map)}); - return std::move(map); - }()}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); -} - - -TEST(SparkRow, StructMapTypes) -{ - const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); - const auto tuple_type = std::make_shared(DataTypes{std::make_shared()}); - - DataTypeAndFields type_and_fields = { - {std::make_shared(DataTypes{map_type}), - []() -> Field - { - Map map(1); - map[0] = std::move(Tuple{Int32(1), Int32(2)}); - return std::move(Tuple{std::move(map)}); - }()}, - {std::make_shared(std::make_shared(), tuple_type), - []() -> Field - { - Tuple inner_tuple{Int32(4)}; - Map map(1); - map.back() = std::move(Tuple{Int32(3), std::move(inner_tuple)}); - return std::move(map); - }()}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); -} - - -TEST(SparkRow, StructArrayTypes) -{ - const auto array_type = std::make_shared(std::make_shared()); - const auto tuple_type = std::make_shared(DataTypes{std::make_shared()}); - DataTypeAndFields type_and_fields = { - {std::make_shared(DataTypes{array_type}), - []() -> Field - { - Array array{Int32(1)}; - Tuple tuple(1); - tuple[0] = std::move(array); - return std::move(tuple); - }()}, - {std::make_shared(tuple_type), - []() -> Field - { - Tuple tuple{Int64(2)}; - Array array(1); - array[0] = std::move(tuple); - return std::move(array); - }()}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); - -} - -TEST(SparkRow, ArrayMapTypes) -{ - const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); - const auto array_type = std::make_shared(std::make_shared()); - - DataTypeAndFields type_and_fields = { - {std::make_shared(map_type), - []() -> Field - { - Map map(1); - map[0] = std::move(Tuple{Int32(1),Int32(2)}); - - Array array(1); - array[0] = std::move(map); - return std::move(array); - }()}, - 
{std::make_shared(std::make_shared(), array_type), - []() -> Field - { - Array array{Int32(4)}; - Tuple tuple(2); - tuple[0] = Int32(3); - tuple[1] = std::move(array); - - Map map(1); - map[0] = std::move(tuple); - return std::move(map); - }()}, - }; - - SparkRowInfoPtr spark_row_info; - BlockPtr block; - std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); - assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); - - EXPECT_TRUE( - spark_row_info->getTotalBytes() - == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) - + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); -}
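For reference, every size assertion in the removed gtest_spark_row.cpp follows Spark's UnsafeRow layout: a null bitset padded to whole 8-byte words, then one 8-byte slot per field, then variable-length data padded to 8-byte words, which is also why the new SparkRowReader computes getFieldOffset(ordinal) as base_offset + bit_set_width_in_bytes + ordinal * 8. The snippet below is a minimal standalone sketch of that arithmetic; the helper names (calculateBitSetWidthInBytes, roundNumberOfBytesToNearestWord) appear in the patch, but their bodies here are assumed reconstructions, not the local_engine implementation.

#include <cassert>
#include <cstdint>
#include <cstring>

// Assumed reconstruction: one bit per field, padded to whole 64-bit words
// (0 fields -> 0 bytes), matching the BitSetWidthCalculation expectations.
static int64_t calculateBitSetWidthInBytes(int32_t num_fields)
{
    return ((num_fields + 63) / 64) * 8;
}

// Assumed reconstruction: variable-length payloads are padded to 8-byte boundaries.
static int64_t roundNumberOfBytesToNearestWord(int64_t num_bytes)
{
    return (num_bytes + 7) / 8 * 8;
}

int main()
{
    // Mirrors SparkRow.BitSetWidthCalculation.
    assert(calculateBitSetWidthInBytes(0) == 0);
    assert(calculateBitSetWidthInBytes(1) == 8);
    assert(calculateBitSetWidthInBytes(64) == 8);
    assert(calculateBitSetWidthInBytes(65) == 16);

    // Mirrors SparkRow.PrimitiveStringTypes: 8-byte bitset + 3 fixed slots
    // + the string payload rounded up to a multiple of 8 ("Hello World" -> 16).
    const int32_t num_fields = 3;
    const int64_t total = calculateBitSetWidthInBytes(num_fields) + 8 * num_fields
        + roundNumberOfBytesToNearestWord(static_cast<int64_t>(std::strlen("Hello World")));
    assert(total == 8 + (8 * 3) + 16);

    // Fixed-width field i then lives at base_offset + bitset width + i * 8,
    // which is exactly what the simplified SparkRowReader::getFieldOffset returns.
    return 0;
}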