diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index cd4ee4827021..d7de4fe7fc75 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -839,6 +839,34 @@ jobs: docker ps --all --quiet | xargs --no-run-if-empty docker rm -f ||: sudo rm -fr "$TEMP_PATH" ############################################################################################ +##################################### Docker images ####################################### +############################################################################################ + DockerServerImages: + needs: + - BuilderDebRelease + - BuilderDebAarch64 + runs-on: [self-hosted, style-checker] + steps: + - name: Clear repository + run: | + sudo rm -fr "$GITHUB_WORKSPACE" && mkdir "$GITHUB_WORKSPACE" + - name: Check out repository code + uses: actions/checkout@v2 + with: + fetch-depth: 0 # otherwise we will have no version info + - name: Check docker clickhouse/clickhouse-server building + run: | + cd "$GITHUB_WORKSPACE/tests/ci" + python3 docker_server.py --release-type head + python3 docker_server.py --release-type head --no-ubuntu \ + --image-repo clickhouse/clickhouse-keeper --image-path docker/keeper + - name: Cleanup + if: always() + run: | + docker kill "$(docker ps -q)" ||: + docker rm -f "$(docker ps -a -q)" ||: + sudo rm -fr "$TEMP_PATH" +############################################################################################ ##################################### BUILD REPORTER ####################################### ############################################################################################ BuilderReport: diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index c677ec4bf5c7..774327a767e7 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -894,6 +894,34 @@ jobs: docker ps --all --quiet | xargs --no-run-if-empty docker rm -f ||: sudo rm -fr "$TEMP_PATH" ############################################################################################ +##################################### Docker images ####################################### +############################################################################################ + DockerServerImages: + needs: + - BuilderDebRelease + - BuilderDebAarch64 + runs-on: [self-hosted, style-checker] + steps: + - name: Clear repository + run: | + sudo rm -fr "$GITHUB_WORKSPACE" && mkdir "$GITHUB_WORKSPACE" + - name: Check out repository code + uses: actions/checkout@v2 + with: + fetch-depth: 0 # otherwise we will have no version info + - name: Check docker clickhouse/clickhouse-server building + run: | + cd "$GITHUB_WORKSPACE/tests/ci" + python3 docker_server.py --release-type head --no-push + python3 docker_server.py --release-type head --no-push --no-ubuntu \ + --image-repo clickhouse/clickhouse-keeper --image-path docker/keeper + - name: Cleanup + if: always() + run: | + docker kill "$(docker ps -q)" ||: + docker rm -f "$(docker ps -a -q)" ||: + sudo rm -fr "$TEMP_PATH" +############################################################################################ ##################################### BUILD REPORTER ####################################### ############################################################################################ BuilderReport: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cfc1ccb7175e..2ef05fe989b3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,9 @@ jobs: steps: - name: Deploy packages and assets run: | - curl '${{ secrets.PACKAGES_RELEASE_URL }}/release/${{ github.ref }}?binary=binary_darwin&binary=binary_darwin_aarch64&sync=true' -d '' + GITHUB_TAG="${GITHUB_REF#refs/tags/}" + curl --silent --data '' \ + '${{ secrets.PACKAGES_RELEASE_URL }}/release/'"${GITHUB_TAG}"'?binary=binary_darwin&binary=binary_darwin_aarch64&sync=true' ############################################################################################ ##################################### Docker images ####################################### ############################################################################################ diff --git a/.gitignore b/.gitignore index 7a513ec1a09b..082c15c1f439 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,5 @@ tests/queries/0_stateless/*.generated-expect /rust/**/target # It is autogenerated from *.in /rust/**/.cargo/config.toml + +utils/local-engine/tests/testConfig.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ab9766124014..1d9de6f08999 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -178,7 +178,7 @@ else () set(NO_WHOLE_ARCHIVE --no-whole-archive) endif () -option(ENABLE_CURL_BUILD "Enable curl, azure, sentry build on by default except MacOS." ON) +option(ENABLE_CURL_BUILD "Enable curl, azure, sentry build on by default except MacOS." OFF) if (OS_DARWIN) # Disable the curl, azure, senry build on MacOS set (ENABLE_CURL_BUILD OFF) @@ -422,6 +422,8 @@ endif () include(cmake/dbms_glob_sources.cmake) +set (CMAKE_POSITION_INDEPENDENT_CODE ON) + add_library(global-group INTERFACE) if (OS_LINUX OR OS_ANDROID) include(cmake/linux/default_libs.cmake) @@ -452,7 +454,6 @@ endif () set (CMAKE_POSTFIX_VARIABLE "CMAKE_${CMAKE_BUILD_TYPE_UC}_POSTFIX") -set (CMAKE_POSITION_INDEPENDENT_CODE OFF) if (OS_LINUX AND NOT ARCH_AARCH64) # Slightly more efficient code can be generated # It's disabled for ARM because otherwise ClickHouse cannot run on Android. @@ -510,8 +511,10 @@ macro (clickhouse_add_executable target) if (${type} STREQUAL EXECUTABLE) # disabled for TSAN and gcc since libtsan.a provides overrides too if (TARGET clickhouse_new_delete) - # operator::new/delete for executables (MemoryTracker stuff) - target_link_libraries (${target} PRIVATE clickhouse_new_delete) + if (NOT ${target} STREQUAL Git::Git) + # operator::new/delete for executables (MemoryTracker stuff) + target_link_libraries (${target} PRIVATE clickhouse_new_delete ${MALLOC_LIBRARIES}) + endif() endif() # In case of static jemalloc, because zone_register() is located in zone.c and diff --git a/cmake/autogenerated_versions.txt b/cmake/autogenerated_versions.txt index 27c086673cca..8e495b8257ed 100644 --- a/cmake/autogenerated_versions.txt +++ b/cmake/autogenerated_versions.txt @@ -5,8 +5,8 @@ SET(VERSION_REVISION 54470) SET(VERSION_MAJOR 23) SET(VERSION_MINOR 1) -SET(VERSION_PATCH 1) -SET(VERSION_GITHASH 688e488e930c83eefeac4f87c4cc029cc5b231e3) -SET(VERSION_DESCRIBE v23.1.1.1-testing) -SET(VERSION_STRING 23.1.1.1) +SET(VERSION_PATCH 3) +SET(VERSION_GITHASH 8dfb1700858195fa704221e360fa0798ac6ee9ed) +SET(VERSION_DESCRIBE v23.1.3.1-stable) +SET(VERSION_STRING 23.1.3.1) # end of autochange diff --git a/contrib/jemalloc-cmake/include_linux_x86_64/jemalloc/internal/jemalloc_internal_defs.h.in b/contrib/jemalloc-cmake/include_linux_x86_64/jemalloc/internal/jemalloc_internal_defs.h.in index d21098a4dcce..93d583e2ecb9 100644 --- a/contrib/jemalloc-cmake/include_linux_x86_64/jemalloc/internal/jemalloc_internal_defs.h.in +++ b/contrib/jemalloc-cmake/include_linux_x86_64/jemalloc/internal/jemalloc_internal_defs.h.in @@ -139,7 +139,8 @@ /* #undef JEMALLOC_MUTEX_INIT_CB */ /* Non-empty if the tls_model attribute is supported. */ -#define JEMALLOC_TLS_MODEL __attribute__((tls_model("initial-exec"))) +/*#define JEMALLOC_TLS_MODEL __attribute__((tls_model("initial-exec")))*/ +#define JEMALLOC_TLS_MODEL /* * JEMALLOC_DEBUG enables assertions and other sanity checks, and disables diff --git a/docs/en/sql-reference/functions/json-functions.md b/docs/en/sql-reference/functions/json-functions.md index 71483896189e..5797f7b5bc2f 100644 --- a/docs/en/sql-reference/functions/json-functions.md +++ b/docs/en/sql-reference/functions/json-functions.md @@ -200,6 +200,7 @@ Examples: ``` sql SELECT JSONExtract('{"a": "hello", "b": [-100, 200.0, 300]}', 'Tuple(String, Array(Float64))') = ('hello',[-100,200,300]) SELECT JSONExtract('{"a": "hello", "b": [-100, 200.0, 300]}', 'Tuple(b Array(Float64), a String)') = ([-100,200,300],'hello') +SELECT JSONExtract('{"a": "hello", "b": "world"}', 'Map(String, String)') = map('a', 'hello', 'b', 'world'); SELECT JSONExtract('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 'Array(Nullable(Int8))') = [-100, NULL, NULL] SELECT JSONExtract('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 4, 'Nullable(Int64)') = NULL SELECT JSONExtract('{"passed": true}', 'passed', 'UInt8') = 1 diff --git a/docs/en/sql-reference/functions/tuple-map-functions.md b/docs/en/sql-reference/functions/tuple-map-functions.md index 1905e53af3e8..9885ac9d76ea 100644 --- a/docs/en/sql-reference/functions/tuple-map-functions.md +++ b/docs/en/sql-reference/functions/tuple-map-functions.md @@ -66,6 +66,46 @@ Result: - [Map(key, value)](../../sql-reference/data-types/map.md) data type +## mapFromArrays + +Merges an [Array](../../sql-reference/data-types/array.md) of keys and an [Array](../../sql-reference/data-types/array.md) of values into a [Map(key, value)](../../sql-reference/data-types/map.md). Notice that the second argument could also be a [Map](../../sql-reference/data-types/map.md), thus it is casted to an Array when executing. + +The function is a more convenient alternative to `CAST((key_array, value_array_or_map), 'Map(key_type, value_type)')`. For example, instead of writing `CAST((['aa', 'bb'], [4, 5]), 'Map(String, UInt32)')`, you can write `mapFromArrays(['aa', 'bb'], [4, 5])`. + +**Syntax** + +```sql +mapFromArrays(keys, values) +``` + +Alias: `MAP_FROM_ARRAYS(keys, values)` + +**Parameters** +- `keys` — Given key array to create a map from. The nested type of array must be: [String](../../sql-reference/data-types/string.md), [Integer](../../sql-reference/data-types/int-uint.md), [LowCardinality](../../sql-reference/data-types/lowcardinality.md), [FixedString](../../sql-reference/data-types/fixedstring.md), [UUID](../../sql-reference/data-types/uuid.md), [Date](../../sql-reference/data-types/date.md), [DateTime](../../sql-reference/data-types/datetime.md), [Date32](../../sql-reference/data-types/date32.md), [Enum](../../sql-reference/data-types/enum.md) +- `values` - Given value array or map to create a map from. + +**Returned value** + +- A map whose keys and values are constructed from the key array and value array/map. + +**Example** + +Query: + +```sql +select mapFromArrays(['a', 'b', 'c'], [1, 2, 3]) + +┌─mapFromArrays(['a', 'b', 'c'], [1, 2, 3])─┐ +│ {'a':1,'b':2,'c':3} │ +└───────────────────────────────────────────┘ + +SELECT mapFromArrays([1, 2, 3], map('a', 1, 'b', 2, 'c', 3)) + +┌─mapFromArrays([1, 2, 3], map('a', 1, 'b', 2, 'c', 3))─┐ +│ {1:('a',1),2:('b',2),3:('c',3)} │ +└───────────────────────────────────────────────────────┘ +``` + ## mapAdd Collect all the keys and sum corresponding values. @@ -429,6 +469,8 @@ Result: │ {} │ └────────────────────────────┘ ``` + + ## mapApply diff --git a/programs/benchmark/CMakeLists.txt b/programs/benchmark/CMakeLists.txt index ad211399bb59..661b5106ebce 100644 --- a/programs/benchmark/CMakeLists.txt +++ b/programs/benchmark/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_BENCHMARK_SOURCES Benchmark.cpp) +set (CLICKHOUSE_BENCHMARK_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Benchmark.cpp) set (CLICKHOUSE_BENCHMARK_LINK PRIVATE diff --git a/programs/compressor/CMakeLists.txt b/programs/compressor/CMakeLists.txt index ff642a32fd4f..4eac7aa5136e 100644 --- a/programs/compressor/CMakeLists.txt +++ b/programs/compressor/CMakeLists.txt @@ -1,6 +1,6 @@ # Also in utils -set (CLICKHOUSE_COMPRESSOR_SOURCES Compressor.cpp) +set (CLICKHOUSE_COMPRESSOR_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Compressor.cpp) set (CLICKHOUSE_COMPRESSOR_LINK PRIVATE diff --git a/programs/extract-from-config/CMakeLists.txt b/programs/extract-from-config/CMakeLists.txt index ff2d79371172..a7abd92bae01 100644 --- a/programs/extract-from-config/CMakeLists.txt +++ b/programs/extract-from-config/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_EXTRACT_FROM_CONFIG_SOURCES ExtractFromConfig.cpp) +set (CLICKHOUSE_EXTRACT_FROM_CONFIG_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/ExtractFromConfig.cpp) set (CLICKHOUSE_EXTRACT_FROM_CONFIG_LINK PRIVATE diff --git a/programs/format/CMakeLists.txt b/programs/format/CMakeLists.txt index 49f17ef163f8..b5db7fb14344 100644 --- a/programs/format/CMakeLists.txt +++ b/programs/format/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_FORMAT_SOURCES Format.cpp) +set (CLICKHOUSE_FORMAT_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Format.cpp) set (CLICKHOUSE_FORMAT_LINK PRIVATE diff --git a/programs/git-import/CMakeLists.txt b/programs/git-import/CMakeLists.txt index 279bb35a2722..1eb37b1ca9f8 100644 --- a/programs/git-import/CMakeLists.txt +++ b/programs/git-import/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_GIT_IMPORT_SOURCES git-import.cpp) +set (CLICKHOUSE_GIT_IMPORT_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/git-import.cpp) set (CLICKHOUSE_GIT_IMPORT_LINK PRIVATE diff --git a/programs/install/CMakeLists.txt b/programs/install/CMakeLists.txt index c3f4d96d6319..021f7dbe9717 100644 --- a/programs/install/CMakeLists.txt +++ b/programs/install/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_INSTALL_SOURCES Install.cpp) +set (CLICKHOUSE_INSTALL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Install.cpp) set (CLICKHOUSE_INSTALL_LINK PRIVATE diff --git a/programs/keeper-converter/CMakeLists.txt b/programs/keeper-converter/CMakeLists.txt index d529f94d3885..6c8226fa56d3 100644 --- a/programs/keeper-converter/CMakeLists.txt +++ b/programs/keeper-converter/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_KEEPER_CONVERTER_SOURCES KeeperConverter.cpp) +set (CLICKHOUSE_KEEPER_CONVERTER_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/KeeperConverter.cpp) set (CLICKHOUSE_KEEPER_CONVERTER_LINK PRIVATE diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 6943af48ab90..7ce6661daaa7 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_LOCAL_SOURCES LocalServer.cpp) +set (CLICKHOUSE_LOCAL_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/LocalServer.cpp) set (CLICKHOUSE_LOCAL_LINK PRIVATE diff --git a/programs/obfuscator/CMakeLists.txt b/programs/obfuscator/CMakeLists.txt index d1179b3718c0..c67b3c00dd7f 100644 --- a/programs/obfuscator/CMakeLists.txt +++ b/programs/obfuscator/CMakeLists.txt @@ -1,4 +1,4 @@ -set (CLICKHOUSE_OBFUSCATOR_SOURCES Obfuscator.cpp) +set (CLICKHOUSE_OBFUSCATOR_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Obfuscator.cpp) set (CLICKHOUSE_OBFUSCATOR_LINK PRIVATE diff --git a/programs/server/CMakeLists.txt b/programs/server/CMakeLists.txt index 2cfa748d585e..05183daacd91 100644 --- a/programs/server/CMakeLists.txt +++ b/programs/server/CMakeLists.txt @@ -1,8 +1,8 @@ include(${ClickHouse_SOURCE_DIR}/cmake/embed_binary.cmake) set(CLICKHOUSE_SERVER_SOURCES - MetricsTransmitter.cpp - Server.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/MetricsTransmitter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp ) set (LINK_RESOURCE_LIB INTERFACE "-Wl,${WHOLE_ARCHIVE} $ -Wl,${NO_WHOLE_ARCHIVE}") @@ -36,4 +36,6 @@ clickhouse_embed_binaries( TARGET clickhouse_server_configs RESOURCES config.xml users.xml embedded.xml play.html dashboard.html js/uplot.js ) -add_dependencies(clickhouse-server-lib clickhouse_server_configs) +if(NOT CLICKHOUSE_ONE_SHARED) + add_dependencies(clickhouse-server-lib clickhouse_server_configs) +endif() diff --git a/src/AggregateFunctions/AggregateFunctionArray.h b/src/AggregateFunctions/AggregateFunctionArray.h index d1494f46f4bc..21394e3ce055 100644 --- a/src/AggregateFunctions/AggregateFunctionArray.h +++ b/src/AggregateFunctions/AggregateFunctionArray.h @@ -13,7 +13,7 @@ struct Settings; namespace ErrorCodes { - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } @@ -129,7 +129,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper, properties }); factory.registerFunction("groupArraySample", { createAggregateFunctionGroupArraySample, properties }); diff --git a/src/Columns/ColumnAggregateFunction.cpp b/src/Columns/ColumnAggregateFunction.cpp index fd46b38ada85..2605b3a5966c 100644 --- a/src/Columns/ColumnAggregateFunction.cpp +++ b/src/Columns/ColumnAggregateFunction.cpp @@ -525,6 +525,35 @@ void ColumnAggregateFunction::insertDefault() pushBackAndCreateState(data, arena, func.get()); } +void ColumnAggregateFunction::insertRangeSelective( + const IColumn & from, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + const ColumnAggregateFunction & from_concrete = static_cast(from); + const auto & from_data = from_concrete.data; + if (!empty() && src.get() != &from_concrete) + { + ensureOwnership(); + Arena & arena = createOrGetArena(); + Arena * arena_ptr = &arena; + data.reserve(size() + length); + for (size_t i = 0; i < length; ++i) + { + pushBackAndCreateState(data, arena, func.get()); + func->merge(data.back(), from_data[selector[selector_start + i]], arena_ptr); + } + return; + } + /// Keep shared ownership of aggregation states. + src = from_concrete.getPtr(); + + size_t old_size = data.size(); + data.resize(old_size + length); + auto * data_start = data.data(); + size_t element_size = sizeof(data[0]); + for (size_t i = 0; i < length; ++i) + memcpy(data_start + old_size + i, &from_concrete.data[selector[selector_start + i]], element_size); +} + StringRef ColumnAggregateFunction::serializeValueIntoArena(size_t n, Arena & arena, const char *& begin) const { WriteBufferFromArena out(arena, begin); diff --git a/src/Columns/ColumnAggregateFunction.h b/src/Columns/ColumnAggregateFunction.h index 38040d65d3b2..84b75385991f 100644 --- a/src/Columns/ColumnAggregateFunction.h +++ b/src/Columns/ColumnAggregateFunction.h @@ -184,6 +184,8 @@ class ColumnAggregateFunction final : public COWHelper(src); + const Offsets & src_offsets = src_concrete.getOffsets(); + const IColumn & src_data = src_concrete.getData(); + IColumn & cur_data = getData(); + Offsets & cur_offsets = getOffsets(); + + size_t old_size = cur_offsets.size(); + size_t cur_size = old_size + length; + cur_data.reserve(cur_size); + cur_offsets.resize(cur_size); + + for (size_t i = 0; i < length; ++i) + { + size_t src_pos = selector[selector_start + i]; + size_t offset = src_offsets[src_pos - 1]; + size_t size = src_offsets[src_pos] - offset; + cur_data.insertRangeFrom(src_data, offset, size); + cur_offsets[old_size + i] = cur_offsets[old_size + i - 1] + size; // PaddedPODArray allows to use -1th element that will have value 0 + } +} + ColumnPtr ColumnArray::filter(const Filter & filt, ssize_t result_size_hint) const { diff --git a/src/Columns/ColumnArray.h b/src/Columns/ColumnArray.h index 44652fd0c4b1..7117aad0a895 100644 --- a/src/Columns/ColumnArray.h +++ b/src/Columns/ColumnArray.h @@ -84,6 +84,7 @@ class ColumnArray final : public COWHelper void updateWeakHash32(WeakHash32 & hash) const override; void updateHashFast(SipHash & hash) const override; void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const Selector & selector, size_t selector_start, size_t length) override; void insert(const Field & x) override; void insertFrom(const IColumn & src_, size_t n) override; void insertDefault() override; @@ -149,6 +150,8 @@ class ColumnArray final : public COWHelper void gather(ColumnGathererStream & gatherer_stream) override; + bool canBeInsideNullable() const override { return true; } + ColumnPtr compress() const override; void forEachSubcolumn(ColumnCallback callback) const override diff --git a/src/Columns/ColumnCompressed.h b/src/Columns/ColumnCompressed.h index b258dbac8783..96dc0a2d1a85 100644 --- a/src/Columns/ColumnCompressed.h +++ b/src/Columns/ColumnCompressed.h @@ -85,6 +85,7 @@ class ColumnCompressed : public COWHelper bool isDefaultAt(size_t) const override { throwMustBeDecompressed(); } void insert(const Field &) override { throwMustBeDecompressed(); } void insertRangeFrom(const IColumn &, size_t, size_t) override { throwMustBeDecompressed(); } + void insertRangeSelective(const IColumn &, const Selector &, size_t, size_t) override { throwMustBeDecompressed(); } void insertData(const char *, size_t) override { throwMustBeDecompressed(); } void insertDefault() override { throwMustBeDecompressed(); } void popBack(size_t) override { throwMustBeDecompressed(); } diff --git a/src/Columns/ColumnConst.h b/src/Columns/ColumnConst.h index b86ed393e441..06df1bb91ca7 100644 --- a/src/Columns/ColumnConst.h +++ b/src/Columns/ColumnConst.h @@ -126,6 +126,11 @@ class ColumnConst final : public COWHelper s += length; } + void insertRangeSelective(const IColumn & /*src*/, const Selector & /*selector*/, size_t /*selector_start*/, size_t length) override + { + s += length; + } + void insert(const Field &) override { ++s; diff --git a/src/Columns/ColumnDecimal.cpp b/src/Columns/ColumnDecimal.cpp index e06593c5f45b..828f864a088c 100644 --- a/src/Columns/ColumnDecimal.cpp +++ b/src/Columns/ColumnDecimal.cpp @@ -261,6 +261,18 @@ void ColumnDecimal::insertRangeFrom(const IColumn & src, size_t start, size_t memcpy(data.data() + old_size, &src_vec.data[start], length * sizeof(data[0])); } +template +void ColumnDecimal::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + size_t old_size = data.size(); + data.resize(old_size + length); + const auto & src_data = (static_cast(src)).getData(); + for (size_t i = 0; i < length; ++i) + { + data[old_size + i] = src_data[selector[selector_start + i]]; + } +} + template ColumnPtr ColumnDecimal::filter(const IColumn::Filter & filt, ssize_t result_size_hint) const { diff --git a/src/Columns/ColumnDecimal.h b/src/Columns/ColumnDecimal.h index 5634b9064bfb..e8f02a8c1cd2 100644 --- a/src/Columns/ColumnDecimal.h +++ b/src/Columns/ColumnDecimal.h @@ -65,6 +65,7 @@ class ColumnDecimal final : public COWHelper()); } void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) override; void popBack(size_t n) override { diff --git a/src/Columns/ColumnFixedString.cpp b/src/Columns/ColumnFixedString.cpp index 1b7355d91f59..7f4effd9ab1b 100644 --- a/src/Columns/ColumnFixedString.cpp +++ b/src/Columns/ColumnFixedString.cpp @@ -202,6 +202,22 @@ void ColumnFixedString::insertRangeFrom(const IColumn & src, size_t start, size_ memcpy(chars.data() + old_size, &src_concrete.chars[start * n], length * n); } +void ColumnFixedString::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + const ColumnFixedString & src_concrete = static_cast(src); + + size_t old_size = chars.size(); + chars.resize(old_size + length * n); + auto * cur_data_end = chars.data() + old_size; + auto * src_data_start = src_concrete.chars.data(); + + for (size_t i = 0; i < length; ++i) + { + size_t src_pos = selector[selector_start + i]; + memcpySmallAllowReadWriteOverflow15(cur_data_end + i * n, src_data_start+ n * src_pos, n); + } +} + ColumnPtr ColumnFixedString::filter(const IColumn::Filter & filt, ssize_t result_size_hint) const { size_t col_size = size(); diff --git a/src/Columns/ColumnFixedString.h b/src/Columns/ColumnFixedString.h index 7c2d9b1a155d..cc577e78221c 100644 --- a/src/Columns/ColumnFixedString.h +++ b/src/Columns/ColumnFixedString.h @@ -154,6 +154,8 @@ class ColumnFixedString final : public COWHelperinsertRangeSelective(assert_cast(src).getNestedColumn(), selector, selector_start, length); +} + ColumnPtr ColumnMap::filter(const Filter & filt, ssize_t result_size_hint) const { auto filtered = nested->filter(filt, result_size_hint); diff --git a/src/Columns/ColumnMap.h b/src/Columns/ColumnMap.h index db918c3db501..c1f4ad1285a5 100644 --- a/src/Columns/ColumnMap.h +++ b/src/Columns/ColumnMap.h @@ -65,6 +65,7 @@ class ColumnMap final : public COWHelper void updateWeakHash32(WeakHash32 & hash) const override; void updateHashFast(SipHash & hash) const override; void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const Selector & selector, size_t selector_start, size_t length) override; ColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override; void expand(const Filter & mask, bool inverted) override; ColumnPtr permute(const Permutation & perm, size_t limit) const override; @@ -72,6 +73,7 @@ class ColumnMap final : public COWHelper ColumnPtr replicate(const Offsets & offsets) const override; MutableColumns scatter(ColumnIndex num_columns, const Selector & selector) const override; void gather(ColumnGathererStream & gatherer_stream) override; + bool canBeInsideNullable() const override { return true; } int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const override; void compareColumn(const IColumn & rhs, size_t rhs_row_num, PaddedPODArray * row_indexes, PaddedPODArray & compare_results, diff --git a/src/Columns/ColumnNullable.cpp b/src/Columns/ColumnNullable.cpp index df2537dcbb5a..7a5b46936a62 100644 --- a/src/Columns/ColumnNullable.cpp +++ b/src/Columns/ColumnNullable.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #if USE_EMBEDDED_COMPILER @@ -183,6 +184,21 @@ void ColumnNullable::insertRangeFrom(const IColumn & src, size_t start, size_t l getNestedColumn().insertRangeFrom(*nullable_col.nested_column, start, length); } +void ColumnNullable::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + const ColumnNullable & nullable_col = static_cast(src); + getNestedColumn().insertRangeSelective(*nullable_col.nested_column, selector, selector_start, length); + + if (!memoryIsZero(nullable_col.getNullMapData().data(), 0, nullable_col.size())) + { + getNullMapColumn().insertRangeSelective(*nullable_col.null_map, selector, selector_start, length); + } + else + { + getNullMapColumn().insertManyDefaults(length); + } +} + void ColumnNullable::insert(const Field & x) { if (x.isNull()) diff --git a/src/Columns/ColumnNullable.h b/src/Columns/ColumnNullable.h index 85bf095a9d18..b81f0032fcc8 100644 --- a/src/Columns/ColumnNullable.h +++ b/src/Columns/ColumnNullable.h @@ -66,6 +66,7 @@ class ColumnNullable final : public COWHelper const char * deserializeAndInsertFromArena(const char * pos) override; const char * skipSerializedInArena(const char * pos) const override; void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) override; void insert(const Field & x) override; void insertFrom(const IColumn & src, size_t n) override; @@ -128,6 +129,8 @@ class ColumnNullable final : public COWHelper void gather(ColumnGathererStream & gatherer_stream) override; + bool canBeInsideNullable() const override { return true; } + ColumnPtr compress() const override; void forEachSubcolumn(ColumnCallback callback) const override diff --git a/src/Columns/ColumnString.cpp b/src/Columns/ColumnString.cpp index b00600e17487..667a543f3de0 100644 --- a/src/Columns/ColumnString.cpp +++ b/src/Columns/ColumnString.cpp @@ -576,5 +576,41 @@ void ColumnString::validate() const "ColumnString validation failed: size mismatch (internal logical error) {} != {}", offsets.back(), chars.size()); } +void ColumnString::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + const ColumnString & src_concrete = static_cast(src); + const Offsets & src_offsets = src_concrete.getOffsets(); + auto * src_data_start = src_concrete.chars.data(); + + Offsets & cur_offsets = getOffsets(); + + if (length == 0) + return; + + size_t old_offset_size = cur_offsets.size(); + cur_offsets.resize(old_offset_size + length); + + size_t old_chars_size = chars.size(); + size_t new_chars_size = old_chars_size; + for (size_t i = 0; i < length; ++i) + { + new_chars_size += src_concrete.sizeAt(selector[selector_start + i]); + } + chars.resize(new_chars_size); + + size_t cur_offset = cur_offsets[old_offset_size - 1]; + + auto * cur_chars_start = chars.data(); // realloc memory is not allowed in the following + for (size_t i = 0; i < length; ++i) + { + size_t src_pos = selector[selector_start + i]; + size_t offset = src_offsets[src_pos - 1]; + const size_t size_to_append = src_offsets[src_pos] - offset; /// -1th index is Ok, see PaddedPODArray. + + memcpySmallAllowReadWriteOverflow15(cur_chars_start + cur_offset, src_data_start + offset, size_to_append); + cur_offset += size_to_append; + cur_offsets[old_offset_size + i] = cur_offset; + } +} } diff --git a/src/Columns/ColumnString.h b/src/Columns/ColumnString.h index aa251b1fda0e..caeb34176ae5 100644 --- a/src/Columns/ColumnString.h +++ b/src/Columns/ColumnString.h @@ -193,6 +193,8 @@ class ColumnString final : public COWHelper void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) override; + ColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override; void expand(const Filter & mask, bool inverted) override; diff --git a/src/Columns/ColumnTuple.cpp b/src/Columns/ColumnTuple.cpp index 903540c18592..46f2fcf58ffd 100644 --- a/src/Columns/ColumnTuple.cpp +++ b/src/Columns/ColumnTuple.cpp @@ -233,6 +233,18 @@ void ColumnTuple::insertRangeFrom(const IColumn & src, size_t start, size_t leng start, length); } +void ColumnTuple::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ + const ColumnTuple & src_concrete = static_cast(src); + + const size_t tuple_size = columns.size(); + if (src_concrete.columns.size() != tuple_size) + throw Exception("Cannot insert value of different size into tuple", ErrorCodes::CANNOT_INSERT_VALUE_OF_DIFFERENT_SIZE_INTO_TUPLE); + + for (size_t i = 0; i < tuple_size; ++i) + columns[i]->insertRangeSelective(*src_concrete.columns[i], selector, selector_start, length); +} + ColumnPtr ColumnTuple::filter(const Filter & filt, ssize_t result_size_hint) const { const size_t tuple_size = columns.size(); diff --git a/src/Columns/ColumnTuple.h b/src/Columns/ColumnTuple.h index 25f6328b3fc6..b61c7f1bedab 100644 --- a/src/Columns/ColumnTuple.h +++ b/src/Columns/ColumnTuple.h @@ -68,6 +68,7 @@ class ColumnTuple final : public COWHelper void updateWeakHash32(WeakHash32 & hash) const override; void updateHashFast(SipHash & hash) const override; void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; + void insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) override; ColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override; void expand(const Filter & mask, bool inverted) override; ColumnPtr permute(const Permutation & perm, size_t limit) const override; @@ -75,6 +76,7 @@ class ColumnTuple final : public COWHelper ColumnPtr replicate(const Offsets & offsets) const override; MutableColumns scatter(ColumnIndex num_columns, const Selector & selector) const override; void gather(ColumnGathererStream & gatherer_stream) override; + bool canBeInsideNullable() const override { return true; } int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const override; void compareColumn(const IColumn & rhs, size_t rhs_row_num, PaddedPODArray * row_indexes, PaddedPODArray & compare_results, diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index 96f76b70f313..d55e4fcbbc27 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -630,6 +630,18 @@ inline void doFilterAligned(const UInt8 *& filt_pos, const UInt8 *& filt_end_ali } ) +template +void ColumnVector::insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) +{ +size_t old_size = data.size(); +data.resize(old_size + length); +const auto & src_data = (static_cast(src)).getData(); +for (size_t i = 0; i < length; ++i) +{ + data[old_size + i] = src_data[selector[selector_start + i]]; +} +} + template ColumnPtr ColumnVector::filter(const IColumn::Filter & filt, ssize_t result_size_hint) const { diff --git a/src/Columns/ColumnVector.h b/src/Columns/ColumnVector.h index ded664301607..7e0d6ba12568 100644 --- a/src/Columns/ColumnVector.h +++ b/src/Columns/ColumnVector.h @@ -330,6 +330,8 @@ class ColumnVector final : public COWHelper> return this->template scatterImpl(num_columns, selector); } + void insertRangeSelective(const IColumn & src, const IColumn::Selector & selector, size_t selector_start, size_t length) override; + void gather(ColumnGathererStream & gatherer_stream) override; bool canBeInsideNullable() const override { return true; } diff --git a/src/Columns/IColumn.h b/src/Columns/IColumn.h index 53619c73e5b4..368377d0a6cf 100644 --- a/src/Columns/IColumn.h +++ b/src/Columns/IColumn.h @@ -379,6 +379,18 @@ class IColumn : public COW using Selector = PaddedPODArray; [[nodiscard]] virtual std::vector scatter(ColumnIndex num_columns, const Selector & selector) const = 0; + /// This function will get row index from selector and append the data to this column. + /// This function will handle indexes start from input 'selector_start' and will append 'size' times + /// For example: + /// input selector: [1, 2, 3, 4, 5, 6] + /// selector_start: 2 + /// length: 3 + /// This function will copy the [3, 4, 5] row of src to this column. + virtual void insertRangeSelective(const IColumn & /*src*/, const Selector & /*selector*/, size_t /*selector_start*/, size_t /*length*/) + { + throw Exception("insertRangeSelective is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + /// Insert data from several other columns according to source mask (used in vertical merge). /// For now it is a helper to de-virtualize calls to insert*() functions inside gather loop /// (descendants should call gatherer_stream.gather(*this) to implement this function.) diff --git a/src/Common/CurrentMemoryTracker.cpp b/src/Common/CurrentMemoryTracker.cpp index 720df07efb93..d3a2ebb47c3b 100644 --- a/src/Common/CurrentMemoryTracker.cpp +++ b/src/Common/CurrentMemoryTracker.cpp @@ -36,6 +36,9 @@ MemoryTracker * getMemoryTracker() } using DB::current_thread; +thread_local std::function CurrentMemoryTracker::before_alloc = nullptr; + +thread_local std::function CurrentMemoryTracker::before_free = nullptr; void CurrentMemoryTracker::allocImpl(Int64 size, bool throw_if_memory_exceeded) { @@ -52,11 +55,12 @@ void CurrentMemoryTracker::allocImpl(Int64 size, bool throw_if_memory_exceeded) if (current_thread) { Int64 will_be = current_thread->untracked_memory + size; - - if (will_be > current_thread->untracked_memory_limit) + if (will_be > current_thread->untracked_memory_limit) { - memory_tracker->allocImpl(will_be, throw_if_memory_exceeded); current_thread->untracked_memory = 0; + if (before_alloc) + before_alloc(will_be, throw_if_memory_exceeded); + memory_tracker->allocImpl(will_be, throw_if_memory_exceeded); } else { @@ -106,6 +110,8 @@ void CurrentMemoryTracker::free(Int64 size) current_thread->untracked_memory -= size; if (current_thread->untracked_memory < -current_thread->untracked_memory_limit) { + if (before_free) + before_free(-current_thread->untracked_memory); memory_tracker->free(-current_thread->untracked_memory); current_thread->untracked_memory = 0; } diff --git a/src/Common/CurrentMemoryTracker.h b/src/Common/CurrentMemoryTracker.h index e125e4cbe4ab..42ec20417454 100644 --- a/src/Common/CurrentMemoryTracker.h +++ b/src/Common/CurrentMemoryTracker.h @@ -13,7 +13,8 @@ struct CurrentMemoryTracker /// This function should be called after memory deallocation. static void free(Int64 size); static void check(); - + static thread_local std::function before_alloc; + static thread_local std::function before_free; private: static void allocImpl(Int64 size, bool throw_if_memory_exceeded); }; diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index 0ad4cbb9e6f5..549fd58af0d5 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -197,7 +197,7 @@ M(187, COLLATION_COMPARISON_FAILED) \ M(188, UNKNOWN_ACTION) \ M(189, TABLE_MUST_NOT_BE_CREATED_MANUALLY) \ - M(190, SIZES_OF_ARRAYS_DOESNT_MATCH) \ + M(190, SIZES_OF_ARRAYS_DONT_MATCH) \ M(191, SET_SIZE_LIMIT_EXCEEDED) \ M(192, UNKNOWN_USER) \ M(193, WRONG_PASSWORD) \ diff --git a/src/Common/formatIPv6.h b/src/Common/formatIPv6.h index 69963336cefb..bc8f70f047ce 100644 --- a/src/Common/formatIPv6.h +++ b/src/Common/formatIPv6.h @@ -194,6 +194,9 @@ inline bool parseIPv6(T * &src, EOFfunction eof, unsigned char * dst, int32_t fi if (groups <= 1 && zptr == nullptr) /// IPv4 block can't be the first return clear_dst(); + if (group_start) /// first octet of IPv4 should be already parsed as an IPv6 group + return clear_dst(); + ++src; if (eof()) return clear_dst(); diff --git a/src/DataTypes/DataTypeArray.h b/src/DataTypes/DataTypeArray.h index 033a657c845d..7e1e8ab764ef 100644 --- a/src/DataTypes/DataTypeArray.h +++ b/src/DataTypes/DataTypeArray.h @@ -33,7 +33,7 @@ class DataTypeArray final : public IDataType bool canBeInsideNullable() const override { - return false; + return true; } MutableColumnPtr createColumn() const override; diff --git a/src/DataTypes/DataTypeMap.h b/src/DataTypes/DataTypeMap.h index 2ab5c602a259..793866ff9aa1 100644 --- a/src/DataTypes/DataTypeMap.h +++ b/src/DataTypes/DataTypeMap.h @@ -31,7 +31,7 @@ class DataTypeMap final : public IDataType std::string doGetName() const override; const char * getFamilyName() const override { return "Map"; } - bool canBeInsideNullable() const override { return false; } + bool canBeInsideNullable() const override { return true; } MutableColumnPtr createColumn() const override; diff --git a/src/DataTypes/DataTypeNullable.h b/src/DataTypes/DataTypeNullable.h index 379119b364c0..0eba6bcdf3ab 100644 --- a/src/DataTypes/DataTypeNullable.h +++ b/src/DataTypes/DataTypeNullable.h @@ -41,6 +41,7 @@ class DataTypeNullable final : public IDataType bool onlyNull() const override; bool canBeInsideLowCardinality() const override { return nested_data_type->canBeInsideLowCardinality(); } bool canBePromoted() const override { return nested_data_type->canBePromoted(); } + bool canBeInsideNullable() const override { return true; } const DataTypePtr & getNestedType() const { return nested_data_type; } private: diff --git a/src/DataTypes/DataTypeTuple.h b/src/DataTypes/DataTypeTuple.h index 152f21015f5a..2dd8307726dc 100644 --- a/src/DataTypes/DataTypeTuple.h +++ b/src/DataTypes/DataTypeTuple.h @@ -34,7 +34,7 @@ class DataTypeTuple final : public IDataType std::string doGetName() const override; const char * getFamilyName() const override { return "Tuple"; } - bool canBeInsideNullable() const override { return false; } + bool canBeInsideNullable() const override { return true; } bool supportsSparseSerialization() const override { return true; } MutableColumnPtr createColumn() const override; diff --git a/src/DataTypes/NestedUtils.cpp b/src/DataTypes/NestedUtils.cpp index e5ba23b9df83..f029ac6ba27f 100644 --- a/src/DataTypes/NestedUtils.cpp +++ b/src/DataTypes/NestedUtils.cpp @@ -25,7 +25,7 @@ namespace DB namespace ErrorCodes { extern const int ILLEGAL_COLUMN; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } namespace Nested @@ -242,7 +242,7 @@ void validateArraySizes(const Block & block) const ColumnArray & another_array_column = assert_cast(*elem.column); if (!first_array_column.hasEqualOffsets(another_array_column)) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Elements '{}' and '{}' " "of Nested data structure '{}' (Array columns) have different array sizes.", block.getByPosition(it->second).name, elem.name, split.first); diff --git a/src/Formats/FormatSettings.h b/src/Formats/FormatSettings.h index 93cdbe630632..12b6a1f5812b 100644 --- a/src/Formats/FormatSettings.h +++ b/src/Formats/FormatSettings.h @@ -33,6 +33,7 @@ struct FormatSettings bool null_as_default = true; bool decimal_trailing_zeros = false; bool defaults_for_omitted_fields = true; + bool use_lowercase_column_name = false; bool seekable_read = true; UInt64 max_rows_to_read_for_schema_inference = 100; diff --git a/src/Functions/FunctionHelpers.cpp b/src/Functions/FunctionHelpers.cpp index 791b2c1bbdb7..0721c52cdadc 100644 --- a/src/Functions/FunctionHelpers.cpp +++ b/src/Functions/FunctionHelpers.cpp @@ -16,7 +16,7 @@ namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } @@ -213,7 +213,7 @@ checkAndGetNestedArrayOffset(const IColumn ** columns, size_t num_arguments) if (i == 0) offsets = offsets_i; else if (*offsets_i != *offsets) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Lengths of all arrays passed to aggregate function must be equal."); + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths of all arrays passed to aggregate function must be equal."); } return {nested_columns, offsets->data()}; } diff --git a/src/Functions/FunctionsJSON.cpp b/src/Functions/FunctionsJSON.cpp index 69edadc4db91..fbd987577e9a 100644 --- a/src/Functions/FunctionsJSON.cpp +++ b/src/Functions/FunctionsJSON.cpp @@ -1,1653 +1,10 @@ -#include -#include - -#include - -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - +#include #include -#include -#include -#include -#include -#include - -#include -#include - - -#include "config.h" namespace DB { -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ILLEGAL_COLUMN; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} - -template -concept HasIndexOperator = requires (T t) -{ - t[0]; -}; - -/// Functions to parse JSONs and extract values from it. -/// The first argument of all these functions gets a JSON, -/// after that there are any number of arguments specifying path to a desired part from the JSON's root. -/// For example, -/// select JSONExtractInt('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 1) = -100 - -class FunctionJSONHelpers -{ -public: - template typename Impl, class JSONParser> - class Executor - { - public: - static ColumnPtr run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) - { - MutableColumnPtr to{result_type->createColumn()}; - to->reserve(input_rows_count); - - if (arguments.empty()) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument", String(Name::name)); - - const auto & first_column = arguments[0]; - if (!isString(first_column.type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "The first argument of function {} should be a string containing JSON, illegal type: " - "{}", String(Name::name), first_column.type->getName()); - - const ColumnPtr & arg_json = first_column.column; - const auto * col_json_const = typeid_cast(arg_json.get()); - const auto * col_json_string - = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); - - if (!col_json_string) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()); - - const ColumnString::Chars & chars = col_json_string->getChars(); - const ColumnString::Offsets & offsets = col_json_string->getOffsets(); - - size_t num_index_arguments = Impl::getNumberOfIndexArguments(arguments); - std::vector moves = prepareMoves(Name::name, arguments, 1, num_index_arguments); - - /// Preallocate memory in parser if necessary. - JSONParser parser; - if constexpr (has_member_function_reserve::value) - { - size_t max_size = calculateMaxSize(offsets); - if (max_size) - parser.reserve(max_size); - } - - Impl impl; - - /// prepare() does Impl-specific preparation before handling each row. - if constexpr (has_member_function_prepare::*)(const char *, const ColumnsWithTypeAndName &, const DataTypePtr &)>::value) - impl.prepare(Name::name, arguments, result_type); - - using Element = typename JSONParser::Element; - - Element document; - bool document_ok = false; - if (col_json_const) - { - std::string_view json{reinterpret_cast(chars.data()), offsets[0] - 1}; - document_ok = parser.parse(json, document); - } - - for (const auto i : collections::range(0, input_rows_count)) - { - if (!col_json_const) - { - std::string_view json{reinterpret_cast(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1}; - document_ok = parser.parse(json, document); - } - - bool added_to_column = false; - if (document_ok) - { - /// Perform moves. - Element element; - std::string_view last_key; - bool moves_ok = performMoves(arguments, i, document, moves, element, last_key); - - if (moves_ok) - added_to_column = impl.insertResultToColumn(*to, element, last_key); - } - - /// We add default value (=null or zero) if something goes wrong, we don't throw exceptions in these JSON functions. - if (!added_to_column) - to->insertDefault(); - } - return to; - } - }; - -private: - BOOST_TTI_HAS_MEMBER_FUNCTION(reserve) - BOOST_TTI_HAS_MEMBER_FUNCTION(prepare) - - /// Represents a move of a JSON iterator described by a single argument passed to a JSON function. - /// For example, the call JSONExtractInt('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 1) - /// contains two moves: {MoveType::ConstKey, "b"} and {MoveType::ConstIndex, 1}. - /// Keys and indices can be nonconst, in this case they are calculated for each row. - enum class MoveType - { - Key, - Index, - ConstKey, - ConstIndex, - }; - - struct Move - { - explicit Move(MoveType type_, size_t index_ = 0) : type(type_), index(index_) {} - Move(MoveType type_, const String & key_) : type(type_), key(key_) {} - MoveType type; - size_t index = 0; - String key; - }; - - static std::vector prepareMoves( - const char * function_name, - const ColumnsWithTypeAndName & columns, - size_t first_index_argument, - size_t num_index_arguments) - { - std::vector moves; - moves.reserve(num_index_arguments); - for (const auto i : collections::range(first_index_argument, first_index_argument + num_index_arguments)) - { - const auto & column = columns[i]; - if (!isString(column.type) && !isNativeInteger(column.type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "The argument {} of function {} should be a string specifying key " - "or an integer specifying index, illegal type: {}", - std::to_string(i + 1), String(function_name), column.type->getName()); - - if (column.column && isColumnConst(*column.column)) - { - const auto & column_const = assert_cast(*column.column); - if (isString(column.type)) - moves.emplace_back(MoveType::ConstKey, column_const.getValue()); - else - moves.emplace_back(MoveType::ConstIndex, column_const.getInt(0)); - } - else - { - if (isString(column.type)) - moves.emplace_back(MoveType::Key, ""); - else - moves.emplace_back(MoveType::Index, 0); - } - } - return moves; - } - - - /// Performs moves of types MoveType::Index and MoveType::ConstIndex. - template - static bool performMoves(const ColumnsWithTypeAndName & arguments, size_t row, - const typename JSONParser::Element & document, const std::vector & moves, - typename JSONParser::Element & element, std::string_view & last_key) - { - typename JSONParser::Element res_element = document; - std::string_view key; - - for (size_t j = 0; j != moves.size(); ++j) - { - switch (moves[j].type) - { - case MoveType::ConstIndex: - { - if (!moveToElementByIndex(res_element, static_cast(moves[j].index), key)) - return false; - break; - } - case MoveType::ConstKey: - { - key = moves[j].key; - if (!moveToElementByKey(res_element, key)) - return false; - break; - } - case MoveType::Index: - { - Int64 index = (*arguments[j + 1].column)[row].get(); - if (!moveToElementByIndex(res_element, static_cast(index), key)) - return false; - break; - } - case MoveType::Key: - { - key = (*arguments[j + 1].column).getDataAt(row).toView(); - if (!moveToElementByKey(res_element, key)) - return false; - break; - } - } - } - - element = res_element; - last_key = key; - return true; - } - - template - static bool moveToElementByIndex(typename JSONParser::Element & element, int index, std::string_view & out_key) - { - if (element.isArray()) - { - auto array = element.getArray(); - if (index >= 0) - --index; - else - index += array.size(); - - if (static_cast(index) >= array.size()) - return false; - element = array[index]; - out_key = {}; - return true; - } - - if constexpr (HasIndexOperator) - { - if (element.isObject()) - { - auto object = element.getObject(); - if (index >= 0) - --index; - else - index += object.size(); - - if (static_cast(index) >= object.size()) - return false; - std::tie(out_key, element) = object[index]; - return true; - } - } - - return {}; - } - - /// Performs moves of types MoveType::Key and MoveType::ConstKey. - template - static bool moveToElementByKey(typename JSONParser::Element & element, std::string_view key) - { - if (!element.isObject()) - return false; - auto object = element.getObject(); - return object.find(key, element); - } - - static size_t calculateMaxSize(const ColumnString::Offsets & offsets) - { - size_t max_size = 0; - for (const auto i : collections::range(0, offsets.size())) - { - size_t size = offsets[i] - offsets[i - 1]; - if (max_size < size) - max_size = size; - } - if (max_size) - --max_size; - return max_size; - } - -}; - - -template typename Impl> -class ExecutableFunctionJSON : public IExecutableFunction, WithContext -{ - -public: - explicit ExecutableFunctionJSON(const NullPresence & null_presence_, bool allow_simdjson_, const DataTypePtr & json_return_type_) - : null_presence(null_presence_), allow_simdjson(allow_simdjson_), json_return_type(json_return_type_) - { - } - - String getName() const override { return Name::name; } - bool useDefaultImplementationForNulls() const override { return false; } - bool useDefaultImplementationForConstants() const override { return true; } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override - { - if (null_presence.has_null_constant) - return result_type->createColumnConstWithDefaultValue(input_rows_count); - - ColumnsWithTypeAndName temporary_columns = null_presence.has_nullable ? createBlockWithNestedColumns(arguments) : arguments; - ColumnPtr temporary_result = chooseAndRunJSONParser(temporary_columns, json_return_type, input_rows_count); - if (null_presence.has_nullable) - return wrapInNullable(temporary_result, arguments, result_type, input_rows_count); - return temporary_result; - } - -private: - - ColumnPtr - chooseAndRunJSONParser(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const - { -#if USE_SIMDJSON - if (allow_simdjson) - return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); -#endif - -#if USE_RAPIDJSON - return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); -#else - return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); -#endif - } - - NullPresence null_presence; - bool allow_simdjson; - DataTypePtr json_return_type; -}; - - -template typename Impl> -class FunctionBaseFunctionJSON : public IFunctionBase -{ -public: - explicit FunctionBaseFunctionJSON( - const NullPresence & null_presence_, - bool allow_simdjson_, - DataTypes argument_types_, - DataTypePtr return_type_, - DataTypePtr json_return_type_) - : null_presence(null_presence_) - , allow_simdjson(allow_simdjson_) - , argument_types(std::move(argument_types_)) - , return_type(std::move(return_type_)) - , json_return_type(std::move(json_return_type_)) - { - } - - String getName() const override { return Name::name; } - - const DataTypes & getArgumentTypes() const override - { - return argument_types; - } - - const DataTypePtr & getResultType() const override - { - return return_type; - } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - - ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override - { - return std::make_unique>(null_presence, allow_simdjson, json_return_type); - } - -private: - NullPresence null_presence; - bool allow_simdjson; - DataTypes argument_types; - DataTypePtr return_type; - DataTypePtr json_return_type; -}; - - -/// We use IFunctionOverloadResolver instead of IFunction to handle non-default NULL processing. -/// Both NULL and JSON NULL should generate NULL value. If any argument is NULL, return NULL. -template typename Impl> -class JSONOverloadResolver : public IFunctionOverloadResolver, WithContext -{ -public: - static constexpr auto name = Name::name; - - String getName() const override { return name; } - - static FunctionOverloadResolverPtr create(ContextPtr context_) - { - return std::make_unique(context_); - } - - explicit JSONOverloadResolver(ContextPtr context_) : WithContext(context_) {} - - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 0; } - bool useDefaultImplementationForNulls() const override { return false; } - - FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const override - { - bool has_nothing_argument = false; - for (const auto & arg : arguments) - has_nothing_argument |= isNothing(arg.type); - - DataTypePtr json_return_type = Impl::getReturnType(Name::name, createBlockWithNestedColumns(arguments)); - NullPresence null_presence = getNullPresense(arguments); - DataTypePtr return_type; - if (has_nothing_argument) - return_type = std::make_shared(); - else if (null_presence.has_null_constant) - return_type = makeNullable(std::make_shared()); - else if (null_presence.has_nullable) - return_type = makeNullable(json_return_type); - else - return_type = json_return_type; - - /// Top-level LowCardinality columns are processed outside JSON parser. - json_return_type = removeLowCardinality(json_return_type); - - DataTypes argument_types; - argument_types.reserve(arguments.size()); - for (const auto & argument : arguments) - argument_types.emplace_back(argument.type); - return std::make_unique>( - null_presence, getContext()->getSettingsRef().allow_simdjson, argument_types, return_type, json_return_type); - } -}; - - -struct NameJSONHas { static constexpr auto name{"JSONHas"}; }; -struct NameIsValidJSON { static constexpr auto name{"isValidJSON"}; }; -struct NameJSONLength { static constexpr auto name{"JSONLength"}; }; -struct NameJSONKey { static constexpr auto name{"JSONKey"}; }; -struct NameJSONType { static constexpr auto name{"JSONType"}; }; -struct NameJSONExtractInt { static constexpr auto name{"JSONExtractInt"}; }; -struct NameJSONExtractUInt { static constexpr auto name{"JSONExtractUInt"}; }; -struct NameJSONExtractFloat { static constexpr auto name{"JSONExtractFloat"}; }; -struct NameJSONExtractBool { static constexpr auto name{"JSONExtractBool"}; }; -struct NameJSONExtractString { static constexpr auto name{"JSONExtractString"}; }; -struct NameJSONExtract { static constexpr auto name{"JSONExtract"}; }; -struct NameJSONExtractKeysAndValues { static constexpr auto name{"JSONExtractKeysAndValues"}; }; -struct NameJSONExtractRaw { static constexpr auto name{"JSONExtractRaw"}; }; -struct NameJSONExtractArrayRaw { static constexpr auto name{"JSONExtractArrayRaw"}; }; -struct NameJSONExtractKeysAndValuesRaw { static constexpr auto name{"JSONExtractKeysAndValuesRaw"}; }; -struct NameJSONExtractKeys { static constexpr auto name{"JSONExtractKeys"}; }; - - -template -class JSONHasImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view) - { - ColumnVector & col_vec = assert_cast &>(dest); - col_vec.insertValue(1); - return true; - } -}; - - -template -class IsValidJSONImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) - { - if (arguments.size() != 1) - { - /// IsValidJSON() shouldn't get parameters other than JSON. - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs exactly one argument", - String(function_name)); - } - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName &) { return 0; } - - static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view) - { - /// This function is called only if JSON is valid. - /// If JSON isn't valid then `FunctionJSON::Executor::run()` adds default value (=zero) to `dest` without calling this function. - ColumnVector & col_vec = assert_cast &>(dest); - col_vec.insertValue(1); - return true; - } -}; - - -template -class JSONLengthImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - size_t size; - if (element.isArray()) - size = element.getArray().size(); - else if (element.isObject()) - size = element.getObject().size(); - else - return false; - - ColumnVector & col_vec = assert_cast &>(dest); - col_vec.insertValue(size); - return true; - } -}; - - -template -class JSONKeyImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view last_key) - { - if (last_key.empty()) - return false; - ColumnString & col_str = assert_cast(dest); - col_str.insertData(last_key.data(), last_key.size()); - return true; - } -}; - - -template -class JSONTypeImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - static const std::vector> values = { - {"Array", '['}, - {"Object", '{'}, - {"String", '"'}, - {"Int64", 'i'}, - {"UInt64", 'u'}, - {"Double", 'd'}, - {"Bool", 'b'}, - {"Null", 0}, /// the default value for the column. - }; - return std::make_shared>(values); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - UInt8 type; - switch (element.type()) - { - case ElementType::INT64: - type = 'i'; - break; - case ElementType::UINT64: - type = 'u'; - break; - case ElementType::DOUBLE: - type = 'd'; - break; - case ElementType::STRING: - type = '"'; - break; - case ElementType::ARRAY: - type = '['; - break; - case ElementType::OBJECT: - type = '{'; - break; - case ElementType::NULL_VALUE: - type = 0; - break; - default: - return false; - } - - ColumnVector & col_vec = assert_cast &>(dest); - col_vec.insertValue(type); - return true; - } -}; - - -template -class JSONExtractNumericImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared>(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - NumberType value; - - switch (element.type()) - { - case ElementType::DOUBLE: - if constexpr (std::is_floating_point_v) - { - /// We permit inaccurate conversion of double to float. - /// Example: double 0.1 from JSON is not representable in float. - /// But it will be more convenient for user to perform conversion. - value = static_cast(element.getDouble()); - } - else if (!accurate::convertNumeric(element.getDouble(), value)) - return false; - break; - case ElementType::UINT64: - if (!accurate::convertNumeric(element.getUInt64(), value)) - return false; - break; - case ElementType::INT64: - if (!accurate::convertNumeric(element.getInt64(), value)) - return false; - break; - case ElementType::BOOL: - if constexpr (is_integer && convert_bool_to_integer) - { - value = static_cast(element.getBool()); - break; - } - return false; - case ElementType::STRING: - { - auto rb = ReadBufferFromMemory{element.getString()}; - if constexpr (std::is_floating_point_v) - { - if (!tryReadFloatText(value, rb) || !rb.eof()) - return false; - } - else - { - if (tryReadIntText(value, rb) && rb.eof()) - break; - - /// Try to parse float and convert it to integer. - Float64 tmp_float; - rb.position() = rb.buffer().begin(); - if (!tryReadFloatText(tmp_float, rb) || !rb.eof()) - return false; - - if (!accurate::convertNumeric(tmp_float, value)) - return false; - } - break; - } - default: - return false; - } - - if (dest.getDataType() == TypeIndex::LowCardinality) - { - ColumnLowCardinality & col_low = assert_cast(dest); - col_low.insertData(reinterpret_cast(&value), sizeof(value)); - } - else - { - auto & col_vec = assert_cast &>(dest); - col_vec.insertValue(value); - } - return true; - } -}; - - -template -using JSONExtractInt8Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt8Impl = JSONExtractNumericImpl; -template -using JSONExtractInt16Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt16Impl = JSONExtractNumericImpl; -template -using JSONExtractInt32Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt32Impl = JSONExtractNumericImpl; -template -using JSONExtractInt64Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt64Impl = JSONExtractNumericImpl; -template -using JSONExtractInt128Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt128Impl = JSONExtractNumericImpl; -template -using JSONExtractInt256Impl = JSONExtractNumericImpl; -template -using JSONExtractUInt256Impl = JSONExtractNumericImpl; -template -using JSONExtractFloat32Impl = JSONExtractNumericImpl; -template -using JSONExtractFloat64Impl = JSONExtractNumericImpl; -template -using JSONExtractDecimal32Impl = JSONExtractNumericImpl; -template -using JSONExtractDecimal64Impl = JSONExtractNumericImpl; -template -using JSONExtractDecimal128Impl = JSONExtractNumericImpl; -template -using JSONExtractDecimal256Impl = JSONExtractNumericImpl; - - -template -class JSONExtractBoolImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - bool value; - switch (element.type()) - { - case ElementType::BOOL: - value = element.getBool(); - break; - case ElementType::INT64: - value = element.getInt64() != 0; - break; - case ElementType::UINT64: - value = element.getUInt64() != 0; - break; - default: - return false; - } - - auto & col_vec = assert_cast &>(dest); - col_vec.insertValue(static_cast(value)); - return true; - } -}; - -template -class JSONExtractRawImpl; - -template -class JSONExtractStringImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (element.isNull()) - return false; - - if (!element.isString()) - return JSONExtractRawImpl::insertResultToColumn(dest, element, {}); - - auto str = element.getString(); - - if (dest.getDataType() == TypeIndex::LowCardinality) - { - ColumnLowCardinality & col_low = assert_cast(dest); - col_low.insertData(str.data(), str.size()); - } - else - { - ColumnString & col_str = assert_cast(dest); - col_str.insertData(str.data(), str.size()); - } - return true; - } -}; - -/// Nodes of the extract tree. We need the extract tree to extract from JSON complex values containing array, tuples or nullables. -template -struct JSONExtractTree -{ - using Element = typename JSONParser::Element; - - class Node - { - public: - Node() = default; - virtual ~Node() = default; - virtual bool insertResultToColumn(IColumn &, const Element &) = 0; - }; - - template - class NumericNode : public Node - { - public: - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - return JSONExtractNumericImpl::insertResultToColumn(dest, element, {}); - } - }; - - class LowCardinalityFixedStringNode : public Node - { - public: - explicit LowCardinalityFixedStringNode(const size_t fixed_length_) : fixed_length(fixed_length_) { } - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - // If element is an object we delegate the insertion to JSONExtractRawImpl - if (element.isObject()) - return JSONExtractRawImpl::insertResultToLowCardinalityFixedStringColumn(dest, element, fixed_length); - else if (!element.isString()) - return false; - - auto str = element.getString(); - if (str.size() > fixed_length) - return false; - - // For the non low cardinality case of FixedString, the padding is done in the FixedString Column implementation. - // In order to avoid having to pass the data to a FixedString Column and read it back (which would slow down the execution) - // the data is padded here and written directly to the Low Cardinality Column - if (str.size() == fixed_length) - { - assert_cast(dest).insertData(str.data(), str.size()); - } - else - { - String padded_str(str); - padded_str.resize(fixed_length, '\0'); - - assert_cast(dest).insertData(padded_str.data(), padded_str.size()); - } - return true; - } - - private: - const size_t fixed_length; - }; - - class UUIDNode : public Node - { - public: - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - if (!element.isString()) - return false; - - auto uuid = parseFromString(element.getString()); - if (dest.getDataType() == TypeIndex::LowCardinality) - { - ColumnLowCardinality & col_low = assert_cast(dest); - col_low.insertData(reinterpret_cast(&uuid), sizeof(uuid)); - } - else - { - assert_cast(dest).insert(uuid); - } - return true; - } - }; - - template - class DecimalNode : public Node - { - public: - explicit DecimalNode(DataTypePtr data_type_) : data_type(data_type_) {} - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - const auto * type = assert_cast *>(data_type.get()); - - DecimalType value{}; - - switch (element.type()) - { - case ElementType::DOUBLE: - value = convertToDecimal, DataTypeDecimal>( - element.getDouble(), type->getScale()); - break; - case ElementType::UINT64: - value = convertToDecimal, DataTypeDecimal>( - element.getUInt64(), type->getScale()); - break; - case ElementType::INT64: - value = convertToDecimal, DataTypeDecimal>( - element.getInt64(), type->getScale()); - break; - case ElementType::STRING: { - auto rb = ReadBufferFromMemory{element.getString()}; - if (!SerializationDecimal::tryReadText(value, rb, DecimalUtils::max_precision, type->getScale())) - return false; - break; - } - default: - return false; - } - - assert_cast &>(dest).insert(value); - return true; - } - - private: - DataTypePtr data_type; - }; - - class StringNode : public Node - { - public: - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - return JSONExtractStringImpl::insertResultToColumn(dest, element, {}); - } - }; - - class FixedStringNode : public Node - { - public: - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - if (element.isNull()) - return false; - - if (!element.isString()) - return JSONExtractRawImpl::insertResultToFixedStringColumn(dest, element, {}); - - auto str = element.getString(); - auto & col_str = assert_cast(dest); - if (str.size() > col_str.getN()) - return false; - col_str.insertData(str.data(), str.size()); - - return true; - } - }; - - template - class EnumNode : public Node - { - public: - explicit EnumNode(const std::vector> & name_value_pairs_) : name_value_pairs(name_value_pairs_) - { - for (const auto & name_value_pair : name_value_pairs) - { - name_to_value_map.emplace(name_value_pair.first, name_value_pair.second); - only_values.emplace(name_value_pair.second); - } - } - - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - auto & col_vec = assert_cast &>(dest); - - if (element.isInt64()) - { - Type value; - if (!accurate::convertNumeric(element.getInt64(), value) || !only_values.contains(value)) - return false; - col_vec.insertValue(value); - return true; - } - - if (element.isUInt64()) - { - Type value; - if (!accurate::convertNumeric(element.getUInt64(), value) || !only_values.contains(value)) - return false; - col_vec.insertValue(value); - return true; - } - - if (element.isString()) - { - auto value = name_to_value_map.find(element.getString()); - if (value == name_to_value_map.end()) - return false; - col_vec.insertValue(value->second); - return true; - } - - return false; - } - - private: - std::vector> name_value_pairs; - std::unordered_map name_to_value_map; - std::unordered_set only_values; - }; - - class NullableNode : public Node - { - public: - explicit NullableNode(std::unique_ptr nested_) : nested(std::move(nested_)) {} - - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - ColumnNullable & col_null = assert_cast(dest); - if (!nested->insertResultToColumn(col_null.getNestedColumn(), element)) - return false; - col_null.getNullMapColumn().insertValue(0); - return true; - } - - private: - std::unique_ptr nested; - }; - - class ArrayNode : public Node - { - public: - explicit ArrayNode(std::unique_ptr nested_) : nested(std::move(nested_)) {} - - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - if (!element.isArray()) - return false; - - auto array = element.getArray(); - - ColumnArray & col_arr = assert_cast(dest); - auto & data = col_arr.getData(); - size_t old_size = data.size(); - bool were_valid_elements = false; - - for (auto value : array) - { - if (nested->insertResultToColumn(data, value)) - were_valid_elements = true; - else - data.insertDefault(); - } - - if (!were_valid_elements) - { - data.popBack(data.size() - old_size); - return false; - } - - col_arr.getOffsets().push_back(data.size()); - return true; - } - - private: - std::unique_ptr nested; - }; - - class TupleNode : public Node - { - public: - TupleNode(std::vector> nested_, const std::vector & explicit_names_) : nested(std::move(nested_)), explicit_names(explicit_names_) - { - for (size_t i = 0; i != explicit_names.size(); ++i) - name_to_index_map.emplace(explicit_names[i], i); - } - - bool insertResultToColumn(IColumn & dest, const Element & element) override - { - ColumnTuple & tuple = assert_cast(dest); - size_t old_size = dest.size(); - bool were_valid_elements = false; - - auto set_size = [&](size_t size) - { - for (size_t i = 0; i != tuple.tupleSize(); ++i) - { - auto & col = tuple.getColumn(i); - if (col.size() != size) - { - if (col.size() > size) - col.popBack(col.size() - size); - else - while (col.size() < size) - col.insertDefault(); - } - } - }; - - if (element.isArray()) - { - auto array = element.getArray(); - auto it = array.begin(); - - for (size_t index = 0; (index != nested.size()) && (it != array.end()); ++index) - { - if (nested[index]->insertResultToColumn(tuple.getColumn(index), *it++)) - were_valid_elements = true; - else - tuple.getColumn(index).insertDefault(); - } - - set_size(old_size + static_cast(were_valid_elements)); - return were_valid_elements; - } - - if (element.isObject()) - { - auto object = element.getObject(); - if (name_to_index_map.empty()) - { - auto it = object.begin(); - for (size_t index = 0; (index != nested.size()) && (it != object.end()); ++index) - { - if (nested[index]->insertResultToColumn(tuple.getColumn(index), (*it++).second)) - were_valid_elements = true; - else - tuple.getColumn(index).insertDefault(); - } - } - else - { - for (const auto & [key, value] : object) - { - auto index = name_to_index_map.find(key); - if (index != name_to_index_map.end()) - { - if (nested[index->second]->insertResultToColumn(tuple.getColumn(index->second), value)) - were_valid_elements = true; - } - } - } - - set_size(old_size + static_cast(were_valid_elements)); - return were_valid_elements; - } - - return false; - } - - private: - std::vector> nested; - std::vector explicit_names; - std::unordered_map name_to_index_map; - }; - - static std::unique_ptr build(const char * function_name, const DataTypePtr & type) - { - switch (type->getTypeId()) - { - case TypeIndex::UInt8: return std::make_unique>(); - case TypeIndex::UInt16: return std::make_unique>(); - case TypeIndex::UInt32: return std::make_unique>(); - case TypeIndex::UInt64: return std::make_unique>(); - case TypeIndex::UInt128: return std::make_unique>(); - case TypeIndex::UInt256: return std::make_unique>(); - case TypeIndex::Int8: return std::make_unique>(); - case TypeIndex::Int16: return std::make_unique>(); - case TypeIndex::Int32: return std::make_unique>(); - case TypeIndex::Int64: return std::make_unique>(); - case TypeIndex::Int128: return std::make_unique>(); - case TypeIndex::Int256: return std::make_unique>(); - case TypeIndex::Float32: return std::make_unique>(); - case TypeIndex::Float64: return std::make_unique>(); - case TypeIndex::String: return std::make_unique(); - case TypeIndex::FixedString: return std::make_unique(); - case TypeIndex::UUID: return std::make_unique(); - case TypeIndex::LowCardinality: - { - // The low cardinality case is treated in two different ways: - // For FixedString type, an especial class is implemented for inserting the data in the destination column, - // as the string length must be passed in order to check and pad the incoming data. - // For the rest of low cardinality types, the insertion is done in their corresponding class, adapting the data - // as needed for the insertData function of the ColumnLowCardinality. - auto dictionary_type = typeid_cast(type.get())->getDictionaryType(); - if ((*dictionary_type).getTypeId() == TypeIndex::FixedString) - { - auto fixed_length = typeid_cast(dictionary_type.get())->getN(); - return std::make_unique(fixed_length); - } - return build(function_name, dictionary_type); - } - case TypeIndex::Decimal256: return std::make_unique>(type); - case TypeIndex::Decimal128: return std::make_unique>(type); - case TypeIndex::Decimal64: return std::make_unique>(type); - case TypeIndex::Decimal32: return std::make_unique>(type); - case TypeIndex::Enum8: - return std::make_unique>(static_cast(*type).getValues()); - case TypeIndex::Enum16: - return std::make_unique>(static_cast(*type).getValues()); - case TypeIndex::Nullable: - { - return std::make_unique(build(function_name, static_cast(*type).getNestedType())); - } - case TypeIndex::Array: - { - return std::make_unique(build(function_name, static_cast(*type).getNestedType())); - } - case TypeIndex::Tuple: - { - const auto & tuple = static_cast(*type); - const auto & tuple_elements = tuple.getElements(); - std::vector> elements; - elements.reserve(tuple_elements.size()); - for (const auto & tuple_element : tuple_elements) - elements.emplace_back(build(function_name, tuple_element)); - return std::make_unique(std::move(elements), tuple.haveExplicitNames() ? tuple.getElementNames() : Strings{}); - } - default: - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Function {} doesn't support the return type schema: {}", - String(function_name), type->getName()); - } - } -}; - - -template -class JSONExtractImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) - { - if (arguments.size() < 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", String(function_name)); - - const auto & col = arguments.back(); - const auto * col_type_const = typeid_cast(col.column.get()); - if (!col_type_const || !isString(col.type)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "The last argument of function {} should " - "be a constant string specifying the return data type, illegal value: {}", - String(function_name), col.name); - - return DataTypeFactory::instance().get(col_type_const->getValue()); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 2; } - - void prepare(const char * function_name, const ColumnsWithTypeAndName &, const DataTypePtr & result_type) - { - extract_tree = JSONExtractTree::build(function_name, result_type); - } - - bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - return extract_tree->insertResultToColumn(dest, element); - } - -protected: - std::unique_ptr::Node> extract_tree; -}; - - -template -class JSONExtractKeysAndValuesImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) - { - if (arguments.size() < 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", String(function_name)); - - const auto & col = arguments.back(); - const auto * col_type_const = typeid_cast(col.column.get()); - if (!col_type_const || !isString(col.type)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "The last argument of function {} should " - "be a constant string specifying the values' data type, illegal value: {}", - String(function_name), col.name); - - DataTypePtr key_type = std::make_unique(); - DataTypePtr value_type = DataTypeFactory::instance().get(col_type_const->getValue()); - DataTypePtr tuple_type = std::make_unique(DataTypes{key_type, value_type}); - return std::make_unique(tuple_type); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 2; } - - void prepare(const char * function_name, const ColumnsWithTypeAndName &, const DataTypePtr & result_type) - { - const auto tuple_type = typeid_cast(result_type.get())->getNestedType(); - const auto value_type = typeid_cast(tuple_type.get())->getElements()[1]; - extract_tree = JSONExtractTree::build(function_name, value_type); - } - - bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (!element.isObject()) - return false; - - auto object = element.getObject(); - - auto & col_arr = assert_cast(dest); - auto & col_tuple = assert_cast(col_arr.getData()); - size_t old_size = col_tuple.size(); - auto & col_key = assert_cast(col_tuple.getColumn(0)); - auto & col_value = col_tuple.getColumn(1); - - for (const auto & [key, value] : object) - { - if (extract_tree->insertResultToColumn(col_value, value)) - col_key.insertData(key.data(), key.size()); - } - - if (col_tuple.size() == old_size) - return false; - - col_arr.getOffsets().push_back(col_tuple.size()); - return true; - } - -private: - std::unique_ptr::Node> extract_tree; -}; - - -template -class JSONExtractRawImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (dest.getDataType() == TypeIndex::LowCardinality) - { - ColumnString::Chars chars; - WriteBufferFromVector buf(chars, AppendModeTag()); - traverse(element, buf); - buf.finalize(); - assert_cast(dest).insertData(reinterpret_cast(chars.data()), chars.size()); - } - else - { - ColumnString & col_str = assert_cast(dest); - auto & chars = col_str.getChars(); - WriteBufferFromVector buf(chars, AppendModeTag()); - traverse(element, buf); - buf.finalize(); - chars.push_back(0); - col_str.getOffsets().push_back(chars.size()); - } - return true; - } - - // We use insertResultToFixedStringColumn in case we are inserting raw data in a FixedString column - static bool insertResultToFixedStringColumn(IColumn & dest, const Element & element, std::string_view) - { - ColumnFixedString::Chars chars; - WriteBufferFromVector buf(chars, AppendModeTag()); - traverse(element, buf); - buf.finalize(); - - auto & col_str = assert_cast(dest); - - if (chars.size() > col_str.getN()) - return false; - - chars.resize_fill(col_str.getN()); - col_str.insertData(reinterpret_cast(chars.data()), chars.size()); - - - return true; - } - - // We use insertResultToLowCardinalityFixedStringColumn in case we are inserting raw data in a Low Cardinality FixedString column - static bool insertResultToLowCardinalityFixedStringColumn(IColumn & dest, const Element & element, size_t fixed_length) - { - if (element.getObject().size() > fixed_length) - return false; - - ColumnFixedString::Chars chars; - WriteBufferFromVector buf(chars, AppendModeTag()); - traverse(element, buf); - buf.finalize(); - - if (chars.size() > fixed_length) - return false; - chars.resize_fill(fixed_length); - assert_cast(dest).insertData(reinterpret_cast(chars.data()), chars.size()); - - return true; - } - -private: - static void traverse(const Element & element, WriteBuffer & buf) - { - if (element.isInt64()) - { - writeIntText(element.getInt64(), buf); - return; - } - if (element.isUInt64()) - { - writeIntText(element.getUInt64(), buf); - return; - } - if (element.isDouble()) - { - writeFloatText(element.getDouble(), buf); - return; - } - if (element.isBool()) - { - if (element.getBool()) - writeCString("true", buf); - else - writeCString("false", buf); - return; - } - if (element.isString()) - { - writeJSONString(element.getString(), buf, formatSettings()); - return; - } - if (element.isArray()) - { - writeChar('[', buf); - bool need_comma = false; - for (auto value : element.getArray()) - { - if (std::exchange(need_comma, true)) - writeChar(',', buf); - traverse(value, buf); - } - writeChar(']', buf); - return; - } - if (element.isObject()) - { - writeChar('{', buf); - bool need_comma = false; - for (auto [key, value] : element.getObject()) - { - if (std::exchange(need_comma, true)) - writeChar(',', buf); - writeJSONString(key, buf, formatSettings()); - writeChar(':', buf); - traverse(value, buf); - } - writeChar('}', buf); - return; - } - if (element.isNull()) - { - writeCString("null", buf); - return; - } - } - - static const FormatSettings & formatSettings() - { - static const FormatSettings the_instance = [] - { - FormatSettings settings; - settings.json.escape_forward_slashes = false; - return settings; - }(); - return the_instance; - } -}; - - -template -class JSONExtractArrayRawImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_shared(std::make_shared()); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (!element.isArray()) - return false; - - auto array = element.getArray(); - ColumnArray & col_res = assert_cast(dest); - - for (auto value : array) - JSONExtractRawImpl::insertResultToColumn(col_res.getData(), value, {}); - - col_res.getOffsets().push_back(col_res.getOffsets().back() + array.size()); - return true; - } -}; - - -template -class JSONExtractKeysAndValuesRawImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - DataTypePtr string_type = std::make_unique(); - DataTypePtr tuple_type = std::make_unique(DataTypes{string_type, string_type}); - return std::make_unique(tuple_type); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (!element.isObject()) - return false; - - auto object = element.getObject(); - - auto & col_arr = assert_cast(dest); - auto & col_tuple = assert_cast(col_arr.getData()); - auto & col_key = assert_cast(col_tuple.getColumn(0)); - auto & col_value = assert_cast(col_tuple.getColumn(1)); - - for (const auto & [key, value] : object) - { - col_key.insertData(key.data(), key.size()); - JSONExtractRawImpl::insertResultToColumn(col_value, value, {}); - } - - col_arr.getOffsets().push_back(col_arr.getOffsets().back() + object.size()); - return true; - } -}; - -template -class JSONExtractKeysImpl -{ -public: - using Element = typename JSONParser::Element; - - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) - { - return std::make_unique(std::make_shared()); - } - - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - - bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) - { - if (!element.isObject()) - return false; - - auto object = element.getObject(); - - ColumnArray & col_res = assert_cast(dest); - auto & col_key = assert_cast(col_res.getData()); - - for (const auto & [key, value] : object) - { - col_key.insertData(key.data(), key.size()); - } - - col_res.getOffsets().push_back(col_res.getOffsets().back() + object.size()); - return true; - } -}; - REGISTER_FUNCTION(JSON) { factory.registerFunction>(); diff --git a/src/Functions/FunctionsJSON.h b/src/Functions/FunctionsJSON.h new file mode 100644 index 000000000000..4727bb892bfe --- /dev/null +++ b/src/Functions/FunctionsJSON.h @@ -0,0 +1,1679 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + + +#include "config.h" + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +template +concept HasIndexOperator = requires (T t) +{ + t[0]; +}; + +/// Functions to parse JSONs and extract values from it. +/// The first argument of all these functions gets a JSON, +/// after that there are any number of arguments specifying path to a desired part from the JSON's root. +/// For example, +/// select JSONExtractInt('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 1) = -100 + +class FunctionJSONHelpers +{ +public: + template typename Impl, class JSONParser> + class Executor + { + public: + static ColumnPtr run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) + { + MutableColumnPtr to{result_type->createColumn()}; + to->reserve(input_rows_count); + + if (arguments.empty()) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument", String(Name::name)); + + const auto & first_column = arguments[0]; + if (!isString(first_column.type)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "The first argument of function {} should be a string containing JSON, illegal type: " + "{}", String(Name::name), first_column.type->getName()); + + const ColumnPtr & arg_json = first_column.column; + const auto * col_json_const = typeid_cast(arg_json.get()); + const auto * col_json_string + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); + + if (!col_json_string) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()); + + const ColumnString::Chars & chars = col_json_string->getChars(); + const ColumnString::Offsets & offsets = col_json_string->getOffsets(); + + size_t num_index_arguments = Impl::getNumberOfIndexArguments(arguments); + std::vector moves = prepareMoves(Name::name, arguments, 1, num_index_arguments); + + /// Preallocate memory in parser if necessary. + JSONParser parser; + if constexpr (has_member_function_reserve::value) + { + size_t max_size = calculateMaxSize(offsets); + if (max_size) + parser.reserve(max_size); + } + + Impl impl; + + /// prepare() does Impl-specific preparation before handling each row. + if constexpr (has_member_function_prepare::*)(const char *, const ColumnsWithTypeAndName &, const DataTypePtr &)>::value) + impl.prepare(Name::name, arguments, result_type); + + using Element = typename JSONParser::Element; + + Element document; + bool document_ok = false; + if (col_json_const) + { + std::string_view json{reinterpret_cast(chars.data()), offsets[0] - 1}; + document_ok = parser.parse(json, document); + } + + for (const auto i : collections::range(0, input_rows_count)) + { + if (!col_json_const) + { + std::string_view json{reinterpret_cast(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1}; + document_ok = parser.parse(json, document); + } + + bool added_to_column = false; + if (document_ok) + { + /// Perform moves. + Element element; + std::string_view last_key; + bool moves_ok = performMoves(arguments, i, document, moves, element, last_key); + + if (moves_ok) + added_to_column = impl.insertResultToColumn(*to, element, last_key); + } + + /// We add default value (=null or zero) if something goes wrong, we don't throw exceptions in these JSON functions. + if (!added_to_column) + to->insertDefault(); + } + return to; + } + }; + +private: + BOOST_TTI_HAS_MEMBER_FUNCTION(reserve) + BOOST_TTI_HAS_MEMBER_FUNCTION(prepare) + + /// Represents a move of a JSON iterator described by a single argument passed to a JSON function. + /// For example, the call JSONExtractInt('{"a": "hello", "b": [-100, 200.0, 300]}', 'b', 1) + /// contains two moves: {MoveType::ConstKey, "b"} and {MoveType::ConstIndex, 1}. + /// Keys and indices can be nonconst, in this case they are calculated for each row. + enum class MoveType + { + Key, + Index, + ConstKey, + ConstIndex, + }; + + struct Move + { + explicit Move(MoveType type_, size_t index_ = 0) : type(type_), index(index_) {} + Move(MoveType type_, const String & key_) : type(type_), key(key_) {} + MoveType type; + size_t index = 0; + String key; + }; + + static std::vector prepareMoves( + const char * function_name, + const ColumnsWithTypeAndName & columns, + size_t first_index_argument, + size_t num_index_arguments) + { + std::vector moves; + moves.reserve(num_index_arguments); + for (const auto i : collections::range(first_index_argument, first_index_argument + num_index_arguments)) + { + const auto & column = columns[i]; + if (!isString(column.type) && !isNativeInteger(column.type)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "The argument {} of function {} should be a string specifying key " + "or an integer specifying index, illegal type: {}", + std::to_string(i + 1), String(function_name), column.type->getName()); + + if (column.column && isColumnConst(*column.column)) + { + const auto & column_const = assert_cast(*column.column); + if (isString(column.type)) + moves.emplace_back(MoveType::ConstKey, column_const.getValue()); + else + moves.emplace_back(MoveType::ConstIndex, column_const.getInt(0)); + } + else + { + if (isString(column.type)) + moves.emplace_back(MoveType::Key, ""); + else + moves.emplace_back(MoveType::Index, 0); + } + } + return moves; + } + + + /// Performs moves of types MoveType::Index and MoveType::ConstIndex. + template + static bool performMoves(const ColumnsWithTypeAndName & arguments, size_t row, + const typename JSONParser::Element & document, const std::vector & moves, + typename JSONParser::Element & element, std::string_view & last_key) + { + typename JSONParser::Element res_element = document; + std::string_view key; + + for (size_t j = 0; j != moves.size(); ++j) + { + switch (moves[j].type) + { + case MoveType::ConstIndex: + { + if (!moveToElementByIndex(res_element, static_cast(moves[j].index), key)) + return false; + break; + } + case MoveType::ConstKey: + { + key = moves[j].key; + if (!moveToElementByKey(res_element, key)) + return false; + break; + } + case MoveType::Index: + { + Int64 index = (*arguments[j + 1].column)[row].get(); + if (!moveToElementByIndex(res_element, static_cast(index), key)) + return false; + break; + } + case MoveType::Key: + { + key = (*arguments[j + 1].column).getDataAt(row).toView(); + if (!moveToElementByKey(res_element, key)) + return false; + break; + } + } + } + + element = res_element; + last_key = key; + return true; + } + + template + static bool moveToElementByIndex(typename JSONParser::Element & element, int index, std::string_view & out_key) + { + if (element.isArray()) + { + auto array = element.getArray(); + if (index >= 0) + --index; + else + index += array.size(); + + if (static_cast(index) >= array.size()) + return false; + element = array[index]; + out_key = {}; + return true; + } + + if constexpr (HasIndexOperator) + { + if (element.isObject()) + { + auto object = element.getObject(); + if (index >= 0) + --index; + else + index += object.size(); + + if (static_cast(index) >= object.size()) + return false; + std::tie(out_key, element) = object[index]; + return true; + } + } + + return {}; + } + + /// Performs moves of types MoveType::Key and MoveType::ConstKey. + template + static bool moveToElementByKey(typename JSONParser::Element & element, std::string_view key) + { + if (!element.isObject()) + return false; + auto object = element.getObject(); + return object.find(key, element); + } + + static size_t calculateMaxSize(const ColumnString::Offsets & offsets) + { + size_t max_size = 0; + for (const auto i : collections::range(0, offsets.size())) + { + size_t size = offsets[i] - offsets[i - 1]; + if (max_size < size) + max_size = size; + } + if (max_size) + --max_size; + return max_size; + } + +}; + + +template typename Impl> +class ExecutableFunctionJSON : public IExecutableFunction, WithContext +{ + +public: + explicit ExecutableFunctionJSON(const NullPresence & null_presence_, bool allow_simdjson_, const DataTypePtr & json_return_type_) + : null_presence(null_presence_), allow_simdjson(allow_simdjson_), json_return_type(json_return_type_) + { + } + + String getName() const override { return Name::name; } + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + { + if (null_presence.has_null_constant) + return result_type->createColumnConstWithDefaultValue(input_rows_count); + + ColumnsWithTypeAndName temporary_columns = null_presence.has_nullable ? createBlockWithNestedColumns(arguments) : arguments; + ColumnPtr temporary_result = chooseAndRunJSONParser(temporary_columns, json_return_type, input_rows_count); + if (null_presence.has_nullable) + return wrapInNullable(temporary_result, arguments, result_type, input_rows_count); + return temporary_result; + } + +private: + + ColumnPtr + chooseAndRunJSONParser(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const + { +#if USE_SIMDJSON + if (allow_simdjson) + return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); +#endif + +#if USE_RAPIDJSON + return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); +#else + return FunctionJSONHelpers::Executor::run(arguments, result_type, input_rows_count); +#endif + } + + NullPresence null_presence; + bool allow_simdjson; + DataTypePtr json_return_type; +}; + + +template typename Impl> +class FunctionBaseFunctionJSON : public IFunctionBase +{ +public: + explicit FunctionBaseFunctionJSON( + const NullPresence & null_presence_, + bool allow_simdjson_, + DataTypes argument_types_, + DataTypePtr return_type_, + DataTypePtr json_return_type_) + : null_presence(null_presence_) + , allow_simdjson(allow_simdjson_) + , argument_types(std::move(argument_types_)) + , return_type(std::move(return_type_)) + , json_return_type(std::move(json_return_type_)) + { + } + + String getName() const override { return Name::name; } + + const DataTypes & getArgumentTypes() const override + { + return argument_types; + } + + const DataTypePtr & getResultType() const override + { + return return_type; + } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override + { + return std::make_unique>(null_presence, allow_simdjson, json_return_type); + } + +private: + NullPresence null_presence; + bool allow_simdjson; + DataTypes argument_types; + DataTypePtr return_type; + DataTypePtr json_return_type; +}; + + +/// We use IFunctionOverloadResolver instead of IFunction to handle non-default NULL processing. +/// Both NULL and JSON NULL should generate NULL value. If any argument is NULL, return NULL. +template typename Impl> +class JSONOverloadResolver : public IFunctionOverloadResolver, WithContext +{ +public: + static constexpr auto name = Name::name; + + String getName() const override { return name; } + + static FunctionOverloadResolverPtr create(ContextPtr context_) + { + return std::make_unique(context_); + } + + explicit JSONOverloadResolver(ContextPtr context_) : WithContext(context_) {} + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForNulls() const override { return false; } + + FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const override + { + bool has_nothing_argument = false; + for (const auto & arg : arguments) + has_nothing_argument |= isNothing(arg.type); + + DataTypePtr json_return_type = Impl::getReturnType(Name::name, createBlockWithNestedColumns(arguments)); + NullPresence null_presence = getNullPresense(arguments); + DataTypePtr return_type; + if (has_nothing_argument) + return_type = std::make_shared(); + else if (null_presence.has_null_constant) + return_type = makeNullable(std::make_shared()); + else if (null_presence.has_nullable) + return_type = makeNullable(json_return_type); + else + return_type = json_return_type; + + /// Top-level LowCardinality columns are processed outside JSON parser. + json_return_type = removeLowCardinality(json_return_type); + + DataTypes argument_types; + argument_types.reserve(arguments.size()); + for (const auto & argument : arguments) + argument_types.emplace_back(argument.type); + return std::make_unique>( + null_presence, getContext()->getSettingsRef().allow_simdjson, argument_types, return_type, json_return_type); + } +}; + + +struct NameJSONHas { static constexpr auto name{"JSONHas"}; }; +struct NameIsValidJSON { static constexpr auto name{"isValidJSON"}; }; +struct NameJSONLength { static constexpr auto name{"JSONLength"}; }; +struct NameJSONKey { static constexpr auto name{"JSONKey"}; }; +struct NameJSONType { static constexpr auto name{"JSONType"}; }; +struct NameJSONExtractInt { static constexpr auto name{"JSONExtractInt"}; }; +struct NameJSONExtractUInt { static constexpr auto name{"JSONExtractUInt"}; }; +struct NameJSONExtractFloat { static constexpr auto name{"JSONExtractFloat"}; }; +struct NameJSONExtractBool { static constexpr auto name{"JSONExtractBool"}; }; +struct NameJSONExtractString { static constexpr auto name{"JSONExtractString"}; }; +struct NameJSONExtract { static constexpr auto name{"JSONExtract"}; }; +struct NameJSONExtractKeysAndValues { static constexpr auto name{"JSONExtractKeysAndValues"}; }; +struct NameJSONExtractRaw { static constexpr auto name{"JSONExtractRaw"}; }; +struct NameJSONExtractArrayRaw { static constexpr auto name{"JSONExtractArrayRaw"}; }; +struct NameJSONExtractKeysAndValuesRaw { static constexpr auto name{"JSONExtractKeysAndValuesRaw"}; }; +struct NameJSONExtractKeys { static constexpr auto name{"JSONExtractKeys"}; }; + + +template +class JSONHasImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view) + { + ColumnVector & col_vec = assert_cast &>(dest); + col_vec.insertValue(1); + return true; + } +}; + + +template +class IsValidJSONImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) + { + if (arguments.size() != 1) + { + /// IsValidJSON() shouldn't get parameters other than JSON. + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs exactly one argument", + String(function_name)); + } + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName &) { return 0; } + + static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view) + { + /// This function is called only if JSON is valid. + /// If JSON isn't valid then `FunctionJSON::Executor::run()` adds default value (=zero) to `dest` without calling this function. + ColumnVector & col_vec = assert_cast &>(dest); + col_vec.insertValue(1); + return true; + } +}; + + +template +class JSONLengthImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + size_t size; + if (element.isArray()) + size = element.getArray().size(); + else if (element.isObject()) + size = element.getObject().size(); + else + return false; + + ColumnVector & col_vec = assert_cast &>(dest); + col_vec.insertValue(size); + return true; + } +}; + + +template +class JSONKeyImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element &, std::string_view last_key) + { + if (last_key.empty()) + return false; + ColumnString & col_str = assert_cast(dest); + col_str.insertData(last_key.data(), last_key.size()); + return true; + } +}; + + +template +class JSONTypeImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + static const std::vector> values = { + {"Array", '['}, + {"Object", '{'}, + {"String", '"'}, + {"Int64", 'i'}, + {"UInt64", 'u'}, + {"Double", 'd'}, + {"Bool", 'b'}, + {"Null", 0}, /// the default value for the column. + }; + return std::make_shared>(values); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + UInt8 type; + switch (element.type()) + { + case ElementType::INT64: + type = 'i'; + break; + case ElementType::UINT64: + type = 'u'; + break; + case ElementType::DOUBLE: + type = 'd'; + break; + case ElementType::STRING: + type = '"'; + break; + case ElementType::ARRAY: + type = '['; + break; + case ElementType::OBJECT: + type = '{'; + break; + case ElementType::BOOL: + type = 'b'; + break; + case ElementType::NULL_VALUE: + type = 0; + break; + } + + ColumnVector & col_vec = assert_cast &>(dest); + col_vec.insertValue(type); + return true; + } +}; + + +template +class JSONExtractNumericImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared>(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + NumberType value; + + switch (element.type()) + { + case ElementType::DOUBLE: + if constexpr (std::is_floating_point_v) + { + /// We permit inaccurate conversion of double to float. + /// Example: double 0.1 from JSON is not representable in float. + /// But it will be more convenient for user to perform conversion. + value = static_cast(element.getDouble()); + } + else if (!accurate::convertNumeric(element.getDouble(), value)) + return false; + break; + case ElementType::UINT64: + if (!accurate::convertNumeric(element.getUInt64(), value)) + return false; + break; + case ElementType::INT64: + if (!accurate::convertNumeric(element.getInt64(), value)) + return false; + break; + case ElementType::BOOL: + if constexpr (is_integer && convert_bool_to_integer) + { + value = static_cast(element.getBool()); + break; + } + return false; + case ElementType::STRING: + { + auto rb = ReadBufferFromMemory{element.getString()}; + if constexpr (std::is_floating_point_v) + { + if (!tryReadFloatText(value, rb) || !rb.eof()) + return false; + } + else + { + if (tryReadIntText(value, rb) && rb.eof()) + break; + + /// Try to parse float and convert it to integer. + Float64 tmp_float; + rb.position() = rb.buffer().begin(); + if (!tryReadFloatText(tmp_float, rb) || !rb.eof()) + return false; + + if (!accurate::convertNumeric(tmp_float, value)) + return false; + } + break; + } + default: + return false; + } + + if (dest.getDataType() == TypeIndex::LowCardinality) + { + ColumnLowCardinality & col_low = assert_cast(dest); + col_low.insertData(reinterpret_cast(&value), sizeof(value)); + } + else + { + auto & col_vec = assert_cast &>(dest); + col_vec.insertValue(value); + } + return true; + } +}; + + +template +using JSONExtractInt64Impl = JSONExtractNumericImpl; +template +using JSONExtractUInt64Impl = JSONExtractNumericImpl; +template +using JSONExtractFloat64Impl = JSONExtractNumericImpl; + + +template +class JSONExtractBoolImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + bool value; + switch (element.type()) + { + case ElementType::BOOL: + value = element.getBool(); + break; + case ElementType::INT64: + value = element.getInt64() != 0; + break; + case ElementType::UINT64: + value = element.getUInt64() != 0; + break; + default: + return false; + } + + auto & col_vec = assert_cast &>(dest); + col_vec.insertValue(static_cast(value)); + return true; + } +}; + +template +class JSONExtractRawImpl; + +template +class JSONExtractStringImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (element.isNull()) + return false; + + if (!element.isString()) + return JSONExtractRawImpl::insertResultToColumn(dest, element, {}); + + auto str = element.getString(); + + if (dest.getDataType() == TypeIndex::LowCardinality) + { + ColumnLowCardinality & col_low = assert_cast(dest); + col_low.insertData(str.data(), str.size()); + } + else + { + ColumnString & col_str = assert_cast(dest); + col_str.insertData(str.data(), str.size()); + } + return true; + } +}; + +/// Nodes of the extract tree. We need the extract tree to extract from JSON complex values containing array, tuples or nullables. +template +struct JSONExtractTree +{ + using Element = typename JSONParser::Element; + + class Node + { + public: + Node() = default; + virtual ~Node() = default; + virtual bool insertResultToColumn(IColumn &, const Element &) = 0; + }; + + template + class NumericNode : public Node + { + public: + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + return JSONExtractNumericImpl::insertResultToColumn(dest, element, {}); + } + }; + + class LowCardinalityFixedStringNode : public Node + { + public: + explicit LowCardinalityFixedStringNode(const size_t fixed_length_) : fixed_length(fixed_length_) { } + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + // If element is an object we delegate the insertion to JSONExtractRawImpl + if (element.isObject()) + return JSONExtractRawImpl::insertResultToLowCardinalityFixedStringColumn(dest, element, fixed_length); + else if (!element.isString()) + return false; + + auto str = element.getString(); + if (str.size() > fixed_length) + return false; + + // For the non low cardinality case of FixedString, the padding is done in the FixedString Column implementation. + // In order to avoid having to pass the data to a FixedString Column and read it back (which would slow down the execution) + // the data is padded here and written directly to the Low Cardinality Column + if (str.size() == fixed_length) + { + assert_cast(dest).insertData(str.data(), str.size()); + } + else + { + String padded_str(str); + padded_str.resize(fixed_length, '\0'); + + assert_cast(dest).insertData(padded_str.data(), padded_str.size()); + } + return true; + } + + private: + const size_t fixed_length; + }; + + class UUIDNode : public Node + { + public: + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + if (!element.isString()) + return false; + + auto uuid = parseFromString(element.getString()); + if (dest.getDataType() == TypeIndex::LowCardinality) + { + ColumnLowCardinality & col_low = assert_cast(dest); + col_low.insertData(reinterpret_cast(&uuid), sizeof(uuid)); + } + else + { + assert_cast(dest).insert(uuid); + } + return true; + } + }; + + template + class DecimalNode : public Node + { + public: + explicit DecimalNode(DataTypePtr data_type_) : data_type(data_type_) {} + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + const auto * type = assert_cast *>(data_type.get()); + + DecimalType value{}; + + switch (element.type()) + { + case ElementType::DOUBLE: + value = convertToDecimal, DataTypeDecimal>( + element.getDouble(), type->getScale()); + break; + case ElementType::UINT64: + value = convertToDecimal, DataTypeDecimal>( + element.getUInt64(), type->getScale()); + break; + case ElementType::INT64: + value = convertToDecimal, DataTypeDecimal>( + element.getInt64(), type->getScale()); + break; + case ElementType::STRING: { + auto rb = ReadBufferFromMemory{element.getString()}; + if (!SerializationDecimal::tryReadText(value, rb, DecimalUtils::max_precision, type->getScale())) + return false; + break; + } + default: + return false; + } + + assert_cast &>(dest).insertValue(value); + return true; + } + + private: + DataTypePtr data_type; + }; + + class StringNode : public Node + { + public: + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + return JSONExtractStringImpl::insertResultToColumn(dest, element, {}); + } + }; + + class FixedStringNode : public Node + { + public: + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + if (element.isNull()) + return false; + + if (!element.isString()) + return JSONExtractRawImpl::insertResultToFixedStringColumn(dest, element, {}); + + auto str = element.getString(); + auto & col_str = assert_cast(dest); + if (str.size() > col_str.getN()) + return false; + col_str.insertData(str.data(), str.size()); + + return true; + } + }; + + template + class EnumNode : public Node + { + public: + explicit EnumNode(const std::vector> & name_value_pairs_) : name_value_pairs(name_value_pairs_) + { + for (const auto & name_value_pair : name_value_pairs) + { + name_to_value_map.emplace(name_value_pair.first, name_value_pair.second); + only_values.emplace(name_value_pair.second); + } + } + + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + auto & col_vec = assert_cast &>(dest); + + if (element.isInt64()) + { + Type value; + if (!accurate::convertNumeric(element.getInt64(), value) || !only_values.contains(value)) + return false; + col_vec.insertValue(value); + return true; + } + + if (element.isUInt64()) + { + Type value; + if (!accurate::convertNumeric(element.getUInt64(), value) || !only_values.contains(value)) + return false; + col_vec.insertValue(value); + return true; + } + + if (element.isString()) + { + auto value = name_to_value_map.find(element.getString()); + if (value == name_to_value_map.end()) + return false; + col_vec.insertValue(value->second); + return true; + } + + return false; + } + + private: + std::vector> name_value_pairs; + std::unordered_map name_to_value_map; + std::unordered_set only_values; + }; + + class NullableNode : public Node + { + public: + explicit NullableNode(std::unique_ptr nested_) : nested(std::move(nested_)) {} + + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + ColumnNullable & col_null = assert_cast(dest); + if (!nested->insertResultToColumn(col_null.getNestedColumn(), element)) + return false; + col_null.getNullMapColumn().insertValue(0); + return true; + } + + private: + std::unique_ptr nested; + }; + + class ArrayNode : public Node + { + public: + explicit ArrayNode(std::unique_ptr nested_) : nested(std::move(nested_)) {} + + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + if (!element.isArray()) + return false; + + auto array = element.getArray(); + + ColumnArray & col_arr = assert_cast(dest); + auto & data = col_arr.getData(); + size_t old_size = data.size(); + bool were_valid_elements = false; + + for (auto value : array) + { + if (nested->insertResultToColumn(data, value)) + were_valid_elements = true; + else + data.insertDefault(); + } + + if (!were_valid_elements) + { + data.popBack(data.size() - old_size); + return false; + } + + col_arr.getOffsets().push_back(data.size()); + return true; + } + + private: + std::unique_ptr nested; + }; + + class TupleNode : public Node + { + public: + TupleNode(std::vector> nested_, const std::vector & explicit_names_) : nested(std::move(nested_)), explicit_names(explicit_names_) + { + for (size_t i = 0; i != explicit_names.size(); ++i) + name_to_index_map.emplace(explicit_names[i], i); + } + + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + ColumnTuple & tuple = assert_cast(dest); + size_t old_size = dest.size(); + bool were_valid_elements = false; + + auto set_size = [&](size_t size) + { + for (size_t i = 0; i != tuple.tupleSize(); ++i) + { + auto & col = tuple.getColumn(i); + if (col.size() != size) + { + if (col.size() > size) + col.popBack(col.size() - size); + else + while (col.size() < size) + col.insertDefault(); + } + } + }; + + if (element.isArray()) + { + auto array = element.getArray(); + auto it = array.begin(); + + for (size_t index = 0; (index != nested.size()) && (it != array.end()); ++index) + { + if (nested[index]->insertResultToColumn(tuple.getColumn(index), *it++)) + were_valid_elements = true; + else + tuple.getColumn(index).insertDefault(); + } + + set_size(old_size + static_cast(were_valid_elements)); + return were_valid_elements; + } + + if (element.isObject()) + { + auto object = element.getObject(); + if (name_to_index_map.empty()) + { + auto it = object.begin(); + for (size_t index = 0; (index != nested.size()) && (it != object.end()); ++index) + { + if (nested[index]->insertResultToColumn(tuple.getColumn(index), (*it++).second)) + were_valid_elements = true; + else + tuple.getColumn(index).insertDefault(); + } + } + else + { + for (const auto & [key, value] : object) + { + auto index = name_to_index_map.find(key); + if (index != name_to_index_map.end()) + { + if (nested[index->second]->insertResultToColumn(tuple.getColumn(index->second), value)) + were_valid_elements = true; + } + } + } + + set_size(old_size + static_cast(were_valid_elements)); + return were_valid_elements; + } + + return false; + } + + private: + std::vector> nested; + std::vector explicit_names; + std::unordered_map name_to_index_map; + }; + + class MapNode : public Node + { + public: + MapNode(std::unique_ptr key_, std::unique_ptr value_) : key(std::move(key_)), value(std::move(value_)) { } + + bool insertResultToColumn(IColumn & dest, const Element & element) override + { + if (!element.isObject()) + return false; + + ColumnMap & map_col = assert_cast(dest); + auto & offsets = map_col.getNestedColumn().getOffsets(); + auto & tuple_col = map_col.getNestedData(); + auto & key_col = tuple_col.getColumn(0); + auto & value_col = tuple_col.getColumn(1); + size_t old_size = tuple_col.size(); + + auto object = element.getObject(); + auto it = object.begin(); + for (; it != object.end(); ++it) + { + auto pair = *it; + + /// Insert key + key_col.insertData(pair.first.data(), pair.first.size()); + + /// Insert value + if (!value->insertResultToColumn(value_col, pair.second)) + value_col.insertDefault(); + } + + offsets.push_back(old_size + object.size()); + return true; + } + + private: + std::unique_ptr key; + std::unique_ptr value; + }; + + static std::unique_ptr build(const char * function_name, const DataTypePtr & type) + { + switch (type->getTypeId()) + { + case TypeIndex::UInt8: return std::make_unique>(); + case TypeIndex::UInt16: return std::make_unique>(); + case TypeIndex::UInt32: return std::make_unique>(); + case TypeIndex::UInt64: return std::make_unique>(); + case TypeIndex::UInt128: return std::make_unique>(); + case TypeIndex::UInt256: return std::make_unique>(); + case TypeIndex::Int8: return std::make_unique>(); + case TypeIndex::Int16: return std::make_unique>(); + case TypeIndex::Int32: return std::make_unique>(); + case TypeIndex::Int64: return std::make_unique>(); + case TypeIndex::Int128: return std::make_unique>(); + case TypeIndex::Int256: return std::make_unique>(); + case TypeIndex::Float32: return std::make_unique>(); + case TypeIndex::Float64: return std::make_unique>(); + case TypeIndex::String: return std::make_unique(); + case TypeIndex::FixedString: return std::make_unique(); + case TypeIndex::UUID: return std::make_unique(); + case TypeIndex::LowCardinality: + { + // The low cardinality case is treated in two different ways: + // For FixedString type, an especial class is implemented for inserting the data in the destination column, + // as the string length must be passed in order to check and pad the incoming data. + // For the rest of low cardinality types, the insertion is done in their corresponding class, adapting the data + // as needed for the insertData function of the ColumnLowCardinality. + auto dictionary_type = typeid_cast(type.get())->getDictionaryType(); + if ((*dictionary_type).getTypeId() == TypeIndex::FixedString) + { + auto fixed_length = typeid_cast(dictionary_type.get())->getN(); + return std::make_unique(fixed_length); + } + return build(function_name, dictionary_type); + } + case TypeIndex::Decimal256: return std::make_unique>(type); + case TypeIndex::Decimal128: return std::make_unique>(type); + case TypeIndex::Decimal64: return std::make_unique>(type); + case TypeIndex::Decimal32: return std::make_unique>(type); + case TypeIndex::Enum8: + return std::make_unique>(static_cast(*type).getValues()); + case TypeIndex::Enum16: + return std::make_unique>(static_cast(*type).getValues()); + case TypeIndex::Nullable: + { + return std::make_unique(build(function_name, static_cast(*type).getNestedType())); + } + case TypeIndex::Array: + { + return std::make_unique(build(function_name, static_cast(*type).getNestedType())); + } + case TypeIndex::Tuple: + { + const auto & tuple = static_cast(*type); + const auto & tuple_elements = tuple.getElements(); + std::vector> elements; + elements.reserve(tuple_elements.size()); + for (const auto & tuple_element : tuple_elements) + elements.emplace_back(build(function_name, tuple_element)); + return std::make_unique(std::move(elements), tuple.haveExplicitNames() ? tuple.getElementNames() : Strings{}); + } + case TypeIndex::Map: + { + const auto & map_type = static_cast(*type); + const auto & key_type = map_type.getKeyType(); + if (!isString(removeLowCardinality(key_type))) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Function {} doesn't support the return type schema: {} with key type not String", + String(function_name), + type->getName()); + + const auto & value_type = map_type.getValueType(); + return std::make_unique(build(function_name, key_type), build(function_name, value_type)); + } + default: + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Function {} doesn't support the return type schema: {}", + String(function_name), type->getName()); + } + } +}; + + +template +class JSONExtractImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) + { + if (arguments.size() < 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", String(function_name)); + + const auto & col = arguments.back(); + const auto * col_type_const = typeid_cast(col.column.get()); + if (!col_type_const || !isString(col.type)) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "The last argument of function {} should " + "be a constant string specifying the return data type, illegal value: {}", + String(function_name), col.name); + + return DataTypeFactory::instance().get(col_type_const->getValue()); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 2; } + + void prepare(const char * function_name, const ColumnsWithTypeAndName &, const DataTypePtr & result_type) + { + extract_tree = JSONExtractTree::build(function_name, result_type); + } + + bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + return extract_tree->insertResultToColumn(dest, element); + } + +protected: + std::unique_ptr::Node> extract_tree; +}; + + +template +class JSONExtractKeysAndValuesImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char * function_name, const ColumnsWithTypeAndName & arguments) + { + if (arguments.size() < 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", String(function_name)); + + const auto & col = arguments.back(); + const auto * col_type_const = typeid_cast(col.column.get()); + if (!col_type_const || !isString(col.type)) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "The last argument of function {} should " + "be a constant string specifying the values' data type, illegal value: {}", + String(function_name), col.name); + + DataTypePtr key_type = std::make_unique(); + DataTypePtr value_type = DataTypeFactory::instance().get(col_type_const->getValue()); + DataTypePtr tuple_type = std::make_unique(DataTypes{key_type, value_type}); + return std::make_unique(tuple_type); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 2; } + + void prepare(const char * function_name, const ColumnsWithTypeAndName &, const DataTypePtr & result_type) + { + const auto tuple_type = typeid_cast(result_type.get())->getNestedType(); + const auto value_type = typeid_cast(tuple_type.get())->getElements()[1]; + extract_tree = JSONExtractTree::build(function_name, value_type); + } + + bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (!element.isObject()) + return false; + + auto object = element.getObject(); + + auto & col_arr = assert_cast(dest); + auto & col_tuple = assert_cast(col_arr.getData()); + size_t old_size = col_tuple.size(); + auto & col_key = assert_cast(col_tuple.getColumn(0)); + auto & col_value = col_tuple.getColumn(1); + + for (const auto & [key, value] : object) + { + if (extract_tree->insertResultToColumn(col_value, value)) + col_key.insertData(key.data(), key.size()); + } + + if (col_tuple.size() == old_size) + return false; + + col_arr.getOffsets().push_back(col_tuple.size()); + return true; + } + +private: + std::unique_ptr::Node> extract_tree; +}; + + +template +class JSONExtractRawImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (dest.getDataType() == TypeIndex::LowCardinality) + { + ColumnString::Chars chars; + WriteBufferFromVector buf(chars, AppendModeTag()); + traverse(element, buf); + buf.finalize(); + assert_cast(dest).insertData(reinterpret_cast(chars.data()), chars.size()); + } + else + { + ColumnString & col_str = assert_cast(dest); + auto & chars = col_str.getChars(); + WriteBufferFromVector buf(chars, AppendModeTag()); + traverse(element, buf); + buf.finalize(); + chars.push_back(0); + col_str.getOffsets().push_back(chars.size()); + } + return true; + } + + // We use insertResultToFixedStringColumn in case we are inserting raw data in a FixedString column + static bool insertResultToFixedStringColumn(IColumn & dest, const Element & element, std::string_view) + { + ColumnFixedString::Chars chars; + WriteBufferFromVector buf(chars, AppendModeTag()); + traverse(element, buf); + buf.finalize(); + + auto & col_str = assert_cast(dest); + + if (chars.size() > col_str.getN()) + return false; + + chars.resize_fill(col_str.getN()); + col_str.insertData(reinterpret_cast(chars.data()), chars.size()); + + + return true; + } + + // We use insertResultToLowCardinalityFixedStringColumn in case we are inserting raw data in a Low Cardinality FixedString column + static bool insertResultToLowCardinalityFixedStringColumn(IColumn & dest, const Element & element, size_t fixed_length) + { + if (element.getObject().size() > fixed_length) + return false; + + ColumnFixedString::Chars chars; + WriteBufferFromVector buf(chars, AppendModeTag()); + traverse(element, buf); + buf.finalize(); + + if (chars.size() > fixed_length) + return false; + chars.resize_fill(fixed_length); + assert_cast(dest).insertData(reinterpret_cast(chars.data()), chars.size()); + + return true; + } + +private: + static void traverse(const Element & element, WriteBuffer & buf) + { + if (element.isInt64()) + { + writeIntText(element.getInt64(), buf); + return; + } + if (element.isUInt64()) + { + writeIntText(element.getUInt64(), buf); + return; + } + if (element.isDouble()) + { + writeFloatText(element.getDouble(), buf); + return; + } + if (element.isBool()) + { + if (element.getBool()) + writeCString("true", buf); + else + writeCString("false", buf); + return; + } + if (element.isString()) + { + writeJSONString(element.getString(), buf, formatSettings()); + return; + } + if (element.isArray()) + { + writeChar('[', buf); + bool need_comma = false; + for (auto value : element.getArray()) + { + if (std::exchange(need_comma, true)) + writeChar(',', buf); + traverse(value, buf); + } + writeChar(']', buf); + return; + } + if (element.isObject()) + { + writeChar('{', buf); + bool need_comma = false; + for (auto [key, value] : element.getObject()) + { + if (std::exchange(need_comma, true)) + writeChar(',', buf); + writeJSONString(key, buf, formatSettings()); + writeChar(':', buf); + traverse(value, buf); + } + writeChar('}', buf); + return; + } + if (element.isNull()) + { + writeCString("null", buf); + return; + } + } + + static const FormatSettings & formatSettings() + { + static const FormatSettings the_instance = [] + { + FormatSettings settings; + settings.json.escape_forward_slashes = false; + return settings; + }(); + return the_instance; + } +}; + + +template +class JSONExtractArrayRawImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(std::make_shared()); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (!element.isArray()) + return false; + + auto array = element.getArray(); + ColumnArray & col_res = assert_cast(dest); + + for (auto value : array) + JSONExtractRawImpl::insertResultToColumn(col_res.getData(), value, {}); + + col_res.getOffsets().push_back(col_res.getOffsets().back() + array.size()); + return true; + } +}; + + +template +class JSONExtractKeysAndValuesRawImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + DataTypePtr string_type = std::make_unique(); + DataTypePtr tuple_type = std::make_unique(DataTypes{string_type, string_type}); + return std::make_unique(tuple_type); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (!element.isObject()) + return false; + + auto object = element.getObject(); + + auto & col_arr = assert_cast(dest); + auto & col_tuple = assert_cast(col_arr.getData()); + auto & col_key = assert_cast(col_tuple.getColumn(0)); + auto & col_value = assert_cast(col_tuple.getColumn(1)); + + for (const auto & [key, value] : object) + { + col_key.insertData(key.data(), key.size()); + JSONExtractRawImpl::insertResultToColumn(col_value, value, {}); + } + + col_arr.getOffsets().push_back(col_arr.getOffsets().back() + object.size()); + return true; + } +}; + +template +class JSONExtractKeysImpl +{ +public: + using Element = typename JSONParser::Element; + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_unique(std::make_shared()); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + bool insertResultToColumn(IColumn & dest, const Element & element, std::string_view) + { + if (!element.isObject()) + return false; + + auto object = element.getObject(); + + ColumnArray & col_res = assert_cast(dest); + auto & col_key = assert_cast(col_res.getData()); + + for (const auto & [key, value] : object) + { + col_key.insertData(key.data(), key.size()); + } + + col_res.getOffsets().push_back(col_res.getOffsets().back() + object.size()); + return true; + } +}; + +} diff --git a/src/Functions/IFunction.cpp b/src/Functions/IFunction.cpp index 1c30dee04820..c5d1a574871f 100644 --- a/src/Functions/IFunction.cpp +++ b/src/Functions/IFunction.cpp @@ -194,7 +194,15 @@ ColumnPtr IExecutableFunction::defaultImplementationForNulls( if (null_presence.has_nullable) { ColumnsWithTypeAndName temporary_columns = createBlockWithNestedColumns(args); - auto temporary_result_type = removeNullable(result_type); + + DataTypePtr temporary_result_type; + if (resolver) + { + auto temporary_function_base = resolver->build(temporary_columns); + temporary_result_type = temporary_function_base->getResultType(); + } + else + temporary_result_type = removeNullable(result_type); auto res = executeWithoutLowCardinalityColumns(temporary_columns, temporary_result_type, input_rows_count, dry_run); return wrapInNullable(res, args, result_type, input_rows_count); diff --git a/src/Functions/IFunction.h b/src/Functions/IFunction.h index b613ddb89873..1e161d4536e5 100644 --- a/src/Functions/IFunction.h +++ b/src/Functions/IFunction.h @@ -31,6 +31,8 @@ namespace ErrorCodes } class Field; +class IFunctionOverloadResolver; +using FunctionOverloadResolverPtr = std::shared_ptr; /// The simplest executable object. /// Motivation: @@ -48,6 +50,8 @@ class IExecutableFunction ColumnPtr execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const; + void setResolver(const FunctionOverloadResolverPtr & resolver_) { resolver = resolver_; } + protected: virtual ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const = 0; @@ -99,6 +103,8 @@ class IExecutableFunction */ virtual bool canBeExecutedOnDefaultArguments() const { return true; } + FunctionOverloadResolverPtr resolver; + private: ColumnPtr defaultImplementationForConstantArguments( @@ -396,7 +402,6 @@ class IFunctionOverloadResolver DataTypePtr getReturnTypeWithoutLowCardinality(const ColumnsWithTypeAndName & arguments) const; }; -using FunctionOverloadResolverPtr = std::shared_ptr; /// Old function interface. Check documentation in IFunction.h. /// If client do not need stateful properties it can implement this interface. diff --git a/src/Functions/array/FunctionArrayMapped.h b/src/Functions/array/FunctionArrayMapped.h index 89599edd9d10..ccfc7869036e 100644 --- a/src/Functions/array/FunctionArrayMapped.h +++ b/src/Functions/array/FunctionArrayMapped.h @@ -36,7 +36,7 @@ namespace ErrorCodes extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int LOGICAL_ERROR; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } @@ -308,7 +308,7 @@ class FunctionArrayMapped : public IFunction if (getOffsetsPtr(*column_array) != offsets_column && getOffsets(*column_array) != typeid_cast(*offsets_column).getData()) throw Exception( - ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "{}s passed to {} must have equal size", argument_type_name, getName()); diff --git a/src/Functions/array/arrayDistance.cpp b/src/Functions/array/arrayDistance.cpp index c1137848cc50..c68c89ee0d59 100644 --- a/src/Functions/array/arrayDistance.cpp +++ b/src/Functions/array/arrayDistance.cpp @@ -16,7 +16,7 @@ namespace ErrorCodes extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int LOGICAL_ERROR; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int ARGUMENT_OUT_OF_BOUND; } @@ -356,7 +356,7 @@ class FunctionArrayDistance : public IFunction { ColumnArray::Offset prev_offset = row > 0 ? offsets_x[row] : 0; throw Exception( - ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Arguments of function {} have different array sizes: {} and {}", getName(), offsets_x[row] - prev_offset, @@ -423,7 +423,7 @@ class FunctionArrayDistance : public IFunction if (unlikely(offsets_x[0] != offsets_y[row] - prev_offset)) { throw Exception( - ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Arguments of function {} have different array sizes: {} and {}", getName(), offsets_x[0], diff --git a/src/Functions/array/arrayElement.cpp b/src/Functions/array/arrayElement.cpp index 299f25b82927..8dabf17a18d9 100644 --- a/src/Functions/array/arrayElement.cpp +++ b/src/Functions/array/arrayElement.cpp @@ -483,17 +483,25 @@ ColumnPtr FunctionArrayElement::executeNumberConst( } else if (index.getType() == Field::Types::Int64) { - /// Cast to UInt64 before negation allows to avoid undefined behaviour for negation of the most negative number. - /// NOTE: this would be undefined behaviour in C++ sense, but nevertheless, compiler cannot see it on user provided data, - /// and generates the code that we want on supported CPU architectures (overflow in sense of two's complement arithmetic). - /// This is only needed to avoid UBSan report. - - /// Negative array indices work this way: - /// arr[-1] is the element at offset 0 from the last - /// arr[-2] is the element at offset 1 from the last and so on. - - ArrayElementNumImpl::template vectorConst( - col_nested->getData(), col_array->getOffsets(), -(static_cast(index.safeGet()) + 1), col_res->getData(), builder); + auto value = index.safeGet(); + if (value >= 0) + { + ArrayElementNumImpl::template vectorConst( + col_nested->getData(), col_array->getOffsets(), static_cast(value) - 1, col_res->getData(), builder); + } + else + { + /// Cast to UInt64 before negation allows to avoid undefined behaviour for negation of the most negative number. + /// NOTE: this would be undefined behaviour in C++ sense, but nevertheless, compiler cannot see it on user provided data, + /// and generates the code that we want on supported CPU architectures (overflow in sense of two's complement arithmetic). + /// This is only needed to avoid UBSan report. + + /// Negative array indices work this way: + /// arr[-1] is the element at offset 0 from the last + /// arr[-2] is the element at offset 1 from the last and so on. + ArrayElementNumImpl::template vectorConst( + col_nested->getData(), col_array->getOffsets(), -(static_cast(index.safeGet()) + 1), col_res->getData(), builder); + } } else throw Exception(ErrorCodes::LOGICAL_ERROR, "Illegal type of array index"); diff --git a/src/Functions/array/arrayEnumerateExtended.h b/src/Functions/array/arrayEnumerateExtended.h index c3d69bb6972a..3f145c05b54c 100644 --- a/src/Functions/array/arrayEnumerateExtended.h +++ b/src/Functions/array/arrayEnumerateExtended.h @@ -20,7 +20,7 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } class FunctionArrayEnumerateUniq; @@ -153,7 +153,7 @@ ColumnPtr FunctionArrayEnumerateExtended::executeImpl(const ColumnsWith offsets_column = array->getOffsetsPtr(); } else if (offsets_i != *offsets) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Lengths of all arrays passed to {} must be equal.", + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths of all arrays passed to {} must be equal.", getName()); const auto * array_data = &array->getData(); diff --git a/src/Functions/array/arrayEnumerateRanked.h b/src/Functions/array/arrayEnumerateRanked.h index 73feb3e46ea0..8a348c07421e 100644 --- a/src/Functions/array/arrayEnumerateRanked.h +++ b/src/Functions/array/arrayEnumerateRanked.h @@ -60,7 +60,7 @@ namespace DB namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } class FunctionArrayEnumerateUniqRanked; @@ -194,7 +194,7 @@ ColumnPtr FunctionArrayEnumerateRankedExtended::executeImpl( { if (*offsets_by_depth[0] != array->getOffsets()) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths and effective depths of all arrays passed to {} must be equal.", getName()); } } @@ -217,7 +217,7 @@ ColumnPtr FunctionArrayEnumerateRankedExtended::executeImpl( { if (*offsets_by_depth[col_depth] != array->getOffsets()) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths and effective depths of all arrays passed to {} must be equal.", getName()); } } @@ -225,7 +225,7 @@ ColumnPtr FunctionArrayEnumerateRankedExtended::executeImpl( if (col_depth < arrays_depths.depths[array_num]) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "{}: Passed array number {} depth ({}) is more than the actual array depth ({}).", getName(), array_num, std::to_string(arrays_depths.depths[array_num]), col_depth); } diff --git a/src/Functions/array/arrayReduce.cpp b/src/Functions/array/arrayReduce.cpp index d4896595941c..a4b2cc037ab1 100644 --- a/src/Functions/array/arrayReduce.cpp +++ b/src/Functions/array/arrayReduce.cpp @@ -19,7 +19,7 @@ namespace DB namespace ErrorCodes { - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; @@ -144,7 +144,7 @@ ColumnPtr FunctionArrayReduce::executeImpl(const ColumnsWithTypeAndName & argume if (i == 0) offsets = offsets_i; else if (*offsets_i != *offsets) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Lengths of all arrays passed to {} must be equal.", + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths of all arrays passed to {} must be equal.", getName()); } const IColumn ** aggregate_arguments = aggregate_arguments_vec.data(); diff --git a/src/Functions/array/arrayReduceInRanges.cpp b/src/Functions/array/arrayReduceInRanges.cpp index 07391c963a6a..790bc3ef8798 100644 --- a/src/Functions/array/arrayReduceInRanges.cpp +++ b/src/Functions/array/arrayReduceInRanges.cpp @@ -21,7 +21,7 @@ namespace DB namespace ErrorCodes { - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; @@ -190,7 +190,7 @@ ColumnPtr FunctionArrayReduceInRanges::executeImpl( if (i == 0) offsets = offsets_i; else if (*offsets_i != *offsets) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Lengths of all arrays passed to {} must be equal.", + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths of all arrays passed to {} must be equal.", getName()); } const IColumn ** aggregate_arguments = aggregate_arguments_vec.data(); diff --git a/src/Functions/array/arrayUniq.cpp b/src/Functions/array/arrayUniq.cpp index 1d1cf4e6392f..81ba5b620943 100644 --- a/src/Functions/array/arrayUniq.cpp +++ b/src/Functions/array/arrayUniq.cpp @@ -18,7 +18,7 @@ namespace DB namespace ErrorCodes { - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; @@ -151,7 +151,7 @@ ColumnPtr FunctionArrayUniq::executeImpl(const ColumnsWithTypeAndName & argument if (i == 0) offsets = &offsets_i; else if (offsets_i != *offsets) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Lengths of all arrays passed to {} must be equal.", + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Lengths of all arrays passed to {} must be equal.", getName()); const auto * array_data = &array->getData(); diff --git a/src/Functions/array/arrayZip.cpp b/src/Functions/array/arrayZip.cpp index 3a50491fd4b2..44c323e3fe31 100644 --- a/src/Functions/array/arrayZip.cpp +++ b/src/Functions/array/arrayZip.cpp @@ -13,7 +13,7 @@ namespace DB namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_COLUMN; } @@ -81,7 +81,7 @@ class FunctionArrayZip : public IFunction } else if (!column_array->hasEqualOffsets(static_cast(*first_array_column))) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "The argument 1 and argument {} of function {} have different array sizes", i + 1, getName()); } diff --git a/src/Functions/map.cpp b/src/Functions/map.cpp index 3160c5ddb437..4217550d5b03 100644 --- a/src/Functions/map.cpp +++ b/src/Functions/map.cpp @@ -26,6 +26,8 @@ namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; + extern const int ILLEGAL_COLUMN; } namespace @@ -147,6 +149,94 @@ class FunctionMap : public IFunction } }; +/// mapFromArrays(keys, values) is a function that allows you to make key-value pair from a pair of arrays +class FunctionMapFromArrays : public IFunction +{ +public: + static constexpr auto name = "mapFromArrays"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + bool useDefaultImplementationForNulls() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (arguments.size() != 2) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Function {} requires 2 arguments, but {} given", + getName(), + arguments.size()); + + /// The first argument should always be Array. + /// Because key type can not be nested type of Map, which is Tuple + DataTypePtr key_type; + if (const auto * keys_type = checkAndGetDataType(arguments[0].get())) + key_type = keys_type->getNestedType(); + else + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be an Array", getName()); + + DataTypePtr value_type; + if (const auto * value_array_type = checkAndGetDataType(arguments[1].get())) + value_type = value_array_type->getNestedType(); + else if (const auto * value_map_type = checkAndGetDataType(arguments[1].get())) + value_type = std::make_shared(value_map_type->getKeyValueTypes()); + else + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument for function {} must be Array or Map", getName()); + + DataTypes key_value_types{key_type, value_type}; + return std::make_shared(key_value_types); + } + + ColumnPtr executeImpl( + const ColumnsWithTypeAndName & arguments, const DataTypePtr & /* result_type */, size_t /* input_rows_count */) const override + { + bool is_keys_const = isColumnConst(*arguments[0].column); + ColumnPtr holder_keys; + const ColumnArray * col_keys; + if (is_keys_const) + { + holder_keys = arguments[0].column->convertToFullColumnIfConst(); + col_keys = checkAndGetColumn(holder_keys.get()); + } + else + { + col_keys = checkAndGetColumn(arguments[0].column.get()); + } + + if (!col_keys) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The first argument of function {} must be Array", getName()); + + bool is_values_const = isColumnConst(*arguments[1].column); + ColumnPtr holder_values; + if (is_values_const) + holder_values = arguments[1].column->convertToFullColumnIfConst(); + else + holder_values = arguments[1].column; + + const ColumnArray * col_values; + if (const auto * col_values_array = checkAndGetColumn(holder_values.get())) + col_values = col_values_array; + else if (const auto * col_values_map = checkAndGetColumn(holder_values.get())) + col_values = &col_values_map->getNestedColumn(); + else + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The second arguments of function {} must be Array or Map", getName()); + + if (!col_keys->hasEqualOffsets(*col_values)) + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Two arguments for function {} must have equal sizes", getName()); + + const auto & data_keys = col_keys->getDataPtr(); + const auto & data_values = col_values->getDataPtr(); + const auto & offsets = col_keys->getOffsetsPtr(); + auto nested_column = ColumnArray::create(ColumnTuple::create(Columns{data_keys, data_values}), offsets); + return ColumnMap::create(nested_column); + } +}; struct NameMapContains { static constexpr auto name = "mapContains"; }; @@ -649,6 +739,9 @@ REGISTER_FUNCTION(Map) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); + factory.registerAlias("MAP_FROM_ARRAYS", "mapFromArrays"); + } } diff --git a/src/Functions/tupleElement.cpp b/src/Functions/tupleElement.cpp index 829262de30a7..048159c3b2f6 100644 --- a/src/Functions/tupleElement.cpp +++ b/src/Functions/tupleElement.cpp @@ -21,7 +21,7 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NOT_FOUND_COLUMN_IN_BLOCK; extern const int NUMBER_OF_DIMENSIONS_MISMATCHED; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } namespace @@ -201,7 +201,7 @@ class FunctionTupleElement : public IFunction const auto & array_y = *assert_cast(col_y.get()); if (!array_x.hasEqualOffsets(array_y)) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "The argument 1 and argument 3 of function {} have different array sizes", getName()); } } @@ -223,7 +223,7 @@ class FunctionTupleElement : public IFunction { if (unlikely(offsets_x[0] != offsets_y[row] - prev_offset)) { - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "The argument 1 and argument 3 of function {} have different array sizes", getName()); } prev_offset = offsets_y[row]; diff --git a/src/Functions/validateNestedArraySizes.cpp b/src/Functions/validateNestedArraySizes.cpp index 7e1dbc798d80..c422637ba7f2 100644 --- a/src/Functions/validateNestedArraySizes.cpp +++ b/src/Functions/validateNestedArraySizes.cpp @@ -12,7 +12,7 @@ namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } /** Function validateNestedArraySizes is used to check the consistency of Nested DataType subcolumns's offsets when Update @@ -106,7 +106,7 @@ ColumnPtr FunctionValidateNestedArraySizes::executeImpl( else if (first_length != length) { throw Exception( - ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, + ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Elements '{}' and '{}' of Nested data structure (Array columns) " "have different array sizes ({} and {} respectively) on row {}", arguments[1].name, arguments[args_idx].name, first_length, length, i); diff --git a/src/Interpreters/ArrayJoinAction.cpp b/src/Interpreters/ArrayJoinAction.cpp index 3650b888f9eb..4f42122e98f3 100644 --- a/src/Interpreters/ArrayJoinAction.cpp +++ b/src/Interpreters/ArrayJoinAction.cpp @@ -14,7 +14,7 @@ namespace DB namespace ErrorCodes { extern const int LOGICAL_ERROR; - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; extern const int TYPE_MISMATCH; } @@ -186,7 +186,7 @@ void ArrayJoinAction::execute(Block & block) const ColumnArray & array = typeid_cast(*array_ptr); if (!is_unaligned && !array.hasEqualOffsets(*any_array)) - throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH, "Sizes of ARRAY-JOIN-ed arrays do not match"); + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Sizes of ARRAY-JOIN-ed arrays do not match"); current.column = typeid_cast(*array_ptr).getDataPtr(); current.type = type->getNestedType(); diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index 84390adc0df4..376d36520abf 100644 --- a/src/Interpreters/TableJoin.h +++ b/src/Interpreters/TableJoin.h @@ -208,6 +208,9 @@ class TableJoin JoinKind kind() const { return table_join.kind; } JoinStrictness strictness() const { return table_join.strictness; } bool sameStrictnessAndKind(JoinStrictness, JoinKind) const; + void setKind(JoinKind kind) { table_join.kind = kind; } + void setStrictness(JoinStrictness strictness) { table_join.strictness = strictness; } + void setColumnsFromJoinedTable(NamesAndTypesList columns_from_joined_table_) {columns_from_joined_table = columns_from_joined_table_;} const SizeLimits & sizeLimits() const { return size_limits; } VolumePtr getTemporaryVolume() { return tmp_volume; } diff --git a/src/Parsers/ParserInsertQuery.cpp b/src/Parsers/ParserInsertQuery.cpp index 9d01cda98a2f..8601e12ebcba 100644 --- a/src/Parsers/ParserInsertQuery.cpp +++ b/src/Parsers/ParserInsertQuery.cpp @@ -256,14 +256,21 @@ bool ParserInsertQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) if (infile) { query->infile = infile; + query->compression = compression; + + query->children.push_back(infile); if (compression) - query->compression = compression; + query->children.push_back(compression); } if (table_function) { query->table_function = table_function; query->partition_by = partition_by_expr; + + query->children.push_back(table_function); + if (partition_by_expr) + query->children.push_back(partition_by_expr); } else { diff --git a/src/Processors/Formats/Impl/ParquetBlockInputFormat.h b/src/Processors/Formats/Impl/ParquetBlockInputFormat.h index 258140905874..75a5be131305 100644 --- a/src/Processors/Formats/Impl/ParquetBlockInputFormat.h +++ b/src/Processors/Formats/Impl/ParquetBlockInputFormat.h @@ -29,6 +29,7 @@ class ParquetBlockInputFormat : public IInputFormat private: Chunk generate() override; +protected: void prepareReader(); void onCancel() override diff --git a/src/Processors/Transforms/AggregatingTransform.h b/src/Processors/Transforms/AggregatingTransform.h index 0771761fa5c4..453c2229315b 100644 --- a/src/Processors/Transforms/AggregatingTransform.h +++ b/src/Processors/Transforms/AggregatingTransform.h @@ -24,6 +24,9 @@ using AggregatorListPtr = std::shared_ptr; using AggregatorList = std::list; using AggregatorListPtr = std::shared_ptr; +using AggregatorList = std::list; +using AggregatorListPtr = std::shared_ptr; + struct AggregatingTransformParams { Aggregator::Params params; diff --git a/src/Storages/MergeTree/IMergeTreeDataPart.cpp b/src/Storages/MergeTree/IMergeTreeDataPart.cpp index e1427413f62b..7b4b7d3b94f3 100644 --- a/src/Storages/MergeTree/IMergeTreeDataPart.cpp +++ b/src/Storages/MergeTree/IMergeTreeDataPart.cpp @@ -471,6 +471,7 @@ void IMergeTreeDataPart::setColumns(const NamesAndTypesList & new_columns, const } columns_description = ColumnsDescription(columns); + columns_description_with_collected_nested = ColumnsDescription(Nested::collect(columns)); } NameAndTypePair IMergeTreeDataPart::getColumn(const String & column_name) const diff --git a/src/Storages/MergeTree/IMergeTreeDataPart.h b/src/Storages/MergeTree/IMergeTreeDataPart.h index 68d5147362b7..154a87cb2647 100644 --- a/src/Storages/MergeTree/IMergeTreeDataPart.h +++ b/src/Storages/MergeTree/IMergeTreeDataPart.h @@ -148,6 +148,7 @@ class IMergeTreeDataPart : public std::enable_shared_from_this tryGetColumn(const String & column_name) const; @@ -531,6 +532,10 @@ class IMergeTreeDataPart : public std::enable_shared_from_this #include +#include #include namespace DB @@ -41,6 +42,10 @@ class IMergeTreeDataPartInfoForReader : public WithContext virtual const NamesAndTypesList & getColumns() const = 0; + virtual const ColumnsDescription & getColumnsDescription() const = 0; + + virtual const ColumnsDescription & getColumnsDescriptionWithCollectedNested() const = 0; + virtual std::optional getColumnPosition(const String & column_name) const = 0; virtual String getColumnNameWithMinimumCompressedSize(bool with_subcolumns) const = 0; diff --git a/src/Storages/MergeTree/IMergeTreeReader.cpp b/src/Storages/MergeTree/IMergeTreeReader.cpp index 10476c1b1299..d8c3de622cb2 100644 --- a/src/Storages/MergeTree/IMergeTreeReader.cpp +++ b/src/Storages/MergeTree/IMergeTreeReader.cpp @@ -41,8 +41,12 @@ IMergeTreeReader::IMergeTreeReader( , alter_conversions(data_part_info_for_read->getAlterConversions()) /// For wide parts convert plain arrays of Nested to subcolumns /// to allow to use shared offset column from cache. - , requested_columns(data_part_info_for_read->isWidePart() ? Nested::convertToSubcolumns(columns_) : columns_) - , part_columns(data_part_info_for_read->isWidePart() ? Nested::collect(data_part_info_for_read->getColumns()) : data_part_info_for_read->getColumns()) + , requested_columns(data_part_info_for_read->isWidePart() + ? Nested::convertToSubcolumns(columns_) + : columns_) + , part_columns(data_part_info_for_read->isWidePart() + ? data_part_info_for_read->getColumnsDescriptionWithCollectedNested() + : data_part_info_for_read->getColumnsDescription()) { columns_to_read.reserve(requested_columns.size()); serializations.reserve(requested_columns.size()); diff --git a/src/Storages/MergeTree/IMergeTreeReader.h b/src/Storages/MergeTree/IMergeTreeReader.h index 16db13692aa9..4a383e4e5218 100644 --- a/src/Storages/MergeTree/IMergeTreeReader.h +++ b/src/Storages/MergeTree/IMergeTreeReader.h @@ -104,7 +104,7 @@ class IMergeTreeReader : private boost::noncopyable NamesAndTypesList requested_columns; /// Actual columns description in part. - ColumnsDescription part_columns; + const ColumnsDescription & part_columns; }; } diff --git a/src/Storages/MergeTree/KeyCondition.cpp b/src/Storages/MergeTree/KeyCondition.cpp index 1fcf564693f8..fda1daec3a31 100644 --- a/src/Storages/MergeTree/KeyCondition.cpp +++ b/src/Storages/MergeTree/KeyCondition.cpp @@ -739,12 +739,9 @@ KeyCondition::KeyCondition( , single_point(single_point_) , strict(strict_) { - for (size_t i = 0, size = key_column_names.size(); i < size; ++i) - { - const auto & name = key_column_names[i]; + for (const auto & name : key_column_names) if (!key_columns.contains(name)) - key_columns[name] = i; - } + key_columns[name] = key_columns.size(); auto filter_node = buildFilterNode(query, additional_filter_asts); @@ -807,12 +804,9 @@ KeyCondition::KeyCondition( , single_point(single_point_) , strict(strict_) { - for (size_t i = 0, size = key_column_names.size(); i < size; ++i) - { - const auto & name = key_column_names[i]; + for (const auto & name : key_column_names) if (!key_columns.contains(name)) - key_columns[name] = i; - } + key_columns[name] = key_columns.size(); if (!filter_dag) { diff --git a/src/Storages/MergeTree/LoadedMergeTreeDataPartInfoForReader.h b/src/Storages/MergeTree/LoadedMergeTreeDataPartInfoForReader.h index bc786ec0428e..3363c75dd6ff 100644 --- a/src/Storages/MergeTree/LoadedMergeTreeDataPartInfoForReader.h +++ b/src/Storages/MergeTree/LoadedMergeTreeDataPartInfoForReader.h @@ -27,6 +27,10 @@ class LoadedMergeTreeDataPartInfoForReader final : public IMergeTreeDataPartInfo const NamesAndTypesList & getColumns() const override { return data_part->getColumns(); } + const ColumnsDescription & getColumnsDescription() const override { return data_part->getColumnsDescription(); } + + const ColumnsDescription & getColumnsDescriptionWithCollectedNested() const override { return data_part->getColumnsDescriptionWithCollectedNested(); } + std::optional getColumnPosition(const String & column_name) const override { return data_part->getColumnPosition(column_name); } AlterConversions getAlterConversions() const override { return data_part->storage.getAlterConversionsForPart(data_part); } diff --git a/src/Storages/System/StorageSystemContributors.generated.cpp b/src/Storages/System/StorageSystemContributors.generated.cpp index f69f9f8ee7fb..6ca6a9db0461 100644 --- a/src/Storages/System/StorageSystemContributors.generated.cpp +++ b/src/Storages/System/StorageSystemContributors.generated.cpp @@ -30,6 +30,7 @@ const char * auto_contributors[] { "Aleksandr Shalimov", "Aleksandra (Ася)", "Aleksandrov Vladimir", + "Aleksei Filatov", "Aleksei Levushkin", "Aleksei Semiglazov", "Aleksey", @@ -192,6 +193,7 @@ const char * auto_contributors[] { "Bill", "BiteTheDDDDt", "BlahGeek", + "Bo Lu", "Bogdan", "Bogdan Voronin", "BohuTANG", @@ -256,6 +258,7 @@ const char * auto_contributors[] { "Denis Krivak", "Denis Zhuravlev", "Denny Crane", + "Denys Golotiuk", "Derek Chia", "Derek Perkins", "Diego Nieto (lesandie)", @@ -300,6 +303,7 @@ const char * auto_contributors[] { "Elizaveta Mironyuk", "Elykov Alexandr", "Emmanuel Donin de Rosière", + "Enrique Herreros", "Eric", "Eric Daniel", "Erixonich", @@ -476,6 +480,7 @@ const char * auto_contributors[] { "Kirill Shvakov", "Koblikov Mihail", "KochetovNicolai", + "Konstantin Bogdanov", "Konstantin Grabar", "Konstantin Ilchenko", "Konstantin Lebedev", @@ -571,6 +576,7 @@ const char * auto_contributors[] { "Mc.Spring", "Meena Renganathan", "Meena-Renganathan", + "MeenaRenganathan22", "MeiK", "Memo", "Metehan Çetinkaya", @@ -866,10 +872,12 @@ const char * auto_contributors[] { "VDimir", "VVMak", "Vadim", + "Vadim Akhma", "Vadim Plakhtinskiy", "Vadim Skipin", "Vadim Volodin", "VadimPE", + "Vage Ogannisian", "Val", "Valera Ryaboshapko", "Varinara", @@ -1033,6 +1041,7 @@ const char * auto_contributors[] { "bobrovskij artemij", "booknouse", "bseng", + "candiduslynx", "canenoneko", "caspian", "cekc", @@ -1266,6 +1275,7 @@ const char * auto_contributors[] { "maxim-babenko", "maxkuzn", "maxulan", + "mayamika", "mehanizm", "melin", "memo", @@ -1348,7 +1358,10 @@ const char * auto_contributors[] { "ritaank", "rnbondarenko", "robert", + "robot-ch-test-poll1", + "robot-ch-test-poll4", "robot-clickhouse", + "robot-clickhouse-ci-1", "robot-metrika-test", "rodrigargar", "roman", @@ -1372,7 +1385,9 @@ const char * auto_contributors[] { "shedx", "shuchaome", "shuyang", + "sichenzhao", "simon-says", + "simpleton", "snyk-bot", "songenjie", "sperlingxx", @@ -1380,6 +1395,7 @@ const char * auto_contributors[] { "spongedc", "spume", "spyros87", + "stan", "stavrolia", "stepenhu", "su-houzhen", @@ -1435,6 +1451,7 @@ const char * auto_contributors[] { "wangdh15", "weeds085490", "whysage", + "wineternity", "wuxiaobai24", "wzl", "xPoSx", @@ -1458,6 +1475,7 @@ const char * auto_contributors[] { "yonesko", "youenn lebras", "young scott", + "yuanyimeng", "yuchuansun", "yuefoo", "yulu86", diff --git a/tests/ci/approve_lambda/Dockerfile b/tests/ci/approve_lambda/Dockerfile new file mode 100644 index 000000000000..f53be71a8931 --- /dev/null +++ b/tests/ci/approve_lambda/Dockerfile @@ -0,0 +1,13 @@ +FROM public.ecr.aws/lambda/python:3.9 + +# Copy function code +COPY app.py ${LAMBDA_TASK_ROOT} + +# Install the function's dependencies using file requirements.txt +# from your project folder. + +COPY requirements.txt . +RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "app.handler" ] diff --git a/tests/ci/approve_lambda/app.py b/tests/ci/approve_lambda/app.py new file mode 100644 index 000000000000..ffc5afa2f86c --- /dev/null +++ b/tests/ci/approve_lambda/app.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 + +import json +import time +import fnmatch +from collections import namedtuple +import jwt + +import requests +import boto3 + +API_URL = 'https://api.github.com/repos/ClickHouse/ClickHouse' + +SUSPICIOUS_CHANGED_FILES_NUMBER = 200 + +SUSPICIOUS_PATTERNS = [ + "tests/ci/*", + "docs/tools/*", + ".github/*", + "utils/release/*", + "docker/*", + "release", +] + +MAX_RETRY = 5 + +WorkflowDescription = namedtuple('WorkflowDescription', + ['name', 'action', 'run_id', 'event', 'sender_login', + 'workflow_id', 'fork_owner_login', 'fork_branch', 'sender_orgs']) + +TRUSTED_WORKFLOW_IDS = { + 14586616, # Cancel workflows, always trusted +} + +TRUSTED_ORG_IDS = { + 7409213, # yandex + 28471076, # altinity + 54801242, # clickhouse +} + +# Individual trusted contirbutors who are not in any trusted organization. +# Can be changed in runtime: we will append users that we learned to be in +# a trusted org, to save GitHub API calls. +TRUSTED_CONTRIBUTORS = { + "achimbab", + "adevyatova ", # DOCSUP + "Algunenano", # Raúl Marín, Tinybird + "AnaUvarova", # DOCSUP + "anauvarova", # technical writer, Yandex + "annvsh", # technical writer, Yandex + "atereh", # DOCSUP + "azat", + "bharatnc", # Newbie, but already with many contributions. + "bobrik", # Seasoned contributor, CloundFlare + "BohuTANG", + "damozhaeva", # DOCSUP + "den-crane", + "gyuton", # DOCSUP + "hagen1778", # Roman Khavronenko, seasoned contributor + "hczhcz", + "hexiaoting", # Seasoned contributor + "ildus", # adjust, ex-pgpro + "javisantana", # a Spanish ClickHouse enthusiast, ex-Carto + "ka1bi4", # DOCSUP + "kirillikoff", # DOCSUP + "kreuzerkrieg", + "lehasm", # DOCSUP + "michon470", # DOCSUP + "MyroTk", # Tester in Altinity + "myrrc", # Michael Kot, Altinity + "nikvas0", + "nvartolomei", + "olgarev", # DOCSUP + "otrazhenia", # Yandex docs contractor + "pdv-ru", # DOCSUP + "podshumok", # cmake expert from QRator Labs + "s-mx", # Maxim Sabyanin, former employee, present contributor + "sevirov", # technical writer, Yandex + "spongedu", # Seasoned contributor + "ucasfl", # Amos Bird's friend + "vdimir", # Employee + "vzakaznikov", + "YiuRULE", + "zlobober" # Developer of YT +} + + +def get_installation_id(jwt_token): + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github.v3+json", + } + response = requests.get("https://api.github.com/app/installations", headers=headers) + response.raise_for_status() + data = response.json() + return data[0]['id'] + +def get_access_token(jwt_token, installation_id): + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github.v3+json", + } + response = requests.post(f"https://api.github.com/app/installations/{installation_id}/access_tokens", headers=headers) + response.raise_for_status() + data = response.json() + return data['token'] + +def get_key_and_app_from_aws(): + secret_name = "clickhouse_github_secret_key" + session = boto3.session.Session() + client = session.client( + service_name='secretsmanager', + ) + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + data = json.loads(get_secret_value_response['SecretString']) + return data['clickhouse-app-key'], int(data['clickhouse-app-id']) + + +def is_trusted_sender(pr_user_login, pr_user_orgs): + if pr_user_login in TRUSTED_CONTRIBUTORS: + print(f"User '{pr_user_login}' is trusted") + return True + + print(f"User '{pr_user_login}' is not trusted") + + for org_id in pr_user_orgs: + if org_id in TRUSTED_ORG_IDS: + print(f"Org '{org_id}' is trusted; will mark user {pr_user_login} as trusted") + return True + print(f"Org '{org_id}' is not trusted") + + return False + +def _exec_get_with_retry(url): + for i in range(MAX_RETRY): + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as ex: + print("Got exception executing request", ex) + time.sleep(i + 1) + + raise Exception("Cannot execute GET request with retries") + +def _exec_post_with_retry(url, token, data=None): + headers = { + "Authorization": f"token {token}" + } + for i in range(MAX_RETRY): + try: + if data: + response = requests.post(url, headers=headers, json=data) + else: + response = requests.post(url, headers=headers) + if response.status_code == 403: + data = response.json() + if 'message' in data and data['message'] == 'This workflow run is not waiting for approval': + print("Workflow doesn't need approval") + return data + response.raise_for_status() + return response.json() + except Exception as ex: + print("Got exception executing request", ex) + time.sleep(i + 1) + + raise Exception("Cannot execute POST request with retry") + +def _get_pull_requests_from(owner, branch): + url = f"{API_URL}/pulls?head={owner}:{branch}" + return _exec_get_with_retry(url) + +def get_workflow_description_from_event(event): + action = event['action'] + sender_login = event['sender']['login'] + run_id = event['workflow_run']['id'] + event_type = event['workflow_run']['event'] + fork_owner = event['workflow_run']['head_repository']['owner']['login'] + fork_branch = event['workflow_run']['head_branch'] + orgs_data = _exec_get_with_retry(event['sender']['organizations_url']) + sender_orgs = [org['id'] for org in orgs_data] + name = event['workflow_run']['name'] + workflow_id = event['workflow_run']['workflow_id'] + return WorkflowDescription( + name=name, + action=action, + sender_login=sender_login, + run_id=run_id, + event=event_type, + fork_owner_login=fork_owner, + fork_branch=fork_branch, + sender_orgs=sender_orgs, + workflow_id=workflow_id, + ) + + +def get_changed_files_for_pull_request(pull_request): + number = pull_request['number'] + + changed_files = set([]) + for i in range(1, 31): + print("Requesting changed files page", i) + url = f"{API_URL}/pulls/{number}/files?page={i}&per_page=100" + data = _exec_get_with_retry(url) + print(f"Got {len(data)} changed files") + if len(data) == 0: + print("No more changed files") + break + + for change in data: + #print("Adding changed file", change['filename']) + changed_files.add(change['filename']) + + if len(changed_files) >= SUSPICIOUS_CHANGED_FILES_NUMBER: + print(f"More than {len(changed_files)} changed files. Will stop fetching new files.") + break + + return changed_files + +def check_suspicious_changed_files(changed_files): + if len(changed_files) >= SUSPICIOUS_CHANGED_FILES_NUMBER: + print(f"Too many files changed {len(changed_files)}, need manual approve") + return True + + for path in changed_files: + for pattern in SUSPICIOUS_PATTERNS: + if fnmatch.fnmatch(path, pattern): + print(f"File {path} match suspicious pattern {pattern}, will not approve automatically") + return True + + print("No changed files match suspicious patterns, run will be approved") + return False + +def approve_run(run_id, token): + url = f"{API_URL}/actions/runs/{run_id}/approve" + _exec_post_with_retry(url, token) + +def label_manual_approve(pull_request, token): + number = pull_request['number'] + url = f"{API_URL}/issues/{number}/labels" + data = {"labels" : "manual approve"} + + _exec_post_with_retry(url, token, data) + +def get_token_from_aws(): + private_key, app_id = get_key_and_app_from_aws() + payload = { + "iat": int(time.time()) - 60, + "exp": int(time.time()) + (10 * 60), + "iss": app_id, + } + + encoded_jwt = jwt.encode(payload, private_key, algorithm="RS256") + installation_id = get_installation_id(encoded_jwt) + return get_access_token(encoded_jwt, installation_id) + +def main(event): + token = get_token_from_aws() + event_data = json.loads(event['body']) + workflow_description = get_workflow_description_from_event(event_data) + + print("Got workflow description", workflow_description) + if workflow_description.action != "requested": + print("Exiting, event action is", workflow_description.action) + return + + if workflow_description.workflow_id in TRUSTED_WORKFLOW_IDS: + print("Workflow in trusted list, approving run") + approve_run(workflow_description.run_id, token) + return + + if is_trusted_sender(workflow_description.sender_login, workflow_description.sender_orgs): + print("Sender is trusted, approving run") + approve_run(workflow_description.run_id, token) + return + + pull_requests = _get_pull_requests_from(workflow_description.fork_owner_login, workflow_description.fork_branch) + print("Got pull requests for workflow", len(pull_requests)) + if len(pull_requests) > 1: + raise Exception("Received more than one PR for workflow run") + + if len(pull_requests) < 1: + raise Exception("Cannot find any pull requests for workflow run") + + pull_request = pull_requests[0] + print("Pull request for workflow number", pull_request['number']) + + changed_files = get_changed_files_for_pull_request(pull_request) + print(f"Totally have {len(changed_files)} changed files in PR:", changed_files) + if check_suspicious_changed_files(changed_files): + print(f"Pull Request {pull_request['number']} has suspicious changes, label it for manuall approve") + label_manual_approve(pull_request, token) + else: + print(f"Pull Request {pull_request['number']} has no suspicious changes") + approve_run(workflow_description.run_id, token) + +def handler(event, _): + main(event) diff --git a/tests/ci/approve_lambda/requirements.txt b/tests/ci/approve_lambda/requirements.txt new file mode 100644 index 000000000000..c0dcf4a4dde7 --- /dev/null +++ b/tests/ci/approve_lambda/requirements.txt @@ -0,0 +1,3 @@ +requests +PyJWT +cryptography diff --git a/tests/ci/cancel_workflow_lambda/Dockerfile b/tests/ci/cancel_workflow_lambda/Dockerfile new file mode 100644 index 000000000000..f53be71a8931 --- /dev/null +++ b/tests/ci/cancel_workflow_lambda/Dockerfile @@ -0,0 +1,13 @@ +FROM public.ecr.aws/lambda/python:3.9 + +# Copy function code +COPY app.py ${LAMBDA_TASK_ROOT} + +# Install the function's dependencies using file requirements.txt +# from your project folder. + +COPY requirements.txt . +RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" + +# Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) +CMD [ "app.handler" ] diff --git a/tests/ci/cancel_workflow_lambda/app.py b/tests/ci/cancel_workflow_lambda/app.py new file mode 100644 index 000000000000..e475fcb931a2 --- /dev/null +++ b/tests/ci/cancel_workflow_lambda/app.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +import json +import time +import jwt + +import requests +import boto3 + +# https://docs.github.com/en/rest/reference/actions#cancel-a-workflow-run +# +API_URL = 'https://api.github.com/repos/ClickHouse/ClickHouse' + +MAX_RETRY = 5 + +def get_installation_id(jwt_token): + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github.v3+json", + } + response = requests.get("https://api.github.com/app/installations", headers=headers) + response.raise_for_status() + data = response.json() + return data[0]['id'] + +def get_access_token(jwt_token, installation_id): + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github.v3+json", + } + response = requests.post(f"https://api.github.com/app/installations/{installation_id}/access_tokens", headers=headers) + response.raise_for_status() + data = response.json() + return data['token'] + +def get_key_and_app_from_aws(): + secret_name = "clickhouse_github_secret_key" + session = boto3.session.Session() + client = session.client( + service_name='secretsmanager', + ) + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + data = json.loads(get_secret_value_response['SecretString']) + return data['clickhouse-app-key'], int(data['clickhouse-app-id']) + +def get_token_from_aws(): + private_key, app_id = get_key_and_app_from_aws() + payload = { + "iat": int(time.time()) - 60, + "exp": int(time.time()) + (10 * 60), + "iss": app_id, + } + + encoded_jwt = jwt.encode(payload, private_key, algorithm="RS256") + installation_id = get_installation_id(encoded_jwt) + return get_access_token(encoded_jwt, installation_id) + +def _exec_get_with_retry(url): + for i in range(MAX_RETRY): + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except Exception as ex: + print("Got exception executing request", ex) + time.sleep(i + 1) + + raise Exception("Cannot execute GET request with retries") + + +def get_workflows_cancel_urls_for_pull_request(pull_request_event): + head_branch = pull_request_event['head']['ref'] + print("PR", pull_request_event['number'], "has head ref", head_branch) + workflows = _exec_get_with_retry(API_URL + f"/actions/runs?branch={head_branch}") + workflows_urls_to_cancel = set([]) + for workflow in workflows['workflow_runs']: + if workflow['status'] != 'completed': + print("Workflow", workflow['url'], "not finished, going to be cancelled") + workflows_urls_to_cancel.add(workflow['cancel_url']) + else: + print("Workflow", workflow['url'], "already finished, will not try to cancel") + + return workflows_urls_to_cancel + +def _exec_post_with_retry(url, token): + headers = { + "Authorization": f"token {token}" + } + for i in range(MAX_RETRY): + try: + response = requests.post(url, headers=headers) + response.raise_for_status() + return response.json() + except Exception as ex: + print("Got exception executing request", ex) + time.sleep(i + 1) + + raise Exception("Cannot execute POST request with retry") + +def cancel_workflows(urls_to_cancel, token): + for url in urls_to_cancel: + print("Cancelling workflow using url", url) + _exec_post_with_retry(url, token) + print("Workflow cancelled") + +def main(event): + token = get_token_from_aws() + event_data = json.loads(event['body']) + + print("Got event for PR", event_data['number']) + action = event_data['action'] + print("Got action", event_data['action']) + pull_request = event_data['pull_request'] + labels = { l['name'] for l in pull_request['labels'] } + print("PR has labels", labels) + if action == 'closed' or 'do not test' in labels: + print("PR merged/closed or manually labeled 'do not test' will kill workflows") + workflows_to_cancel = get_workflows_cancel_urls_for_pull_request(pull_request) + print(f"Found {len(workflows_to_cancel)} workflows to cancel") + cancel_workflows(workflows_to_cancel, token) + else: + print("Nothing to do") + +def handler(event, _): + main(event) diff --git a/tests/ci/cancel_workflow_lambda/requirements.txt b/tests/ci/cancel_workflow_lambda/requirements.txt new file mode 100644 index 000000000000..c0dcf4a4dde7 --- /dev/null +++ b/tests/ci/cancel_workflow_lambda/requirements.txt @@ -0,0 +1,3 @@ +requests +PyJWT +cryptography diff --git a/tests/ci/pvs_check.py b/tests/ci/pvs_check.py new file mode 100644 index 000000000000..aa4a130902b0 --- /dev/null +++ b/tests/ci/pvs_check.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +# pylint: disable=line-too-long + +import os +import json +import logging +import sys +from github import Github +from s3_helper import S3Helper +from pr_info import PRInfo, get_event +from get_robot_token import get_best_robot_token, get_parameter_from_ssm +from upload_result_helper import upload_results +from commit_status_helper import get_commit +from clickhouse_helper import ClickHouseHelper, prepare_tests_results_for_clickhouse +from stopwatch import Stopwatch +from rerun_helper import RerunHelper +from tee_popen import TeePopen + +NAME = 'PVS Studio (actions)' +LICENCE_NAME = 'Free license: ClickHouse, Yandex' +HTML_REPORT_FOLDER = 'pvs-studio-html-report' +TXT_REPORT_NAME = 'pvs-studio-task-report.txt' + +def _process_txt_report(path): + warnings = [] + errors = [] + with open(path, 'r') as report_file: + for line in report_file: + if 'viva64' in line: + continue + + if 'warn' in line: + warnings.append(':'.join(line.split('\t')[0:2])) + elif 'err' in line: + errors.append(':'.join(line.split('\t')[0:2])) + + return warnings, errors + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + stopwatch = Stopwatch() + + repo_path = os.path.join(os.getenv("REPO_COPY", os.path.abspath("../../"))) + temp_path = os.path.join(os.getenv("TEMP_PATH")) + + pr_info = PRInfo(get_event()) + # this check modify repository so copy it to the temp directory + logging.info("Repo copy path %s", repo_path) + + gh = Github(get_best_robot_token()) + rerun_helper = RerunHelper(gh, pr_info, NAME) + if rerun_helper.is_already_finished_by_status(): + logging.info("Check is already finished according to github status, exiting") + sys.exit(0) + + images_path = os.path.join(temp_path, 'changed_images.json') + docker_image = 'clickhouse/pvs-test' + if os.path.exists(images_path): + logging.info("Images file exists") + with open(images_path, 'r') as images_fd: + images = json.load(images_fd) + logging.info("Got images %s", images) + if 'clickhouse/pvs-test' in images: + docker_image += ':' + images['clickhouse/pvs-test'] + + logging.info("Got docker image %s", docker_image) + + s3_helper = S3Helper('https://s3.amazonaws.com') + + licence_key = get_parameter_from_ssm('pvs_studio_key') + cmd = f"docker run -u $(id -u ${{USER}}):$(id -g ${{USER}}) --volume={repo_path}:/repo_folder --volume={temp_path}:/test_output -e LICENCE_NAME='{LICENCE_NAME}' -e LICENCE_KEY='{licence_key}' {docker_image}" + commit = get_commit(gh, pr_info.sha) + + run_log_path = os.path.join(temp_path, 'run_log.log') + + with TeePopen(cmd, run_log_path) as process: + retcode = process.wait() + if retcode != 0: + logging.info("Run failed") + else: + logging.info("Run Ok") + + if retcode != 0: + commit.create_status(context=NAME, description='PVS report failed to build', state='failure', target_url=f"https://github.com/ClickHouse/ClickHouse/actions/runs/{os.getenv('GITHUB_RUN_ID')}") + sys.exit(1) + + try: + s3_path_prefix = str(pr_info.number) + "/" + pr_info.sha + "/" + NAME.lower().replace(' ', '_') + html_urls = s3_helper.upload_test_folder_to_s3(os.path.join(temp_path, HTML_REPORT_FOLDER), s3_path_prefix) + index_html = None + + for url in html_urls: + if 'index.html' in url: + index_html = 'HTML report'.format(url) + break + + if not index_html: + commit.create_status(context=NAME, description='PVS report failed to build', state='failure', + target_url=f"{os.getenv('GITHUB_SERVER_URL')}/{os.getenv('GITHUB_REPOSITORY')}/actions/runs/{os.getenv('GITHUB_RUN_ID')}") + sys.exit(1) + + txt_report = os.path.join(temp_path, TXT_REPORT_NAME) + warnings, errors = _process_txt_report(txt_report) + errors = errors + warnings + + status = 'success' + test_results = [(index_html, "Look at the report"), ("Errors count not checked", "OK")] + description = "Total errors {}".format(len(errors)) + additional_logs = [txt_report, os.path.join(temp_path, 'pvs-studio.log')] + report_url = upload_results(s3_helper, pr_info.number, pr_info.sha, test_results, additional_logs, NAME) + + print("::notice ::Report url: {}".format(report_url)) + commit = get_commit(gh, pr_info.sha) + commit.create_status(context=NAME, description=description, state=status, target_url=report_url) + + ch_helper = ClickHouseHelper() + prepared_events = prepare_tests_results_for_clickhouse(pr_info, test_results, status, stopwatch.duration_seconds, stopwatch.start_time_str, report_url, NAME) + ch_helper.insert_events_into(db="gh-data", table="checks", events=prepared_events) + except Exception as ex: + print("Got an exception", ex) + sys.exit(1) diff --git a/tests/ci/worker/init_builder.sh b/tests/ci/worker/init_builder.sh new file mode 100644 index 000000000000..8fd00c1db0a6 --- /dev/null +++ b/tests/ci/worker/init_builder.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -uo pipefail + +echo "Running init script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_HOME=/home/ubuntu/actions-runner + +export RUNNER_URL="https://github.com/ClickHouse" +# Funny fact, but metadata service has fixed IP +export INSTANCE_ID=`curl -s http://169.254.169.254/latest/meta-data/instance-id` + +while true; do + runner_pid=`pgrep run.sh` + echo "Got runner pid $runner_pid" + + cd $RUNNER_HOME + if [ -z "$runner_pid" ]; then + echo "Receiving token" + RUNNER_TOKEN=`/usr/local/bin/aws ssm get-parameter --name github_runner_registration_token --with-decryption --output text --query Parameter.Value` + + echo "Will try to remove runner" + sudo -u ubuntu ./config.sh remove --token $RUNNER_TOKEN ||: + + echo "Going to configure runner" + sudo -u ubuntu ./config.sh --url $RUNNER_URL --token $RUNNER_TOKEN --name $INSTANCE_ID --runnergroup Default --labels 'self-hosted,Linux,X64,builder' --work _work + + echo "Run" + sudo -u ubuntu ./run.sh & + sleep 15 + else + echo "Runner is working with pid $runner_pid, nothing to do" + sleep 10 + fi +done diff --git a/tests/ci/worker/init_func_tester.sh b/tests/ci/worker/init_func_tester.sh new file mode 100644 index 000000000000..d3ee3cb3d7fb --- /dev/null +++ b/tests/ci/worker/init_func_tester.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -uo pipefail + +echo "Running init script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_HOME=/home/ubuntu/actions-runner + +export RUNNER_URL="https://github.com/ClickHouse" +# Funny fact, but metadata service has fixed IP +export INSTANCE_ID=`curl -s http://169.254.169.254/latest/meta-data/instance-id` + +while true; do + runner_pid=`pgrep run.sh` + echo "Got runner pid $runner_pid" + + cd $RUNNER_HOME + if [ -z "$runner_pid" ]; then + echo "Receiving token" + RUNNER_TOKEN=`/usr/local/bin/aws ssm get-parameter --name github_runner_registration_token --with-decryption --output text --query Parameter.Value` + + echo "Will try to remove runner" + sudo -u ubuntu ./config.sh remove --token $RUNNER_TOKEN ||: + + echo "Going to configure runner" + sudo -u ubuntu ./config.sh --url $RUNNER_URL --token $RUNNER_TOKEN --name $INSTANCE_ID --runnergroup Default --labels 'self-hosted,Linux,X64,func-tester' --work _work + + echo "Run" + sudo -u ubuntu ./run.sh & + sleep 15 + else + echo "Runner is working with pid $runner_pid, nothing to do" + sleep 10 + fi +done diff --git a/tests/ci/worker/init_fuzzer_unit_tester.sh b/tests/ci/worker/init_fuzzer_unit_tester.sh new file mode 100644 index 000000000000..2fbedba9e40b --- /dev/null +++ b/tests/ci/worker/init_fuzzer_unit_tester.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -uo pipefail + +echo "Running init script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_HOME=/home/ubuntu/actions-runner + +export RUNNER_URL="https://github.com/ClickHouse" +# Funny fact, but metadata service has fixed IP +export INSTANCE_ID=`curl -s http://169.254.169.254/latest/meta-data/instance-id` + +while true; do + runner_pid=`pgrep run.sh` + echo "Got runner pid $runner_pid" + + cd $RUNNER_HOME + if [ -z "$runner_pid" ]; then + echo "Receiving token" + RUNNER_TOKEN=`/usr/local/bin/aws ssm get-parameter --name github_runner_registration_token --with-decryption --output text --query Parameter.Value` + + echo "Will try to remove runner" + sudo -u ubuntu ./config.sh remove --token $RUNNER_TOKEN ||: + + echo "Going to configure runner" + sudo -u ubuntu ./config.sh --url $RUNNER_URL --token $RUNNER_TOKEN --name $INSTANCE_ID --runnergroup Default --labels 'self-hosted,Linux,X64,fuzzer-unit-tester' --work _work + + echo "Run" + sudo -u ubuntu ./run.sh & + sleep 15 + else + echo "Runner is working with pid $runner_pid, nothing to do" + sleep 10 + fi +done diff --git a/tests/ci/worker/init_stress_tester.sh b/tests/ci/worker/init_stress_tester.sh new file mode 100644 index 000000000000..234f035e1eaf --- /dev/null +++ b/tests/ci/worker/init_stress_tester.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -uo pipefail + +echo "Running init script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_HOME=/home/ubuntu/actions-runner + +export RUNNER_URL="https://github.com/ClickHouse" +# Funny fact, but metadata service has fixed IP +export INSTANCE_ID=`curl -s http://169.254.169.254/latest/meta-data/instance-id` + +while true; do + runner_pid=`pgrep run.sh` + echo "Got runner pid $runner_pid" + + cd $RUNNER_HOME + if [ -z "$runner_pid" ]; then + echo "Receiving token" + RUNNER_TOKEN=`/usr/local/bin/aws ssm get-parameter --name github_runner_registration_token --with-decryption --output text --query Parameter.Value` + + echo "Will try to remove runner" + sudo -u ubuntu ./config.sh remove --token $RUNNER_TOKEN ||: + + echo "Going to configure runner" + sudo -u ubuntu ./config.sh --url $RUNNER_URL --token $RUNNER_TOKEN --name $INSTANCE_ID --runnergroup Default --labels 'self-hosted,Linux,X64,stress-tester' --work _work + + echo "Run" + sudo -u ubuntu ./run.sh & + sleep 15 + else + echo "Runner is working with pid $runner_pid, nothing to do" + sleep 10 + fi +done diff --git a/tests/ci/worker/init_style_checker.sh b/tests/ci/worker/init_style_checker.sh new file mode 100644 index 000000000000..77cf66b5262e --- /dev/null +++ b/tests/ci/worker/init_style_checker.sh @@ -0,0 +1,20 @@ +#!/usr/bin/bash +set -euo pipefail + +echo "Running init script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_HOME=/home/ubuntu/actions-runner + +echo "Receiving token" +export RUNNER_TOKEN=`/usr/local/bin/aws ssm get-parameter --name github_runner_registration_token --with-decryption --output text --query Parameter.Value` +export RUNNER_URL="https://github.com/ClickHouse" +# Funny fact, but metadata service has fixed IP +export INSTANCE_ID=`curl -s http://169.254.169.254/latest/meta-data/instance-id` + +cd $RUNNER_HOME + +echo "Going to configure runner" +sudo -u ubuntu ./config.sh --url $RUNNER_URL --token $RUNNER_TOKEN --name $INSTANCE_ID --runnergroup Default --labels 'self-hosted,Linux,X64,style-checker' --work _work + +echo "Run" +sudo -u ubuntu ./run.sh diff --git a/tests/ci/worker/ubuntu_style_check.sh b/tests/ci/worker/ubuntu_style_check.sh new file mode 100644 index 000000000000..bf5c6057bed7 --- /dev/null +++ b/tests/ci/worker/ubuntu_style_check.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +set -euo pipefail + +echo "Running prepare script" +export DEBIAN_FRONTEND=noninteractive +export RUNNER_VERSION=2.283.1 +export RUNNER_HOME=/home/ubuntu/actions-runner + +apt-get update + +apt-get install --yes --no-install-recommends \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg \ + lsb-release \ + python3-pip \ + unzip + +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg + +echo "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null + +apt-get update + +apt-get install --yes --no-install-recommends docker-ce docker-ce-cli containerd.io + +usermod -aG docker ubuntu + +# enable ipv6 in containers (fixed-cidr-v6 is some random network mask) +cat < /etc/docker/daemon.json +{ + "ipv6": true, + "fixed-cidr-v6": "2001:db8:1::/64" +} +EOT + +systemctl restart docker + +pip install boto3 pygithub requests urllib3 unidiff + +mkdir -p $RUNNER_HOME && cd $RUNNER_HOME + +curl -O -L https://github.com/actions/runner/releases/download/v$RUNNER_VERSION/actions-runner-linux-x64-$RUNNER_VERSION.tar.gz + +tar xzf ./actions-runner-linux-x64-$RUNNER_VERSION.tar.gz +rm -f ./actions-runner-linux-x64-$RUNNER_VERSION.tar.gz +./bin/installdependencies.sh + +chown -R ubuntu:ubuntu $RUNNER_HOME + +cd /home/ubuntu +curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" +unzip awscliv2.zip +./aws/install + +rm -rf /home/ubuntu/awscliv2.zip /home/ubuntu/aws diff --git a/tests/integration/test_mask_sensitive_info/test.py b/tests/integration/test_mask_sensitive_info/test.py index f546c559f667..f938148e5a0b 100644 --- a/tests/integration/test_mask_sensitive_info/test.py +++ b/tests/integration/test_mask_sensitive_info/test.py @@ -21,11 +21,21 @@ def check_logs(must_contain=[], must_not_contain=[]): node.query("SYSTEM FLUSH LOGS") for str in must_contain: - escaped_str = str.replace("`", "\\`").replace("[", "\\[").replace("]", "\\]") + escaped_str = ( + str.replace("`", "\\`") + .replace("[", "\\[") + .replace("]", "\\]") + .replace("*", "\\*") + ) assert node.contains_in_log(escaped_str) for str in must_not_contain: - escaped_str = str.replace("`", "\\`").replace("[", "\\[").replace("]", "\\]") + escaped_str = ( + str.replace("`", "\\`") + .replace("[", "\\[") + .replace("]", "\\]") + .replace("*", "\\*") + ) assert not node.contains_in_log(escaped_str) for str in must_contain: @@ -257,6 +267,34 @@ def test_table_functions(): node.query(f"DROP TABLE tablefunc{i}") +def test_table_function_ways_to_call(): + password = new_password() + + table_function = f"s3('http://minio1:9001/root/data/testfuncw.tsv.gz', 'minio', '{password}', 'TSV', 'x int')" + + queries = [ + "CREATE TABLE tablefuncw (`x` int) AS {}", + "INSERT INTO FUNCTION {} SELECT * FROM numbers(10)", + "DESCRIBE TABLE {}", + ] + + for query in queries: + # query_and_get_answer_with_error() is used here because we don't want to stop on error "Cannot connect to AWS". + # We test logging here and not actual work with AWS server. + node.query_and_get_answer_with_error(query.format(table_function)) + + table_function_with_hidden_arg = "s3('http://minio1:9001/root/data/testfuncw.tsv.gz', 'minio', '[HIDDEN]', 'TSV', 'x int')" + + check_logs( + must_contain=[ + query.format(table_function_with_hidden_arg) for query in queries + ], + must_not_contain=[password], + ) + + node.query("DROP TABLE tablefuncw") + + def test_encryption_functions(): plaintext = new_password() cipher = new_password() diff --git a/tests/integration/test_timezone_config/test.py b/tests/integration/test_timezone_config/test.py index 180026c58187..a40893e29e19 100644 --- a/tests/integration/test_timezone_config/test.py +++ b/tests/integration/test_timezone_config/test.py @@ -31,6 +31,7 @@ def test_overflow_toDate32(start_cluster): assert node.query("SELECT toDate32('1000-12-31','UTC')") == "1900-01-01\n" + def test_overflow_toDateTime(start_cluster): assert ( node.query("SELECT toDateTime('2999-12-31 00:00:00','UTC')") diff --git a/tests/queries/0_stateless/00918_json_functions.reference b/tests/queries/0_stateless/00918_json_functions.reference index 7f5d5fabf122..7b972fc5cd46 100644 --- a/tests/queries/0_stateless/00918_json_functions.reference +++ b/tests/queries/0_stateless/00918_json_functions.reference @@ -107,6 +107,12 @@ true Bool 123456789012 UInt64 0 UInt64 0 Int8 +{'a':'hello','b':'world'} +{'a':'hello','b':'world'} +{'a':('hello',100),'b':('world',200)} +{'a':[100,200],'b':[-100,200,300]} +{'a':{'c':'hello'},'b':{'d':'world'}} +{'c':'hello'} --JSONExtractKeysAndValues-- [('a','hello'),('b','[-100,200,300]')] [('b',[-100,200,300])] @@ -151,6 +157,7 @@ e u v --show error: type should be const string +--show error: key of map type should be String --allow_simdjson=0-- --JSONLength-- 2 @@ -215,6 +222,12 @@ Friday (3,0) (3,5) (3,0) +{'a':'hello','b':'world'} +{'a':'hello','b':'world'} +{'a':('hello',100),'b':('world',200)} +{'a':[100,200],'b':[-100,200,300]} +{'a':{'c':'hello'},'b':{'d':'world'}} +{'c':'hello'} --JSONExtractKeysAndValues-- [('a','hello'),('b','[-100,200,300]')] [('b',[-100,200,300])] @@ -264,3 +277,4 @@ u v --show error: type should be const string --show error: index type should be integer +--show error: key of map type should be String diff --git a/tests/queries/0_stateless/00918_json_functions.sql b/tests/queries/0_stateless/00918_json_functions.sql index ab4bc6390849..0f534a98d6d1 100644 --- a/tests/queries/0_stateless/00918_json_functions.sql +++ b/tests/queries/0_stateless/00918_json_functions.sql @@ -121,6 +121,13 @@ SELECT JSONExtract('{"a": "123456789012.345"}', 'a', 'UInt64') as a, toTypeName( SELECT JSONExtract('{"a": "-2000.22"}', 'a', 'UInt64') as a, toTypeName(a); SELECT JSONExtract('{"a": "-2000.22"}', 'a', 'Int8') as a, toTypeName(a); +SELECT JSONExtract('{"a": "hello", "b": "world"}', 'Map(String, String)'); +SELECT JSONExtract('{"a": "hello", "b": "world"}', 'Map(LowCardinality(String), String)'); +SELECT JSONExtract('{"a": ["hello", 100.0], "b": ["world", 200]}', 'Map(String, Tuple(String, Float64))'); +SELECT JSONExtract('{"a": [100.0, 200], "b": [-100, 200.0, 300]}', 'Map(String, Array(Float64))'); +SELECT JSONExtract('{"a": {"c": "hello"}, "b": {"d": "world"}}', 'Map(String, Map(String, String))'); +SELECT JSONExtract('{"a": {"c": "hello"}, "b": {"d": "world"}}', 'a', 'Map(String, String)'); + SELECT '--JSONExtractKeysAndValues--'; SELECT JSONExtractKeysAndValues('{"a": "hello", "b": [-100, 200.0, 300]}', 'String'); SELECT JSONExtractKeysAndValues('{"a": "hello", "b": [-100, 200.0, 300]}', 'Array(Float64)'); @@ -164,8 +171,11 @@ SELECT JSONExtractString('["a", "b", "c", "d", "e"]', idx) FROM (SELECT arrayJoi SELECT JSONExtractString(json, 's') FROM (SELECT arrayJoin(['{"s":"u"}', '{"s":"v"}']) AS json); SELECT '--show error: type should be const string'; -SELECT JSONExtractKeysAndValues([], JSONLength('^?V{LSwp')); -- { serverError 44 } -WITH '{"i": 1, "f": 1.2}' AS json SELECT JSONExtract(json, 'i', JSONType(json, 'i')); -- { serverError 44 } +SELECT JSONExtractKeysAndValues([], JSONLength('^?V{LSwp')); -- { serverError ILLEGAL_COLUMN } +WITH '{"i": 1, "f": 1.2}' AS json SELECT JSONExtract(json, 'i', JSONType(json, 'i')); -- { serverError ILLEGAL_COLUMN } + +SELECT '--show error: key of map type should be String'; +SELECT JSONExtract('{"a": [100.0, 200], "b": [-100, 200.0, 300]}', 'Map(Int64, Array(Float64))'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT '--allow_simdjson=0--'; @@ -244,6 +254,13 @@ SELECT JSONExtract('{"a":3}', 'Tuple(Int, Int)'); SELECT JSONExtract('[3,5,7]', 'Tuple(Int, Int)'); SELECT JSONExtract('[3]', 'Tuple(Int, Int)'); +SELECT JSONExtract('{"a": "hello", "b": "world"}', 'Map(String, String)'); +SELECT JSONExtract('{"a": "hello", "b": "world"}', 'Map(LowCardinality(String), String)'); +SELECT JSONExtract('{"a": ["hello", 100.0], "b": ["world", 200]}', 'Map(String, Tuple(String, Float64))'); +SELECT JSONExtract('{"a": [100.0, 200], "b": [-100, 200.0, 300]}', 'Map(String, Array(Float64))'); +SELECT JSONExtract('{"a": {"c": "hello"}, "b": {"d": "world"}}', 'Map(String, Map(String, String))'); +SELECT JSONExtract('{"a": {"c": "hello"}, "b": {"d": "world"}}', 'a', 'Map(String, String)'); + SELECT '--JSONExtractKeysAndValues--'; SELECT JSONExtractKeysAndValues('{"a": "hello", "b": [-100, 200.0, 300]}', 'String'); SELECT JSONExtractKeysAndValues('{"a": "hello", "b": [-100, 200.0, 300]}', 'Array(Float64)'); @@ -292,8 +309,11 @@ SELECT JSONExtractString('["a", "b", "c", "d", "e"]', idx) FROM (SELECT arrayJoi SELECT JSONExtractString(json, 's') FROM (SELECT arrayJoin(['{"s":"u"}', '{"s":"v"}']) AS json); SELECT '--show error: type should be const string'; -SELECT JSONExtractKeysAndValues([], JSONLength('^?V{LSwp')); -- { serverError 44 } -WITH '{"i": 1, "f": 1.2}' AS json SELECT JSONExtract(json, 'i', JSONType(json, 'i')); -- { serverError 44 } +SELECT JSONExtractKeysAndValues([], JSONLength('^?V{LSwp')); -- { serverError ILLEGAL_COLUMN } +WITH '{"i": 1, "f": 1.2}' AS json SELECT JSONExtract(json, 'i', JSONType(json, 'i')); -- { serverError ILLEGAL_COLUMN } SELECT '--show error: index type should be integer'; -SELECT JSONExtract('[]', JSONExtract('0', 'UInt256'), 'UInt256'); -- { serverError 43 } +SELECT JSONExtract('[]', JSONExtract('0', 'UInt256'), 'UInt256'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } + +SELECT '--show error: key of map type should be String'; +SELECT JSONExtract('{"a": [100.0, 200], "b": [-100, 200.0, 300]}', 'Map(Int64, Array(Float64))'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } diff --git a/tests/queries/0_stateless/01651_map_functions.reference b/tests/queries/0_stateless/01651_map_functions.reference index 06adaf48cfd9..60f1b6e0d0c4 100644 --- a/tests/queries/0_stateless/01651_map_functions.reference +++ b/tests/queries/0_stateless/01651_map_functions.reference @@ -8,6 +8,8 @@ 0 ['name','age'] ['name','gender'] +{'name':'zhangsan','age':'10'} +{'name':'lisi','gender':'female'} 1 0 0 1 0 1 1 0 0 @@ -17,7 +19,20 @@ [1000] [1001] [1002] +{'1000':'2000','1000':'3000','1000':'4000'} +{'1001':'2002','1001':'3003','1001':'4004'} +{'1002':'2004','1002':'3006','1002':'4008'} {'aa':4,'bb':5} ['aa','bb'] [4,5] {'aa':4,'bb':5} 1 0 {0:0} 1 {0:0} 0 +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':4,'bb':5} +{'aa':('a',4),'bb':('b',5)} +{'aa':('a',4),'bb':('b',5)} +{'aa':('a',4),'bb':('b',5)} diff --git a/tests/queries/0_stateless/01651_map_functions.sql b/tests/queries/0_stateless/01651_map_functions.sql index bbaaf9bee84b..5942bf8b2c22 100644 --- a/tests/queries/0_stateless/01651_map_functions.sql +++ b/tests/queries/0_stateless/01651_map_functions.sql @@ -2,23 +2,25 @@ set allow_experimental_map_type = 1; -- String type drop table if exists table_map; -create table table_map (a Map(String, String), b String) engine = Memory; -insert into table_map values ({'name':'zhangsan', 'age':'10'}, 'name'), ({'name':'lisi', 'gender':'female'},'age'); +create table table_map (a Map(String, String), b String, c Array(String), d Array(String)) engine = Memory; +insert into table_map values ({'name':'zhangsan', 'age':'10'}, 'name', ['name', 'age'], ['zhangsan', '10']), ({'name':'lisi', 'gender':'female'},'age',['name', 'gender'], ['lisi', 'female']); select mapContains(a, 'name') from table_map; select mapContains(a, 'gender') from table_map; select mapContains(a, 'abc') from table_map; select mapContains(a, b) from table_map; -select mapContains(a, 10) from table_map; -- { serverError 386 } +select mapContains(a, 10) from table_map; -- { serverError NO_COMMON_TYPE } select mapKeys(a) from table_map; +select mapFromArrays(c, d) from table_map; drop table if exists table_map; -CREATE TABLE table_map (a Map(UInt8, Int), b UInt8, c UInt32) engine = MergeTree order by tuple(); -insert into table_map select map(number, number), number, number from numbers(1000, 3); +CREATE TABLE table_map (a Map(UInt8, Int), b UInt8, c UInt32, d Array(String), e Array(String)) engine = MergeTree order by tuple(); +insert into table_map select map(number, number), number, number, [number, number, number], [number*2, number*3, number*4] from numbers(1000, 3); select mapContains(a, b), mapContains(a, c), mapContains(a, 233) from table_map; -select mapContains(a, 'aaa') from table_map; -- { serverError 386 } -select mapContains(b, 'aaa') from table_map; -- { serverError 43 } +select mapContains(a, 'aaa') from table_map; -- { serverError NO_COMMON_TYPE } +select mapContains(b, 'aaa') from table_map; -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } select mapKeys(a) from table_map; select mapValues(a) from table_map; +select mapFromArrays(d, e) from table_map; drop table if exists table_map; @@ -27,3 +29,18 @@ select map( 'aa', 4, 'bb' , 5) as m, mapKeys(m), mapValues(m); select map( 'aa', 4, 'bb' , 5) as m, mapContains(m, 'aa'), mapContains(m, 'k'); select map(0, 0) as m, mapContains(m, number % 2) from numbers(2); + +select mapFromArrays(['aa', 'bb'], [4, 5]); +select mapFromArrays(['aa', 'bb'], materialize([4, 5])) from numbers(2); +select mapFromArrays(materialize(['aa', 'bb']), [4, 5]) from numbers(2); +select mapFromArrays(materialize(['aa', 'bb']), materialize([4, 5])) from numbers(2); +select mapFromArrays('aa', [4, 5]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select mapFromArrays(['aa', 'bb'], 5); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select mapFromArrays(['aa', 'bb'], [4, 5], [6, 7]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select mapFromArrays(['aa', 'bb'], [4, 5, 6]); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } +select mapFromArrays([[1,2], [3,4]], [4, 5, 6]); -- { serverError BAD_ARGUMENTS } + +select mapFromArrays(['aa', 'bb'], map('a', 4, 'b', 5)); +select mapFromArrays(['aa', 'bb'], materialize(map('a', 4, 'b', 5))) from numbers(2); +select mapFromArrays(map('a', 4, 'b', 4), ['aa', 'bb']) from numbers(2); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select mapFromArrays(['aa', 'bb'], map('a', 4)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } diff --git a/tests/queries/0_stateless/02234_cast_to_ip_address.reference b/tests/queries/0_stateless/02234_cast_to_ip_address.reference index 96aae2a978c9..9023b36a9bfc 100644 --- a/tests/queries/0_stateless/02234_cast_to_ip_address.reference +++ b/tests/queries/0_stateless/02234_cast_to_ip_address.reference @@ -31,6 +31,9 @@ IPv6 functions ::ffff:127.0.0.1 ::ffff:127.0.0.1 ::ffff:127.0.0.1 +:: +\N +100000000 -- ::ffff:127.0.0.1 -- diff --git a/tests/queries/0_stateless/02234_cast_to_ip_address.sql b/tests/queries/0_stateless/02234_cast_to_ip_address.sql index 436f232e441e..6c65fe86cc91 100644 --- a/tests/queries/0_stateless/02234_cast_to_ip_address.sql +++ b/tests/queries/0_stateless/02234_cast_to_ip_address.sql @@ -56,6 +56,12 @@ SELECT toIPv6('::ffff:127.0.0.1'); SELECT toIPv6OrDefault('::ffff:127.0.0.1'); SELECT toIPv6OrNull('::ffff:127.0.0.1'); +SELECT toIPv6('::.1.2.3'); --{serverError CANNOT_PARSE_IPV6} +SELECT toIPv6OrDefault('::.1.2.3'); +SELECT toIPv6OrNull('::.1.2.3'); + +SELECT count() FROM numbers_mt(100000000) WHERE NOT ignore(toIPv6OrZero(randomString(8))); + SELECT '--'; SELECT cast('test' , 'IPv6'); --{serverError CANNOT_PARSE_IPV6} diff --git a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference index d225cf5f332d..1a8d03bc320f 100644 --- a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference +++ b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference @@ -421,6 +421,7 @@ mapContains mapContainsKeyLike mapExtractKeyLike mapFilter +mapFromArrays mapKeys mapPopulateSeries mapSubtract diff --git a/tests/queries/0_stateless/02540_duplicate_primary_key.reference b/tests/queries/0_stateless/02540_duplicate_primary_key.reference new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/queries/0_stateless/02540_duplicate_primary_key.sql b/tests/queries/0_stateless/02540_duplicate_primary_key.sql new file mode 100644 index 000000000000..5934f5973347 --- /dev/null +++ b/tests/queries/0_stateless/02540_duplicate_primary_key.sql @@ -0,0 +1,106 @@ +drop table if exists test; + +set allow_suspicious_low_cardinality_types = 1; + +CREATE TABLE test +( + `timestamp` DateTime, + `latitude` Nullable(Float32) CODEC(Gorilla, ZSTD(1)), + `longitude` Nullable(Float32) CODEC(Gorilla, ZSTD(1)), + `m_registered` UInt8, + `m_mcc` Nullable(Int16), + `m_mnc` Nullable(Int16), + `m_ci` Nullable(Int32), + `m_tac` Nullable(Int32), + `enb_id` Nullable(Int32), + `ci` Nullable(Int32), + `m_earfcn` Int32, + `rsrp` Nullable(Int16), + `rsrq` Nullable(Int16), + `cqi` Nullable(Int16), + `source` String, + `gps_accuracy` Nullable(Float32), + `operator_name` String, + `band` Nullable(String), + `NAME_2` String, + `NAME_1` String, + `quadkey_19_key` FixedString(19), + `quadkey_17_key` FixedString(17), + `manipulation` UInt8, + `ss_rsrp` Nullable(Int16), + `ss_rsrq` Nullable(Int16), + `ss_sinr` Nullable(Int16), + `csi_rsrp` Nullable(Int16), + `csi_rsrq` Nullable(Int16), + `csi_sinr` Nullable(Int16), + `altitude` Nullable(Float32), + `access_technology` Nullable(String), + `buildingtype` String, + `LocationType` String, + `carrier_name` Nullable(String), + `CustomPolygonName` String, + `h3_10_pixel` UInt64, + `stc_cluster` Nullable(String), + PROJECTION cumsum_projection_simple + ( + SELECT + m_registered, + toStartOfInterval(timestamp, toIntervalMonth(1)), + toStartOfWeek(timestamp, 8), + toStartOfInterval(timestamp, toIntervalDay(1)), + NAME_1, + NAME_2, + operator_name, + rsrp, + rsrq, + ss_rsrp, + ss_rsrq, + cqi, + sum(multiIf(ss_rsrp IS NULL, 0, 1)), + sum(multiIf(ss_rsrq IS NULL, 0, 1)), + sum(multiIf(ss_sinr IS NULL, 0, 1)), + max(toStartOfInterval(timestamp, toIntervalDay(1))), + max(CAST(CAST(toStartOfInterval(timestamp, toIntervalDay(1)), 'Nullable(DATE)'), 'Nullable(TIMESTAMP)')), + min(toStartOfInterval(timestamp, toIntervalDay(1))), + min(CAST(CAST(toStartOfInterval(timestamp, toIntervalDay(1)), 'Nullable(DATE)'), 'Nullable(TIMESTAMP)')), + count(), + sum(1) + GROUP BY + m_registered, + toStartOfInterval(timestamp, toIntervalMonth(1)), + toStartOfWeek(timestamp, 8), + toStartOfInterval(timestamp, toIntervalDay(1)), + m_registered, + toStartOfInterval(timestamp, toIntervalMonth(1)), + toStartOfWeek(timestamp, 8), + toStartOfInterval(timestamp, toIntervalDay(1)), + NAME_1, + NAME_2, + operator_name, + rsrp, + rsrq, + ss_rsrp, + ss_rsrq, + cqi + ) +) +ENGINE = MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (timestamp, operator_name, NAME_1, NAME_2) +SETTINGS index_granularity = 8192; + +insert into test select * from generateRandom() limit 10; + +with tt as ( + select cast(toStartOfInterval(timestamp, INTERVAL 1 day) as Date) as dd, count() as samples + from test + group by dd having dd >= toDate(now())-100 + ), +tt2 as ( + select dd, samples from tt + union distinct + select toDate(now())-1, ifnull((select samples from tt where dd = toDate(now())-1),0) as samples +) +select dd, samples from tt2 order by dd with fill step 1 limit 100 format Null; + +drop table test; diff --git a/tests/queries/0_stateless/2010_lc_native.python b/tests/queries/0_stateless/2010_lc_native.python new file mode 100755 index 000000000000..c850bf3f9060 --- /dev/null +++ b/tests/queries/0_stateless/2010_lc_native.python @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import socket +import os + +CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1') +CLICKHOUSE_PORT = int(os.environ.get('CLICKHOUSE_PORT_TCP', '900000')) +CLICKHOUSE_DATABASE = os.environ.get('CLICKHOUSE_DATABASE', 'default') + +def writeVarUInt(x, ba): + for _ in range(0, 9): + + byte = x & 0x7F + if x > 0x7F: + byte |= 0x80 + + ba.append(byte) + + x >>= 7 + if x == 0: + return + + +def writeStringBinary(s, ba): + b = bytes(s, 'utf-8') + writeVarUInt(len(s), ba) + ba.extend(b) + + +def readStrict(s, size = 1): + res = bytearray() + while size: + cur = s.recv(size) + # if not res: + # raise "Socket is closed" + size -= len(cur) + res.extend(cur) + + return res + + +def readUInt(s, size=1): + res = readStrict(s, size) + val = 0 + for i in range(len(res)): + val += res[i] << (i * 8) + return val + +def readUInt8(s): + return readUInt(s) + +def readUInt16(s): + return readUInt(s, 2) + +def readUInt32(s): + return readUInt(s, 4) + +def readUInt64(s): + return readUInt(s, 8) + +def readVarUInt(s): + x = 0 + for i in range(9): + byte = readStrict(s)[0] + x |= (byte & 0x7F) << (7 * i) + + if not byte & 0x80: + return x + + return x + + +def readStringBinary(s): + size = readVarUInt(s) + s = readStrict(s, size) + return s.decode('utf-8') + + +def sendHello(s): + ba = bytearray() + writeVarUInt(0, ba) # Hello + writeStringBinary('simple native protocol', ba) + writeVarUInt(21, ba) + writeVarUInt(9, ba) + writeVarUInt(54449, ba) + writeStringBinary('default', ba) # database + writeStringBinary('default', ba) # user + writeStringBinary('', ba) # pwd + s.sendall(ba) + + +def receiveHello(s): + p_type = readVarUInt(s) + assert (p_type == 0) # Hello + server_name = readStringBinary(s) + # print("Server name: ", server_name) + server_version_major = readVarUInt(s) + # print("Major: ", server_version_major) + server_version_minor = readVarUInt(s) + # print("Minor: ", server_version_minor) + server_revision = readVarUInt(s) + # print("Revision: ", server_revision) + server_timezone = readStringBinary(s) + # print("Timezone: ", server_timezone) + server_display_name = readStringBinary(s) + # print("Display name: ", server_display_name) + server_version_patch = readVarUInt(s) + # print("Version patch: ", server_version_patch) + + +def serializeClientInfo(ba): + writeStringBinary('default', ba) # initial_user + writeStringBinary('123456', ba) # initial_query_id + writeStringBinary('127.0.0.1:9000', ba) # initial_address + ba.extend([0] * 8) # initial_query_start_time_microseconds + ba.append(1) # TCP + writeStringBinary('os_user', ba) # os_user + writeStringBinary('client_hostname', ba) # client_hostname + writeStringBinary('client_name', ba) # client_name + writeVarUInt(21, ba) + writeVarUInt(9, ba) + writeVarUInt(54449, ba) + writeStringBinary('', ba) # quota_key + writeVarUInt(0, ba) # distributed_depth + writeVarUInt(1, ba) # client_version_patch + ba.append(0) # No telemetry + + +def sendQuery(s, query): + ba = bytearray() + writeVarUInt(1, ba) # query + writeStringBinary('123456', ba) + + ba.append(1) # INITIAL_QUERY + + # client info + serializeClientInfo(ba) + + writeStringBinary('', ba) # No settings + writeStringBinary('', ba) # No interserver secret + writeVarUInt(2, ba) # Stage - Complete + ba.append(0) # No compression + writeStringBinary(query + ' settings input_format_defaults_for_omitted_fields=0', ba) # query, finally + s.sendall(ba) + + +def serializeBlockInfo(ba): + writeVarUInt(1, ba) # 1 + ba.append(0) # is_overflows + writeVarUInt(2, ba) # 2 + writeVarUInt(0, ba) # 0 + ba.extend([0] * 4) # bucket_num + + +def sendEmptyBlock(s): + ba = bytearray() + writeVarUInt(2, ba) # Data + writeStringBinary('', ba) + serializeBlockInfo(ba) + writeVarUInt(0, ba) # rows + writeVarUInt(0, ba) # columns + s.sendall(ba) + + +def readHeader(s): + readVarUInt(s) # Data + readStringBinary(s) # external table name + # BlockInfo + readVarUInt(s) # 1 + readUInt8(s) # is_overflows + readVarUInt(s) # 2 + readUInt32(s) # bucket_num + readVarUInt(s) # 0 + columns = readVarUInt(s) # rows + rows = readVarUInt(s) # columns + print("Rows {} Columns {}".format(rows, columns)) + for _ in range(columns): + col_name = readStringBinary(s) + type_name = readStringBinary(s) + print("Column {} type {}".format(col_name, type_name)) + + +def readException(s): + assert(readVarUInt(s) == 2) + code = readUInt32(s) + name = readStringBinary(s) + text = readStringBinary(s) + readStringBinary(s) # trace + assert(readUInt8(s) == 0) # has_nested + print("code {}: {}".format(code, text.replace('DB::Exception:', ''))) + + +def insertValidLowCardinalityRow(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(10) + s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT)) + sendHello(s) + receiveHello(s) + sendQuery(s, 'insert into {}.tab format TSV'.format(CLICKHOUSE_DATABASE)) + + # external tables + sendEmptyBlock(s) + readHeader(s) + + # Data + ba = bytearray() + writeVarUInt(2, ba) # Data + writeStringBinary('', ba) + serializeBlockInfo(ba) + writeVarUInt(1, ba) # rows + writeVarUInt(1, ba) # columns + writeStringBinary('x', ba) + writeStringBinary('LowCardinality(String)', ba) + ba.extend([1] + [0] * 7) # SharedDictionariesWithAdditionalKeys + ba.extend([3, 2] + [0] * 6) # indexes type: UInt64 [3], with additional keys [2] + ba.extend([1] + [0] * 7) # num_keys in dict + writeStringBinary('hello', ba) # key + ba.extend([1] + [0] * 7) # num_indexes + ba.extend([0] * 8) # UInt64 index (0 for 'hello') + s.sendall(ba) + + # Fin block + sendEmptyBlock(s) + + assert(readVarUInt(s) == 5) # End of stream + s.close() + + +def insertLowCardinalityRowWithIndexOverflow(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(10) + s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT)) + sendHello(s) + receiveHello(s) + sendQuery(s, 'insert into {}.tab format TSV'.format(CLICKHOUSE_DATABASE)) + + # external tables + sendEmptyBlock(s) + readHeader(s) + + # Data + ba = bytearray() + writeVarUInt(2, ba) # Data + writeStringBinary('', ba) + serializeBlockInfo(ba) + writeVarUInt(1, ba) # rows + writeVarUInt(1, ba) # columns + writeStringBinary('x', ba) + writeStringBinary('LowCardinality(String)', ba) + ba.extend([1] + [0] * 7) # SharedDictionariesWithAdditionalKeys + ba.extend([3, 2] + [0] * 6) # indexes type: UInt64 [3], with additional keys [2] + ba.extend([1] + [0] * 7) # num_keys in dict + writeStringBinary('hello', ba) # key + ba.extend([1] + [0] * 7) # num_indexes + ba.extend([0] * 7 + [1]) # UInt64 index (overflow) + s.sendall(ba) + + readException(s) + s.close() + + +def insertLowCardinalityRowWithIncorrectDictType(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(10) + s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT)) + sendHello(s) + receiveHello(s) + sendQuery(s, 'insert into {}.tab format TSV'.format(CLICKHOUSE_DATABASE)) + + # external tables + sendEmptyBlock(s) + readHeader(s) + + # Data + ba = bytearray() + writeVarUInt(2, ba) # Data + writeStringBinary('', ba) + serializeBlockInfo(ba) + writeVarUInt(1, ba) # rows + writeVarUInt(1, ba) # columns + writeStringBinary('x', ba) + writeStringBinary('LowCardinality(String)', ba) + ba.extend([1] + [0] * 7) # SharedDictionariesWithAdditionalKeys + ba.extend([3, 3] + [0] * 6) # indexes type: UInt64 [3], with global dict and add keys [1 + 2] + ba.extend([1] + [0] * 7) # num_keys in dict + writeStringBinary('hello', ba) # key + ba.extend([1] + [0] * 7) # num_indexes + ba.extend([0] * 8) # UInt64 index (overflow) + s.sendall(ba) + + readException(s) + s.close() + + +def main(): + insertValidLowCardinalityRow() + insertLowCardinalityRowWithIndexOverflow() + insertLowCardinalityRowWithIncorrectDictType() + +if __name__ == "__main__": + main() diff --git a/tests/queries/0_stateless/2010_lc_native.reference b/tests/queries/0_stateless/2010_lc_native.reference new file mode 100644 index 000000000000..0167f05c952b --- /dev/null +++ b/tests/queries/0_stateless/2010_lc_native.reference @@ -0,0 +1,8 @@ +Rows 0 Columns 1 +Column x type LowCardinality(String) +Rows 0 Columns 1 +Column x type LowCardinality(String) +code 117: Index for LowCardinality is out of range. Dictionary size is 1, but found index with value 72057594037927936 +Rows 0 Columns 1 +Column x type LowCardinality(String) +code 117: LowCardinality indexes serialization type for Native format cannot use global dictionary diff --git a/tests/queries/0_stateless/2010_lc_native.sh b/tests/queries/0_stateless/2010_lc_native.sh new file mode 100755 index 000000000000..0890e271318c --- /dev/null +++ b/tests/queries/0_stateless/2010_lc_native.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +$CLICKHOUSE_CLIENT -q "drop table if exists tab;" +$CLICKHOUSE_CLIENT -q "create table tab(x LowCardinality(String)) engine = MergeTree order by tuple();" + +# We should have correct env vars from shell_config.sh to run this test +python3 "$CURDIR"/2010_lc_native.python + +$CLICKHOUSE_CLIENT -q "drop table if exists tab;" diff --git a/tests/queries/0_stateless/2013_lc_nullable_and_infinity.reference b/tests/queries/0_stateless/2013_lc_nullable_and_infinity.reference new file mode 100644 index 000000000000..ef5038b2236f --- /dev/null +++ b/tests/queries/0_stateless/2013_lc_nullable_and_infinity.reference @@ -0,0 +1,4 @@ +0 \N + +0 \N +0 \N diff --git a/tests/queries/0_stateless/2013_lc_nullable_and_infinity.sql b/tests/queries/0_stateless/2013_lc_nullable_and_infinity.sql new file mode 100644 index 000000000000..c1c8a9c00b1a --- /dev/null +++ b/tests/queries/0_stateless/2013_lc_nullable_and_infinity.sql @@ -0,0 +1,3 @@ +set receive_timeout = '10', receive_data_timeout_ms = '10000', extremes = '1', allow_suspicious_low_cardinality_types = '1', force_primary_key = '1', join_use_nulls = '1', max_rows_to_read = '1', join_algorithm = 'partial_merge'; + +SELECT * FROM (SELECT dummy AS val FROM system.one) AS s1 ANY LEFT JOIN (SELECT toLowCardinality(dummy) AS rval FROM system.one) AS s2 ON (val + 9223372036854775806) = (rval * 1); diff --git a/tests/queries/0_stateless/2014_dict_get_nullable_key.reference b/tests/queries/0_stateless/2014_dict_get_nullable_key.reference new file mode 100644 index 000000000000..08127d35829a --- /dev/null +++ b/tests/queries/0_stateless/2014_dict_get_nullable_key.reference @@ -0,0 +1,13 @@ +Non nullable value only null key +\N +Non nullable value nullable key +Test +\N + +Nullable value only null key +\N +Nullable value nullable key +Test +\N +\N +\N diff --git a/tests/queries/0_stateless/2014_dict_get_nullable_key.sql b/tests/queries/0_stateless/2014_dict_get_nullable_key.sql new file mode 100644 index 000000000000..d6c058b285f8 --- /dev/null +++ b/tests/queries/0_stateless/2014_dict_get_nullable_key.sql @@ -0,0 +1,29 @@ +DROP TABLE IF EXISTS dictionary_non_nullable_source_table; +CREATE TABLE dictionary_non_nullable_source_table (id UInt64, value String) ENGINE=TinyLog; +INSERT INTO dictionary_non_nullable_source_table VALUES (0, 'Test'); + +DROP DICTIONARY IF EXISTS test_dictionary_non_nullable; +CREATE DICTIONARY test_dictionary_non_nullable (id UInt64, value String) PRIMARY KEY id LAYOUT(DIRECT()) SOURCE(CLICKHOUSE(TABLE 'dictionary_non_nullable_source_table')); + +SELECT 'Non nullable value only null key '; +SELECT dictGet('test_dictionary_non_nullable', 'value', NULL); +SELECT 'Non nullable value nullable key'; +SELECT dictGet('test_dictionary_non_nullable', 'value', arrayJoin([toUInt64(0), NULL, 1])); + +DROP DICTIONARY test_dictionary_non_nullable; +DROP TABLE dictionary_non_nullable_source_table; + +DROP TABLE IF EXISTS dictionary_nullable_source_table; +CREATE TABLE dictionary_nullable_source_table (id UInt64, value Nullable(String)) ENGINE=TinyLog; +INSERT INTO dictionary_nullable_source_table VALUES (0, 'Test'), (1, NULL); + +DROP DICTIONARY IF EXISTS test_dictionary_nullable; +CREATE DICTIONARY test_dictionary_nullable (id UInt64, value Nullable(String)) PRIMARY KEY id LAYOUT(DIRECT()) SOURCE(CLICKHOUSE(TABLE 'dictionary_nullable_source_table')); + +SELECT 'Nullable value only null key '; +SELECT dictGet('test_dictionary_nullable', 'value', NULL); +SELECT 'Nullable value nullable key'; +SELECT dictGet('test_dictionary_nullable', 'value', arrayJoin([toUInt64(0), NULL, 1, 2])); + +DROP DICTIONARY test_dictionary_nullable; +DROP TABLE dictionary_nullable_source_table; diff --git a/tests/queries/0_stateless/2015_column_default_dict_get_identifier.reference b/tests/queries/0_stateless/2015_column_default_dict_get_identifier.reference new file mode 100644 index 000000000000..29e04d559e12 --- /dev/null +++ b/tests/queries/0_stateless/2015_column_default_dict_get_identifier.reference @@ -0,0 +1 @@ +5 0 diff --git a/tests/queries/0_stateless/2015_column_default_dict_get_identifier.sql b/tests/queries/0_stateless/2015_column_default_dict_get_identifier.sql new file mode 100644 index 000000000000..292f53952d03 --- /dev/null +++ b/tests/queries/0_stateless/2015_column_default_dict_get_identifier.sql @@ -0,0 +1,37 @@ +DROP TABLE IF EXISTS test_table; +CREATE TABLE test_table +( + key_column UInt64, + data_column_1 UInt64, + data_column_2 UInt8 +) +ENGINE = MergeTree +ORDER BY key_column; + +INSERT INTO test_table VALUES (0, 0, 0); + +DROP DICTIONARY IF EXISTS test_dictionary; +CREATE DICTIONARY test_dictionary +( + key_column UInt64 DEFAULT 0, + data_column_1 UInt64 DEFAULT 1, + data_column_2 UInt8 DEFAULT 1 +) +PRIMARY KEY key_column +LAYOUT(DIRECT()) +SOURCE(CLICKHOUSE(TABLE 'test_table')); + +DROP TABLE IF EXISTS test_table_default; +CREATE TABLE test_table_default +( + data_1 DEFAULT dictGetUInt64('test_dictionary', 'data_column_1', toUInt64(0)), + data_2 DEFAULT dictGet(test_dictionary, 'data_column_2', toUInt64(0)) +) +ENGINE=TinyLog; + +INSERT INTO test_table_default(data_1) VALUES (5); +SELECT * FROM test_table_default; + +DROP DICTIONARY test_dictionary; +DROP TABLE test_table; +DROP TABLE test_table_default; diff --git a/tests/queries/0_stateless/2015_global_in_threads.reference b/tests/queries/0_stateless/2015_global_in_threads.reference new file mode 100644 index 000000000000..af81158ecae0 --- /dev/null +++ b/tests/queries/0_stateless/2015_global_in_threads.reference @@ -0,0 +1,2 @@ +10 +1 diff --git a/tests/queries/0_stateless/2015_global_in_threads.sh b/tests/queries/0_stateless/2015_global_in_threads.sh new file mode 100755 index 000000000000..c112e47fe92f --- /dev/null +++ b/tests/queries/0_stateless/2015_global_in_threads.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +${CLICKHOUSE_CLIENT} --log_queries=1 --max_threads=32 --query_id "2015_${CLICKHOUSE_DATABASE}_query" -q "select count() from remote('127.0.0.{2,3}', numbers(10)) where number global in (select number % 5 from numbers_mt(1000000))" +${CLICKHOUSE_CLIENT} -q "system flush logs" +${CLICKHOUSE_CLIENT} -q "select length(thread_ids) >= 32 from system.query_log where event_date = today() and query_id = '2015_${CLICKHOUSE_DATABASE}_query' and type = 'QueryFinish' and current_database = currentDatabase()" diff --git a/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.reference b/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.reference new file mode 100644 index 000000000000..07258cd829ac --- /dev/null +++ b/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.reference @@ -0,0 +1,9 @@ + + + + + + + + +Hello diff --git a/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.sql b/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.sql new file mode 100644 index 000000000000..f0d90f151b2c --- /dev/null +++ b/tests/queries/0_stateless/2015_order_by_with_fill_misoptimization.sql @@ -0,0 +1 @@ +SELECT s FROM (SELECT 5 AS x, 'Hello' AS s ORDER BY x WITH FILL FROM 1 TO 10) ORDER BY s; diff --git a/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.reference b/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.reference new file mode 100644 index 000000000000..9edaf84f2959 --- /dev/null +++ b/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.reference @@ -0,0 +1,5 @@ +0 +0 +0 +0 +\N diff --git a/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.sql b/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.sql new file mode 100644 index 000000000000..005358eb4254 --- /dev/null +++ b/tests/queries/0_stateless/2016_agg_empty_result_bug_28880.sql @@ -0,0 +1,10 @@ +SELECT count() AS cnt WHERE 0 HAVING cnt = 0; + +select cnt from (select count() cnt where 0) where cnt = 0; + +select cnt from (select count() cnt from system.one where 0) where cnt = 0; + +select sum from (select sum(dummy) sum from system.one where 0) where sum = 0; + +set aggregate_functions_null_for_empty=1; +select sum from (select sum(dummy) sum from system.one where 0) where sum is null; diff --git a/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.reference b/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.reference new file mode 100644 index 000000000000..264f29a6ecd1 --- /dev/null +++ b/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.reference @@ -0,0 +1,3 @@ +2021-07-07 15:21:00 +2021-07-07 15:21:05 +2021-07-07 15:21:10 diff --git a/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.sql b/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.sql new file mode 100644 index 000000000000..bf232ed5c864 --- /dev/null +++ b/tests/queries/0_stateless/2016_order_by_with_fill_monotonic_functions_removal.sql @@ -0,0 +1,6 @@ +SELECT toStartOfMinute(some_time) AS ts +FROM +( + SELECT toDateTime('2021-07-07 15:21:05') AS some_time +) +ORDER BY ts ASC WITH FILL FROM toDateTime('2021-07-07 15:21:00') TO toDateTime('2021-07-07 15:21:15') STEP 5; diff --git a/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.reference b/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.reference new file mode 100644 index 000000000000..07193989308c --- /dev/null +++ b/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.reference @@ -0,0 +1,9 @@ +1 +2 +3 +4 +5 +6 +7 +8 +9 diff --git a/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.sql b/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.sql new file mode 100644 index 000000000000..6f3e6787c344 --- /dev/null +++ b/tests/queries/0_stateless/2017_order_by_with_fill_redundant_functions.sql @@ -0,0 +1 @@ +SELECT x FROM (SELECT 5 AS x) ORDER BY -x, x WITH FILL FROM 1 TO 10; diff --git a/tests/queries/0_stateless/2018_multiple_with_fill_for_the_same_column.reference b/tests/queries/0_stateless/2018_multiple_with_fill_for_the_same_column.reference new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/queries/0_stateless/2018_multiple_with_fill_for_the_same_column.sql b/tests/queries/0_stateless/2018_multiple_with_fill_for_the_same_column.sql new file mode 100644 index 000000000000..32b38388cf6d --- /dev/null +++ b/tests/queries/0_stateless/2018_multiple_with_fill_for_the_same_column.sql @@ -0,0 +1 @@ +SELECT x, y FROM (SELECT 5 AS x, 'Hello' AS y) ORDER BY x WITH FILL FROM 3 TO 7, y, x WITH FILL FROM 1 TO 10; -- { serverError 475 } diff --git a/tests/queries/0_stateless/2019_multiple_weird_with_fill.reference b/tests/queries/0_stateless/2019_multiple_weird_with_fill.reference new file mode 100644 index 000000000000..822d290564a8 --- /dev/null +++ b/tests/queries/0_stateless/2019_multiple_weird_with_fill.reference @@ -0,0 +1,45 @@ +3 -10 +3 -9 +3 -8 +3 -7 +3 -6 +3 -5 +3 -4 +3 -3 +3 -2 +4 -10 +4 -9 +4 -8 +4 -7 +4 -6 +4 -5 +4 -4 +4 -3 +4 -2 +5 -10 +5 -9 +5 -8 +5 -7 +5 -6 +5 -5 Hello +5 -4 +5 -3 +5 -2 +6 -10 +6 -9 +6 -8 +6 -7 +6 -6 +6 -5 +6 -4 +6 -3 +6 -2 +7 -10 +7 -9 +7 -8 +7 -7 +7 -6 +7 -5 +7 -4 +7 -3 +7 -2 diff --git a/tests/queries/0_stateless/2019_multiple_weird_with_fill.sql b/tests/queries/0_stateless/2019_multiple_weird_with_fill.sql new file mode 100644 index 000000000000..a2ed33c51ddc --- /dev/null +++ b/tests/queries/0_stateless/2019_multiple_weird_with_fill.sql @@ -0,0 +1,14 @@ +SELECT + x, + -x, + y +FROM +( + SELECT + 5 AS x, + 'Hello' AS y +) +ORDER BY + x ASC WITH FILL FROM 3 TO 7, + y ASC, + -x ASC WITH FILL FROM -10 TO -1; diff --git a/tests/queries/0_stateless/2020_cast_integer_overflow.reference b/tests/queries/0_stateless/2020_cast_integer_overflow.reference new file mode 100644 index 000000000000..acceae4a72e1 --- /dev/null +++ b/tests/queries/0_stateless/2020_cast_integer_overflow.reference @@ -0,0 +1,2 @@ +-2147483648 +-2147483648 diff --git a/tests/queries/0_stateless/2020_cast_integer_overflow.sql b/tests/queries/0_stateless/2020_cast_integer_overflow.sql new file mode 100644 index 000000000000..57aeff9a9828 --- /dev/null +++ b/tests/queries/0_stateless/2020_cast_integer_overflow.sql @@ -0,0 +1,2 @@ +SELECT toInt32('-2147483648'); +SELECT toInt32OrNull('-2147483648'); diff --git a/tests/queries/0_stateless/2025_having_filter_column.reference b/tests/queries/0_stateless/2025_having_filter_column.reference new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/queries/0_stateless/2025_having_filter_column.sql b/tests/queries/0_stateless/2025_having_filter_column.sql new file mode 100644 index 000000000000..aab419adc160 --- /dev/null +++ b/tests/queries/0_stateless/2025_having_filter_column.sql @@ -0,0 +1,40 @@ +drop table if exists test; + +-- #29010 +CREATE TABLE test +( + d DateTime, + a String, + b UInt64 +) +ENGINE = MergeTree +PARTITION BY toDate(d) +ORDER BY d; + +SELECT * +FROM ( + SELECT + a, + max((d, b)).2 AS value + FROM test + GROUP BY rollup(a) +) +WHERE a <> ''; + +-- the same query, but after syntax optimization +SELECT + a, + value +FROM +( + SELECT + a, + max((d, b)).2 AS value + FROM test + GROUP BY a + WITH ROLLUP + HAVING a != '' +) +WHERE a != ''; + +drop table if exists test; diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt index bb9d4c88fa16..98f344c7baac 100644 --- a/utils/CMakeLists.txt +++ b/utils/CMakeLists.txt @@ -16,6 +16,9 @@ add_subdirectory (report) # Not used in package if (NOT DEFINED ENABLE_UTILS OR ENABLE_UTILS) add_subdirectory (compressor) + add_subdirectory (local-engine) + add_subdirectory (clickhouse-dep) +# add_subdirectory (iotest) add_subdirectory (corrector_utf8) add_subdirectory (zookeeper-cli) add_subdirectory (zookeeper-dump-tree) diff --git a/utils/clickhouse-dep/CMakeLists.txt b/utils/clickhouse-dep/CMakeLists.txt new file mode 100644 index 000000000000..b0cfaace618d --- /dev/null +++ b/utils/clickhouse-dep/CMakeLists.txt @@ -0,0 +1,115 @@ +set(SOURCE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) + +# prepare gtest +set(GTEST_LIB "") +if (ENABLE_TESTS) + set(GTEST_LIB _gtest _gtest_main) +endif() +add_custom_command(OUTPUT copy_deps + COMMAND rm -rf deps && mkdir -p deps && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + + cp $ deps/ && + cp ${CMAKE_BINARY_DIR}/contrib/boost-cmake/*.a deps/ && + cp ${CMAKE_BINARY_DIR}/contrib/abseil-cpp/absl/*/*.a deps/ && + cp ${CMAKE_BINARY_DIR}/contrib/icu-cmake/*.a deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + + cp ${CMAKE_BINARY_DIR}/contrib/libunwind-cmake/*.a deps/ && + cp ${CMAKE_BINARY_DIR}/src/Functions/divide/*.a deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp ${CMAKE_BINARY_DIR}/contrib/poco-cmake/XML/*.a deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp ${CMAKE_BINARY_DIR}/contrib/llvm-project/llvm/lib/*.a deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ && + cp $ deps/ + DEPENDS dbms + clickhouse_parsers + clickhouse_aggregate_functions + clickhouse_common_io + clickhouse_new_delete + clickhouse_functions_jsonpath + clickhouse_common_zookeeper + ${GTEST_LIB} + ) + +add_custom_command(OUTPUT copy_objects + COMMAND rm -rf objects && mkdir -p objects && + cp ${CMAKE_BINARY_DIR}/src/Functions/URL/CMakeFiles/clickhouse_functions_url.dir/*.o objects/ && + cp ${CMAKE_BINARY_DIR}/src/Functions/array/CMakeFiles/clickhouse_functions_array.dir/*.o objects/ && + cp ${CMAKE_BINARY_DIR}/src/Functions/CMakeFiles/clickhouse_functions_obj.dir/*.o objects/ + DEPENDS clickhouse_functions_obj clickhouse_functions_array clickhouse_functions_url readpassphrase) + +add_custom_target(build_clickhouse_dep ALL DEPENDS copy_objects copy_deps) + + diff --git a/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp b/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp new file mode 100644 index 000000000000..18a31b5e9f3c --- /dev/null +++ b/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp @@ -0,0 +1,84 @@ +#include +#include +#include + + +using namespace DB; + +namespace DB +{ + namespace ErrorCodes + { + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + } +} + +namespace local_engine +{ + +namespace +{ + + class AggregateFunctionCombinatorPartialMerge final : public IAggregateFunctionCombinator + { + public: + String getName() const override { return "PartialMerge"; } + + DataTypes transformArguments(const DataTypes & arguments) const override + { + if (arguments.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function with " + getName() + " suffix", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + const DataTypePtr & argument = arguments[0]; + + const DataTypeAggregateFunction * function = typeid_cast(argument.get()); + if (!function) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix" + + " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + const DataTypeAggregateFunction * function2 = typeid_cast(function->getArgumentsDataTypes()[0].get()); + if (function2) { + return transformArguments(function->getArgumentsDataTypes()); + } + return function->getArgumentsDataTypes(); + } + + AggregateFunctionPtr transformAggregateFunction( + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array & params) const override + { + DataTypePtr & argument = const_cast(arguments[0]); + + const DataTypeAggregateFunction * function = typeid_cast(argument.get()); + if (!function) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix" + + " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + while (nested_function->getName() != function->getFunctionName()) { + argument = function->getArgumentsDataTypes()[0]; + function = typeid_cast(function->getArgumentsDataTypes()[0].get()); + if (!function) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix" + + " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + if (nested_function->getName() != function->getFunctionName()) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix" + + ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(nested_function, argument, params); + } + }; + +} + +void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory & factory) +{ + factory.registerCombinator(std::make_shared()); +} + +} diff --git a/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h b/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h new file mode 100644 index 000000000000..ee7120710a2e --- /dev/null +++ b/utils/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h @@ -0,0 +1,136 @@ +#pragma once + +#include +#include +#include +#include +#include + + +namespace DB +{ + namespace ErrorCodes + { + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + } +} + +namespace local_engine +{ + +using namespace DB; + +struct Settings; + +/** + * this class is copied from AggregateFunctionMerge with little enhancement. + * we use this PartialMerge for both spark PartialMerge and Final + */ + + +class AggregateFunctionPartialMerge final : public IAggregateFunctionHelper +{ +private: + AggregateFunctionPtr nested_func; + +public: + AggregateFunctionPartialMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_) + : IAggregateFunctionHelper({argument}, params_, createResultType(nested_)) + , nested_func(nested_) + { + const DataTypeAggregateFunction * data_type = typeid_cast(argument.get()); + + if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction())) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, " + "expected {} or equivalent type", argument->getName(), getName(), getStateType()->getName()); + } + + String getName() const override + { + return nested_func->getName() + "PartialMerge"; + } + + static DataTypePtr createResultType(const AggregateFunctionPtr & nested_) + { + return nested_->getResultType(); + } + + const DataTypePtr & getResultType() const override + { + return nested_func->getResultType(); + } + + bool isVersioned() const override + { + return nested_func->isVersioned(); + } + + size_t getDefaultVersion() const override + { + return nested_func->getDefaultVersion(); + } + + DataTypePtr getStateType() const override + { + return nested_func->getStateType(); + } + + void create(AggregateDataPtr __restrict place) const override + { + nested_func->create(place); + } + + void destroy(AggregateDataPtr __restrict place) const noexcept override + { + nested_func->destroy(place); + } + + bool hasTrivialDestructor() const override + { + return nested_func->hasTrivialDestructor(); + } + + size_t sizeOfData() const override + { + return nested_func->sizeOfData(); + } + + size_t alignOfData() const override + { + return nested_func->alignOfData(); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + nested_func->merge(place, assert_cast(*columns[0]).getData()[row_num], arena); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + nested_func->merge(place, rhs, arena); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional version) const override + { + nested_func->serialize(place, buf, version); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional version, Arena * arena) const override + { + nested_func->deserialize(place, buf, version, arena); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + { + nested_func->insertResultInto(place, to, arena); + } + + bool allocatesMemoryInArena() const override + { + return nested_func->allocatesMemoryInArena(); + } + + AggregateFunctionPtr getNestedFunction() const override { return nested_func; } +}; + +} diff --git a/utils/local-engine/AggregateFunctions/CMakeLists.txt b/utils/local-engine/AggregateFunctions/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Builder/BroadCastJoinBuilder.cpp b/utils/local-engine/Builder/BroadCastJoinBuilder.cpp new file mode 100644 index 000000000000..b4449172fd51 --- /dev/null +++ b/utils/local-engine/Builder/BroadCastJoinBuilder.cpp @@ -0,0 +1,160 @@ +#include "BroadCastJoinBuilder.h" +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; +} +} + +namespace local_engine +{ +using namespace DB; + +std::queue BroadCastJoinBuilder::storage_join_queue; +std::unordered_map> BroadCastJoinBuilder::storage_join_map; +std::unordered_map> BroadCastJoinBuilder::storage_join_lock; +std::mutex BroadCastJoinBuilder::join_lock_mutex; + +struct StorageJoinContext +{ + std::string key; + jobject input; + size_t io_buffer_size; + DB::Names key_names; + DB::JoinKind kind; + DB::JoinStrictness strictness; + DB::ColumnsDescription columns; +}; + +void BroadCastJoinBuilder::buildJoinIfNotExist( + const std::string & key, + jobject input, + size_t io_buffer_size, + const DB::Names & key_names_, + DB::JoinKind kind_, + DB::JoinStrictness strictness_, + const DB::ColumnsDescription & columns_) +{ + if (!storage_join_map.contains(key)) + { + std::lock_guard build_lock(join_lock_mutex); + if (!storage_join_map.contains(key)) + { + StorageJoinContext context + { + key, input, io_buffer_size, key_names_, kind_, strictness_, columns_ + }; + // use another thread, exclude broadcast memory allocation from current memory tracker + auto func = [context]() -> void + { + // limit memory usage + if (storage_join_queue.size() > 10) + { + auto tmp = storage_join_queue.front(); + storage_join_queue.pop(); + storage_join_map.erase(tmp); + } + auto storage_join = std::make_shared( + std::make_unique(context.input, context.io_buffer_size), + StorageID("default", context.key), + context.key_names, + true, + SizeLimits(), + context.kind, + context.strictness, + context.columns, + ConstraintsDescription(), + context.key, + true); + storage_join_map.emplace(context.key, storage_join); + storage_join_queue.push(context.key); + }; + ThreadFromGlobalPool build_thread(std::move(func)); + build_thread.join(); + } + } + else + { + GET_JNIENV(env) + // it needs to delete global ref of the input object, otherwise it will hold the input object + // and lead to memory leak. + env->DeleteGlobalRef(input); + CLEAN_JNIENV + } +} +std::shared_ptr BroadCastJoinBuilder::getJoin(const std::string & key) +{ + if (storage_join_map.contains(key)) + { + return storage_join_map.at(key); + } + else + { + return std::shared_ptr(); + } +} + void BroadCastJoinBuilder::buildJoinIfNotExist( + const std::string & key, + jobject input, + size_t io_buffer_size, + const std::string & join_keys, + const std::string & join_type, + const std::string & named_struct) +{ + auto join_key_list = Poco::StringTokenizer(join_keys, ","); + Names key_names; + for (const auto & key_name : join_key_list) + { + key_names.emplace_back(key_name); + } + DB::JoinKind kind; + DB::JoinStrictness strictness; + if (join_type == "Inner") + { + kind = DB::JoinKind::Inner; + strictness = DB::JoinStrictness::All; + } + else if (join_type == "Semi") + { + kind = DB::JoinKind::Left; + strictness = DB::JoinStrictness::Semi; + } + else if (join_type == "Anti") + { + kind = DB::JoinKind::Left; + strictness = DB::JoinStrictness::Anti; + } + else if (join_type == "Left") + { + kind = DB::JoinKind::Left; + strictness = DB::JoinStrictness::All; + } + else + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", join_type); + } + + auto substrait_struct = std::make_unique(); + substrait_struct->ParseFromString(named_struct); + + Block header = SerializedPlanParser::parseNameStruct(*substrait_struct); + ColumnsDescription columns_description(header.getNamesAndTypesList()); + buildJoinIfNotExist(key, input, io_buffer_size, key_names, kind, strictness, columns_description); +} +void BroadCastJoinBuilder::clean() +{ + storage_join_lock.clear(); + storage_join_map.clear(); + while (!storage_join_queue.empty()) + { + storage_join_queue.pop(); + } +} + +} diff --git a/utils/local-engine/Builder/BroadCastJoinBuilder.h b/utils/local-engine/Builder/BroadCastJoinBuilder.h new file mode 100644 index 000000000000..bf39580e4007 --- /dev/null +++ b/utils/local-engine/Builder/BroadCastJoinBuilder.h @@ -0,0 +1,37 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class BroadCastJoinBuilder +{ +public: + static void buildJoinIfNotExist( + const std::string & key, + jobject input, + size_t io_buffer_size, + const DB::Names & key_names_, + DB::JoinKind kind_, + DB::JoinStrictness strictness_, + const DB::ColumnsDescription & columns_); + + static void buildJoinIfNotExist( + const std::string & key, + jobject input, + size_t io_buffer_size, + const std::string & join_keys, + const std::string & join_type, + const std::string & named_struct); + + static std::shared_ptr getJoin(const std::string & key); + + static void clean(); + +private: + static std::queue storage_join_queue; + static std::unordered_map> storage_join_map; + static std::unordered_map> storage_join_lock; + static std::mutex join_lock_mutex; +}; +} diff --git a/utils/local-engine/Builder/CMakeLists.txt b/utils/local-engine/Builder/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Builder/SerializedPlanBuilder.cpp b/utils/local-engine/Builder/SerializedPlanBuilder.cpp new file mode 100644 index 000000000000..e8aae9ecdca7 --- /dev/null +++ b/utils/local-engine/Builder/SerializedPlanBuilder.cpp @@ -0,0 +1,378 @@ +#include "SerializedPlanBuilder.h" +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; +} +} + +namespace dbms +{ + +using namespace DB; +SchemaPtr SerializedSchemaBuilder::build() +{ + for (const auto & [name, type] : this->type_map) + { + this->schema->add_names(name); + auto * type_struct = this->schema->mutable_struct_(); + if (type == "I8") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_i8()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "I32") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_i32()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "I64") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_i64()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "Boolean") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_bool_()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "I16") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_i16()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "String") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_string()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "FP32") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_fp32()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "FP64") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_fp64()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "Date") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_date()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else if (type == "Timestamp") + { + auto * t = type_struct->mutable_types()->Add(); + t->mutable_timestamp()->set_nullability( + this->nullability_map[name] ? substrait::Type_Nullability_NULLABILITY_NULLABLE + : substrait::Type_Nullability_NULLABILITY_REQUIRED); + } + else + { + throw std::runtime_error("doesn't support type " + type); + } + } + return std::move(this->schema); +} +SerializedSchemaBuilder & SerializedSchemaBuilder::column(const std::string & name, const std::string & type, bool nullable) +{ + this->type_map.emplace(name, type); + this->nullability_map.emplace(name, nullable); + return *this; +} +SerializedSchemaBuilder::SerializedSchemaBuilder() : schema(new substrait::NamedStruct()) +{ +} +SerializedPlanBuilder & SerializedPlanBuilder::registerFunction(int id, const std::string & name) +{ + auto * extension = this->plan->mutable_extensions()->Add(); + auto * function_mapping = extension->mutable_extension_function(); + function_mapping->set_function_anchor(id); + function_mapping->set_name(name); + return *this; +} + +void SerializedPlanBuilder::setInputToPrev(substrait::Rel * input) +{ + if (!this->prev_rel) + { + auto * root = this->plan->mutable_relations()->Add()->mutable_root(); + root->set_allocated_input(input); + return; + } + if (this->prev_rel->has_filter()) + { + this->prev_rel->mutable_filter()->set_allocated_input(input); + } + else if (this->prev_rel->has_aggregate()) + { + this->prev_rel->mutable_aggregate()->set_allocated_input(input); + } + else if (this->prev_rel->has_project()) + { + this->prev_rel->mutable_project()->set_allocated_input(input); + } + else + { + throw std::runtime_error("does support rel type"); + } +} + +SerializedPlanBuilder & SerializedPlanBuilder::filter(substrait::Expression * condition) +{ + substrait::Rel * filter = new substrait::Rel(); + filter->mutable_filter()->set_allocated_condition(condition); + setInputToPrev(filter); + this->prev_rel = filter; + return *this; +} + +SerializedPlanBuilder & SerializedPlanBuilder::read(const std::string & path, SchemaPtr schema) +{ + substrait::Rel * rel = new substrait::Rel(); + auto * read = rel->mutable_read(); + read->mutable_local_files()->add_items()->set_uri_file(path); + read->set_allocated_base_schema(schema); + setInputToPrev(rel); + this->prev_rel = rel; + return *this; +} + +SerializedPlanBuilder & SerializedPlanBuilder::readMergeTree( + const std::string & database, + const std::string & table, + const std::string & relative_path, + int min_block, + int max_block, + SchemaPtr schema) +{ + substrait::Rel * rel = new substrait::Rel(); + auto * read = rel->mutable_read(); + read->mutable_extension_table()->mutable_detail()->set_value(local_engine::MergeTreeTable{.database=database,.table=table,.relative_path=relative_path,.min_block=min_block,.max_block=max_block}.toString()); + read->set_allocated_base_schema(schema); + setInputToPrev(rel); + this->prev_rel = rel; + return *this; +} + + +std::unique_ptr SerializedPlanBuilder::build() +{ + return std::move(this->plan); +} + +SerializedPlanBuilder::SerializedPlanBuilder() : plan(std::make_unique()) +{ +} + +SerializedPlanBuilder & SerializedPlanBuilder::aggregate(std::vector /*keys*/, std::vector aggregates) +{ + substrait::Rel * rel = new substrait::Rel(); + auto * agg = rel->mutable_aggregate(); + // TODO support group + auto * measures = agg->mutable_measures(); + for (auto * measure : aggregates) + { + measures->AddAllocated(measure); + } + setInputToPrev(rel); + this->prev_rel = rel; + return *this; +} + +SerializedPlanBuilder & SerializedPlanBuilder::project(std::vector projections) +{ + substrait::Rel * project = new substrait::Rel(); + for (auto * expr : projections) + { + project->mutable_project()->mutable_expressions()->AddAllocated(expr); + } + setInputToPrev(project); + this->prev_rel = project; + return *this; +} + +std::shared_ptr SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type) +{ + const auto * ch_type_nullable = checkAndGetDataType(ch_type.get()); + const bool is_nullable = (ch_type_nullable != nullptr); + auto type_nullability + = is_nullable ? substrait::Type_Nullability_NULLABILITY_NULLABLE : substrait::Type_Nullability_NULLABILITY_REQUIRED; + + const auto ch_type_without_nullable = DB::removeNullable(ch_type); + const DB::WhichDataType which(ch_type_without_nullable); + + auto res = std::make_shared(); + if (which.isUInt8()) + res->mutable_bool_()->set_nullability(type_nullability); + else if (which.isInt8()) + res->mutable_i8()->set_nullability(type_nullability); + else if (which.isInt16()) + res->mutable_i16()->set_nullability(type_nullability); + else if (which.isInt32()) + res->mutable_i32()->set_nullability(type_nullability); + else if (which.isInt64()) + res->mutable_i64()->set_nullability(type_nullability); + else if (which.isString() || which.isAggregateFunction()) + res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type + else if (which.isFloat32()) + res->mutable_fp32()->set_nullability(type_nullability); + else if (which.isFloat64()) + res->mutable_fp64()->set_nullability(type_nullability); + else if (which.isFloat64()) + res->mutable_fp64()->set_nullability(type_nullability); + else if (which.isDateTime64()) + { + const auto * ch_type_datetime64 = checkAndGetDataType(ch_type_without_nullable.get()); + if (ch_type_datetime64->getScale() != 6) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + res->mutable_timestamp()->set_nullability(type_nullability); + } + else if (which.isDate32()) + res->mutable_date()->set_nullability(type_nullability); + else if (which.isDecimal()) + { + if (which.isDecimal256()) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + + const auto scale = getDecimalScale(*ch_type_without_nullable); + const auto precision = getDecimalPrecision(*ch_type_without_nullable); + if (scale == 0 && precision == 0) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + + res->mutable_decimal()->set_nullability(type_nullability); + res->mutable_decimal()->set_scale(scale); + res->mutable_decimal()->set_precision(precision); + } + else if (which.isTuple()) + { + const auto * ch_tuple_type = checkAndGetDataType(ch_type_without_nullable.get()); + const auto & ch_field_types = ch_tuple_type->getElements(); + res->mutable_struct_()->set_nullability(type_nullability); + for (const auto & ch_field_type: ch_field_types) + res->mutable_struct_()->mutable_types()->Add(std::move(*buildType(ch_field_type))); + } + else if (which.isArray()) + { + const auto * ch_array_type = checkAndGetDataType(ch_type_without_nullable.get()); + const auto & ch_nested_type = ch_array_type->getNestedType(); + res->mutable_list()->set_nullability(type_nullability); + *(res->mutable_list()->mutable_type()) = *buildType(ch_nested_type); + } + else if (which.isMap()) + { + const auto & ch_map_type = checkAndGetDataType(ch_type_without_nullable.get()); + const auto & ch_key_type = ch_map_type->getKeyType(); + const auto & ch_val_type = ch_map_type->getValueType(); + res->mutable_map()->set_nullability(type_nullability); + *(res->mutable_map()->mutable_key()) = *buildType(ch_key_type); + *(res->mutable_map()->mutable_value()) = *buildType(ch_val_type); + } + else if (which.isNothing()) + res->mutable_nothing(); + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + + return std::move(res); +} + +void SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type, String & substrait_type) +{ + auto pb = buildType(ch_type); + substrait_type = pb->SerializeAsString(); +} + + +substrait::Expression * selection(int32_t field_id) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * selection = rel->mutable_selection(); + selection->mutable_direct_reference()->mutable_struct_field()->set_field(field_id); + return rel; +} +substrait::Expression * scalarFunction(int32_t id, ExpressionList args) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * function = rel->mutable_scalar_function(); + function->set_function_reference(id); + std::for_each(args.begin(), args.end(), [function](auto * expr) { function->mutable_args()->AddAllocated(expr); }); + return rel; +} +substrait::AggregateRel_Measure * measureFunction(int32_t id, ExpressionList args) +{ + substrait::AggregateRel_Measure * rel = new substrait::AggregateRel_Measure(); + auto * measure = rel->mutable_measure(); + measure->set_function_reference(id); + std::for_each(args.begin(), args.end(), [measure](auto * expr) { measure->mutable_args()->AddAllocated(expr); }); + return rel; +} +substrait::Expression * literal(double_t value) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * literal = rel->mutable_literal(); + literal->set_fp64(value); + return rel; +} + +substrait::Expression * literal(int32_t value) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * literal = rel->mutable_literal(); + literal->set_i32(value); + return rel; +} + +substrait::Expression * literal(const std::string & value) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * literal = rel->mutable_literal(); + literal->set_string(value); + return rel; +} + +substrait::Expression* literalDate(int32_t value) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * literal = rel->mutable_literal(); + literal->set_date(value); + return rel; +} + +/// Timestamp in units of microseconds since the UNIX epoch. +substrait::Expression * literalTimestamp(int64_t value) +{ + substrait::Expression * rel = new substrait::Expression(); + auto * literal = rel->mutable_literal(); + literal->set_timestamp(value); + return rel; +} + +} diff --git a/utils/local-engine/Builder/SerializedPlanBuilder.h b/utils/local-engine/Builder/SerializedPlanBuilder.h new file mode 100644 index 000000000000..3b0638a3eeb4 --- /dev/null +++ b/utils/local-engine/Builder/SerializedPlanBuilder.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include + + +namespace dbms +{ +enum Function +{ + IS_NOT_NULL = 0, + GREATER_THAN_OR_EQUAL, + AND, + LESS_THAN_OR_EQUAL, + LESS_THAN, + MULTIPLY, + SUM, + TO_DATE, + EQUAL_TO +}; + +using SchemaPtr = substrait::NamedStruct *; + +class SerializedPlanBuilder +{ +public: + SerializedPlanBuilder(); + + SerializedPlanBuilder & registerSupportedFunctions() + { + this->registerFunction(IS_NOT_NULL, "is_not_null") + .registerFunction(GREATER_THAN_OR_EQUAL, "gte") + .registerFunction(AND, "and") + .registerFunction(LESS_THAN_OR_EQUAL, "lte") + .registerFunction(LESS_THAN, "lt") + .registerFunction(MULTIPLY, "multiply") + .registerFunction(SUM, "sum") + .registerFunction(TO_DATE, "to_date") + .registerFunction(EQUAL_TO, "equal"); + return *this; + } + SerializedPlanBuilder& registerFunction(int id, const std::string & name); + SerializedPlanBuilder& filter(substrait::Expression* condition); + SerializedPlanBuilder& project(std::vector projections); + SerializedPlanBuilder& aggregate(std::vector keys, std::vector aggregates); + SerializedPlanBuilder& read(const std::string & path, SchemaPtr schema); + SerializedPlanBuilder & readMergeTree( + const std::string & database, + const std::string & table, + const std::string & relative_path, + int min_block, + int max_block, + SchemaPtr schema); + std::unique_ptr build(); + + static std::shared_ptr buildType(const DB::DataTypePtr & ch_type); + static void buildType(const DB::DataTypePtr & ch_type, String & substrait_type); + +private: + void setInputToPrev(substrait::Rel * input); + substrait::Rel * prev_rel = nullptr; + std::unique_ptr plan; +}; + + +using Type = substrait::Type; +/** + * build a schema, need define column name and column. + * 1. column name + * 2. column type + * 3. nullability + */ +class SerializedSchemaBuilder +{ +public: + SerializedSchemaBuilder(); + SchemaPtr build(); + SerializedSchemaBuilder& column(const std::string & name, const std::string & type, bool nullable = false); +private: + std::map type_map; + std::map nullability_map; + SchemaPtr schema; +}; + +using ExpressionList = std::vector; +using MeasureList = std::vector; + + +substrait::Expression * scalarFunction(int32_t id, ExpressionList args); +substrait::AggregateRel_Measure * measureFunction(int32_t id, ExpressionList args); + +substrait::Expression* literal(double_t value); +substrait::Expression* literal(int32_t value); +substrait::Expression* literal(const std::string & value); +substrait::Expression* literalDate(int32_t value); + +substrait::Expression * selection(int32_t field_id); + +} diff --git a/utils/local-engine/CMakeLists.txt b/utils/local-engine/CMakeLists.txt new file mode 100644 index 000000000000..cc5137e14161 --- /dev/null +++ b/utils/local-engine/CMakeLists.txt @@ -0,0 +1,90 @@ +set(THRIFT_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/thrift/lib/cpp/src") + +# Find java/jni +include(FindJava) +include(UseJava) +include(FindJNI) + +#set(JNI_NATIVE_SOURCES local_engine_jni.cpp) +set(LOCALENGINE_SHARED_LIB ch) +set (ENABLE_CURL_BUILD OFF) + +add_subdirectory(proto) + +add_headers_and_sources(builder Builder) +add_headers_and_sources(parser Parser) +add_headers_and_sources(storages Storages) +add_headers_and_sources(common Common) +add_headers_and_sources(external External) +add_headers_and_sources(shuffle Shuffle) +add_headers_and_sources(operator Operator) +add_headers_and_sources(jni jni) +add_headers_and_sources(aggregate_functions AggregateFunctions) +add_headers_and_sources(functions Functions) + +include_directories( + ${JNI_INCLUDE_DIRS} + ${CMAKE_CURRENT_BINARY_DIR}/proto + ${THRIFT_INCLUDE_DIR} + ${CMAKE_BINARY_DIR}/contrib/thrift-cmake + ${ClickHouse_SOURCE_DIR}/utils/local-engine + ${ClickHouse_SOURCE_DIR}/src + ${ClickHouse_SOURCE_DIR}/base + ${ClickHouse_SOURCE_DIR}/contrib/orc/c++/include + ${CMAKE_BINARY_DIR}/contrib/orc/c++/include + ${ClickHouse_SOURCE_DIR}/contrib/azure/sdk/storage/azure-storage-blobs/inc + ${ClickHouse_SOURCE_DIR}/contrib/azure/sdk/core/azure-core/inc + ${ClickHouse_SOURCE_DIR}/contrib/azure/sdk/storage/azure-storage-common/inc +) + +add_subdirectory(Storages/ch_parquet) +add_subdirectory(Storages/SubstraitSource) + +add_library(${LOCALENGINE_SHARED_LIB} SHARED + ${builder_sources} + ${parser_sources} + ${storages_sources} + ${common_sources} + ${external_sources} + ${shuffle_sources} + ${jni_sources} + ${substrait_source} + ${operator_sources} + ${aggregate_functions_sources} + ${functions_sources} + local_engine_jni.cpp) + + +target_compile_options(${LOCALENGINE_SHARED_LIB} PUBLIC -fPIC + -Wno-shorten-64-to-32) + +target_link_libraries(${LOCALENGINE_SHARED_LIB} PUBLIC + clickhouse_aggregate_functions + clickhouse_common_config + clickhouse_common_io + clickhouse_functions + clickhouse_parsers + clickhouse_storages_system + substrait + loggers + ch_parquet + substait_source + xxHash +) + +#set(CPACK_PACKAGE_VERSION 0.1.0) +#set(CPACK_GENERATOR "RPM") +#set(CPACK_PACKAGE_NAME "local_engine_jni") +#set(CPACK_PACKAGE_RELEASE 1) +#set(CPACK_CMAKE_GENERATOR Ninja) +#set(CPACK_PACKAGE_CONTACT "neng.liu@kyligence.io") +#set(CPACK_PACKAGE_VENDOR "Kyligence") +#set(CPACK_PACKAGING_INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) +#set(CPACK_RPM_PACKAGE_AUTOREQPROV "no") +#set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}-${CPACK_PACKAGE_RELEASE}.${CMAKE_SYSTEM_PROCESSOR}") +#include(CPack) + +if (ENABLE_TESTS) + add_subdirectory(tests) +endif () + diff --git a/utils/local-engine/Common/BlockIterator.cpp b/utils/local-engine/Common/BlockIterator.cpp new file mode 100644 index 000000000000..47d269d0aff5 --- /dev/null +++ b/utils/local-engine/Common/BlockIterator.cpp @@ -0,0 +1,42 @@ +#include "BlockIterator.h" +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} + +namespace local_engine +{ +void local_engine::BlockIterator::checkNextValid() +{ + if (consumed) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Block iterator next should after hasNext"); + } +} +void BlockIterator::produce() +{ + consumed = false; +} +void BlockIterator::consume() +{ + consumed = true; +} +bool BlockIterator::isConsumed() const +{ + return consumed; +} +DB::Block & BlockIterator::currentBlock() +{ + return cached_block; +} +void BlockIterator::setCurrentBlock(DB::Block & block) +{ + cached_block = block; +} +} + diff --git a/utils/local-engine/Common/BlockIterator.h b/utils/local-engine/Common/BlockIterator.h new file mode 100644 index 000000000000..fc75b150556f --- /dev/null +++ b/utils/local-engine/Common/BlockIterator.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace local_engine +{ +class BlockIterator +{ +protected: + void checkNextValid(); + // make current block available + void produce(); + // consume current block + void consume(); + bool isConsumed() const; + DB::Block & currentBlock(); + void setCurrentBlock(DB::Block & block); + +private: + DB::Block cached_block; + bool consumed = true; +}; +} diff --git a/utils/local-engine/Common/CHUtil.cpp b/utils/local-engine/Common/CHUtil.cpp new file mode 100644 index 000000000000..4e70f9a6fa23 --- /dev/null +++ b/utils/local-engine/Common/CHUtil.cpp @@ -0,0 +1,659 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "CHUtil.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ +constexpr auto VIRTUAL_ROW_COUNT_COLOUMN = "__VIRTUAL_ROW_COUNT_COLOUMNOUMN__"; + +namespace fs = std::filesystem; + +DB::Block BlockUtil::buildRowCountHeader() +{ + DB::Block header; + auto uint8_ty = std::make_shared(); + auto col = uint8_ty->createColumn(); + DB::ColumnWithTypeAndName named_col(std::move(col), uint8_ty, VIRTUAL_ROW_COUNT_COLOUMN); + header.insert(named_col); + return header.cloneEmpty(); +} + +DB::Chunk BlockUtil::buildRowCountChunk(UInt64 rows) +{ + auto data_type = std::make_shared(); + auto col = data_type->createColumnConst(rows, 0); + DB::Columns res_columns; + res_columns.emplace_back(std::move(col)); + return DB::Chunk(std::move(res_columns), rows); +} + +DB::Block BlockUtil::buildRowCountBlock(UInt64 rows) +{ + DB::Block block; + auto uint8_ty = std::make_shared(); + auto col = uint8_ty->createColumnConst(rows, 0); + DB::ColumnWithTypeAndName named_col(col, uint8_ty, VIRTUAL_ROW_COUNT_COLOUMN); + block.insert(named_col); + return block; +} + +DB::Block BlockUtil::buildHeader(const DB::NamesAndTypesList & names_types_list) +{ + DB::ColumnsWithTypeAndName cols; + for (const auto & name_type : names_types_list) + { + DB::ColumnWithTypeAndName col(name_type.type->createColumn(), name_type.type, name_type.name); + cols.emplace_back(col); + } + return DB::Block(cols); +} + +/** + * There is a special case with which we need be careful. In spark, struct/map/list are always + * wrapped in Nullable, but this should not happen in clickhouse. + */ +DB::Block BlockUtil::flattenBlock(const DB::Block & block, UInt64 flags, bool recursively) +{ + DB::Block res; + + for (const auto & elem : block) + { + DB::DataTypePtr nested_type = nullptr; + DB::ColumnPtr nested_col = nullptr; + DB::ColumnPtr null_map_col = nullptr; + if (elem.type->isNullable()) + { + nested_type = typeid_cast(elem.type.get())->getNestedType(); + const auto * null_col = typeid_cast(elem.column->getPtr().get()); + nested_col = null_col->getNestedColumnPtr(); + null_map_col = null_col->getNullMapColumnPtr(); + } + else + { + nested_type = elem.type; + nested_col = elem.column; + } + if (const DB::DataTypeArray * type_arr = typeid_cast(nested_type.get())) + { + const DB::DataTypeTuple * type_tuple = typeid_cast(type_arr->getNestedType().get()); + if (type_tuple && type_tuple->haveExplicitNames() && (flags & FLAT_NESTED_TABLE)) + { + const DB::DataTypes & element_types = type_tuple->getElements(); + const DB::Strings & names = type_tuple->getElementNames(); + size_t tuple_size = element_types.size(); + + bool is_const = isColumnConst(*nested_col); + const DB::ColumnArray * column_array; + if (is_const) + column_array = typeid_cast(&assert_cast(*nested_col).getDataColumn()); + else + column_array = typeid_cast(nested_col.get()); + + const DB::ColumnPtr & column_offsets = column_array->getOffsetsPtr(); + + const DB::ColumnTuple & column_tuple = typeid_cast(column_array->getData()); + const auto & element_columns = column_tuple.getColumns(); + + for (size_t i = 0; i < tuple_size; ++i) + { + String nested_name = DB::Nested::concatenateName(elem.name, names[i]); + DB::ColumnPtr column_array_of_element = DB::ColumnArray::create(element_columns[i], column_offsets); + auto named_column_array_of_element = DB::ColumnWithTypeAndName( + is_const ? DB::ColumnConst::create(std::move(column_array_of_element), block.rows()) : column_array_of_element, + std::make_shared(element_types[i]), + nested_name); + if (null_map_col) + { + // Should all field columns have the same null map ? + DB::DataTypePtr null_type = std::make_shared(element_types[i]); + named_column_array_of_element.column = DB::ColumnNullable::create(named_column_array_of_element.column, null_map_col); + named_column_array_of_element.type = null_type; + } + if (recursively) + { + auto flatten_one_col_block = flattenBlock({named_column_array_of_element}, flags, recursively); + for (const auto & named_col : flatten_one_col_block.getColumnsWithTypeAndName()) + { + res.insert(named_col); + } + } + else + { + res.insert(named_column_array_of_element); + } + } + } + else + { + res.insert(elem); + } + } + else if (const DB::DataTypeTuple * type_tuple = typeid_cast(nested_type.get())) + { + if (type_tuple->haveExplicitNames() && (flags & FLAT_STRUCT)) + { + const DB::DataTypes & element_types = type_tuple->getElements(); + const DB::Strings & names = type_tuple->getElementNames(); + const DB::ColumnTuple * column_tuple; + if (isColumnConst(*nested_col)) + column_tuple = typeid_cast(&assert_cast(*nested_col).getDataColumn()); + else + column_tuple = typeid_cast(nested_col.get()); + size_t tuple_size = column_tuple->tupleSize(); + for (size_t i = 0; i < tuple_size; ++i) + { + const auto & element_column = column_tuple->getColumn(i); + String nested_name = DB::Nested::concatenateName(elem.name, names[i]); + auto new_element_col = DB::ColumnWithTypeAndName(element_column.getPtr(), element_types[i], nested_name); + if (null_map_col && !element_types[i]->isNullable()) + { + // Should all field columns have the same null map ? + new_element_col.column = DB::ColumnNullable::create(new_element_col.column, null_map_col); + new_element_col.type = std::make_shared(new_element_col.type); + } + if (recursively) + { + DB::Block one_col_block({new_element_col}); + auto flatten_one_col_block = flattenBlock(one_col_block, flags, recursively); + for (const auto & named_col : flatten_one_col_block.getColumnsWithTypeAndName()) + { + res.insert(named_col); + } + } + else + { + res.insert(std::move(new_element_col)); + } + } + } + else + { + res.insert(elem); + } + } + else + { + res.insert(elem); + } + } + + return res; +} + +std::string PlanUtil::explainPlan(DB::QueryPlan & plan) +{ + std::string plan_str; + DB::QueryPlan::ExplainPlanOptions buf_opt + { + .header = true, + .actions = true, + .indexes = true, + }; + DB::WriteBufferFromOwnString buf; + plan.explainPlan(buf, buf_opt); + plan_str = buf.str(); + return plan_str; +} + +std::vector MergeTreeUtil::getAllMergeTreeParts(const Path &storage_path) +{ + if (!fs::exists(storage_path)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid merge tree store path:{}", storage_path.string()); + } + + // TODO: May need to check the storage format version + std::vector res; + for (const auto & entry : fs::directory_iterator(storage_path)) + { + auto filename = entry.path().filename(); + if (filename == "format_version.txt" || filename == "detached" || filename == "_delta_log") + continue; + res.push_back(entry.path()); + } + return res; +} + +DB::NamesAndTypesList MergeTreeUtil::getSchemaFromMergeTreePart(const fs::path & part_path) +{ + DB::NamesAndTypesList names_types_list; + if (!fs::exists(part_path)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid merge tree store path:{}", part_path.string()); + } + DB::ReadBufferFromFile readbuffer((part_path / "columns.txt").string()); + names_types_list.readText(readbuffer); + return names_types_list; +} + + +NestedColumnExtractHelper::NestedColumnExtractHelper(const DB::Block & block_, bool case_insentive_) + : block(block_) + , case_insentive(case_insentive_) +{} + +std::optional NestedColumnExtractHelper::extractColumn(const String & column_name) +{ + if (const auto * col = findColumn(block, column_name)) + return {*col}; + + auto nested_names = DB::Nested::splitName(column_name); + if (case_insentive) + { + boost::to_lower(nested_names.first); + boost::to_lower(nested_names.second); + } + if (!findColumn(block, nested_names.first)) + return {}; + + if (!nested_tables.contains(nested_names.first)) + { + DB::ColumnsWithTypeAndName columns = {*findColumn(block, nested_names.first)}; + nested_tables[nested_names.first] = std::make_shared(BlockUtil::flattenBlock(columns)); + } + + return extractColumn(column_name, nested_names.first, nested_names.second); +} + +std::optional NestedColumnExtractHelper::extractColumn( + const String & original_column_name, const String & column_name_prefix, const String & column_name_suffix) +{ + auto table_iter = nested_tables.find(column_name_prefix); + if (table_iter == nested_tables.end()) + { + return {}; + } + + auto & nested_table = table_iter->second; + auto nested_names = DB::Nested::splitName(column_name_suffix); + auto new_column_name_prefix = DB::Nested::concatenateName(column_name_prefix, nested_names.first); + if (nested_names.second.empty()) + { + if (const auto * column_ref = findColumn(*nested_table, new_column_name_prefix)) + { + DB::ColumnWithTypeAndName column = *column_ref; + if (case_insentive) + column.name = original_column_name; + return {std::move(column)}; + } + else + { + return {}; + } + } + + const auto * sub_col = findColumn(*nested_table, new_column_name_prefix); + if (!sub_col) + { + return {}; + } + + DB::ColumnsWithTypeAndName columns = {*sub_col}; + DB::Block sub_block(columns); + nested_tables[new_column_name_prefix] = std::make_shared(BlockUtil::flattenBlock(sub_block)); + return extractColumn(original_column_name, new_column_name_prefix, nested_names.second); +} + +const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB::Block & in_block, const std::string & name) const +{ + + if (case_insentive) + { + std::string final_name = name; + boost::to_lower(final_name); + const auto & cols = in_block.getColumnsWithTypeAndName(); + auto found = std::find_if(cols.begin(), cols.end(), [&](const auto & column) { return boost::iequals(column.name, name); }); + if (found == cols.end()) + { + return nullptr; + } + return &*found; + } + + const auto * col = in_block.findByName(name); + if (col) + return col; + return nullptr; +} + +const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType( + DB::ActionsDAGPtr & actions_dag, const DB::ActionsDAG::Node * node, const std::string & type_name, const std::string & result_name) +{ + DB::ColumnWithTypeAndName type_name_col; + type_name_col.name = type_name; + type_name_col.column = DB::DataTypeString().createColumnConst(0, type_name_col.name); + type_name_col.type = std::make_shared(); + const auto * right_arg = &actions_dag->addColumn(std::move(type_name_col)); + const auto * left_arg = node; + DB::FunctionCastBase::Diagnostic diagnostic = {node->result_name, node->result_name}; + DB::FunctionOverloadResolverPtr func_builder_cast + = DB::CastInternalOverloadResolver::createImpl(std::move(diagnostic)); + + DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg}; + return &actions_dag->addFunction(func_builder_cast, std::move(children), result_name); +} + +String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) +{ + DB::WriteBufferFromOwnString buf; + const auto & processors = pipeline.getProcessors(); + DB::printPipelineCompact(processors, buf, true); + return buf.str(); +} + +using namespace DB; + +std::map BackendInitializerUtil::getBackendConfMap(const std::string &plan) +{ + std::map ch_backend_conf; + + /// Parse backend configs from plan extensions + do + { + auto plan_ptr = std::make_unique(); + auto success = plan_ptr->ParseFromString(plan); + if (!success) + break; + + if (!plan_ptr->has_advanced_extensions() || !plan_ptr->advanced_extensions().has_enhancement()) + break; + const auto & enhancement = plan_ptr->advanced_extensions().enhancement(); + + if (!enhancement.Is()) + break; + + substrait::Expression expression; + if (!enhancement.UnpackTo(&expression) || !expression.has_literal() || !expression.literal().has_map()) + break; + + const auto & key_values = expression.literal().map().key_values(); + for (const auto & key_value : key_values) + { + if (!key_value.has_key() || !key_value.has_value()) + continue; + + const auto & key = key_value.key(); + const auto & value = key_value.value(); + if (!key.has_string() || !value.has_string()) + continue; + + if (!key.string().starts_with(CH_BACKEND_CONF_PREFIX) && key.string() != std::string(GLUTEN_TIMEZONE_KEY)) + continue; + + ch_backend_conf[key.string()] = value.string(); + } + } while (false); + + if (!ch_backend_conf.count(CH_RUNTIME_CONF_FILE)) + { + /// Try to get config path from environment variable + const char * config_path = std::getenv("CLICKHOUSE_BACKEND_CONFIG"); /// NOLINT + if (config_path) + { + ch_backend_conf[CH_RUNTIME_CONF_FILE] = config_path; + } + } + return ch_backend_conf; +} + +void BackendInitializerUtil::initConfig(const std::string &plan) +{ + /// Parse input substrait plan, and get native conf map from it. + std::map backend_conf_map; + backend_conf_map = getBackendConfMap(plan); + + if (backend_conf_map.count(CH_RUNTIME_CONF_FILE)) + { + if (fs::exists(CH_RUNTIME_CONF_FILE) && fs::is_regular_file(CH_RUNTIME_CONF_FILE)) + { + ConfigProcessor config_processor(CH_RUNTIME_CONF_FILE, false, true); + config_processor.setConfigPath(fs::path(CH_RUNTIME_CONF_FILE).parent_path()); + auto loaded_config = config_processor.loadConfig(false); + config = loaded_config.configuration; + } + else + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} is not a valid configure file.", CH_RUNTIME_CONF_FILE); + } + else + config = Poco::AutoPtr(new Poco::Util::MapConfiguration()); + + /// Update specified settings + for (const auto & kv : backend_conf_map) + { + if (kv.first.starts_with(CH_RUNTIME_CONF_PREFIX) && kv.first != CH_RUNTIME_CONF_FILE) + config->setString(kv.first.substr(CH_RUNTIME_CONF_PREFIX.size() + 1), kv.second); + else if (kv.first == std::string(GLUTEN_TIMEZONE_KEY)) + config->setString(kv.first, kv.second); + } +} + +void BackendInitializerUtil::initLoggers() +{ + auto level = config->getString("logger.level", "error"); + if (config->has("logger.log")) + local_engine::Logger::initFileLogger(*config, "ClickHouseBackend"); + else + local_engine::Logger::initConsoleLogger(level); + + logger = &Poco::Logger::get("ClickHouseBackend"); +} + +void BackendInitializerUtil::initEnvs() +{ + /// Set environment variable TZ if possible + if (config->has(GLUTEN_TIMEZONE_KEY)) + { + String timezone_name = config->getString(GLUTEN_TIMEZONE_KEY); + if (0 != setenv("TZ", timezone_name.data(), 1)) /// NOLINT + throw Poco::Exception("Cannot setenv TZ variable"); + + tzset(); + DateLUT::setDefaultTimezone(timezone_name); + } + + /// Set environment variable LIBHDFS3_CONF if possible + if (config->has(LIBHDFS3_CONF_KEY)) + { + std::string libhdfs3_conf = config->getString(LIBHDFS3_CONF_KEY, ""); + setenv("LIBHDFS3_CONF", libhdfs3_conf.c_str(), true); /// NOLINT + } +} + +void BackendInitializerUtil::initSettings() +{ + static const std::string settings_path("local_engine.settings"); + + settings = Settings(); + Poco::Util::AbstractConfiguration::Keys config_keys; + config->keys(settings_path, config_keys); + + for (const std::string & key : config_keys) + settings.set(key, config->getString(settings_path + "." + key)); + settings.set("join_use_nulls", true); + settings.set("input_format_orc_allow_missing_columns", true); + settings.set("input_format_orc_case_insensitive_column_matching", true); + settings.set("input_format_parquet_allow_missing_columns", true); + settings.set("input_format_parquet_case_insensitive_column_matching", true); + settings.set("function_json_value_return_type_allow_complex", true); + settings.set("function_json_value_return_type_allow_nullable", true); +} + +void BackendInitializerUtil::initContexts() +{ + /// Make sure global_context and shared_context are constructed only once. + auto & shared_context = SerializedPlanParser::shared_context; + if (!shared_context.get()) + { + shared_context = SharedContextHolder(Context::createShared()); + } + + auto & global_context = SerializedPlanParser::global_context; + if (!global_context) + { + global_context = Context::createGlobal(shared_context.get()); + global_context->makeGlobalContext(); + global_context->setTemporaryStoragePath("/tmp/libch", 0); + global_context->setPath(config->getString("path", "/")); + } +} + +void BackendInitializerUtil::applyConfigAndSettings() +{ + auto & global_context = SerializedPlanParser::global_context; + global_context->setConfig(config); + global_context->setSettings(settings); +} + +extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &); +extern void registerFunctions(FunctionFactory &); + +void registerAllFunctions() +{ + DB::registerFunctions(); + DB::registerAggregateFunctions(); + + { + /// register aggregate function combinators from local_engine + auto & factory = AggregateFunctionCombinatorFactory::instance(); + registerAggregateFunctionCombinatorPartialMerge(factory); + } + + { + /// register ordinary functions from local_engine + auto & factory = FunctionFactory::instance(); + registerFunctions(factory); + } +} + +extern void registerAllFunctions(); + +void BackendInitializerUtil::registerAllFactories() +{ + registerReadBufferBuilders(); + LOG_INFO(logger, "Register read buffer builders."); + + registerRelParsers(); + LOG_INFO(logger, "Register relation parsers."); + + registerAllFunctions(); + LOG_INFO(logger, "Register all functions."); +} + +void BackendInitializerUtil::initCompiledExpressionCache() +{ + #if USE_EMBEDDED_COMPILER + /// 128 MB + constexpr size_t compiled_expression_cache_size_default = 1024 * 1024 * 128; + size_t compiled_expression_cache_size = config->getUInt64("compiled_expression_cache_size", compiled_expression_cache_size_default); + + constexpr size_t compiled_expression_cache_elements_size_default = 10000; + size_t compiled_expression_cache_elements_size + = config->getUInt64("compiled_expression_cache_elements_size", compiled_expression_cache_elements_size_default); + + CompiledExpressionCacheFactory::instance().init(compiled_expression_cache_size, compiled_expression_cache_elements_size); +#endif +} + +void BackendInitializerUtil::init(const std::string & plan) +{ + initConfig(plan); + initLoggers(); + + initEnvs(); + LOG_INFO(logger, "Init environment variables."); + + initSettings(); + LOG_INFO(logger, "Init settings."); + + initContexts(); + LOG_INFO(logger, "Init shared context and global context."); + + applyConfigAndSettings(); + LOG_INFO(logger, "Apply configuration and setting for global context."); + + std::call_once( + init_flag, + [&] + { + registerAllFactories(); + LOG_INFO(logger, "Register all factories."); + + initCompiledExpressionCache(); + LOG_INFO(logger, "Init compiled expressions cache factory."); + }); +} + +void BackendFinalizerUtil::finalizeGlobally() +{ + local_engine::BroadCastJoinBuilder::clean(); + + auto & global_context = SerializedPlanParser::global_context; + auto & shared_context = SerializedPlanParser::shared_context; + auto * logger = BackendInitializerUtil::logger; + if (global_context) + { + global_context->shutdown(); + global_context.reset(); + shared_context.reset(); + } +} + +void BackendFinalizerUtil::finalizeSessionally() +{ +} + +} diff --git a/utils/local-engine/Common/CHUtil.h b/utils/local-engine/Common/CHUtil.h new file mode 100644 index 000000000000..665afade3e4b --- /dev/null +++ b/utils/local-engine/Common/CHUtil.h @@ -0,0 +1,138 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ + +class BlockUtil +{ +public: + // Build a header block with a virtual column which will be + // use to indicate the number of rows in a block. + // Commonly seen in the following quries: + // - select count(1) from t + // - select 1 from t + static DB::Block buildRowCountHeader(); + static DB::Chunk buildRowCountChunk(UInt64 rows); + static DB::Block buildRowCountBlock(UInt64 rows); + + static DB::Block buildHeader(const DB::NamesAndTypesList & names_types_list); + + static constexpr UInt64 FLAT_STRUCT = 1; + static constexpr UInt64 FLAT_NESTED_TABLE = 2; + // flatten the struct and array(struct) columns. + // It's different from Nested::flattend() + static DB::Block flattenBlock(const DB::Block & block, UInt64 flags = FLAT_STRUCT|FLAT_NESTED_TABLE, bool recursively = false); +}; + +/// Use this class to extract element columns from columns of nested type in a block, e.g. named Tuple. +/// It can extract a column from a multiple nested type column, e.g. named Tuple in named Tuple +/// Keeps some intermediate data to avoid rebuild them multi-times. +class NestedColumnExtractHelper +{ +public: + explicit NestedColumnExtractHelper(const DB::Block & block_, bool case_insentive_); + std::optional extractColumn(const String & column_name); +private: + std::optional + extractColumn(const String & original_column_name, const String & column_name_prefix, const String & column_name_suffix); + const DB::Block & block; + bool case_insentive; + std::map nested_tables; + + const DB::ColumnWithTypeAndName * findColumn(const DB::Block & block, const std::string & name) const; +}; + +class PlanUtil +{ +public: + static std::string explainPlan(DB::QueryPlan & plan); +}; + +class MergeTreeUtil +{ +public: + using Path = std::filesystem::path; + static std::vector getAllMergeTreeParts(const Path & storage_path); + static DB::NamesAndTypesList getSchemaFromMergeTreePart(const Path & part_path); +}; + +class ActionsDAGUtil +{ +public: + static const DB::ActionsDAG::Node * convertNodeType( + DB::ActionsDAGPtr & actions_dag, + const DB::ActionsDAG::Node * node, + const std::string & type_name, + const std::string & result_name = ""); +}; + +class QueryPipelineUtil +{ +public: + static String explainPipeline(DB::QueryPipeline & pipeline); +}; + + +class BackendFinalizerUtil; +class JNIUtils; +class BackendInitializerUtil +{ +public: + /// Initialize two kinds of resources + /// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime + /// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver + static void init(const std::string & plan); + +private: + friend class BackendFinalizerUtil; + friend class JNIUtils; + + static void initConfig(const std::string & plan); + static void initLoggers(); + static void initEnvs(); + static void initSettings(); + static void initContexts(); + static void applyConfigAndSettings(); + static void registerAllFactories(); + static void initCompiledExpressionCache(); + + static std::map getBackendConfMap(const std::string & plan); + + inline static const String CH_BACKEND_CONF_PREFIX = "spark.gluten.sql.columnar.backend.ch"; + inline static const String CH_RUNTIME_CONF = "runtime_conf"; + inline static const String CH_RUNTIME_CONF_PREFIX = CH_BACKEND_CONF_PREFIX + "." + CH_RUNTIME_CONF; + inline static const String CH_RUNTIME_CONF_FILE = CH_RUNTIME_CONF_PREFIX + ".conf_file"; + inline static const String GLUTEN_TIMEZONE_KEY = "spark.gluten.timezone"; + inline static const String LIBHDFS3_CONF_KEY = "hdfs.libhdfs3_conf"; + inline static const String SETTINGs_PATH = "local_engine.settings"; + + inline static std::once_flag init_flag; + inline static DB::Context::ConfigurationPtr config; + inline static Poco::Logger * logger; + inline static DB::Settings settings; +}; + +class BackendFinalizerUtil +{ +public: + /// Release global level resources like global_context/shared_context. Invoked only once in the lifetime of process when JVM is shuting down. + static void finalizeGlobally(); + + /// Release session level resources like StorageJoinBuilder. Invoked every time executor/driver shutdown. + static void finalizeSessionally(); +}; + +} diff --git a/utils/local-engine/Common/CMakeLists.txt b/utils/local-engine/Common/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Common/ChunkBuffer.cpp b/utils/local-engine/Common/ChunkBuffer.cpp new file mode 100644 index 000000000000..702428b37a91 --- /dev/null +++ b/utils/local-engine/Common/ChunkBuffer.cpp @@ -0,0 +1,34 @@ +#include "ChunkBuffer.h" + +namespace local_engine +{ +void ChunkBuffer::add(DB::Chunk & columns, int start, int end) +{ + if (accumulated_columns.empty()) + { + auto num_cols = columns.getNumColumns(); + accumulated_columns.reserve(num_cols); + for (size_t i = 0; i < num_cols; i++) + { + accumulated_columns.emplace_back(columns.getColumns()[i]->cloneEmpty()); + } + } + + for (size_t i = 0; i < columns.getNumColumns(); ++i) + accumulated_columns[i]->insertRangeFrom(*columns.getColumns()[i], start, end - start); +} +size_t ChunkBuffer::size() const +{ + if (accumulated_columns.empty()) + return 0; + return accumulated_columns.at(0)->size(); +} +DB::Chunk ChunkBuffer::releaseColumns() +{ + auto rows = size(); + DB::Columns res(std::make_move_iterator(accumulated_columns.begin()), std::make_move_iterator(accumulated_columns.end())); + accumulated_columns.clear(); + return DB::Chunk(res, rows); +} + +} diff --git a/utils/local-engine/Common/ChunkBuffer.h b/utils/local-engine/Common/ChunkBuffer.h new file mode 100644 index 000000000000..fcaf21cba4da --- /dev/null +++ b/utils/local-engine/Common/ChunkBuffer.h @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace local_engine +{ +class ChunkBuffer +{ +public: + void add(DB::Chunk & columns, int start, int end); + size_t size() const; + DB::Chunk releaseColumns(); + +private: + DB::MutableColumns accumulated_columns; +}; + +} diff --git a/utils/local-engine/Common/ConcurrentMap.h b/utils/local-engine/Common/ConcurrentMap.h new file mode 100644 index 000000000000..c20729fdc97a --- /dev/null +++ b/utils/local-engine/Common/ConcurrentMap.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +namespace local_engine +{ +template +class ConcurrentMap { +public: + void insert(const K& key, const V& value) { + std::lock_guard lock{mutex}; + map.insert({key, value}); + } + + V get(const K& key) { + std::lock_guard lock{mutex}; + auto it = map.find(key); + if (it == map.end()) { + return nullptr; + } + return it->second; + } + + void erase(const K& key) { + std::lock_guard lock{mutex}; + map.erase(key); + } + + void clear() { + std::lock_guard lock{mutex}; + map.clear(); + } + + size_t size() const { + std::lock_guard lock{mutex}; + return map.size(); + } + +private: + std::unordered_map map; + mutable std::mutex mutex; +}; +} + diff --git a/utils/local-engine/Common/DebugUtils.cpp b/utils/local-engine/Common/DebugUtils.cpp new file mode 100644 index 000000000000..b104daadfe00 --- /dev/null +++ b/utils/local-engine/Common/DebugUtils.cpp @@ -0,0 +1,64 @@ +#include "DebugUtils.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace debug +{ + +void headBlock(const DB::Block & block, size_t count) +{ + std::cerr << "============Block============" << std::endl; + std::cerr << block.dumpStructure() << std::endl; + // print header + for (const auto& name : block.getNames()) + std::cerr << name << "\t"; + std::cerr << std::endl; + + // print rows + for (size_t row = 0; row < std::min(count, block.rows()); ++row) + { + for (size_t column = 0; column < block.columns(); ++column) + { + const auto type = block.getByPosition(column).type; + auto col = block.getByPosition(column).column; + + if (column > 0) + std::cerr << "\t"; + DB::WhichDataType which(type); + if (which.isAggregateFunction()) + { + std::cerr << "Nan"; + } + else if (col->isNullAt(row)) + { + std::cerr << "null"; + } + else + { + std::cerr << toString((*col)[row]); + } + } + std::cerr << std::endl; + } +} + +void headColumn(const DB::ColumnPtr & column, size_t count) +{ + std::cerr << "============Column============" << std::endl; + + // print header + std::cerr << column->getName() << "\t"; + std::cerr << std::endl; + + // print rows + for (size_t row = 0; row < std::min(count, column->size()); ++row) + std::cerr << toString((*column)[row]) << std::endl; +} + +} diff --git a/utils/local-engine/Common/DebugUtils.h b/utils/local-engine/Common/DebugUtils.h new file mode 100644 index 000000000000..03121f50bd82 --- /dev/null +++ b/utils/local-engine/Common/DebugUtils.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace debug +{ + +void headBlock(const DB::Block & block, size_t count=10); + +void headColumn(const DB::ColumnPtr & column, size_t count=10); +} diff --git a/utils/local-engine/Common/ExceptionUtils.cpp b/utils/local-engine/Common/ExceptionUtils.cpp new file mode 100644 index 000000000000..627687a2d4c8 --- /dev/null +++ b/utils/local-engine/Common/ExceptionUtils.cpp @@ -0,0 +1,13 @@ +#include "ExceptionUtils.h" + +using namespace DB; + +namespace local_engine +{ +void ExceptionUtils::handleException(const Exception & exception) +{ + LOG_ERROR(&Poco::Logger::get("ExceptionUtils"), "{}\n{}", exception.message(), exception.getStackTraceString()); + exception.rethrow(); +} + +} diff --git a/utils/local-engine/Common/ExceptionUtils.h b/utils/local-engine/Common/ExceptionUtils.h new file mode 100644 index 000000000000..8676db299dfd --- /dev/null +++ b/utils/local-engine/Common/ExceptionUtils.h @@ -0,0 +1,12 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class ExceptionUtils +{ +public: + static void handleException(const DB::Exception & exception); +}; +} diff --git a/utils/local-engine/Common/JNIUtils.cpp b/utils/local-engine/Common/JNIUtils.cpp new file mode 100644 index 000000000000..4d8dd6e9333a --- /dev/null +++ b/utils/local-engine/Common/JNIUtils.cpp @@ -0,0 +1,36 @@ +#include "JNIUtils.h" + +namespace local_engine +{ + +JNIEnv * JNIUtils::getENV(int * attach) +{ + if (vm == nullptr) + return nullptr; + + *attach = 0; + JNIEnv * jni_env = nullptr; + + int status = vm->GetEnv(reinterpret_cast(&jni_env), JNI_VERSION_1_8); + + if (status == JNI_EDETACHED || jni_env == nullptr) + { + status = vm->AttachCurrentThread(reinterpret_cast(&jni_env), nullptr); + if (status < 0) + { + jni_env = nullptr; + } + else + { + *attach = 1; + } + } + return jni_env; +} + +void JNIUtils::detachCurrentThread() +{ + vm->DetachCurrentThread(); +} + +} diff --git a/utils/local-engine/Common/JNIUtils.h b/utils/local-engine/Common/JNIUtils.h new file mode 100644 index 000000000000..1a4e43015c60 --- /dev/null +++ b/utils/local-engine/Common/JNIUtils.h @@ -0,0 +1,26 @@ +#pragma once +#include + +namespace local_engine +{ +class JNIUtils +{ +public: + inline static JavaVM * vm = nullptr; + + static JNIEnv * getENV(int * attach); + + static void detachCurrentThread(); +}; + +#define GET_JNIENV(env) \ + int attached; \ + JNIEnv * (env) = JNIUtils::getENV(&attached); + +#define CLEAN_JNIENV \ + if (attached) [[unlikely]]\ + { \ + JNIUtils::detachCurrentThread(); \ + } + +} diff --git a/utils/local-engine/Common/JoinHelper.cpp b/utils/local-engine/Common/JoinHelper.cpp new file mode 100644 index 000000000000..cecdcf4a3128 --- /dev/null +++ b/utils/local-engine/Common/JoinHelper.cpp @@ -0,0 +1,29 @@ +#include "JoinHelper.h" +#include +#include + +using namespace DB; + +namespace local_engine +{ + +JoinOptimizationInfo parseJoinOptimizationInfo(const std::string & optimization) +{ + JoinOptimizationInfo info; + ReadBufferFromString in(optimization); + assertString("JoinParameters:", in); + assertString("isBHJ=", in); + readBoolText(info.is_broadcast, in); + assertChar('\n', in); + if (info.is_broadcast) + { + assertString("isNullAwareAntiJoin=", in); + readBoolText(info.is_null_aware_anti_join, in); + assertChar('\n', in); + assertString("buildHashTableId=", in); + readString(info.storage_join_key, in); + assertChar('\n', in); + } + return info; +} +} diff --git a/utils/local-engine/Common/JoinHelper.h b/utils/local-engine/Common/JoinHelper.h new file mode 100644 index 000000000000..ec6a4f778a2c --- /dev/null +++ b/utils/local-engine/Common/JoinHelper.h @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace local_engine +{ +struct JoinOptimizationInfo +{ + bool is_broadcast; + bool is_null_aware_anti_join; + std::string storage_join_key; +}; + + +JoinOptimizationInfo parseJoinOptimizationInfo(const std::string & optimization); + + +} + + diff --git a/utils/local-engine/Common/Logger.cpp b/utils/local-engine/Common/Logger.cpp new file mode 100644 index 000000000000..37d9546cb5b8 --- /dev/null +++ b/utils/local-engine/Common/Logger.cpp @@ -0,0 +1,26 @@ +#include "Logger.h" + +#include +#include +#include +#include +#include + + +using Poco::ConsoleChannel; +using Poco::AutoPtr; +using Poco::AsyncChannel; + +void local_engine::Logger::initConsoleLogger(const std::string & level) +{ + AutoPtr chan(new ConsoleChannel); + AutoPtr async_chann(new AsyncChannel(chan)); + Poco::Logger::root().setChannel(async_chann); + Poco::Logger::root().setLevel(level); +} + +void local_engine::Logger::initFileLogger(Poco::Util::AbstractConfiguration & config, const std::string & cmd_name) +{ + static Loggers loggers; + loggers.buildLoggers(config, Poco::Logger::root(), cmd_name); +} diff --git a/utils/local-engine/Common/Logger.h b/utils/local-engine/Common/Logger.h new file mode 100644 index 000000000000..e5dace96a7d3 --- /dev/null +++ b/utils/local-engine/Common/Logger.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace local_engine +{ +class Logger +{ +public: + static void initConsoleLogger(const std::string & level = "error"); + static void initFileLogger(Poco::Util::AbstractConfiguration & config, const std::string & cmd_name); +}; +} + + diff --git a/utils/local-engine/Common/MergeTreeTool.cpp b/utils/local-engine/Common/MergeTreeTool.cpp new file mode 100644 index 000000000000..0ff8e030e221 --- /dev/null +++ b/utils/local-engine/Common/MergeTreeTool.cpp @@ -0,0 +1,81 @@ +#include "MergeTreeTool.h" +#include +#include +#include +#include + +using namespace DB; + +namespace local_engine +{ +std::shared_ptr buildMetaData(DB::NamesAndTypesList columns, ContextPtr context) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + for (const auto &item : columns) + { + columns_description.add(ColumnDescription(item.name, item.type)); + } + metadata->setColumns(std::move(columns_description)); + metadata->partition_key.expression_list_ast = std::make_shared(); + metadata->sorting_key = KeyDescription::getSortingKeyFromAST(makeASTFunction("tuple"), metadata->getColumns(), context, {}); + metadata->primary_key.expression = std::make_shared(std::make_shared()); + return metadata; +} + +std::unique_ptr buildMergeTreeSettings() +{ + auto settings = std::make_unique(); + settings->set("min_bytes_for_wide_part", Field(0)); + settings->set("min_rows_for_wide_part", Field(0)); + return settings; +} + +std::unique_ptr buildQueryInfo(NamesAndTypesList& names_and_types_list) +{ + std::unique_ptr query_info = std::make_unique(); + query_info->query = std::make_shared(); + auto syntax_analyzer_result = std::make_shared(names_and_types_list); + syntax_analyzer_result->analyzed_join = std::make_shared(); + query_info->syntax_analyzer_result = syntax_analyzer_result; + return query_info; +} + + +MergeTreeTable parseMergeTreeTableString(const std::string & info) +{ + ReadBufferFromString in(info); + assertString("MergeTree;", in); + MergeTreeTable table; + readString(table.database, in); + assertChar('\n', in); + readString(table.table, in); + assertChar('\n', in); + readString(table.relative_path, in); + assertChar('\n', in); + readIntText(table.min_block, in); + assertChar('\n', in); + readIntText(table.max_block, in); + assertChar('\n', in); + assertEOF(in); + return table; +} + +std::string MergeTreeTable::toString() const +{ + WriteBufferFromOwnString out; + writeString("MergeTree;", out); + writeString(database, out); + writeChar('\n', out); + writeString(table, out); + writeChar('\n', out); + writeString(relative_path, out); + writeChar('\n', out); + writeIntText(min_block, out); + writeChar('\n', out); + writeIntText(max_block, out); + writeChar('\n', out); + return out.str(); +} + +} diff --git a/utils/local-engine/Common/MergeTreeTool.h b/utils/local-engine/Common/MergeTreeTool.h new file mode 100644 index 000000000000..df82133c887c --- /dev/null +++ b/utils/local-engine/Common/MergeTreeTool.h @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace local_engine +{ + +using namespace DB; +std::shared_ptr buildMetaData(DB::NamesAndTypesList columns, ContextPtr context); + +std::unique_ptr buildMergeTreeSettings(); + +std::unique_ptr buildQueryInfo(NamesAndTypesList & names_and_types_list); + +struct MergeTreeTable +{ + std::string database; + std::string table; + std::string relative_path; + int min_block; + int max_block; + + std::string toString() const; +}; + +MergeTreeTable parseMergeTreeTableString(const std::string & info); + +} diff --git a/utils/local-engine/Common/QueryContext.cpp b/utils/local-engine/Common/QueryContext.cpp new file mode 100644 index 000000000000..1722a737a7b8 --- /dev/null +++ b/utils/local-engine/Common/QueryContext.cpp @@ -0,0 +1,71 @@ +#include "QueryContext.h" +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} + +namespace local_engine +{ +using namespace DB; +thread_local std::weak_ptr query_scope; +thread_local std::weak_ptr thread_status; +ConcurrentMap allocator_map; + +int64_t initializeQuery(ReservationListenerWrapperPtr listener) +{ + auto query_context = Context::createCopy(SerializedPlanParser::global_context); + query_context->makeQueryContext(); + auto allocator_context = std::make_shared(); + allocator_context->thread_status = std::make_shared(); + allocator_context->query_scope = std::make_shared(query_context); + allocator_context->query_context = query_context; + allocator_context->listener = listener; + thread_status = std::weak_ptr(allocator_context->thread_status); + query_scope = std::weak_ptr(allocator_context->query_scope); + auto allocator_id = reinterpret_cast(allocator_context.get()); + CurrentMemoryTracker::before_alloc = [listener](Int64 size, bool throw_if_memory_exceed) -> void + { + if (throw_if_memory_exceed) + listener->reserveOrThrow(size); + else + listener->reserve(size); + }; + CurrentMemoryTracker::before_free = [listener](Int64 size) -> void { listener->free(size); }; + allocator_map.insert(allocator_id, allocator_context); + return allocator_id; +} + +void releaseAllocator(int64_t allocator_id) +{ + if (!allocator_map.get(allocator_id)) + { + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "allocator {} not found", allocator_id); + } + auto status = allocator_map.get(allocator_id)->thread_status; + auto listener = allocator_map.get(allocator_id)->listener; + if (status->untracked_memory < 0) + listener->free(-status->untracked_memory); + allocator_map.erase(allocator_id); +} + +NativeAllocatorContextPtr getAllocator(int64_t allocator) +{ + return allocator_map.get(allocator); +} + +int64_t allocatorMemoryUsage(int64_t allocator_id) +{ + return allocator_map.get(allocator_id)->thread_status->memory_tracker.get(); +} + +} diff --git a/utils/local-engine/Common/QueryContext.h b/utils/local-engine/Common/QueryContext.h new file mode 100644 index 000000000000..50b52c5563e4 --- /dev/null +++ b/utils/local-engine/Common/QueryContext.h @@ -0,0 +1,27 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace local_engine +{ +int64_t initializeQuery(ReservationListenerWrapperPtr listener); + +void releaseAllocator(int64_t allocator_id); + +int64_t allocatorMemoryUsage(int64_t allocator_id); + +struct NativeAllocatorContext +{ + std::shared_ptr query_scope; + std::shared_ptr thread_status; + DB::ContextPtr query_context; + ReservationListenerWrapperPtr listener; +}; + +using NativeAllocatorContextPtr = std::shared_ptr; + +NativeAllocatorContextPtr getAllocator(int64_t allocator); +} diff --git a/utils/local-engine/Common/StringUtils.cpp b/utils/local-engine/Common/StringUtils.cpp new file mode 100644 index 000000000000..720bb32cbf9f --- /dev/null +++ b/utils/local-engine/Common/StringUtils.cpp @@ -0,0 +1,27 @@ +#include "StringUtils.h" +#include +#include + +namespace local_engine +{ +PartitionValues StringUtils::parsePartitionTablePath(const std::string & file) +{ + PartitionValues result; + Poco::StringTokenizer path(file, "/"); + for (const auto & item : path) + { + auto position = item.find('='); + if (position != std::string::npos) + { + result.emplace_back(PartitionValue(item.substr(0,position), item.substr(position+1))); + } + } + return result; +} +bool StringUtils::isNullPartitionValue(const std::string & value) +{ + return value == "__HIVE_DEFAULT_PARTITION__"; +} +} + + diff --git a/utils/local-engine/Common/StringUtils.h b/utils/local-engine/Common/StringUtils.h new file mode 100644 index 000000000000..40f33500513c --- /dev/null +++ b/utils/local-engine/Common/StringUtils.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace local_engine +{ +using PartitionValue = std::pair; +using PartitionValues = std::vector; + +class StringUtils +{ +public: + static PartitionValues parsePartitionTablePath(const std::string & file); + static bool isNullPartitionValue(const std::string & value); +}; +} diff --git a/utils/local-engine/Common/common.cpp b/utils/local-engine/Common/common.cpp new file mode 100644 index 000000000000..08008129f35a --- /dev/null +++ b/utils/local-engine/Common/common.cpp @@ -0,0 +1,30 @@ +#include +#include + +#include + +using namespace DB; + +#ifdef __cplusplus +extern "C" { +#endif + +char * createExecutor(const std::string & plan_string) +{ + auto context = Context::createCopy(local_engine::SerializedPlanParser::global_context); + local_engine::SerializedPlanParser parser(context); + auto query_plan = parser.parse(plan_string); + local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(parser.query_context); + executor->execute(std::move(query_plan)); + return reinterpret_cast(executor); +} + +bool executorHasNext(char * executor_address) +{ + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + return executor->hasNext(); +} + +#ifdef __cplusplus +} +#endif diff --git a/utils/local-engine/Functions/CMakeLists.txt b/utils/local-engine/Functions/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Functions/FunctionsHashingExtended.cpp b/utils/local-engine/Functions/FunctionsHashingExtended.cpp new file mode 100644 index 000000000000..ef9700874638 --- /dev/null +++ b/utils/local-engine/Functions/FunctionsHashingExtended.cpp @@ -0,0 +1,13 @@ +#include "FunctionsHashingExtended.h" + +#include + +namespace local_engine +{ +void registerFunctionsHashingExtended(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); +} + +} diff --git a/utils/local-engine/Functions/FunctionsHashingExtended.h b/utils/local-engine/Functions/FunctionsHashingExtended.h new file mode 100644 index 000000000000..e557eeb88694 --- /dev/null +++ b/utils/local-engine/Functions/FunctionsHashingExtended.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include + +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wused-but-marked-unused" +#endif + +#include + +namespace local_engine +{ +using namespace DB; + +/// For spark compatiability of ClickHouse. +/// The difference between spark xxhash64 and CH xxHash64 +/// In Spark, the seed is 42 +/// In CH, the seed is 0. So we need to add new impl ImplXxHash64Spark in CH with seed = 42. +struct ImplXxHashSpark64 +{ + static constexpr auto name = "xxHashSpark64"; + using ReturnType = UInt64; + using uint128_t = CityHash_v1_0_2::uint128; + + static auto apply(const char * s, const size_t len) { return XXH_INLINE_XXH64(s, len, 42); } + + /* + With current implementation with more than 1 arguments it will give the results + non-reproducible from outside of CH. (see comment on ImplXxHash32). + */ + static auto combineHashes(UInt64 h1, UInt64 h2) { return CityHash_v1_0_2::Hash128to64(uint128_t(h1, h2)); } + + static constexpr bool use_int_hash_for_pods = false; +}; + + +/// Block read - if your platform needs to do endian-swapping or can only +/// handle aligned reads, do the conversion here +static ALWAYS_INLINE uint32_t getblock32(const uint32_t * p, int i) +{ + uint32_t res; + memcpy(&res, p + i, sizeof(res)); + return res; +} + +static ALWAYS_INLINE uint32_t rotl32(uint32_t x, int8_t r) +{ + return (x << r) | (x >> (32 - r)); +} + +static void MurmurHashSpark3_x86_32(const void * key, size_t len, uint32_t seed, void * out) +{ + const uint8_t * data = static_cast(key); + const int nblocks = len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + /// body + const uint32_t * blocks = reinterpret_cast(data + nblocks * 4); + + for (int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks, i); + + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = rotl32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + + /// tail + const uint8_t * tail = (data + nblocks * 4); + uint32_t k1 = 0; + while (tail != data + len) + { + k1 = *tail; + + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + + h1 ^= k1; + h1 = rotl32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + + ++tail; + } + + /// finalization + h1 ^= len; + h1 ^= h1 >> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >> 16; + + /// output + *static_cast(out) = h1; +} + +/// For spark compatiability of ClickHouse. +/// The difference between spark hash and CH murmurHash3_32 +/// 1. They calculate hash functions with different seeds +/// 2. Spark current impl is not right, but it is not fixed for backward compatiability. See: https://issues.apache.org/jira/browse/SPARK-23381 +struct MurmurHashSpark3Impl32 +{ + static constexpr auto name = "murmurHashSpark3_32"; + using ReturnType = UInt32; + + static UInt32 apply(const char * data, const size_t size) + { + union + { + UInt32 h; + char bytes[sizeof(h)]; + }; + MurmurHashSpark3_x86_32(data, size, 42, bytes); + return h; + } + + static UInt32 combineHashes(UInt32 h1, UInt32 h2) { return IntHash32Impl::apply(h1) ^ h2; } + + static constexpr bool use_int_hash_for_pods = false; +}; + +using FunctionXxHashSpark64 = FunctionAnyHash; +using FunctionMurmurHashSpark3_32 = FunctionAnyHash; + +} diff --git a/utils/local-engine/Functions/positionUTF8Spark.cpp b/utils/local-engine/Functions/positionUTF8Spark.cpp new file mode 100644 index 000000000000..6a987aea567c --- /dev/null +++ b/utils/local-engine/Functions/positionUTF8Spark.cpp @@ -0,0 +1,296 @@ +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_COLUMN; +} + +} + +namespace local_engine +{ + +using namespace DB; + +// Spark-specific version of PositionImpl +template +struct PositionSparkImpl +{ + static constexpr bool use_default_implementation_for_constants = false; + static constexpr bool supports_start_pos = true; + static constexpr auto name = Name::name; + + static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {};} + + using ResultType = UInt64; + + /// Find one substring in many strings. + static void vectorConstant( + const ColumnString::Chars & data, + const ColumnString::Offsets & offsets, + const std::string & needle, + const ColumnPtr & start_pos, + PaddedPODArray & res, + [[maybe_unused]] ColumnUInt8 * res_null) + { + + /// `res_null` serves as an output parameter for implementing an XYZOrNull variant. + assert(!res_null); + + const UInt8 * begin = data.data(); + const UInt8 * pos = begin; + const UInt8 * end = pos + data.size(); + + /// Current index in the array of strings. + size_t i = 0; + + typename Impl::SearcherInBigHaystack searcher = Impl::createSearcherInBigHaystack(needle.data(), needle.size(), end - pos); + + /// We will search for the next occurrence in all strings at once. + while (pos < end && end != (pos = searcher.search(pos, end - pos))) + { + /// Determine which index it refers to. + while (begin + offsets[i] <= pos) + { + res[i] = 0; + ++i; + } + auto start = start_pos != nullptr ? start_pos->getUInt(i) : 0; + + /// We check that the entry does not pass through the boundaries of strings. + // The result is 0 if start_pos is 0, in compliance with Spark semantics + if (start != 0 && pos + needle.size() < begin + offsets[i]) + { + auto res_pos = 1 + Impl::countChars(reinterpret_cast(begin + offsets[i - 1]), reinterpret_cast(pos)); + if (res_pos < start) + { + pos = reinterpret_cast(Impl::advancePos( + reinterpret_cast(pos), + reinterpret_cast(begin + offsets[i]), + start - res_pos)); + continue; + } + // The result is 1 if needle is empty, in compliance with Spark semantics + res[i] = needle.empty() ? 1 : res_pos; + } + else + { + res[i] = 0; + } + pos = begin + offsets[i]; + ++i; + } + + if (i < res.size()) + memset(&res[i], 0, (res.size() - i) * sizeof(res[0])); + } + + /// Search for substring in string. + static void constantConstantScalar( + std::string data, + std::string needle, + UInt64 start_pos, + UInt64 & res) + { + size_t start_byte = Impl::advancePos(data.data(), data.data() + data.size(), start_pos - 1) - data.data(); + res = data.find(needle, start_byte); + if (res == std::string::npos) + res = 0; + else + res = 1 + Impl::countChars(data.data(), data.data() + res); + } + + /// Search for substring in string starting from different positions. + static void constantConstant( + std::string data, + std::string needle, + const ColumnPtr & start_pos, + PaddedPODArray & res, + [[maybe_unused]] ColumnUInt8 * res_null) + { + /// `res_null` serves as an output parameter for implementing an XYZOrNull variant. + assert(!res_null); + + Impl::toLowerIfNeed(data); + Impl::toLowerIfNeed(needle); + + if (start_pos == nullptr) + { + res[0] = 0; + return; + } + + size_t haystack_size = Impl::countChars(data.data(), data.data() + data.size()); + + size_t size = start_pos != nullptr ? start_pos->size() : 0; + for (size_t i = 0; i < size; ++i) + { + auto start = start_pos->getUInt(i); + + if (start == 0 || start > haystack_size + 1) + { + res[i] = 0; + continue; + } + if (needle.empty()) + { + res[0] = 1; + continue; + } + constantConstantScalar(data, needle, start, res[i]); + } + } + + /// Search each time for a different single substring inside each time different string. + static void vectorVector( + const ColumnString::Chars & haystack_data, + const ColumnString::Offsets & haystack_offsets, + const ColumnString::Chars & needle_data, + const ColumnString::Offsets & needle_offsets, + const ColumnPtr & start_pos, + PaddedPODArray & res, + [[maybe_unused]] ColumnUInt8 * res_null) + { + /// `res_null` serves as an output parameter for implementing an XYZOrNull variant. + assert(!res_null); + + ColumnString::Offset prev_haystack_offset = 0; + ColumnString::Offset prev_needle_offset = 0; + + size_t size = haystack_offsets.size(); + + for (size_t i = 0; i < size; ++i) + { + size_t needle_size = needle_offsets[i] - prev_needle_offset - 1; + size_t haystack_size = haystack_offsets[i] - prev_haystack_offset - 1; + + auto start = start_pos != nullptr ? start_pos->getUInt(i) : UInt64(0); + + if (start == 0 || start > haystack_size + 1) + { + res[i] = 0; + } + else if (0 == needle_size) + { + /// An empty string is always 1 in compliance with Spark semantics. + res[i] = 1; + } + else + { + /// It is assumed that the StringSearcher is not very difficult to initialize. + typename Impl::SearcherInSmallHaystack searcher = Impl::createSearcherInSmallHaystack( + reinterpret_cast(&needle_data[prev_needle_offset]), + needle_offsets[i] - prev_needle_offset - 1); /// zero byte at the end + + const char * beg = Impl::advancePos( + reinterpret_cast(&haystack_data[prev_haystack_offset]), + reinterpret_cast(&haystack_data[haystack_offsets[i] - 1]), + start - 1); + /// searcher returns a pointer to the found substring or to the end of `haystack`. + size_t pos = searcher.search(reinterpret_cast(beg), &haystack_data[haystack_offsets[i] - 1]) + - &haystack_data[prev_haystack_offset]; + + if (pos != haystack_size) + { + res[i] = 1 + + Impl::countChars( + reinterpret_cast(&haystack_data[prev_haystack_offset]), + reinterpret_cast(&haystack_data[prev_haystack_offset + pos])); + } + else + res[i] = 0; + } + + prev_haystack_offset = haystack_offsets[i]; + prev_needle_offset = needle_offsets[i]; + } + } + + /// Find many substrings in single string. + static void constantVector( + const String & haystack, + const ColumnString::Chars & needle_data, + const ColumnString::Offsets & needle_offsets, + const ColumnPtr & start_pos, + PaddedPODArray & res, + [[maybe_unused]] ColumnUInt8 * res_null) + { + /// `res_null` serves as an output parameter for implementing an XYZOrNull variant. + assert(!res_null); + + /// NOTE You could use haystack indexing. But this is a rare case. + ColumnString::Offset prev_needle_offset = 0; + + size_t size = needle_offsets.size(); + + for (size_t i = 0; i < size; ++i) + { + size_t needle_size = needle_offsets[i] - prev_needle_offset - 1; + + auto start = start_pos != nullptr ? start_pos->getUInt(i) : UInt64(0); + + if (start == 0 || start > haystack.size() + 1) + { + res[i] = 0; + } + else if (0 == needle_size) + { + res[i] = 1; + } + else + { + typename Impl::SearcherInSmallHaystack searcher = Impl::createSearcherInSmallHaystack( + reinterpret_cast(&needle_data[prev_needle_offset]), needle_offsets[i] - prev_needle_offset - 1); + + const char * beg = Impl::advancePos(haystack.data(), haystack.data() + haystack.size(), start - 1); + size_t pos = searcher.search( + reinterpret_cast(beg), + reinterpret_cast(haystack.data()) + haystack.size()) + - reinterpret_cast(haystack.data()); + + if (pos != haystack.size()) + { + res[i] = 1 + Impl::countChars(haystack.data(), haystack.data() + pos); + } + else + res[i] = 0; + } + + prev_needle_offset = needle_offsets[i]; + } + } + + template + static void vectorFixedConstant(Args &&...) + { + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name); + } + + template + static void vectorFixedVector(Args &&...) + { + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name); + } +}; + +struct NamePositionUTF8Spark +{ + static constexpr auto name = "positionUTF8Spark"; +}; + + +using FunctionPositionUTF8Spark = FunctionsStringSearch>; + + +void registerFunctionPositionUTF8Spark(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} diff --git a/utils/local-engine/Functions/registerFunctions.cpp b/utils/local-engine/Functions/registerFunctions.cpp new file mode 100644 index 000000000000..4904cb69ce66 --- /dev/null +++ b/utils/local-engine/Functions/registerFunctions.cpp @@ -0,0 +1,18 @@ +#include + +namespace local_engine +{ + +using namespace DB; +void registerFunctionSparkTrim(FunctionFactory &); +void registerFunctionsHashingExtended(FunctionFactory & factory); +void registerFunctionPositionUTF8Spark(FunctionFactory &); + +void registerFunctions(FunctionFactory & factory) +{ + registerFunctionSparkTrim(factory); + registerFunctionsHashingExtended(factory); + registerFunctionPositionUTF8Spark(factory); +} + +} diff --git a/utils/local-engine/Functions/sparkTrim.cpp b/utils/local-engine/Functions/sparkTrim.cpp new file mode 100644 index 000000000000..2c80f09a4a52 --- /dev/null +++ b/utils/local-engine/Functions/sparkTrim.cpp @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace DB; + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_COLUMN; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + + +namespace local_engine +{ + +struct TrimModeLeft +{ + static constexpr auto name = "sparkTrimLeft"; + static constexpr bool trim_left = true; + static constexpr bool trim_right = false; +}; + +struct TrimModeRight +{ + static constexpr auto name = "sparkTrimRigth"; + static constexpr bool trim_left = false; + static constexpr bool trim_right = true; +}; + +struct TrimModeBoth +{ + static constexpr auto name = "sparkTrimBoth"; + static constexpr bool trim_left = true; + static constexpr bool trim_right = true; +}; + + +namespace +{ + + template + class SparkTrimFunction : public IFunction + { + public: + static constexpr auto name = Name::name; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + String getName() const override { return name; } + + bool isVariadic() const override { return true; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + size_t getNumberOfArguments() const override { return 2; } + + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (arguments.size() != 2) + throw Exception( + "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + + ", should be 2.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + return std::make_shared(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override + { + const ColumnPtr & trimStrPtr = arguments[1].column; + const ColumnConst * trimStrConst = typeid_cast(&*trimStrPtr); + + if (!trimStrConst) { + throw Exception("Second argument of function " + getName() + " must be constant string", ErrorCodes::ILLEGAL_COLUMN); + } + String trimStr = trimStrConst->getValue(); + auto col_res = ColumnString::create(); + const ColumnPtr column = arguments[0].column; + if (const ColumnString * col = checkAndGetColumn(column.get())) + { + executeVector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), trimStr); + return col_res; + } + + return col_res; + } + + private: + void executeVector( + const ColumnString::Chars & data, + const ColumnString::Offsets & offsets, + ColumnString::Chars & res_data, + ColumnString::Offsets & res_offsets, + const DB::String & trimStr) const + { + size_t size = offsets.size(); + res_offsets.resize(size); + res_data.reserve(data.size()); + + size_t prev_offset = 0; + size_t res_offset = 0; + + const UInt8 * start; + size_t length; + + for (size_t i = 0; i < size; ++i) + { + trimInternal(reinterpret_cast(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length, trimStr); + res_data.resize(res_data.size() + length + 1); + memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length); + res_offset += length + 1; + res_data[res_offset - 1] = '\0'; + res_offsets[i] = res_offset; + prev_offset = offsets[i]; + } + } + + void trimInternal(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t & res_size, const DB::String & trimStr) const + { + const char * char_data = reinterpret_cast(data); + const char * char_end = char_data + size; + if constexpr (Name::trim_left) + { + for (size_t i = 0; i < size; i++) + { + char c = * (char_data + i); + if (trimStr.find(c) == std::string::npos) { + char_data += i; + break; + } + } + res_data = reinterpret_cast(char_data); + + } + if constexpr (Name::trim_right) + { + while(char_end != char_data) { + char c = *(char_end -1); + if (trimStr.find(c) == std::string::npos) { + break; + } + char_end -= 1; + } + } + res_size = char_end - char_data; + } + }; + + using FunctionSparkTrimBoth = SparkTrimFunction; + using FunctionSparkTrimLeft = SparkTrimFunction; + using FunctionSparkTrimRight = SparkTrimFunction; +} + +void registerFunctionSparkTrim(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); +} +} diff --git a/utils/local-engine/Operator/BlockCoalesceOperator.cpp b/utils/local-engine/Operator/BlockCoalesceOperator.cpp new file mode 100644 index 000000000000..69931235ce19 --- /dev/null +++ b/utils/local-engine/Operator/BlockCoalesceOperator.cpp @@ -0,0 +1,33 @@ +#include "BlockCoalesceOperator.h" +#include + +namespace local_engine +{ +void BlockCoalesceOperator::mergeBlock(DB::Block & block) +{ + block_buffer.add(block, 0, block.rows()); +} +bool BlockCoalesceOperator::isFull() +{ + return block_buffer.size() >= buf_size; +} +DB::Block* BlockCoalesceOperator::releaseBlock() +{ + clearCache(); + cached_block = new DB::Block(block_buffer.releaseColumns()); + return cached_block; +} +BlockCoalesceOperator::~BlockCoalesceOperator() +{ + clearCache(); +} +void BlockCoalesceOperator::clearCache() +{ + if (cached_block) + { + delete cached_block; + cached_block = nullptr; + } +} +} + diff --git a/utils/local-engine/Operator/BlockCoalesceOperator.h b/utils/local-engine/Operator/BlockCoalesceOperator.h new file mode 100644 index 000000000000..46bbbfd3276d --- /dev/null +++ b/utils/local-engine/Operator/BlockCoalesceOperator.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace DB +{ +class Block; +} + +namespace local_engine +{ +class BlockCoalesceOperator +{ +public: + BlockCoalesceOperator(size_t buf_size_):buf_size(buf_size_){} + virtual ~BlockCoalesceOperator(); + void mergeBlock(DB::Block & block); + bool isFull(); + DB::Block* releaseBlock(); + +private: + size_t buf_size; + ColumnsBuffer block_buffer; + DB::Block * cached_block = nullptr; + + void clearCache(); +}; +} + + diff --git a/utils/local-engine/Operator/ExpandStep.cpp b/utils/local-engine/Operator/ExpandStep.cpp new file mode 100644 index 000000000000..6bc4e659db71 --- /dev/null +++ b/utils/local-engine/Operator/ExpandStep.cpp @@ -0,0 +1,113 @@ +#include "ExpandStep.h" +#include "ExpandTransorm.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +static DB::ITransformingStep::Traits getTraits() +{ + return DB::ITransformingStep::Traits + { + { + .preserves_distinct_columns = false, /// Actually, we may check that distinct names are in aggregation keys + .returns_single_stream = true, + .preserves_number_of_streams = false, + .preserves_sorting = false, + }, + { + .preserves_number_of_rows = false, + } + }; +} + +ExpandStep::ExpandStep( + const DB::DataStream & input_stream_, + const std::vector & aggregating_expressions_columns_, + const std::vector> & grouping_sets_, + const std::string & grouping_id_name_) + : DB::ITransformingStep( + input_stream_, + buildOutputHeader(input_stream_.header, aggregating_expressions_columns_, grouping_id_name_), + getTraits()) + , aggregating_expressions_columns(aggregating_expressions_columns_) + , grouping_sets(grouping_sets_) + , grouping_id_name(grouping_id_name_) +{ + header = input_stream_.header; + output_header = getOutputStream().header; +} + +DB::Block ExpandStep::buildOutputHeader( + const DB::Block & input_header, + const std::vector & aggregating_expressions_columns_, + const std::string & grouping_id_name_) +{ + DB::ColumnsWithTypeAndName cols; + std::set agg_cols; + + for (size_t i = 0; i < input_header.columns(); ++i) + { + const auto & old_col = input_header.getByPosition(i); + if (i < aggregating_expressions_columns_.size()) + { + // do nothing with the aggregating columns. + cols.push_back(old_col); + continue; + } + if (old_col.type->isNullable()) + cols.push_back(old_col); + else + { + auto null_map = DB::ColumnUInt8::create(0, 0); + auto null_col = DB::ColumnNullable::create(old_col.column, std::move(null_map)); + auto null_type = std::make_shared(old_col.type); + cols.push_back(DB::ColumnWithTypeAndName(null_col, null_type, old_col.name)); + } + } + + // add group id column + auto grouping_id_col = DB::ColumnInt64::create(0, 0); + auto grouping_id_type = std::make_shared(); + cols.emplace_back(DB::ColumnWithTypeAndName(std::move(grouping_id_col), grouping_id_type, grouping_id_name_)); + return DB::Block(cols); +} + +void ExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) +{ + DB::QueryPipelineProcessorsCollector collector(pipeline, this); + auto build_transform = [&](DB::OutputPortRawPtrs outputs){ + DB::Processors new_processors; + for (auto & output : outputs) + { + auto expand_op = std::make_shared(header, output_header, aggregating_expressions_columns, grouping_sets); + new_processors.push_back(expand_op); + DB::connect(*output, expand_op->getInputs().front()); + } + return new_processors; + }; + pipeline.transform(build_transform); + processors = collector.detachProcessors(); +} + +void ExpandStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const +{ + if (!processors.empty()) + DB::IQueryPlanStep::describePipeline(processors, settings); +} +void ExpandStep::updateOutputStream() +{ + createOutputStream(input_streams.front(), output_header, getDataStreamTraits()); +} + +} diff --git a/utils/local-engine/Operator/ExpandStep.h b/utils/local-engine/Operator/ExpandStep.h new file mode 100644 index 000000000000..3b12432df9cd --- /dev/null +++ b/utils/local-engine/Operator/ExpandStep.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +namespace local_engine +{ +class ExpandStep : public DB::ITransformingStep +{ +public: + // The input stream should only contain grouping columns. + explicit ExpandStep( + const DB::DataStream & input_stream_, + const std::vector & aggregating_expressions_columns_, + const std::vector> & grouping_sets_, + const std::string & grouping_id_name_); + ~ExpandStep() override = default; + + String getName() const override { return "ExpandStep"; } + + void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override; + void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override; +private: + std::vector aggregating_expressions_columns; + std::vector> grouping_sets; + std::string grouping_id_name; + DB::Block header; + DB::Block output_header; + + void updateOutputStream() override; + + static DB::Block buildOutputHeader( + const DB::Block & header, + const std::vector & aggregating_expressions_columns_, + const std::string & grouping_id_name_); +}; +} diff --git a/utils/local-engine/Operator/ExpandTransform.cpp b/utils/local-engine/Operator/ExpandTransform.cpp new file mode 100644 index 000000000000..c083c1afa552 --- /dev/null +++ b/utils/local-engine/Operator/ExpandTransform.cpp @@ -0,0 +1,122 @@ +#include "ExpandTransorm.h" +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace local_engine +{ +ExpandTransform::ExpandTransform( + const DB::Block & input_, + const DB::Block & output_, + const std::vector & aggregating_expressions_columns_, + const std::vector> & grouping_sets_) + : DB::IProcessor({input_}, {output_}) + , aggregating_expressions_columns(aggregating_expressions_columns_) + , grouping_sets(grouping_sets_) +{} + +ExpandTransform::Status ExpandTransform::prepare() +{ + auto & output = outputs.front(); + auto & input = inputs.front(); + + if (output.isFinished()) + { + input.close(); + return Status::Finished; + } + + if (!output.canPush()) + { + input.setNotNeeded(); + return Status::PortFull; + } + + if (has_output) + { + output.push(nextChunk()); + return Status::PortFull; + } + + if (has_input) + { + return Status::Ready; + } + + if (input.isFinished()) + { + output.finish(); + return Status::Finished; + } + + if (!input.hasData()) + { + input.setNeeded(); + return Status::NeedData; + } + input_chunk = input.pull(); + has_input = true; + return Status::Ready; +} + +void ExpandTransform::work() +{ + assert(expanded_chunks.empty()); + size_t agg_cols_size = aggregating_expressions_columns.size(); + for (int set_id = 0; static_cast(set_id) < grouping_sets.size(); ++set_id) + { + const auto & sets = grouping_sets[set_id]; + DB::Columns cols; + const auto & original_cols = input_chunk.getColumns(); + for (size_t i = 0; i < original_cols.size(); ++i) + { + const auto & original_col = original_cols[i]; + size_t rows = original_col->size(); + if (i < agg_cols_size) + { + cols.push_back(original_col); + continue; + } + // the output columns should all be nullable. + if (!sets.contains(i)) + { + auto null_map = DB::ColumnUInt8::create(rows, 1); + auto col = DB::ColumnNullable::create(original_col, std::move(null_map)); + cols.push_back(std::move(col)); + } + else + { + if (original_col->isNullable()) + cols.push_back(original_col); + else + { + auto null_map = DB::ColumnUInt8::create(rows, 0); + auto col = DB::ColumnNullable::create(original_col, std::move(null_map)); + cols.push_back(std::move(col)); + } + } + } + auto id_col = DB::DataTypeInt64().createColumnConst(input_chunk.getNumRows(), set_id); + cols.push_back(std::move(id_col)); + expanded_chunks.push_back(DB::Chunk(cols, input_chunk.getNumRows())); + } + has_output = true; + has_input = false; +} + +DB::Chunk ExpandTransform::nextChunk() +{ + assert(!expanded_chunks.empty()); + DB::Chunk ret; + ret.swap(expanded_chunks.front()); + expanded_chunks.pop_front(); + has_output = !expanded_chunks.empty(); + return ret; +} +} diff --git a/utils/local-engine/Operator/ExpandTransorm.h b/utils/local-engine/Operator/ExpandTransorm.h new file mode 100644 index 000000000000..b131c36cd809 --- /dev/null +++ b/utils/local-engine/Operator/ExpandTransorm.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include +#include +#include +#include +#include +namespace local_engine +{ +// For handling substrait expand node. +// The implementation in spark for groupingsets/rollup/cube is different from Clickhouse. +// We have to ways to support groupingsets/rollup/cube +// - rewrite the substrait plan in local engine and reuse the implementation of clickhouse. This +// may be more complex. +// - implement new transform to do the expandation. It's more simple, but may suffer some performance +// issues. We try this first. +class ExpandTransform : public DB::IProcessor +{ +public: + using Status = DB::IProcessor::Status; + ExpandTransform( + const DB::Block & input_, + const DB::Block & output_, + const std::vector & aggregating_expressions_columns_, + const std::vector> & grouping_sets_); + + Status prepare() override; + void work() override; + + DB::String getName() const override { return "ExpandTransform"; } +private: + std::vector aggregating_expressions_columns; + std::vector> grouping_sets; + bool has_input = false; + bool has_output = false; + + DB::Chunk input_chunk; + std::list expanded_chunks; + DB::Chunk nextChunk(); +}; +} diff --git a/utils/local-engine/Operator/PartitionColumnFillingTransform.cpp b/utils/local-engine/Operator/PartitionColumnFillingTransform.cpp new file mode 100644 index 000000000000..4b7c5374ceed --- /dev/null +++ b/utils/local-engine/Operator/PartitionColumnFillingTransform.cpp @@ -0,0 +1,123 @@ +#include "PartitionColumnFillingTransform.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace DB; + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; +} +} + +namespace local_engine +{ +template +requires( + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) + ColumnPtr createIntPartitionColumn(DataTypePtr column_type, std::string partition_value) +{ + Type value; + auto value_buffer = ReadBufferFromString(partition_value); + readIntText(value, value_buffer); + return column_type->createColumnConst(1, value); +} + +template +requires(std::is_same_v || std::is_same_v) ColumnPtr + createFloatPartitionColumn(DataTypePtr column_type, std::string partition_value) +{ + Type value; + auto value_buffer = ReadBufferFromString(partition_value); + readFloatText(value, value_buffer); + return column_type->createColumnConst(1, value); +} + +//template <> +//ColumnPtr createFloatPartitionColumn(DataTypePtr column_type, std::string partition_value); +//template <> +//ColumnPtr createFloatPartitionColumn(DataTypePtr column_type, std::string partition_value); + +PartitionColumnFillingTransform::PartitionColumnFillingTransform( + const DB::Block & input_, const DB::Block & output_, const String & partition_col_name_, const String & partition_col_value_) + : ISimpleTransform(input_, output_, true), partition_col_name(partition_col_name_), partition_col_value(partition_col_value_) +{ + partition_col_type = output_.getByName(partition_col_name_).type; + partition_column = createPartitionColumn(); +} + +ColumnPtr PartitionColumnFillingTransform::createPartitionColumn() +{ + ColumnPtr result; + DataTypePtr nested_type = partition_col_type; + if (const DataTypeNullable * nullable_type = checkAndGetDataType(partition_col_type.get())) + { + nested_type = nullable_type->getNestedType(); + if (StringUtils::isNullPartitionValue(partition_col_value)) + { + return nullable_type->createColumnConstWithDefaultValue(1); + } + } + WhichDataType which(nested_type); + if (which.isInt8()) + { + result = createIntPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isInt16()) + { + result = createIntPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isInt32()) + { + result = createIntPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isInt64()) + { + result = createIntPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isFloat32()) + { + result = createFloatPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isFloat64()) + { + result = createFloatPartitionColumn(partition_col_type, partition_col_value); + } + else if (which.isDate()) + { + DayNum value; + auto value_buffer = ReadBufferFromString(partition_col_value); + readDateText(value, value_buffer); + result = partition_col_type->createColumnConst(1, value); + } + else if (which.isString()) + { + result = partition_col_type->createColumnConst(1, partition_col_value); + } + else + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported datatype {}", partition_col_type->getFamilyName()); + } + return result; +} + +void PartitionColumnFillingTransform::transform(DB::Chunk & chunk) +{ + size_t partition_column_position = output.getHeader().getPositionByName(partition_col_name); + if (partition_column_position == input.getHeader().columns()) + { + chunk.addColumn(partition_column->cloneResized(chunk.getNumRows())); + } + else + { + chunk.addColumn(partition_column_position, partition_column->cloneResized(chunk.getNumRows())); + } +} +} diff --git a/utils/local-engine/Operator/PartitionColumnFillingTransform.h b/utils/local-engine/Operator/PartitionColumnFillingTransform.h new file mode 100644 index 000000000000..f3e0a606a506 --- /dev/null +++ b/utils/local-engine/Operator/PartitionColumnFillingTransform.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace local_engine +{ +class PartitionColumnFillingTransform : public DB::ISimpleTransform +{ +public: + PartitionColumnFillingTransform( + const DB::Block & input_, + const DB::Block & output_, + const String & partition_col_name_, + const String & partition_col_value_); + void transform(DB::Chunk & chunk) override; + String getName() const override + { + return "PartitionColumnFillingTransform"; + } + +private: + DB::ColumnPtr createPartitionColumn(); + + DB::DataTypePtr partition_col_type; + String partition_col_name; + String partition_col_value; + DB::ColumnPtr partition_column; +}; + +} + + diff --git a/utils/local-engine/Parser/CHColumnToSparkRow.cpp b/utils/local-engine/Parser/CHColumnToSparkRow.cpp new file mode 100644 index 000000000000..edde9f603a75 --- /dev/null +++ b/utils/local-engine/Parser/CHColumnToSparkRow.cpp @@ -0,0 +1,957 @@ +#include "CHColumnToSparkRow.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_TYPE; +} +} + + +namespace local_engine +{ +using namespace DB; + +int64_t calculateBitSetWidthInBytes(int32_t num_fields) +{ + return ((num_fields + 63) / 64) * 8; +} + +static int64_t calculatedFixeSizePerRow(int64_t num_cols) +{ + return calculateBitSetWidthInBytes(num_cols) + num_cols * 8; +} + +int64_t roundNumberOfBytesToNearestWord(int64_t num_bytes) +{ + auto remainder = num_bytes & 0x07; // This is equivalent to `numBytes % 8` + return num_bytes + ((8 - remainder) & 0x7); +} + + +void bitSet(char * bitmap, int32_t index) +{ + int64_t mask = 1L << (index & 0x3f); // mod 64 and shift + int64_t word_offset = (index >> 6) * 8; + int64_t word; + memcpy(&word, bitmap + word_offset, sizeof(int64_t)); + int64_t value = word | mask; + memcpy(bitmap + word_offset, &value, sizeof(int64_t)); +} + +ALWAYS_INLINE bool isBitSet(const char * bitmap, int32_t index) +{ + assert(index >= 0); + int64_t mask = 1 << (index & 63); + int64_t word_offset = static_cast(index >> 6) * 8L; + int64_t word = *reinterpret_cast(bitmap + word_offset); + return word & mask; +} + +static void writeFixedLengthNonNullableValue( + char * buffer_address, + int64_t field_offset, + const ColumnWithTypeAndName & col, + int64_t num_rows, + const std::vector & offsets) +{ + FixedLengthDataWriter writer(col.type); + for (size_t i = 0; i < static_cast(num_rows); i++) + writer.unsafeWrite(col.column->getDataAt(i), buffer_address + offsets[i] + field_offset); +} + +static void writeFixedLengthNullableValue( + char * buffer_address, + int64_t field_offset, + const ColumnWithTypeAndName & col, + int32_t col_index, + int64_t num_rows, + const std::vector & offsets) +{ + const auto * nullable_column = checkAndGetColumn(*col.column); + const auto & null_map = nullable_column->getNullMapData(); + const auto & nested_column = nullable_column->getNestedColumn(); + FixedLengthDataWriter writer(col.type); + for (size_t i = 0; i < static_cast(num_rows); i++) + { + if (null_map[i]) + bitSet(buffer_address + offsets[i], col_index); + else + writer.unsafeWrite(nested_column.getDataAt(i), buffer_address + offsets[i] + field_offset); + } +} + +static void writeVariableLengthNonNullableValue( + char * buffer_address, + int64_t field_offset, + const ColumnWithTypeAndName & col, + int64_t num_rows, + const std::vector & offsets, + std::vector & buffer_cursor) +{ + const auto type_without_nullable{std::move(removeNullable(col.type))}; + const bool use_raw_data = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); + const bool big_endian = BackingDataLengthCalculator::isBigEndianInSparkRow(type_without_nullable); + VariableLengthDataWriter writer(col.type, buffer_address, offsets, buffer_cursor); + if (use_raw_data) + { + if (!big_endian) + { + for (size_t i = 0; i < static_cast(num_rows); i++) + { + StringRef str = col.column->getDataAt(i); + int64_t offset_and_size = writer.writeUnalignedBytes(i, str.data, str.size, 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + } + else + { + Field field; + for (size_t i = 0; i < static_cast(num_rows); i++) + { + StringRef str_view = col.column->getDataAt(i); + String buf(str_view.data, str_view.size); + BackingDataLengthCalculator::swapDecimalEndianBytes(buf); + int64_t offset_and_size = writer.writeUnalignedBytes(i, buf.data(), buf.size(), 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + } + } + else + { + Field field; + for (size_t i = 0; i < static_cast(num_rows); i++) + { + field = std::move((*col.column)[i]); + int64_t offset_and_size = writer.write(i, field, 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + } +} + +static void writeVariableLengthNullableValue( + char * buffer_address, + int64_t field_offset, + const ColumnWithTypeAndName & col, + int32_t col_index, + int64_t num_rows, + const std::vector & offsets, + std::vector & buffer_cursor) +{ + const auto * nullable_column = checkAndGetColumn(*col.column); + const auto & null_map = nullable_column->getNullMapData(); + const auto & nested_column = nullable_column->getNestedColumn(); + const auto type_without_nullable{std::move(removeNullable(col.type))}; + const bool use_raw_data = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); + const bool big_endian = BackingDataLengthCalculator::isBigEndianInSparkRow(type_without_nullable); + VariableLengthDataWriter writer(col.type, buffer_address, offsets, buffer_cursor); + if (use_raw_data) + { + for (size_t i = 0; i < static_cast(num_rows); i++) + { + if (null_map[i]) + bitSet(buffer_address + offsets[i], col_index); + else if (!big_endian) + { + StringRef str = nested_column.getDataAt(i); + int64_t offset_and_size = writer.writeUnalignedBytes(i, str.data, str.size, 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + else + { + Field field; + nested_column.get(i, field); + StringRef str_view = nested_column.getDataAt(i); + String buf(str_view.data, str_view.size); + BackingDataLengthCalculator::swapDecimalEndianBytes(buf); + int64_t offset_and_size = writer.writeUnalignedBytes(i, buf.data(), buf.size(), 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + } + } + else + { + Field field; + for (size_t i = 0; i < static_cast(num_rows); i++) + { + if (null_map[i]) + bitSet(buffer_address + offsets[i], col_index); + else + { + field = std::move(nested_column[i]); + int64_t offset_and_size = writer.write(i, field, 0); + memcpy(buffer_address + offsets[i] + field_offset, &offset_and_size, 8); + } + } + } +} + + +static void writeValue( + char * buffer_address, + int64_t field_offset, + const ColumnWithTypeAndName & col, + int32_t col_index, + int64_t num_rows, + const std::vector & offsets, + std::vector & buffer_cursor) +{ + const auto type_without_nullable{std::move(removeNullable(col.type))}; + const auto is_nullable = isColumnNullable(*col.column); + if (BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) + { + if (is_nullable) + writeFixedLengthNullableValue(buffer_address, field_offset, col, col_index, num_rows, offsets); + else + writeFixedLengthNonNullableValue(buffer_address, field_offset, col, num_rows, offsets); + } + else if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + { + if (is_nullable) + writeVariableLengthNullableValue(buffer_address, field_offset, col, col_index, num_rows, offsets, buffer_cursor); + else + writeVariableLengthNonNullableValue(buffer_address, field_offset, col, num_rows, offsets, buffer_cursor); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for writeValue", col.type->getName()); +} + +SparkRowInfo::SparkRowInfo(const DB::ColumnsWithTypeAndName & cols, const DB::DataTypes & types, const size_t & col_size, const size_t & row_size) + : types(types) + , num_rows(row_size) + , num_cols(col_size) + , null_bitset_width_in_bytes(calculateBitSetWidthInBytes(num_cols)) + , total_bytes(0) + , offsets(num_rows, 0) + , lengths(num_rows, 0) + , buffer_cursor(num_rows, 0) + , buffer_address(nullptr) +{ + int64_t fixed_size_per_row = calculatedFixeSizePerRow(num_cols); + + /// Initialize lengths and buffer_cursor + for (int64_t i = 0; i < num_rows; i++) + { + lengths[i] = fixed_size_per_row; + buffer_cursor[i] = fixed_size_per_row; + } + + for (int64_t col_idx = 0; col_idx < num_cols; ++col_idx) + { + const auto & col = cols[col_idx]; + /// No need to calculate backing data length for fixed length types + const auto type_without_nullable = removeNullable(col.type); + if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + { + if (BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable)) + { + auto column = col.column->convertToFullColumnIfConst(); + const auto * nullable_column = checkAndGetColumn(*column); + if (nullable_column) + { + const auto & nested_column = nullable_column->getNestedColumn(); + const auto & null_map = nullable_column->getNullMapData(); + for (auto row_idx = 0; row_idx < num_rows; ++row_idx) + if (!null_map[row_idx]) + lengths[row_idx] += roundNumberOfBytesToNearestWord(nested_column.getDataAt(row_idx).size); + } + else + { + for (auto row_idx = 0; row_idx < num_rows; ++row_idx) + lengths[row_idx] += roundNumberOfBytesToNearestWord(col.column->getDataAt(row_idx).size); + } + } + else + { + BackingDataLengthCalculator calculator(col.type); + for (auto row_idx = 0; row_idx < num_rows; ++row_idx) + { + const auto field = (*col.column)[row_idx]; + lengths[row_idx] += calculator.calculate(field); + } + } + } + } + + /// Initialize offsets + for (int64_t i = 1; i < num_rows; ++i) + offsets[i] = offsets[i - 1] + lengths[i - 1]; + + /// Initialize total_bytes + for (int64_t i = 0; i < num_rows; ++i) + total_bytes += lengths[i]; +} + +SparkRowInfo::SparkRowInfo(const Block & block) : SparkRowInfo(block.getColumnsWithTypeAndName(), block.getDataTypes(), block.columns(), block.rows()){} + +const DB::DataTypes & SparkRowInfo::getDataTypes() const +{ + return types; +} + +int64_t SparkRowInfo::getFieldOffset(int32_t col_idx) const +{ + return null_bitset_width_in_bytes + 8L * col_idx; +} + +int64_t SparkRowInfo::getNullBitsetWidthInBytes() const +{ + return null_bitset_width_in_bytes; +} + +void SparkRowInfo::setNullBitsetWidthInBytes(int64_t null_bitset_width_in_bytes_) +{ + null_bitset_width_in_bytes = null_bitset_width_in_bytes_; +} + +int64_t SparkRowInfo::getNumCols() const +{ + return num_cols; +} + +void SparkRowInfo::setNumCols(int64_t num_cols_) +{ + num_cols = num_cols_; +} + +int64_t SparkRowInfo::getNumRows() const +{ + return num_rows; +} + +void SparkRowInfo::setNumRows(int64_t num_rows_) +{ + num_rows = num_rows_; +} + +char * SparkRowInfo::getBufferAddress() const +{ + return buffer_address; +} + +void SparkRowInfo::setBufferAddress(char * buffer_address_) +{ + buffer_address = buffer_address_; +} + +const std::vector & SparkRowInfo::getOffsets() const +{ + return offsets; +} + +const std::vector & SparkRowInfo::getLengths() const +{ + return lengths; +} + +std::vector & SparkRowInfo::getBufferCursor() +{ + return buffer_cursor; +} + +int64_t SparkRowInfo::getTotalBytes() const +{ + return total_bytes; +} + +std::unique_ptr CHColumnToSparkRow::convertCHColumnToSparkRow(const Block & block) +{ + if (!block.columns()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns"); + + auto block_col = block.getByPosition(0); + DB::ColumnPtr nested_col = block_col.column; + if (const auto * const_col = checkAndGetColumn(nested_col.get())) + { + nested_col = const_col->getDataColumnPtr(); + } + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + + auto checkAndGetTupleDataTypes = [] (const DB::ColumnPtr& column) -> DB::DataTypes + { + DB::DataTypes data_types; + if (column->getDataType() != DB::TypeIndex::Tuple) + { + return data_types; + } + const auto * tuple_col = checkAndGetColumn(column.get()); + const size_t col_size = tuple_col->tupleSize(); + for (size_t i = 0; i < col_size; i++) + { + DB::DataTypePtr field_type; + const auto & field_col = tuple_col->getColumn(i); + if (field_col.isNullable()) + { + const auto & field_nested_col= assert_cast(field_col).getNestedColumn(); + if (field_nested_col.getDataType() != DB::TypeIndex::String) + { + data_types.clear(); + return data_types; + } + else + { + DataTypePtr string_type = std::make_shared(); + field_type = std::make_shared(string_type); + } + } + else if (field_col.getDataType() == DB::TypeIndex::String) + { + field_type = std::make_shared(); + } + else + { + data_types.clear(); + return data_types; + } + data_types.emplace_back(field_type); + } + return data_types; + }; + + std::unique_ptr spark_row_info; + DB::ColumnsWithTypeAndName columns; + auto data_types = checkAndGetTupleDataTypes(nested_col); + if (data_types.size() > 0) + { + const auto * tuple_col = checkAndGetColumn(nested_col.get()); + for (size_t i = 0; i < tuple_col->tupleSize(); i++) + { + DB::ColumnWithTypeAndName col_type_name(tuple_col->getColumnPtr(i), data_types[i], "c" + std::to_string(i)); + columns.emplace_back(col_type_name); + } + spark_row_info = std::make_unique(columns, data_types, tuple_col->tupleSize(), block.rows()); + } + else + { + spark_row_info = std::make_unique(block); + columns = block.getColumnsWithTypeAndName(); + } + spark_row_info->setBufferAddress(reinterpret_cast(alloc(spark_row_info->getTotalBytes(), 64))); + // spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(), 64)); + memset(spark_row_info->getBufferAddress(), 0, spark_row_info->getTotalBytes()); + for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++) + { + const auto & col = columns[col_idx]; + int64_t field_offset = spark_row_info->getFieldOffset(col_idx); + + ColumnWithTypeAndName col_not_const{col.column->convertToFullColumnIfConst(), col.type, col.name}; + writeValue( + spark_row_info->getBufferAddress(), + field_offset, + col_not_const, + col_idx, + spark_row_info->getNumRows(), + spark_row_info->getOffsets(), + spark_row_info->getBufferCursor()); + } + return spark_row_info; +} + +void CHColumnToSparkRow::freeMem(char * address, size_t size) +{ + free(address, size); + // rollback(size); +} + +BackingDataLengthCalculator::BackingDataLengthCalculator(const DataTypePtr & type_) + : type_without_nullable(removeNullable(type_)), which(type_without_nullable) +{ + if (!isFixedLengthDataType(type_without_nullable) && !isVariableLengthDataType(type_without_nullable)) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingDataLengthCalculator", type_without_nullable->getName()); +} + +int64_t BackingDataLengthCalculator::calculate(const Field & field) const +{ + if (field.isNull()) + return 0; + + if (which.isNativeInt() || which.isNativeUInt() || which.isFloat() || which.isDateOrDate32() || which.isDateTime64() + || which.isDecimal32() || which.isDecimal64()) + return 0; + + if (which.isStringOrFixedString()) + { + const auto & str = field.get(); + return roundNumberOfBytesToNearestWord(str.size()); + } + + if (which.isDecimal128()) + return 16; + + if (which.isArray()) + { + /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing buffer + const auto & array = field.get(); /// Array can not be wrapped with Nullable + const auto num_elems = array.size(); + int64_t res = 8 + calculateBitSetWidthInBytes(num_elems); + + const auto * array_type = typeid_cast(type_without_nullable.get()); + const auto & nested_type = array_type->getNestedType(); + res += roundNumberOfBytesToNearestWord(getArrayElementSize(nested_type) * num_elems); + + BackingDataLengthCalculator calculator(nested_type); + for (size_t i = 0; i < array.size(); ++i) + res += calculator.calculate(array[i]); + return res; + } + + if (which.isMap()) + { + /// 内存布局:Length of UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value + int64_t res = 8; + + /// Construct Array of keys and values from Map + const auto & map = field.get(); /// Map can not be wrapped with Nullable + const auto num_keys = map.size(); + auto array_key = Array(); + auto array_val = Array(); + array_key.reserve(num_keys); + array_val.reserve(num_keys); + for (size_t i = 0; i < num_keys; ++i) + { + const auto & pair = map[i].get(); + array_key.push_back(pair[0]); + array_val.push_back(pair[1]); + } + + const auto * map_type = typeid_cast(type_without_nullable.get()); + + const auto & key_type = map_type->getKeyType(); + const auto key_array_type = std::make_shared(key_type); + BackingDataLengthCalculator calculator_key(key_array_type); + res += calculator_key.calculate(array_key); + + const auto & val_type = map_type->getValueType(); + const auto type_array_val = std::make_shared(val_type); + BackingDataLengthCalculator calculator_val(type_array_val); + res += calculator_val.calculate(array_val); + return res; + } + + if (which.isTuple()) + { + /// 内存布局:null_bitmap(字节数与字段数成正比) | field1 value(8B) | field2 value(8B) | ... | fieldn value(8B) | backing buffer + const auto & tuple = field.get(); /// Tuple can not be wrapped with Nullable + const auto * type_tuple = typeid_cast(type_without_nullable.get()); + const auto & type_fields = type_tuple->getElements(); + const auto num_fields = type_fields.size(); + int64_t res = calculateBitSetWidthInBytes(num_fields) + 8 * num_fields; + for (size_t i = 0; i < num_fields; ++i) + { + BackingDataLengthCalculator calculator(type_fields[i]); + res += calculator.calculate(tuple[i]); + } + return res; + } + + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingBufferLengthCalculator", type_without_nullable->getName()); +} + +int64_t BackingDataLengthCalculator::getArrayElementSize(const DataTypePtr & nested_type) +{ + const WhichDataType nested_which(removeNullable(nested_type)); + if (nested_which.isUInt8() || nested_which.isInt8()) + return 1; + else if (nested_which.isUInt16() || nested_which.isInt16() || nested_which.isDate()) + return 2; + else if ( + nested_which.isUInt32() || nested_which.isInt32() || nested_which.isFloat32() || nested_which.isDate32() + || nested_which.isDecimal32()) + return 4; + else if ( + nested_which.isUInt64() || nested_which.isInt64() || nested_which.isFloat64() || nested_which.isDateTime64() + || nested_which.isDecimal64()) + return 8; + else + return 8; +} + +bool BackingDataLengthCalculator::isFixedLengthDataType(const DataTypePtr & type_without_nullable) +{ + const WhichDataType which(type_without_nullable); + return which.isUInt8() || which.isInt8() || which.isUInt16() || which.isInt16() || which.isDate() || which.isUInt32() || which.isInt32() + || which.isFloat32() || which.isDate32() || which.isDecimal32() || which.isUInt64() || which.isInt64() || which.isFloat64() + || which.isDateTime64() || which.isDecimal64() || which.isNothing(); +} + +bool BackingDataLengthCalculator::isVariableLengthDataType(const DataTypePtr & type_without_nullable) +{ + const WhichDataType which(type_without_nullable); + return which.isStringOrFixedString() || which.isDecimal128() || which.isArray() || which.isMap() || which.isTuple(); +} + +bool BackingDataLengthCalculator::isDataTypeSupportRawData(const DB::DataTypePtr & type_without_nullable) +{ + const WhichDataType which(type_without_nullable); + return isFixedLengthDataType(type_without_nullable) || which.isStringOrFixedString() || which.isDecimal128(); +} + +bool BackingDataLengthCalculator::isBigEndianInSparkRow(const DB::DataTypePtr & type_without_nullable) +{ + const WhichDataType which(type_without_nullable); + return which.isDecimal128(); +} + +void BackingDataLengthCalculator::swapDecimalEndianBytes(String & buf) +{ + assert(buf.size() == 16); + + using base_type = Decimal128::NativeType::base_type; + auto * decimal128 = reinterpret_cast(buf.data()); + for (size_t i = 0; i != std::size(decimal128->value.items); ++i) + decimal128->value.items[i] = __builtin_bswap64(decimal128->value.items[i]); + + base_type * high = reinterpret_cast(buf.data() + 8); + base_type * low = reinterpret_cast(buf.data()); + std::swap(*high, *low); +} + +VariableLengthDataWriter::VariableLengthDataWriter( + const DataTypePtr & type_, char * buffer_address_, const std::vector & offsets_, std::vector & buffer_cursor_) + : type_without_nullable(removeNullable(type_)) + , which(type_without_nullable) + , buffer_address(buffer_address_) + , offsets(offsets_) + , buffer_cursor(buffer_cursor_) +{ + assert(buffer_address); + assert(!offsets.empty()); + assert(!buffer_cursor.empty()); + assert(offsets.size() == buffer_cursor.size()); + + if (!BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataWriter doesn't support type {}", type_without_nullable->getName()); +} + +int64_t VariableLengthDataWriter::writeArray(size_t row_idx, const DB::Array & array, int64_t parent_offset) +{ + /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing data + const auto & offset = offsets[row_idx]; + auto & cursor = buffer_cursor[row_idx]; + const auto num_elems = array.size(); + const auto * array_type = typeid_cast(type_without_nullable.get()); + const auto & nested_type = array_type->getNestedType(); + + /// Write numElements(8B) + const auto start = cursor; + memcpy(buffer_address + offset + cursor, &num_elems, 8); + cursor += 8; + if (num_elems == 0) + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 8); + + /// Skip null_bitmap(already reset to zero) + const auto len_null_bitmap = calculateBitSetWidthInBytes(num_elems); + cursor += len_null_bitmap; + + /// Skip values(already reset to zero) + const auto elem_size = BackingDataLengthCalculator::getArrayElementSize(nested_type); + const auto len_values = roundNumberOfBytesToNearestWord(elem_size * num_elems); + cursor += len_values; + + if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(nested_type))) + { + /// If nested type is fixed-length data type, update null_bitmap and values in place + FixedLengthDataWriter writer(nested_type); + for (size_t i = 0; i < num_elems; ++i) + { + const auto & elem = array[i]; + if (elem.isNull()) + bitSet(buffer_address + offset + start + 8, i); + else +// writer.write(elem, buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); + writer.unsafeWrite(reinterpret_cast(&elem.get()), buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); + } + } + else + { + /// If nested type is not fixed-length data type, update null_bitmap in place + /// And append values in backing data recursively + VariableLengthDataWriter writer(nested_type, buffer_address, offsets, buffer_cursor); + for (size_t i = 0; i < num_elems; ++i) + { + const auto & elem = array[i]; + if (elem.isNull()) + bitSet(buffer_address + offset + start + 8, i); + else + { + const auto offset_and_size = writer.write(row_idx, elem, start); + memcpy(buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size, &offset_and_size, 8); + } + } + } + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); +} + +int64_t VariableLengthDataWriter::writeMap(size_t row_idx, const DB::Map & map, int64_t parent_offset) +{ + /// 内存布局:Length of UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value + const auto & offset = offsets[row_idx]; + auto & cursor = buffer_cursor[row_idx]; + + /// Skip length of UnsafeArrayData of key(8B) + const auto start = cursor; + cursor += 8; + + /// If Map is empty, return in advance + const auto num_pairs = map.size(); + if (num_pairs == 0) + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 8); + + /// Construct array of keys and array of values from map + auto key_array = Array(); + auto val_array = Array(); + key_array.reserve(num_pairs); + val_array.reserve(num_pairs); + for (size_t i = 0; i < num_pairs; ++i) + { + const auto & pair = map[i].get(); + key_array.push_back(pair[0]); + val_array.push_back(pair[1]); + } + + const auto * map_type = typeid_cast(type_without_nullable.get()); + + /// Append UnsafeArrayData of key + const auto & key_type = map_type->getKeyType(); + const auto key_array_type = std::make_shared(key_type); + VariableLengthDataWriter key_writer(key_array_type, buffer_address, offsets, buffer_cursor); + const auto key_array_size = BackingDataLengthCalculator::extractSize(key_writer.write(row_idx, key_array, start + 8)); + + /// Fill length of UnsafeArrayData of key + memcpy(buffer_address + offset + start, &key_array_size, 8); + + /// Append UnsafeArrayData of value + const auto & val_type = map_type->getValueType(); + const auto val_array_type = std::make_shared(val_type); + VariableLengthDataWriter val_writer(val_array_type, buffer_address, offsets, buffer_cursor); + val_writer.write(row_idx, val_array, start + 8 + key_array_size); + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); +} + +int64_t VariableLengthDataWriter::writeStruct(size_t row_idx, const DB::Tuple & tuple, int64_t parent_offset) +{ + /// 内存布局:null_bitmap(字节数与字段数成正比) | values(num_fields * 8B) | backing data + const auto & offset = offsets[row_idx]; + auto & cursor = buffer_cursor[row_idx]; + const auto start = cursor; + + /// Skip null_bitmap + const auto * tuple_type = typeid_cast(type_without_nullable.get()); + const auto & field_types = tuple_type->getElements(); + const auto num_fields = field_types.size(); + if (num_fields == 0) + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, 0); + const auto len_null_bitmap = calculateBitSetWidthInBytes(num_fields); + cursor += len_null_bitmap; + + /// Skip values + cursor += num_fields * 8; + + /// If field type is fixed-length, fill field value in values region + /// else append it to backing data region, and update offset_and_size in values region + for (size_t i = 0; i < num_fields; ++i) + { + const auto & field_value = tuple[i]; + const auto & field_type = field_types[i]; + if (field_value.isNull()) + { + bitSet(buffer_address + offset + start, i); + continue; + } + + if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(field_type))) + { + FixedLengthDataWriter writer(field_type); + // writer.write(field_value, buffer_address + offset + start + len_null_bitmap + i * 8); + writer.unsafeWrite( + reinterpret_cast(&field_value.get()), buffer_address + offset + start + len_null_bitmap + i * 8); + } + else + { + VariableLengthDataWriter writer(field_type, buffer_address, offsets, buffer_cursor); + const auto offset_and_size = writer.write(row_idx, field_value, start); + memcpy(buffer_address + offset + start + len_null_bitmap + 8 * i, &offset_and_size, 8); + } + } + return BackingDataLengthCalculator::getOffsetAndSize(start - parent_offset, cursor - start); +} + +int64_t VariableLengthDataWriter::write(size_t row_idx, const DB::Field & field, int64_t parent_offset) +{ + assert(row_idx < offsets.size()); + + if (field.isNull()) + return 0; + + if (which.isStringOrFixedString()) + { + const auto & str = field.get(); + return writeUnalignedBytes(row_idx, str.data(), str.size(), parent_offset); + } + + if (which.isDecimal128()) + { + const auto & decimal_field = field.safeGet>(); + auto decimal128 = decimal_field.getValue(); + String buf(reinterpret_cast(&decimal128), sizeof(decimal128)); + BackingDataLengthCalculator::swapDecimalEndianBytes(buf); + return writeUnalignedBytes(row_idx, buf.c_str(), sizeof(Decimal128), parent_offset); + } + + if (which.isArray()) + { + const auto & array = field.get(); + return writeArray(row_idx, array, parent_offset); + } + + if (which.isMap()) + { + const auto & map = field.get(); + return writeMap(row_idx, map, parent_offset); + } + + if (which.isTuple()) + { + const auto & tuple = field.get(); + return writeStruct(row_idx, tuple, parent_offset); + } + + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Doesn't support type {} for BackingDataWriter", type_without_nullable->getName()); +} + +int64_t BackingDataLengthCalculator::getOffsetAndSize(int64_t cursor, int64_t size) +{ + return (cursor << 32) | size; +} + +int64_t BackingDataLengthCalculator::extractOffset(int64_t offset_and_size) +{ + return offset_and_size >> 32; +} + +int64_t BackingDataLengthCalculator::extractSize(int64_t offset_and_size) +{ + return offset_and_size & 0xffffffff; +} + +int64_t VariableLengthDataWriter::writeUnalignedBytes(size_t row_idx, const char * src, size_t size, int64_t parent_offset) +{ + memcpy(buffer_address + offsets[row_idx] + buffer_cursor[row_idx], src, size); + auto res = BackingDataLengthCalculator::getOffsetAndSize(buffer_cursor[row_idx] - parent_offset, size); + buffer_cursor[row_idx] += roundNumberOfBytesToNearestWord(size); + return res; +} + + +FixedLengthDataWriter::FixedLengthDataWriter(const DB::DataTypePtr & type_) + : type_without_nullable(removeNullable(type_)), which(type_without_nullable) +{ + if (!BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthWriter doesn't support type {}", type_without_nullable->getName()); +} + +void FixedLengthDataWriter::write(const DB::Field & field, char * buffer) +{ + /// Skip null value + if (field.isNull()) + return; + + if (which.isUInt8()) + { + const auto value = UInt8(field.get()); + memcpy(buffer, &value, 1); + } + else if (which.isUInt16() || which.isDate()) + { + const auto value = UInt16(field.get()); + memcpy(buffer, &value, 2); + } + else if (which.isUInt32() || which.isDate32()) + { + const auto value = UInt32(field.get()); + memcpy(buffer, &value, 4); + } + else if (which.isUInt64()) + { + const auto & value = field.get(); + memcpy(buffer, &value, 8); + } + else if (which.isInt8()) + { + const auto value = Int8(field.get()); + memcpy(buffer, &value, 1); + } + else if (which.isInt16()) + { + const auto value = Int16(field.get()); + memcpy(buffer, &value, 2); + } + else if (which.isInt32()) + { + const auto value = Int32(field.get()); + memcpy(buffer, &value, 4); + } + else if (which.isInt64()) + { + const auto & value = field.get(); + memcpy(buffer, &value, 8); + } + else if (which.isFloat32()) + { + const auto value = Float32(field.get()); + memcpy(buffer, &value, 4); + } + else if (which.isFloat64()) + { + const auto & value = field.get(); + memcpy(buffer, &value, 8); + } + else if (which.isDecimal32()) + { + const auto & value = field.get(); + const auto decimal = value.getValue(); + memcpy(buffer, &decimal, 4); + } + else if (which.isDecimal64() || which.isDateTime64()) + { + const auto & value = field.get(); + auto decimal = value.getValue(); + memcpy(buffer, &decimal, 8); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthDataWriter doesn't support type {}", type_without_nullable->getName()); +} + +void FixedLengthDataWriter::unsafeWrite(const StringRef & str, char * buffer) +{ + memcpy(buffer, str.data, str.size); +} + +void FixedLengthDataWriter::unsafeWrite(const char * __restrict src, char * __restrict buffer) +{ + memcpy(buffer, src, type_without_nullable->getSizeOfValueInMemory()); +} + +} diff --git a/utils/local-engine/Parser/CHColumnToSparkRow.h b/utils/local-engine/Parser/CHColumnToSparkRow.h new file mode 100644 index 000000000000..f32cf7a6f6d3 --- /dev/null +++ b/utils/local-engine/Parser/CHColumnToSparkRow.h @@ -0,0 +1,175 @@ +#pragma once +#include +#include +#include +#include +#include + + +namespace local_engine +{ +int64_t calculateBitSetWidthInBytes(int32_t num_fields); +int64_t roundNumberOfBytesToNearestWord(int64_t num_bytes); +void bitSet(char * bitmap, int32_t index); +bool isBitSet(const char * bitmap, int32_t index); + +class CHColumnToSparkRow; +class SparkRowToCHColumn; + +class SparkRowInfo : public boost::noncopyable +{ + friend CHColumnToSparkRow; + friend SparkRowToCHColumn; + +public: + explicit SparkRowInfo(const DB::Block & block); + explicit SparkRowInfo(const DB::ColumnsWithTypeAndName & cols, const DB::DataTypes & types, const size_t & col_size, const size_t & row_size); + + const DB::DataTypes & getDataTypes() const; + + int64_t getFieldOffset(int32_t col_idx) const; + + int64_t getNullBitsetWidthInBytes() const; + void setNullBitsetWidthInBytes(int64_t null_bitset_width_in_bytes_); + + int64_t getNumCols() const; + void setNumCols(int64_t num_cols_); + + int64_t getNumRows() const; + void setNumRows(int64_t num_rows_); + + char * getBufferAddress() const; + void setBufferAddress(char * buffer_address); + + const std::vector & getOffsets() const; + const std::vector & getLengths() const; + std::vector & getBufferCursor(); + int64_t getTotalBytes() const; + +private: + const DB::DataTypes types; + int64_t num_rows; + int64_t num_cols; + int64_t null_bitset_width_in_bytes; + int64_t total_bytes; + + std::vector offsets; + std::vector lengths; + std::vector buffer_cursor; + char * buffer_address; +}; + +using SparkRowInfoPtr = std::unique_ptr; + +class CHColumnToSparkRow : private Allocator +// class CHColumnToSparkRow : public DB::Arena +{ +public: + std::unique_ptr convertCHColumnToSparkRow(const DB::Block & block); + void freeMem(char * address, size_t size); +}; + +/// Return backing data length of values with variable-length type in bytes +class BackingDataLengthCalculator +{ +public: + static constexpr size_t DECIMAL_MAX_INT64_DIGITS = 18; + + explicit BackingDataLengthCalculator(const DB::DataTypePtr & type_); + virtual ~BackingDataLengthCalculator() = default; + + /// Return length is guranteed to round up to 8 + virtual int64_t calculate(const DB::Field & field) const; + + static int64_t getArrayElementSize(const DB::DataTypePtr & nested_type); + + /// Is CH DataType can be converted to fixed-length data type in Spark? + static bool isFixedLengthDataType(const DB::DataTypePtr & type_without_nullable); + + /// Is CH DataType can be converted to variable-length data type in Spark? + static bool isVariableLengthDataType(const DB::DataTypePtr & type_without_nullable); + + /// If Data Type can use raw data between CH Column and Spark Row if value is not null + static bool isDataTypeSupportRawData(const DB::DataTypePtr & type_without_nullable); + + /// If bytes in Spark Row is big-endian. If true, we have to transform them to little-endian afterwords + static bool isBigEndianInSparkRow(const DB::DataTypePtr & type_without_nullable); + + /// Convert endian. Big to little or little to big. + /// Note: Spark unsafeRow biginteger is big-endian. + /// CH Int128 is little-endian, is same as system(std::endian::native). + static void swapDecimalEndianBytes(String & buf); + + static int64_t getOffsetAndSize(int64_t cursor, int64_t size); + static int64_t extractOffset(int64_t offset_and_size); + static int64_t extractSize(int64_t offset_and_size); + +private: + // const DB::DataTypePtr type; + const DB::DataTypePtr type_without_nullable; + const DB::WhichDataType which; +}; + +/// Writing variable-length typed values to backing data region of Spark Row +/// User who calls VariableLengthDataWriter is responsible to write offset_and_size +/// returned by VariableLengthDataWriter::write to field value in Spark Row +class VariableLengthDataWriter +{ +public: + VariableLengthDataWriter( + const DB::DataTypePtr & type_, + char * buffer_address_, + const std::vector & offsets_, + std::vector & buffer_cursor_); + + virtual ~VariableLengthDataWriter() = default; + + /// Write value of variable-length to backing data region of structure(row or array) and return offset and size in backing data region + /// It's caller's duty to make sure that row fields or array elements are written in order + /// parent_offset: the starting offset of current structure in which we are updating it's backing data region + virtual int64_t write(size_t row_idx, const DB::Field & field, int64_t parent_offset); + + /// Only support String/FixedString/Decimal128 + int64_t writeUnalignedBytes(size_t row_idx, const char * src, size_t size, int64_t parent_offset); +private: + int64_t writeArray(size_t row_idx, const DB::Array & array, int64_t parent_offset); + int64_t writeMap(size_t row_idx, const DB::Map & map, int64_t parent_offset); + int64_t writeStruct(size_t row_idx, const DB::Tuple & tuple, int64_t parent_offset); + + // const DB::DataTypePtr type; + const DB::DataTypePtr type_without_nullable; + const DB::WhichDataType which; + + /// Global buffer of spark rows + char * const buffer_address; + /// Offsets of each spark row + const std::vector & offsets; + /// Cursors of backing data in each spark row, relative to offsets + std::vector & buffer_cursor; +}; + +class FixedLengthDataWriter +{ +public: + explicit FixedLengthDataWriter(const DB::DataTypePtr & type_); + virtual ~FixedLengthDataWriter() = default; + + /// Write value of fixed-length to values region of structure(struct or array) + /// It's caller's duty to make sure that struct fields or array elements are written in order + virtual void write(const DB::Field & field, char * buffer); + + /// Copy memory chunk of Fixed length typed CH Column directory to buffer for performance. + /// It is unsafe unless you know what you are doing. + virtual void unsafeWrite(const StringRef & str, char * buffer); + + /// Copy memory chunk of in fixed length typed Field directory to buffer for performance. + /// It is unsafe unless you know what you are doing. + virtual void unsafeWrite(const char * __restrict src, char * __restrict buffer); + +private: + // const DB::DataTypePtr type; + const DB::DataTypePtr type_without_nullable; + const DB::WhichDataType which; +}; + +} diff --git a/utils/local-engine/Parser/CMakeLists.txt b/utils/local-engine/Parser/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Parser/ExpandRelParser.cpp b/utils/local-engine/Parser/ExpandRelParser.cpp new file mode 100644 index 000000000000..21c81e9a3f02 --- /dev/null +++ b/utils/local-engine/Parser/ExpandRelParser.cpp @@ -0,0 +1,92 @@ +#include "ExpandRelParser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} +namespace local_engine +{ + +ExpandRelParser::ExpandRelParser(SerializedPlanParser * plan_parser_) + : RelParser(plan_parser_) +{} + +DB::QueryPlanPtr +ExpandRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list & rel_stack) +{ + const auto & expand_rel = rel.expand(); + std::vector aggregating_expressions_columns; + std::set agg_cols_ref; + const auto & header = query_plan->getCurrentDataStream().header; + for (int i = 0; i < expand_rel.aggregate_expressions_size(); ++i) + { + const auto & expr = expand_rel.aggregate_expressions(i); + if (expr.has_selection()) + { + aggregating_expressions_columns.push_back(expr.selection().direct_reference().struct_field().field()); + agg_cols_ref.insert(expr.selection().direct_reference().struct_field().field()); + } + else + { + // FIXEME. see https://github.com/oap-project/gluten/pull/794 + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "Unsupported aggregating expression in expand node. {}. input header:{}.", + expr.ShortDebugString(), + header.dumpNames()); + } + } + std::vector> grouping_sets; + buildGroupingSets(expand_rel, grouping_sets); + // The input header is : aggregating columns + grouping columns. + auto expand_step = std::make_unique( + query_plan->getCurrentDataStream(), aggregating_expressions_columns, grouping_sets, expand_rel.group_name()); + expand_step->setStepDescription("Expand step"); + query_plan->addStep(std::move(expand_step)); + return query_plan; +} + + +void ExpandRelParser::buildGroupingSets(const substrait::ExpandRel & expand_rel, std::vector> & grouping_sets) +{ + for (int i = 0; i < expand_rel.groupings_size(); ++i) + { + const auto grouping_set_pb = expand_rel.groupings(i); + std::set grouping_set; + for (int n = 0; n < grouping_set_pb.groupsets_expressions_size(); ++n) + { + const auto & expr = grouping_set_pb.groupsets_expressions(n); + if (expr.has_selection()) + { + grouping_set.insert(expr.selection().direct_reference().struct_field().field()); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported expression in grouping sets"); + } + } + grouping_sets.emplace_back(std::move(grouping_set)); + } +} + +void registerExpandRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_parser) + { + return std::make_shared(plan_parser); + }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kExpand, builder); +} +} diff --git a/utils/local-engine/Parser/ExpandRelParser.h b/utils/local-engine/Parser/ExpandRelParser.h new file mode 100644 index 000000000000..2d6e07b16002 --- /dev/null +++ b/utils/local-engine/Parser/ExpandRelParser.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class ExpandRelParser : public RelParser +{ +public: + explicit ExpandRelParser(SerializedPlanParser * plan_parser_); + ~ExpandRelParser() override = default; + DB::QueryPlanPtr + parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list & rel_stack_) override; +private: + static void buildGroupingSets(const substrait::ExpandRel & expand_rel, std::vector> & grouping_sets); +}; +} diff --git a/utils/local-engine/Parser/RelParser.cpp b/utils/local-engine/Parser/RelParser.cpp new file mode 100644 index 000000000000..261b5aa765d9 --- /dev/null +++ b/utils/local-engine/Parser/RelParser.cpp @@ -0,0 +1,140 @@ +#include "RelParser.h" +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} +} + +namespace local_engine +{ +AggregateFunctionPtr RelParser::getAggregateFunction( + DB::String & name, DB::DataTypes arg_types, DB::AggregateFunctionProperties & properties, const DB::Array & parameters) +{ + auto & factory = AggregateFunctionFactory::instance(); + return factory.get(name, arg_types, parameters, properties); +} + +std::optional RelParser::parseFunctionName(UInt32 function_ref) +{ + const auto & function_mapping = getFunctionMapping(); + auto it = function_mapping.find(std::to_string(function_ref)); + if (it == function_mapping.end()) + { + return {}; + } + auto function_signature = it->second; + auto function_name = function_signature.substr(0, function_signature.find(':')); + return function_name; +} + +DB::DataTypes RelParser::parseFunctionArgumentTypes( + const Block & header, const google::protobuf::RepeatedPtrField & func_args) +{ + DB::DataTypes res; + for (const auto & arg : func_args) + { + if (!arg.has_value()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect a FunctionArgument with value field"); + } + res.emplace_back(parseExpressionType(header, arg.value())); + } + return res; +} + +DB::DataTypePtr RelParser::parseExpressionType(const Block & header, const substrait::Expression & expr) +{ + DB::DataTypePtr res; + if (expr.has_selection()) + { + auto pos = expr.selection().direct_reference().struct_field().field(); + res = header.getByPosition(pos).type; + } + else if (expr.has_literal()) + { + auto [data_type, _] = SerializedPlanParser::parseLiteral(expr.literal()); + res = data_type; + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow FunctionArgument: {}", expr.DebugString()); + } + return res; +} + + +DB::Names RelParser::parseFunctionArgumentNames( + const Block & header, const google::protobuf::RepeatedPtrField & func_args) +{ + DB::Names res; + for (const auto & arg : func_args) + { + if (!arg.has_value()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect a FunctionArgument with value field"); + } + const auto & value = arg.value(); + if (value.has_selection()) + { + auto pos = value.selection().direct_reference().struct_field().field(); + res.push_back(header.getByPosition(pos).name); + } + else if (value.has_literal()) + { + auto [_, field] = SerializedPlanParser::parseLiteral(value.literal()); + res.push_back(field.dump()); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow FunctionArgument: {}", arg.DebugString()); + } + } + return res; +} + +RelParserFactory & RelParserFactory::instance() +{ + static RelParserFactory factory; + return factory; +} + +void RelParserFactory::registerBuilder(UInt32 k, RelParserBuilder builder) +{ + auto it = builders.find(k); + if (it != builders.end()) + { + throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Duplicated builder key:{}", k); + } + builders[k] = builder; +} + +RelParserFactory::RelParserBuilder RelParserFactory::getBuilder(DB::UInt32 k) +{ + auto it = builders.find(k); + if (it == builders.end()) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Not found builder for key:{}", k); + } + return it->second; +} + +void registerWindowRelParser(RelParserFactory & factory); +void registerSortRelParser(RelParserFactory & factory); +void registerExpandRelParser(RelParserFactory & factory); + +void registerRelParsers() +{ + auto & factory = RelParserFactory::instance(); + registerWindowRelParser(factory); + registerSortRelParser(factory); + registerExpandRelParser(factory); +} +} diff --git a/utils/local-engine/Parser/RelParser.h b/utils/local-engine/Parser/RelParser.h new file mode 100644 index 000000000000..381777571df2 --- /dev/null +++ b/utils/local-engine/Parser/RelParser.h @@ -0,0 +1,70 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ +/// parse a single substrait relation +class RelParser +{ +public: + explicit RelParser(SerializedPlanParser * plan_parser_) + :plan_parser(plan_parser_) + {} + + virtual ~RelParser() = default; + virtual DB::QueryPlanPtr parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) = 0; + + static AggregateFunctionPtr getAggregateFunction( + DB::String & name, + DB::DataTypes arg_types, + DB::AggregateFunctionProperties & properties, + const DB::Array & parameters = {}); + +public: + static DB::DataTypePtr parseType(const substrait::Type & type) { return SerializedPlanParser::parseType(type); } +protected: + inline ContextPtr getContext() { return plan_parser->context; } + inline String getUniqueName(const std::string & name) { return plan_parser->getUniqueName(name); } + inline const std::unordered_map & getFunctionMapping() { return plan_parser->function_mapping; } + std::optional parseFunctionName(UInt32 function_ref); + static DB::DataTypes parseFunctionArgumentTypes(const Block & header, const google::protobuf::RepeatedPtrField & func_args); + static DB::DataTypePtr parseExpressionType(const Block & header, const substrait::Expression & expr); + static DB::Names parseFunctionArgumentNames(const Block & header, const google::protobuf::RepeatedPtrField & func_args); + const DB::ActionsDAG::Node * parseArgument(ActionsDAGPtr action_dag, const substrait::Expression & rel) + { + return plan_parser->parseExpression(action_dag, rel); + } + std::pair parseLiteral(const substrait::Expression_Literal & literal) + { + return plan_parser->parseLiteral(literal); + } + +private: + SerializedPlanParser * plan_parser; +}; + +class RelParserFactory +{ +protected: + RelParserFactory() = default; +public: + using RelParserBuilder = std::function(SerializedPlanParser *)>; + static RelParserFactory & instance(); + void registerBuilder(UInt32 k, RelParserBuilder builder); + RelParserBuilder getBuilder(DB::UInt32 k); +private: + std::map builders; +}; + +void registerRelParsers(); +} diff --git a/utils/local-engine/Parser/SerializedPlanParser.cpp b/utils/local-engine/Parser/SerializedPlanParser.cpp new file mode 100644 index 000000000000..9d8c0f2890e7 --- /dev/null +++ b/utils/local-engine/Parser/SerializedPlanParser.cpp @@ -0,0 +1,2809 @@ +#include "SerializedPlanParser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "Common/Exception.h" +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "DataTypes/IDataType.h" +#include "Parsers/ExpressionListParsers.h" +#include "SerializedPlanParser.h" +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_TYPE; + extern const int BAD_ARGUMENTS; + extern const int NO_SUCH_DATA_PART; + extern const int UNKNOWN_FUNCTION; + extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int INVALID_JOIN_ON_EXPRESSION; +} +} + +namespace local_engine +{ +using namespace DB; + +void join(ActionsDAG::NodeRawConstPtrs v, char c, std::string & s) +{ + s.clear(); + for (auto p = v.begin(); p != v.end(); ++p) + { + s += (*p)->result_name; + if (p != v.end() - 1) + s += c; + } +} + +bool isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type) +{ + const auto parsed_ch_type = SerializedPlanParser::parseType(substrait_type); + return parsed_ch_type->equals(*ch_type); +} + +void SerializedPlanParser::parseExtensions( + const ::google::protobuf::RepeatedPtrField & extensions) +{ + for (const auto & extension : extensions) + { + if (extension.has_extension_function()) + { + function_mapping.emplace( + std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); + } + } +} + +std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( + const std::vector & expressions, + const DB::Block & header, + const DB::Block & read_schema) +{ + auto actions_dag = std::make_shared(blockToNameAndTypeList(header)); + NamesWithAliases required_columns; + std::set distinct_columns; + + for (const auto & expr : expressions) + { + if (expr.has_selection()) + { + auto position = expr.selection().direct_reference().struct_field().field(); + auto col_name = read_schema.getByPosition(position).name; + const ActionsDAG::Node * field = actions_dag->tryFindInOutputs(col_name); + if (distinct_columns.contains(field->result_name)) + { + auto unique_name = getUniqueName(field->result_name); + required_columns.emplace_back(NameWithAlias(field->result_name, unique_name)); + distinct_columns.emplace(unique_name); + } + else + { + required_columns.emplace_back(NameWithAlias(field->result_name, field->result_name)); + distinct_columns.emplace(field->result_name); + } + } + else if (expr.has_scalar_function()) + { + const auto & scalar_function = expr.scalar_function(); + auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference())); + auto function_name = getFunctionName(function_signature, scalar_function); + + std::vector result_names; + std::vector useless; + if (function_name == "arrayJoin") + { + /// Whether the function from spark is explode or posexplode + bool position = startsWith(function_signature, "posexplode"); + actions_dag = parseArrayJoin(header, expr, result_names, useless, actions_dag, true, position); + } + else + { + result_names.resize(1); + actions_dag = parseFunction(header, expr, result_names[0], useless, actions_dag, true); + } + + for (const auto & result_name : result_names) + { + if (result_name.empty()) + continue; + + if (distinct_columns.contains(result_name)) + { + auto unique_name = getUniqueName(result_name); + required_columns.emplace_back(NameWithAlias(result_name, unique_name)); + distinct_columns.emplace(unique_name); + } + else + { + required_columns.emplace_back(NameWithAlias(result_name, result_name)); + distinct_columns.emplace(result_name); + } + } + } + else if (expr.has_cast() || expr.has_if_then() || expr.has_literal()) + { + const auto * node = parseExpression(actions_dag, expr); + actions_dag->addOrReplaceInOutputs(*node); + if (distinct_columns.contains(node->result_name)) + { + auto unique_name = getUniqueName(node->result_name); + required_columns.emplace_back(NameWithAlias(node->result_name, unique_name)); + distinct_columns.emplace(unique_name); + } + else + { + required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name)); + distinct_columns.emplace(node->result_name); + } + } + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case())); + } + + actions_dag->project(required_columns); + return actions_dag; +} + +std::string getDecimalFunction(const substrait::Type_Decimal & decimal, const bool null_on_overflow) { + std::string ch_function_name; + UInt32 precision = decimal.precision(); + UInt32 scale = decimal.scale(); + + if (precision <= DataTypeDecimal32::maxPrecision()) + { + ch_function_name = "toDecimal32"; + } + else if (precision <= DataTypeDecimal64::maxPrecision()) + { + ch_function_name = "toDecimal64"; + } + else if (precision <= DataTypeDecimal128::maxPrecision()) + { + ch_function_name = "toDecimal128"; + } + else + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); + } + + if (null_on_overflow) { + ch_function_name = ch_function_name + "OrNull"; + } + + return ch_function_name; +} +/// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types. +std::string getCastFunction(const substrait::Type & type) +{ + std::string ch_function_name; + if (type.has_fp64()) + { + ch_function_name = "toFloat64"; + } + else if (type.has_fp32()) + { + ch_function_name = "toFloat32"; + } + else if (type.has_string() || type.has_binary()) + { + ch_function_name = "toString"; + } + else if (type.has_i64()) + { + ch_function_name = "toInt64"; + } + else if (type.has_i32()) + { + ch_function_name = "toInt32"; + } + else if (type.has_i16()) + { + ch_function_name = "toInt16"; + } + else if (type.has_i8()) + { + ch_function_name = "toInt8"; + } + else if (type.has_date()) + { + ch_function_name = "toDate32"; + } + // TODO need complete param: scale + else if (type.has_timestamp()) + { + ch_function_name = "toDateTime64"; + } + else if (type.has_bool_()) + { + ch_function_name = "toUInt8"; + } + else if (type.has_decimal()) + { + ch_function_name = getDecimalFunction(type.decimal(), false); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support cast type {}", type.DebugString()); + + /// TODO(taiyang-li): implement cast functions of other types + + return ch_function_name; +} + +bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) +{ + assert(rel.has_local_files()); + assert(rel.has_base_schema()); + return rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with("iterator"); +} + +QueryPlanPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrait::ReadRel & rel) +{ + assert(rel.has_local_files()); + assert(rel.has_base_schema()); + auto header = parseNameStruct(rel.base_schema()); + auto source = std::make_shared(context, header, rel.local_files()); + auto source_pipe = Pipe(source); + auto source_step = std::make_unique(std::move(source_pipe), "substrait local files", nullptr); + source_step->setStepDescription("read local files"); + auto query_plan = std::make_unique(); + query_plan->addStep(std::move(source_step)); + return query_plan; +} + +QueryPlanPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait::ReadRel & rel) +{ + assert(rel.has_local_files()); + assert(rel.local_files().items().size() == 1); + assert(rel.has_base_schema()); + auto iter = rel.local_files().items().at(0).uri_file(); + auto pos = iter.find(':'); + auto iter_index = std::stoi(iter.substr(pos + 1, iter.size())); + auto plan = std::make_unique(); + + auto source = std::make_shared(parseNameStruct(rel.base_schema()), input_iters[iter_index]); + QueryPlanStepPtr source_step = std::make_unique(Pipe(source)); + source_step->setStepDescription("Read From Java Iter"); + plan->addStep(std::move(source_step)); + + return plan; +} + +void SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, std::vector columns) +{ + if (columns.empty()) return; + auto remove_nullable_actions_dag + = std::make_shared(blockToNameAndTypeList(plan.getCurrentDataStream().header)); + removeNullable(columns, remove_nullable_actions_dag); + auto expression_step = std::make_unique(plan.getCurrentDataStream(), remove_nullable_actions_dag); + expression_step->setStepDescription("Remove nullable properties"); + plan.addStep(std::move(expression_step)); +} + +QueryPlanPtr SerializedPlanParser::parseMergeTreeTable(const substrait::ReadRel & rel) +{ + assert(rel.has_extension_table()); + google::protobuf::StringValue table; + table.ParseFromString(rel.extension_table().detail().value()); + auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value()); + DB::Block header; + if (rel.has_base_schema() && rel.base_schema().names_size()) + { + header = parseNameStruct(rel.base_schema()); + } + else + { + // For count(*) case, there will be an empty base_schema, so we try to read at least once column + auto all_parts_dir = MergeTreeUtil::getAllMergeTreeParts( std::filesystem::path("/") / merge_tree_table.relative_path); + if (all_parts_dir.empty()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Empty mergetree directory: {}", merge_tree_table.relative_path); + } + auto part_names_types_list = MergeTreeUtil::getSchemaFromMergeTreePart(all_parts_dir[0]); + NamesAndTypesList one_column_name_type; + one_column_name_type.push_back(part_names_types_list.front()); + header = BlockUtil::buildHeader(one_column_name_type); + LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "Try to read ({}) instead of empty header", header.dumpNames()); + } + auto names_and_types_list = header.getNamesAndTypesList(); + auto storage_factory = StorageMergeTreeFactory::instance(); + auto metadata = buildMetaData(names_and_types_list, context); + query_context.metadata = metadata; + auto storage = storage_factory.getStorage( + StorageID(merge_tree_table.database, merge_tree_table.table), + metadata->getColumns(), + [merge_tree_table, metadata]() -> CustomStorageMergeTreePtr + { + auto custom_storage_merge_tree = std::make_shared( + StorageID(merge_tree_table.database, merge_tree_table.table), + merge_tree_table.relative_path, + *metadata, + false, + global_context, + "", + MergeTreeData::MergingParams(), + buildMergeTreeSettings()); + custom_storage_merge_tree->loadDataParts(false); + return custom_storage_merge_tree; + }); + query_context.storage_snapshot = std::make_shared(*storage, metadata); + query_context.custom_storage_merge_tree = storage; + auto query_info = buildQueryInfo(names_and_types_list); + std::vector not_null_columns; + if (rel.has_filter()) + { + query_info->prewhere_info = parsePreWhereInfo(rel.filter(), header, not_null_columns); + } + auto data_parts = query_context.custom_storage_merge_tree->getAllDataPartsVector(); + int min_block = merge_tree_table.min_block; + int max_block = merge_tree_table.max_block; + MergeTreeData::DataPartsVector selected_parts; + std::copy_if( + std::begin(data_parts), + std::end(data_parts), + std::inserter(selected_parts, std::begin(selected_parts)), + [min_block, max_block](MergeTreeData::DataPartPtr part) + { return part->info.min_block >= min_block && part->info.max_block < max_block; }); + if (selected_parts.empty()) + { + throw Exception(ErrorCodes::NO_SUCH_DATA_PART, "part {} to {} not found.", min_block, max_block); + } + auto query = query_context.custom_storage_merge_tree->reader.readFromParts( + selected_parts, names_and_types_list.getNames(), query_context.storage_snapshot, *query_info, context, 4096 * 2, 1); + if (!not_null_columns.empty()) + { + auto input_header = query->getCurrentDataStream().header; + std::erase_if(not_null_columns, [input_header](auto item) -> bool {return !input_header.has(item);}); + addRemoveNullableStep(*query, not_null_columns); + } + return query; +} + +PrewhereInfoPtr SerializedPlanParser::parsePreWhereInfo(const substrait::Expression & rel, Block & input, std::vector& not_nullable_columns) +{ + auto prewhere_info = std::make_shared(); + prewhere_info->prewhere_actions = std::make_shared(input.getNamesAndTypesList()); + std::string filter_name; + // for in function + if (rel.has_singular_or_list()) + { + const auto *in_node = parseExpression(prewhere_info->prewhere_actions, rel); + prewhere_info->prewhere_actions->addOrReplaceInOutputs(*in_node); + filter_name = in_node->result_name; + } + else + { + parseFunctionWithDAG(rel, filter_name, not_nullable_columns, prewhere_info->prewhere_actions, true); + } + prewhere_info->prewhere_column_name = filter_name; + prewhere_info->need_filter = true; + prewhere_info->remove_prewhere_column = true; + auto cols = prewhere_info->prewhere_actions->getRequiredColumnsNames(); + if (last_project) + { + prewhere_info->prewhere_actions->removeUnusedActions(Names{filter_name}, true, true); + prewhere_info->prewhere_actions->projectInput(false); + for (const auto & expr : last_project->expressions()) + { + if (expr.has_selection()) + { + auto position = expr.selection().direct_reference().struct_field().field(); + auto name = input.getByPosition(position).name; + prewhere_info->prewhere_actions->tryRestoreColumn(name); + } + } + } + else + { + prewhere_info->prewhere_actions->removeUnusedActions(Names{filter_name}, false, true); + prewhere_info->prewhere_actions->projectInput(false); + for (const auto& name : input.getNames()) + { + prewhere_info->prewhere_actions->tryRestoreColumn(name); + } + } + return prewhere_info; +} + +Block SerializedPlanParser::parseNameStruct(const substrait::NamedStruct & struct_) +{ + ColumnsWithTypeAndName internal_cols; + internal_cols.reserve(struct_.names_size()); + std::list field_names; + for (int i = 0; i < struct_.names_size(); ++i) + { + field_names.emplace_back(struct_.names(i)); + } + + for (int i = 0; i < struct_.struct_().types_size(); ++i) + { + auto name = field_names.front(); + const auto & type = struct_.struct_().types(i); + auto data_type = parseType(type, &field_names); + Poco::StringTokenizer name_parts(name, "#"); + if (name_parts.count() == 4) + { + auto agg_function_name = getFunctionName(name_parts[3], {}); + AggregateFunctionProperties properties; + auto tmp = AggregateFunctionFactory::instance().get(agg_function_name, {data_type}, {}, properties); + data_type = tmp->getStateType(); + } + internal_cols.push_back(ColumnWithTypeAndName(data_type, name)); + } + Block res(std::move(internal_cols)); + return std::move(res); +} + +DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) +{ + return wrapNullableType(nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE, nested_type); +} + +DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) +{ + if (nullable && !nested_type->isNullable()) + return std::make_shared(nested_type); + else + return nested_type; +} + +/** + * names is used to name struct type fields. + * + */ +DataTypePtr SerializedPlanParser::parseType(const substrait::Type & substrait_type, std::list * names) +{ + DataTypePtr ch_type; + std::string_view current_name; + if (names) + { + current_name = names->front(); + names->pop_front(); + } + + if (substrait_type.has_bool_()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.bool_().nullability(), ch_type); + } + else if (substrait_type.has_i8()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.i8().nullability(), ch_type); + } + else if (substrait_type.has_i16()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.i16().nullability(), ch_type); + } + else if (substrait_type.has_i32()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.i32().nullability(), ch_type); + } + else if (substrait_type.has_i64()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.i64().nullability(), ch_type); + } + else if (substrait_type.has_string()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.string().nullability(), ch_type); + } + else if (substrait_type.has_binary()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.binary().nullability(), ch_type); + } + else if (substrait_type.has_fixed_char()) + { + const auto & fixed_char = substrait_type.fixed_char(); + ch_type = std::make_shared(fixed_char.length()); + ch_type = wrapNullableType(fixed_char.nullability(), ch_type); + } + else if (substrait_type.has_fixed_binary()) + { + const auto & fixed_binary = substrait_type.fixed_binary(); + ch_type = std::make_shared(fixed_binary.length()); + ch_type = wrapNullableType(fixed_binary.nullability(), ch_type); + } + else if (substrait_type.has_fp32()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.fp32().nullability(), ch_type); + } + else if (substrait_type.has_fp64()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.fp64().nullability(), ch_type); + } + else if (substrait_type.has_timestamp()) + { + ch_type = std::make_shared(6); + ch_type = wrapNullableType(substrait_type.timestamp().nullability(), ch_type); + } + else if (substrait_type.has_date()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(substrait_type.date().nullability(), ch_type); + } + else if (substrait_type.has_decimal()) + { + UInt32 precision = substrait_type.decimal().precision(); + UInt32 scale = substrait_type.decimal().scale(); + if (precision > DataTypeDecimal128::maxPrecision()) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); + ch_type = createDecimal(precision, scale); + ch_type = wrapNullableType(substrait_type.decimal().nullability(), ch_type); + } + else if (substrait_type.has_struct_()) + { + DataTypes ch_field_types(substrait_type.struct_().types().size()); + Strings field_names; + for (size_t i = 0; i < ch_field_types.size(); ++i) + { + if (names) + field_names.push_back(names->front()); + + ch_field_types[i] = std::move(parseType(substrait_type.struct_().types()[i], names)); + } + if (!field_names.empty()) + ch_type = std::make_shared(ch_field_types, field_names); + else + ch_type = std::make_shared(ch_field_types); + ch_type = wrapNullableType(substrait_type.struct_().nullability(), ch_type); + } + else if (substrait_type.has_list()) + { + auto ch_nested_type = parseType(substrait_type.list().type()); + ch_type = std::make_shared(ch_nested_type); + ch_type = wrapNullableType(substrait_type.list().nullability(), ch_type); + } + else if (substrait_type.has_map()) + { + auto ch_key_type = parseType(substrait_type.map().key()); + auto ch_val_type = parseType(substrait_type.map().value()); + ch_type = std::make_shared(ch_key_type, ch_val_type); + ch_type = wrapNullableType(substrait_type.map().nullability(), ch_type); + } + else if (substrait_type.has_nothing()) + { + ch_type = std::make_shared(); + ch_type = wrapNullableType(true, ch_type); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support type {}", substrait_type.DebugString()); + + /// TODO(taiyang-li): consider Time/IntervalYear/IntervalDay/TimestampTZ/UUID/VarChar/FixedBinary/UserDefined + return std::move(ch_type); +} + +DB::DataTypePtr SerializedPlanParser::parseType(const std::string & type) +{ + static std::map type2type = { + {"BooleanType", "UInt8"}, + {"ByteType", "Int8"}, + {"ShortType", "Int16"}, + {"IntegerType", "Int32"}, + {"LongType", "Int64"}, + {"FloatType", "Float32"}, + {"DoubleType", "Float64"}, + {"StringType", "String"}, + {"DateType", "Date"} + }; + + auto it = type2type.find(type); + if (it == type2type.end()) + { + throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unknow spark type: {}", type); + } + return DB::DataTypeFactory::instance().get(it->second); +} + +QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr plan) +{ + auto * logger = &Poco::Logger::get("SerializedPlanParser"); + if (logger->debug()) + { + namespace pb_util = google::protobuf::util; + pb_util::JsonOptions options; + std::string json; + pb_util::MessageToJsonString(*plan, &json, options); + LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "substrait plan:{}", json); + } + parseExtensions(plan->extensions()); + if (plan->relations_size() == 1) + { + auto root_rel = plan->relations().at(0); + if (!root_rel.has_root()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!"); + } + std::list rel_stack; + auto query_plan = parseOp(root_rel.root().input(), rel_stack); + if (root_rel.root().names_size()) + { + ActionsDAGPtr actions_dag = std::make_shared(blockToNameAndTypeList(query_plan->getCurrentDataStream().header)); + NamesWithAliases aliases; + auto cols = query_plan->getCurrentDataStream().header.getNamesAndTypesList(); + for (size_t i = 0; i < cols.getNames().size(); i++) + { + aliases.emplace_back(NameWithAlias(cols.getNames()[i], root_rel.root().names(i))); + } + actions_dag->project(aliases); + auto expression_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag); + expression_step->setStepDescription("Rename Output"); + query_plan->addStep(std::move(expression_step)); + } + return query_plan; + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "too many relations found"); + } +} + +QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list & rel_stack) +{ + QueryPlanPtr query_plan; + switch (rel.rel_type_case()) + { + case substrait::Rel::RelTypeCase::kFetch: { + rel_stack.push_back(&rel); + const auto & limit = rel.fetch(); + query_plan = parseOp(limit.input(), rel_stack); + rel_stack.pop_back(); + auto limit_step = std::make_unique(query_plan->getCurrentDataStream(), limit.count(), limit.offset()); + query_plan->addStep(std::move(limit_step)); + break; + } + case substrait::Rel::RelTypeCase::kFilter: { + rel_stack.push_back(&rel); + const auto & filter = rel.filter(); + query_plan = parseOp(filter.input(), rel_stack); + rel_stack.pop_back(); + std::string filter_name; + std::vector required_columns; + + ActionsDAGPtr actions_dag = nullptr; + if (filter.condition().has_scalar_function()) + { + actions_dag = parseFunction( + query_plan->getCurrentDataStream().header, filter.condition(), filter_name, required_columns, nullptr, true); + } + else + { + actions_dag = std::make_shared(blockToNameAndTypeList(query_plan->getCurrentDataStream().header)); + const auto * node = parseExpression(actions_dag, filter.condition()); + filter_name = node->result_name; + } + + auto input = query_plan->getCurrentDataStream().header.getNames(); + Names input_with_condition(input); + input_with_condition.emplace_back(filter_name); + actions_dag->removeUnusedActions(input_with_condition); + auto filter_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag, filter_name, true); + query_plan->addStep(std::move(filter_step)); + + // remove nullable + addRemoveNullableStep(*query_plan, required_columns); + break; + } + case substrait::Rel::RelTypeCase::kGenerate: + case substrait::Rel::RelTypeCase::kProject: { + const substrait::Rel * input = nullptr; + bool is_generate = false; + std::vector expressions; + + if (rel.has_project()) + { + const auto & project = rel.project(); + last_project = &project; + input = &project.input(); + + expressions.reserve(project.expressions_size()); + for (int i=0; ihas_read() && !input->read().has_local_files(); + if (is_mergetree_input) + read_schema = parseNameStruct(input->read().base_schema()); + else + read_schema = query_plan->getCurrentDataStream().header; + + auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema); + auto expression_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag); + expression_step->setStepDescription(is_generate ? "Generate" : "Project"); + query_plan->addStep(std::move(expression_step)); + break; + } + case substrait::Rel::RelTypeCase::kAggregate: { + rel_stack.push_back(&rel); + const auto & aggregate = rel.aggregate(); + query_plan = parseOp(aggregate.input(), rel_stack); + rel_stack.pop_back(); + + bool is_final; + auto aggregate_step = parseAggregate(*query_plan, aggregate, is_final); + + query_plan->addStep(std::move(aggregate_step)); + + if (is_final) + { + std::vector measure_positions; + std::vector measure_types; + for (int i = 0; i < aggregate.measures_size(); i++) + { + auto position + = aggregate.measures(i).measure().arguments(0).value().selection().direct_reference().struct_field().field(); + measure_positions.emplace_back(position); + measure_types.emplace_back(aggregate.measures(i).measure().output_type()); + } + auto source = query_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(); + auto target = source; + + bool need_convert = false; + for (size_t i = 0; i < measure_positions.size(); i++) + { + if (!isTypeMatched(measure_types[i], source[measure_positions[i]].type)) + { + auto target_type = parseType(measure_types[i]); + target[measure_positions[i]].type = target_type; + target[measure_positions[i]].column = target_type->createColumn(); + need_convert = true; + } + } + + if (need_convert) + { + ActionsDAGPtr convert_action + = ActionsDAG::makeConvertingActions(source, target, DB::ActionsDAG::MatchColumnsMode::Position); + if (convert_action) + { + QueryPlanStepPtr convert_step = std::make_unique(query_plan->getCurrentDataStream(), convert_action); + convert_step->setStepDescription("Convert Aggregate Output"); + query_plan->addStep(std::move(convert_step)); + } + } + } + break; + } + case substrait::Rel::RelTypeCase::kRead: { + const auto & read = rel.read(); + assert(read.has_local_files() || read.has_extension_table() && "Only support local parquet files or merge tree read rel"); + if (read.has_local_files()) + { + if (isReadRelFromJava(read)) + { + query_plan = parseReadRealWithJavaIter(read); + } + else + { + query_plan = parseReadRealWithLocalFile(read); + } + } + else + { + query_plan = parseMergeTreeTable(read); + } + last_project = nullptr; + break; + } + case substrait::Rel::RelTypeCase::kJoin: { + const auto & join = rel.join(); + if (!join.has_left() || !join.has_right()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "left table or right table is missing."); + } + last_project = nullptr; + rel_stack.push_back(&rel); + auto left_plan = parseOp(join.left(), rel_stack); + last_project = nullptr; + auto right_plan = parseOp(join.right(), rel_stack); + rel_stack.pop_back(); + + query_plan = parseJoin(join, std::move(left_plan), std::move(right_plan)); + break; + } + case substrait::Rel::RelTypeCase::kSort: { + rel_stack.push_back(&rel); + const auto & sort_rel = rel.sort(); + query_plan = parseOp(sort_rel.input(), rel_stack); + rel_stack.pop_back(); + auto sort_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kSort)(this); + query_plan = sort_parser->parse(std::move(query_plan), rel, rel_stack); + break; + } + case substrait::Rel::RelTypeCase::kWindow: { + rel_stack.push_back(&rel); + const auto win_rel = rel.window(); + query_plan = parseOp(win_rel.input(), rel_stack); + rel_stack.pop_back(); + auto win_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kWindow)(this); + query_plan = win_parser->parse(std::move(query_plan), rel, rel_stack); + break; + } + case substrait::Rel::RelTypeCase::kExpand: { + rel_stack.push_back(&rel); + const auto & expand_rel = rel.expand(); + query_plan = parseOp(expand_rel.input(), rel_stack); + rel_stack.pop_back(); + auto epand_parser = RelParserFactory::instance().getBuilder(substrait::Rel::RelTypeCase::kExpand)(this); + query_plan = epand_parser->parse(std::move(query_plan), rel, rel_stack); + break; + } + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString()); + } + return query_plan; +} + +AggregateFunctionPtr getAggregateFunction(const std::string & name, DataTypes arg_types) +{ + auto & factory = AggregateFunctionFactory::instance(); + AggregateFunctionProperties properties; + return factory.get(name, arg_types, Array{}, properties); +} + + +NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & header) +{ + NamesAndTypesList types; + for (const auto & name : header.getNames()) + { + const auto * column = header.findByName(name); + types.push_back(NameAndTypePair(column->name, column->type)); + } + return types; +} + +void SerializedPlanParser::addPreProjectStepIfNeeded( + QueryPlan & plan, + const substrait::AggregateRel & rel, + std::vector & measure_names, + std::map & nullable_measure_names) +{ + auto input = plan.getCurrentDataStream(); + ActionsDAGPtr expression = std::make_shared(blockToNameAndTypeList(input.header)); + std::vector required_columns; + std::vector to_wrap_nullable; + String measure_name; + bool need_pre_project = false; + for (const auto & measure : rel.measures()) + { + auto which_measure_type = WhichDataType(parseType(measure.measure().output_type())); + if (measure.measure().arguments_size() != 1) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "only support one argument aggregate function"); + } + auto arg = measure.measure().arguments(0).value(); + + if (arg.has_selection()) + { + measure_name = input.header.getByPosition(arg.selection().direct_reference().struct_field().field()).name; + measure_names.emplace_back(measure_name); + } + else if (arg.has_literal()) + { + const auto * node = parseExpression(expression, arg); + expression->addOrReplaceInOutputs(*node); + measure_name = node->result_name; + measure_names.emplace_back(measure_name); + need_pre_project = true; + } + else + { + // this includes the arg.has_scalar_function() case + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported aggregate argument type {}.", arg.DebugString()); + } + + if (which_measure_type.isNullable() && + measure.measure().phase() == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE && + !expression->findInOutputs(measure_name).result_type->isNullable() + ) + { + to_wrap_nullable.emplace_back(measure_name); + need_pre_project = true; + } + } + wrapNullable(to_wrap_nullable, expression, nullable_measure_names); + + if (need_pre_project) + { + auto expression_before_aggregate = std::make_unique(input, expression); + expression_before_aggregate->setStepDescription("Before Aggregate"); + plan.addStep(std::move(expression_before_aggregate)); + } +} + + +/** + * Gluten will use a pre projection step (search needsPreProjection in HashAggregateExecBaseTransformer) + * so this function can assume all group and agg args are direct references or literals + */ +QueryPlanStepPtr SerializedPlanParser::parseAggregate(QueryPlan & plan, const substrait::AggregateRel & rel, bool & is_final) +{ + std::set phase_set; + for (int i = 0; i < rel.measures_size(); ++i) + { + const auto & measure = rel.measures(i); + phase_set.emplace(measure.measure().phase()); + } + + bool has_first_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE); + bool has_inter_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE); + bool has_final_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT); + + if (phase_set.size() > 1) + { + if (phase_set.size() == 2 && has_first_stage && has_inter_stage) + { + // this will happen in a sql like: + // select sum(a), count(distinct b) from T + } + else + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "too many aggregate phase!"); + } + } + + is_final = has_final_stage; + + std::vector measure_names; + std::map nullable_measure_names; + addPreProjectStepIfNeeded(plan, rel, measure_names, nullable_measure_names); + + Names keys = {}; + if (rel.groupings_size() == 1) + { + for (const auto & group : rel.groupings(0).grouping_expressions()) + { + if (group.has_selection() && group.selection().has_direct_reference()) + { + keys.emplace_back(plan.getCurrentDataStream().header.getNames().at(group.selection().direct_reference().struct_field().field())); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported group expression: {}", group.DebugString()); + } + } + } + // only support one grouping or no grouping + else if (rel.groupings_size() != 0) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "too many groupings"); + } + + auto aggregates = AggregateDescriptions(); + for (int i = 0; i < rel.measures_size(); ++i) + { + const auto & measure = rel.measures(i); + AggregateDescription agg; + auto function_signature = function_mapping.at(std::to_string(measure.measure().function_reference())); + auto function_name_idx = function_signature.find(':'); + // assert(function_name_idx != function_signature.npos && ("invalid function signature: " + function_signature).c_str()); + auto function_name = getFunctionName(function_signature.substr(0, function_name_idx), {}); + if (measure.measure().phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE) + { + agg.column_name = measure_names.at(i); + } + else + { + agg.column_name = function_name + "(" + measure_names.at(i) + ")"; + } + + // if measure arg has nullable version, use it + auto input_column = measure_names.at(i); + auto entry = nullable_measure_names.find(input_column); + if (entry != nullable_measure_names.end()) + { + input_column = entry->second; + } + agg.argument_names = {input_column}; + auto arg_type = plan.getCurrentDataStream().header.getByName(input_column).type; + if (const auto * function_type = checkAndGetDataType(arg_type.get())) + { + const auto * suffix = "PartialMerge"; + agg.function = getAggregateFunction(function_name + suffix, {arg_type}); + } + else + { + auto arg = arg_type; + if (measure.measure().phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE) + { + auto first = getAggregateFunction(function_name, {arg_type}); + arg = first->getStateType(); + const auto * suffix = "PartialMerge"; + function_name = function_name + suffix; + } + + agg.function = getAggregateFunction(function_name, {arg}); + } + aggregates.push_back(agg); + } + + if (has_final_stage) + { + return std::make_unique(plan.getCurrentDataStream(), + getMergedAggregateParam( keys, aggregates), + true, + false, + 1, + 1, + false, + context->getSettingsRef().max_block_size, + context->getSettingsRef().aggregation_in_order_max_block_bytes, + SortDescription(), + context->getSettingsRef().enable_memory_bound_merging_of_aggregation_results); + } + else + { + auto aggregating_step = std::make_unique( + plan.getCurrentDataStream(), + getAggregateParam(keys, aggregates), + GroupingSetsParamsList(), + false, + context->getSettingsRef().max_block_size, + context->getSettingsRef().aggregation_in_order_max_block_bytes, + 1, + 1, + false, + false, + SortDescription(), + SortDescription(), + false, + false); + return std::move(aggregating_step); + } +} + + +std::string +SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function) +{ + const auto & output_type = function.output_type(); + auto args = function.arguments(); + auto pos = function_signature.find(':'); + auto function_name = function_signature.substr(0, pos); + if (!SCALAR_FUNCTIONS.contains(function_name)) + throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", function_name); + + std::string ch_function_name; + if (function_name == "cast") + { + ch_function_name = getCastFunction(output_type); + } + else if (function_name == "trim") + { + if (args.size() == 1) + { + ch_function_name = "trimBoth"; + } + if (args.size() == 2) + { + ch_function_name = "sparkTrimBoth"; + } + } + else if (function_name == "ltrim") + { + if (args.size() == 1) + { + ch_function_name = "trimLeft"; + } + if (args.size() == 2) + { + ch_function_name = "sparkTrimLeft"; + } + } + else if (function_name == "rtrim") + { + if (args.size() == 1) + { + ch_function_name = "trimRight"; + } + if (args.size() == 2) + { + ch_function_name = "sparkTrimRigth"; + } + } + else if (function_name == "extract") + { + if (args.size() != 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function extract requires two args, function:{}", function.ShortDebugString()); + + // Get the first arg: field + const auto & extract_field = args.at(0); + + if (extract_field.value().has_literal()) + { + const auto & field_value = extract_field.value().literal().string(); + if (field_value == "YEAR") + ch_function_name = "toYear"; // spark: extract(YEAR FROM) or year + else if (field_value == "YEAR_OF_WEEK") + ch_function_name = "toISOYear"; // spark: extract(YEAROFWEEK FROM) + else if (field_value == "QUARTER") + ch_function_name = "toQuarter"; // spark: extract(QUARTER FROM) or quarter + else if (field_value == "MONTH") + ch_function_name = "toMonth"; // spark: extract(MONTH FROM) or month + else if (field_value == "WEEK_OF_YEAR") + ch_function_name = "toISOWeek"; // spark: extract(WEEK FROM) or weekofyear + /* + else if (field_value == "WEEK_DAY") + { + /// spark: weekday(t) -> substrait: extract(WEEK_DAY FROM t) -> ch: WEEKDAY(t) + /// spark: extract(DAYOFWEEK_ISO FROM t) -> substrait: 1 + extract(WEEK_DAY FROM t) -> ch: 1 + WEEKDAY(t) + ch_function_name = "?"; + } + else if (field_value == "DAY_OF_WEEK") + ch_function_name = "?"; // spark: extract(DAYOFWEEK FROM) or dayofweek + */ + else if (field_value == "DAY") + ch_function_name = "toDayOfMonth"; // spark: extract(DAY FROM) or dayofmonth + else if (field_value == "DAY_OF_YEAR") + ch_function_name = "toDayOfYear"; // spark: extract(DOY FROM) or dayofyear + else if (field_value == "HOUR") + ch_function_name = "toHour"; // spark: extract(HOUR FROM) or hour + else if (field_value == "MINUTE") + ch_function_name = "toMinute"; // spark: extract(MINUTE FROM) or minute + else if (field_value == "SECOND") + ch_function_name = "toSecond"; // spark: extract(SECOND FROM) or secondwithfraction + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of extract function is wrong."); + } + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of extract function is wrong."); + } + else if (function_name == "trunc") + { + if (args.size() != 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function trunc requires two args, function:{}", function.ShortDebugString()); + + const auto & trunc_field = args.at(0); + if (!trunc_field.value().has_literal() || !trunc_field.value().literal().has_string()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of trunc function is wrong."); + + const auto & field_value = trunc_field.value().literal().string(); + if (field_value == "YEAR" || field_value == "YYYY" || field_value == "YY") + ch_function_name = "toStartOfYear"; + else if (field_value == "QUARTER") + ch_function_name = "toStartOfQuarter"; + else if (field_value == "MONTH" || field_value == "MM" || field_value == "MON") + ch_function_name = "toStartOfMonth"; + else if (field_value == "WEEK") + ch_function_name = "toStartOfWeek"; + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of trunc function is wrong, value:{}", field_value); + } + else if (function_name == "check_overflow") + { + if (args.size() < 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); + ch_function_name = getDecimalFunction(output_type.decimal(), args.at(1).value().literal().boolean()); + } + else + ch_function_name = SCALAR_FUNCTIONS.at(function_name); + + return ch_function_name; +} + +ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( + const substrait::Expression & rel, + std::vector & result_names, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag, + bool keep_result, bool position) +{ + if (!rel.has_scalar_function()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString()); + + const auto & scalar_function = rel.scalar_function(); + + auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); + auto function_name = getFunctionName(function_signature, scalar_function); + if (function_name != "arrayJoin") + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Function parseArrayJoinWithDAG should only process arrayJoin function, but input is {}", + rel.ShortDebugString()); + + /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 + if (scalar_function.arguments_size() != 1) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 but is {}", scalar_function.arguments_size()); + + ActionsDAG::NodeRawConstPtrs args; + parseFunctionArguments(actions_dag, args, required_columns, function_name, scalar_function); + + /// Remove Nullable from Nullable(Array(xx)) or Nullable(Map(xx, xx)) if needed + const auto * arg_not_null = args[0]; + if (arg_not_null->result_type->isNullable()) + { + auto assume_not_null_builder = FunctionFactory::instance().get("assumeNotNull", context); + arg_not_null = &actions_dag->addFunction(assume_not_null_builder, {args[0]}, "assumeNotNull(" + args[0]->result_name + ")"); + } + + /// arrayJoin(arg_not_null) + auto array_join_name = "arrayJoin(" + arg_not_null->result_name + ")"; + const auto * array_join_node = &actions_dag->addArrayJoin(*arg_not_null, array_join_name); + + auto arg_type = arg_not_null->result_type; + WhichDataType which(arg_type.get()); + auto tuple_element_builder = FunctionFactory::instance().get("tupleElement", context); + auto tuple_index_type = std::make_shared(); + + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { + ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); + const auto * index_node = &actions_dag->addColumn(std::move(index_col)); + auto result_name = "tupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; + return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, result_name); + }; + + /// Special process to keep compatiable with Spark + if (!position) + { + /// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map) + if (which.isMap()) + { + /// In Spark: explode(map(k, v)) output 2 columns with default names "key" and "value" + /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type. + /// So we must wrap arrayJoin with tupleElement function for compatiability. + + /// arrayJoin(arg_not_null).1 + const auto * key_node = add_tuple_element(array_join_node, 1); + + /// arrayJoin(arg_not_null).2 + const auto * val_node = add_tuple_element(array_join_node, 2); + + result_names.push_back(key_node->result_name); + result_names.push_back(val_node->result_name); + if (keep_result) + { + actions_dag->addOrReplaceInOutputs(*key_node); + actions_dag->addOrReplaceInOutputs(*val_node); + } + return {key_node, val_node}; + } + else if (which.isArray()) + { + result_names.push_back(array_join_name); + if (keep_result) + actions_dag->addOrReplaceInOutputs(*array_join_node); + return {array_join_node}; + } + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument type of arrayJoin converted from explode should be Array or Map but is {}", arg_type->getName()); + } + else + { + /// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map) + if (which.isMap()) + { + /// In Spark: posexplode(array_of_map) output 2 or 3 columns: (pos, col) or (pos, key, value) + /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type. + /// So we must wrap arrayJoin with tupleElement function for compatiability. + + /// pos = arrayJoin(arg_not_null).1 + const auto * pos_node = add_tuple_element(array_join_node, 1); + + /// col = arrayJoin(arg_not_null).2 or (key, value) = arrayJoin(arg_not_null).2 + const auto * item_node = add_tuple_element(array_join_node, 2); + + /// Get type of y from node: cast(mapFromArrays(x, y), 'Map(K, V)') + DataTypePtr raw_child_type; + if (args[0]->type == ActionsDAG::ActionType::FUNCTION && args[0]->function_base->getName() == "mapFromArrays") + { + /// Get Type of y from node: mapFromArrays(x, y) + raw_child_type = DB::removeNullable(args[0]->children[1]->result_type); + } + else if (args[0]->type == ActionsDAG::ActionType::FUNCTION && args[0]->function_base->getName() == "_CAST" && + args[0]->children[0]->type == ActionsDAG::ActionType::FUNCTION && args[0]->children[0]->function_base->getName() == "mapFromArrays") + { + /// Get Type of y from node: cast(mapFromArrays(x, y), 'Map(K, V)') + raw_child_type = DB::removeNullable(args[0]->children[0]->children[1]->result_type); + } + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid argument type of arrayJoin: {}", actions_dag->dumpDAG()); + + + if (isMap(raw_child_type)) + { + /// key = arrayJoin(arg_not_null).2.1 + const auto * item_key_node = add_tuple_element(item_node, 1); + + /// value = arrayJoin(arg_not_null).2.2 + const auto * item_value_node = add_tuple_element(item_node, 2); + + result_names.push_back(pos_node->result_name); + result_names.push_back(item_key_node->result_name); + result_names.push_back(item_value_node->result_name); + if (keep_result) + { + actions_dag->addOrReplaceInOutputs(*pos_node); + actions_dag->addOrReplaceInOutputs(*item_key_node); + actions_dag->addOrReplaceInOutputs(*item_value_node); + } + + return {pos_node, item_key_node, item_value_node}; + } + else if (isArray(raw_child_type)) + { + /// col = arrayJoin(arg_not_null).2 + result_names.push_back(pos_node->result_name); + result_names.push_back(item_node->result_name); + if (keep_result) + { + actions_dag->addOrReplaceInOutputs(*pos_node); + actions_dag->addOrReplaceInOutputs(*item_node); + } + return {pos_node, item_node}; + } + else + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "The raw input of arrayJoin converted from posexplode should be Array or Map type but is {}", + raw_child_type->getName()); + } + else + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Argument type of arrayJoin converted from posexplode should be Map but is {}", + arg_type->getName()); + } +} + +const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( + const substrait::Expression & rel, + std::string & result_name, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag, + bool keep_result) +{ + if (!rel.has_scalar_function()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); + + const auto & scalar_function = rel.scalar_function(); + + auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); + auto function_name = getFunctionName(function_signature, scalar_function); + ActionsDAG::NodeRawConstPtrs args; + parseFunctionArguments(actions_dag, args, required_columns, function_name, scalar_function); + + /// If the first argument of function formatDateTimeInJodaSyntax is integer, replace formatDateTimeInJodaSyntax with fromUnixTimestampInJodaSyntax + /// to avoid exception + if (function_name == "formatDateTimeInJodaSyntax") + { + if (args.size() > 1 && isInteger(DB::removeNullable(args[0]->result_type))) + function_name = "fromUnixTimestampInJodaSyntax"; + } + + const ActionsDAG::Node * result_node; + if (function_name == "alias") + { + result_name = args[0]->result_name; + actions_dag->addOrReplaceInOutputs(*args[0]); + result_node = &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name); + } + else + { + if (function_name == "isNotNull") + { + required_columns.emplace_back(args[0]->result_name); + } + else if (function_name == "splitByRegexp") + { + if (args.size() >= 2) + { + /// In Spark: split(str, regex [, limit] ) + /// In CH: splitByRegexp(regexp, s) + std::swap(args[0], args[1]); + } + } + + if (startsWith(function_signature, "extract:")) + { + // delete the first arg of extract + args.erase(args.begin()); + } + else if (startsWith(function_signature, "trunc:")) + { + // delete the second arg of trunc + args.pop_back(); + } + + if (function_signature.find("check_overflow:", 0) != function_signature.npos) + { + if (scalar_function.arguments().size() < 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); + + ActionsDAG::NodeRawConstPtrs new_args; + new_args.reserve(2); + + // if toDecimalxxOrNull, first arg need string type + if (scalar_function.arguments().at(1).value().literal().boolean()) + { + std::string check_overflow_args_trans_function = "toString"; + DB::ActionsDAG::NodeRawConstPtrs to_string_args({args[0]}); + + auto to_string_cast = FunctionFactory::instance().get(check_overflow_args_trans_function, context); + std::string to_string_cast_args_name; + join(to_string_args, ',', to_string_cast_args_name); + result_name = check_overflow_args_trans_function + "(" + to_string_cast_args_name + ")"; + const auto * to_string_cast_node = &actions_dag->addFunction(to_string_cast, to_string_args, result_name); + new_args.emplace_back(to_string_cast_node); + } + else + { + new_args.emplace_back(args[0]); + } + + auto type = std::make_shared(); + UInt32 scale = rel.scalar_function().output_type().decimal().scale(); + new_args.emplace_back( + &actions_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, scale), type, getUniqueName(toString(scale))))); + + args = std::move(new_args); + } + + auto function_builder = FunctionFactory::instance().get(function_name, context); + std::string args_name; + join(args, ',', args_name); + result_name = function_name + "(" + args_name + ")"; + const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); + result_node = function_node; + + if (!isTypeMatched(rel.scalar_function().output_type(), function_node->result_type)) + { + result_node = ActionsDAGUtil::convertNodeType( + actions_dag, + function_node, + SerializedPlanParser::parseType(rel.scalar_function().output_type())->getName(), + function_node->result_name); + } + if (function_name == "JSON_VALUE") + { + result_node->function->setResolver(function_builder); + } + if (keep_result) + actions_dag->addOrReplaceInOutputs(*result_node); + } + return result_node; +} + +void SerializedPlanParser::parseFunctionArguments( + DB::ActionsDAGPtr & actions_dag, + ActionsDAG::NodeRawConstPtrs & parsed_args, + std::vector & required_columns, + std::string & function_name, + const substrait::Expression_ScalarFunction & scalar_function) +{ + auto add_column = [&](const DataTypePtr & type, const Field & field) -> auto + { + return &actions_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, field), type, getUniqueName(toString(field)))); + }; + + const auto & args = scalar_function.arguments(); + + // Some functions need to be handled specially. + if (function_name == "JSONExtract") + { + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + auto data_type = parseType(scalar_function.output_type()); + parsed_args.emplace_back(add_column(std::make_shared(), data_type->getName())); + } + else if (function_name == "tupleElement") + { + // tupleElement. the field index must be unsigned integer in CH, cast the signed integer in substrait + // which must be a positive value into unsigned integer here. + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + + // tuple indecies start from 1, in spark, start from 0 + if (!args[1].value().has_literal()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal"); + } + auto [data_type, field] = parseLiteral(args[1].value().literal()); + if (data_type->getTypeId() != DB::TypeIndex::Int32) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32"); + } + UInt32 field_index = field.get() + 1; + const auto * index_node = add_column(std::make_shared(), field_index); + parsed_args.emplace_back(index_node); + } + else if (function_name == "tuple") + { + // Arguments in the format, (, [, , ...]) + // We don't need to care the field names here. + for (int index = 1; index < args.size(); index += 2) + { + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[index]); + } + } + else if (function_name == "has") + { + // since FunctionArrayIndex::useDefaultImplementationForNulls = false, we need to unwrap the + // nullable + const ActionsDAG::Node * arg_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[0]); + if (arg_node->result_type->isNullable()) + { + auto nested_type = typeid_cast(arg_node->result_type.get())->getNestedType(); + arg_node = ActionsDAGUtil::convertNodeType(actions_dag, arg_node, nested_type->getName()); + } + parsed_args.emplace_back(arg_node); + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[1]); + } + else if (function_name == "arrayElement") + { + // arrayElement. in spark, the array element index must a be positive value. But in CH, a array element index + // could be positive or negative and have different effects. So we make a cast here. + // In clickhosue, map element are also accessed by arrayElement, not make the cast. + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + auto element_type = actions_dag->getNodes().back().result_type; + const auto * nested_type = element_type.get(); + if (nested_type->isNullable()) + { + nested_type = typeid_cast(nested_type)->getNestedType().get(); + } + const auto * index_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[1]); + if (nested_type->getTypeId() == DB::TypeIndex::Array) + { + DB::DataTypeNullable target_type(std::make_shared()); + index_node = ActionsDAGUtil::convertNodeType(actions_dag, index_node, target_type.getName()); + parsed_args.emplace_back(index_node); + } + else + parsed_args.push_back(index_node); + + } + else if (function_name == "repeat") + { + // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait + // which must be a positive value into unsigned integer here. + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + const DB::ActionsDAG::Node * repeat_times_node = + parseFunctionArgument(actions_dag, required_columns, function_name, args[1]); + DB::DataTypeNullable target_type(std::make_shared()); + repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName()); + parsed_args.emplace_back(repeat_times_node); + } + else if (function_name == "leftPadUTF8" || function_name == "rightPadUTF8") + { + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + + /// Make sure the second function arguemnt's type is unsigned integer + /// TODO: delete this branch after Kyligence/Clickhouse upgraged to 23.2 + const DB::ActionsDAG::Node * pad_length_node = + parseFunctionArgument(actions_dag, required_columns, function_name, args[1]); + DB::DataTypeNullable target_type(std::make_shared()); + pad_length_node = ActionsDAGUtil::convertNodeType(actions_dag, pad_length_node, target_type.getName()); + parsed_args.emplace_back(pad_length_node); + + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[2]); + } + else if (function_name == "isNaN") + { + // the result of isNaN(NULL) is NULL in CH, but false in Spark + const DB::ActionsDAG::Node * arg_node = nullptr; + if (args[0].value().has_cast()) + { + arg_node = parseExpression(actions_dag, args[0].value().cast().input()); + const auto * res_type = arg_node->result_type.get(); + if (res_type->isNullable()) + { + res_type = typeid_cast(res_type)->getNestedType().get(); + } + if (isString(*res_type)) + { + DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node}; + arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args); + } + else + { + arg_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[0]); + } + } + else + { + arg_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[0]); + } + + DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, add_column(std::make_shared(), 0)}; + parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args)); + } + else if (function_name == "positionUTF8Spark") + { + if (args.size() >= 2) + { + // In Spark: position(substr, str, Int32) + // In CH: position(str, subtr, UInt32) + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[1]); + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]); + } + if (args.size() >= 3) + { + // add cast: cast(start_pos as UInt32) + const auto * start_pos_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[2]); + DB::DataTypeNullable target_type(std::make_shared()); + start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName()); + parsed_args.emplace_back(start_pos_node); + } + } + else if (function_name == "space") + { + // convert space function to repeat + const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, required_columns, "repeat", args[0]); + const DB::ActionsDAG::Node * space_str_node = add_column(std::make_shared(), " "); + function_name = "repeat"; + parsed_args.emplace_back(space_str_node); + parsed_args.emplace_back(repeat_times_node); + } + else if (function_name == "json_tuple") + { + function_name = "JSONExtract"; + const DB::ActionsDAG::Node * json_expr_node = parseFunctionArgument(actions_dag, required_columns, "JSONExtract", args[0]); + std::string extract_expr = "Tuple("; + for (int i = 1; i < args.size(); i++) + { + auto arg_value = args[i].value(); + if (!arg_value.has_literal()) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The arguments of function {} must be string literal", function_name); + } + DB::Field f = arg_value.literal().string(); + std::string s; + if (f.tryGet(s)) + { + extract_expr.append(s).append(" Nullable(String)"); + if (i != args.size() - 1) + { + extract_expr.append(","); + } + } + } + extract_expr.append(")"); + const DB::ActionsDAG::Node * extract_expr_node = add_column(std::make_shared(), extract_expr); + parsed_args.emplace_back(json_expr_node); + parsed_args.emplace_back(extract_expr_node); + } + else + { + // Default handle + for (const auto & arg : args) + parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, arg); + } +} + +void SerializedPlanParser::parseFunctionArgument( + DB::ActionsDAGPtr & actions_dag, + ActionsDAG::NodeRawConstPtrs & parsed_args, + std::vector & required_columns, + const std::string & function_name, + const substrait::FunctionArgument & arg) +{ + parsed_args.emplace_back(parseFunctionArgument(actions_dag, required_columns, function_name, arg)); +} + +const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( + DB::ActionsDAGPtr & actions_dag, + std::vector & required_columns, + const std::string & function_name, + const substrait::FunctionArgument & arg) +{ + const DB::ActionsDAG::Node * res; + if (arg.value().has_scalar_function()) + { + std::string arg_name; + bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name); + parseFunctionWithDAG(arg.value(), arg_name, required_columns, actions_dag, keep_arg); + res = &actions_dag->getNodes().back(); + } + else + { + res = parseExpression(actions_dag, arg.value()); + } + return res; +} + +// Convert signed integer index into unsigned integer index +std::pair +SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field) +{ + // For tupelElement, field index starts from 1, but int substrait plan, it starts from 0. + #define UINT_CONVERT(type_ptr, field, type_name) \ + if ((type_ptr)->getTypeId() == DB::TypeIndex::type_name) \ + {\ + return {std::make_shared(), static_cast((field).get()) + 1};\ + } + + auto type_id = type->getTypeId(); + if (type_id == DB::TypeIndex::UInt8 || type_id == DB::TypeIndex::UInt16 || type_id == DB::TypeIndex::UInt32 + || type_id == DB::TypeIndex::UInt64) + { + return {type, field}; + } + UINT_CONVERT(type, field, Int8) + UINT_CONVERT(type, field, Int16) + UINT_CONVERT(type, field, Int32) + UINT_CONVERT(type, field, Int64) + throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid interger type: {}", type->getName()); + #undef UINT_CONVERT +} + +ActionsDAGPtr SerializedPlanParser::parseFunction( + const Block & input, + const substrait::Expression & rel, + std::string & result_name, + std::vector & required_columns, + ActionsDAGPtr actions_dag, + bool keep_result) +{ + if (!actions_dag) + actions_dag = std::make_shared(blockToNameAndTypeList(input)); + + parseFunctionWithDAG(rel, result_name, required_columns, actions_dag, keep_result); + return actions_dag; +} + +ActionsDAGPtr SerializedPlanParser::parseArrayJoin( + const Block & input, + const substrait::Expression & rel, + std::vector & result_names, + std::vector & required_columns, + ActionsDAGPtr actions_dag, + bool keep_result, bool position) +{ + if (!actions_dag) + actions_dag = std::make_shared(blockToNameAndTypeList(input)); + + parseArrayJoinWithDAG(rel, result_names, required_columns, actions_dag, keep_result, position); + return actions_dag; +} + +const ActionsDAG::Node * +SerializedPlanParser::toFunctionNode(ActionsDAGPtr action_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args) +{ + auto function_builder = DB::FunctionFactory::instance().get(function, context); + std::string args_name; + join(args, ',', args_name); + auto result_name = function + "(" + args_name + ")"; + const auto * function_node = &action_dag->addFunction(function_builder, args, result_name); + return function_node; +} + +std::pair SerializedPlanParser::parseLiteral(const substrait::Expression_Literal & literal) +{ + DataTypePtr type; + Field field; + + switch (literal.literal_type_case()) + { + case substrait::Expression_Literal::kFp64: { + type = std::make_shared(); + field = literal.fp64(); + break; + } + case substrait::Expression_Literal::kFp32: { + type = std::make_shared(); + field = literal.fp32(); + break; + } + case substrait::Expression_Literal::kString: { + type = std::make_shared(); + field = literal.string(); + break; + } + case substrait::Expression_Literal::kBinary: { + type = std::make_shared(); + field = literal.binary(); + break; + } + case substrait::Expression_Literal::kI64: { + type = std::make_shared(); + field = literal.i64(); + break; + } + case substrait::Expression_Literal::kI32: { + type = std::make_shared(); + field = literal.i32(); + break; + } + case substrait::Expression_Literal::kBoolean: { + type = std::make_shared(); + field = literal.boolean() ? UInt8(1) : UInt8(0); + break; + } + case substrait::Expression_Literal::kI16: { + type = std::make_shared(); + field = literal.i16(); + break; + } + case substrait::Expression_Literal::kI8: { + type = std::make_shared(); + field = literal.i8(); + break; + } + case substrait::Expression_Literal::kDate: { + type = std::make_shared(); + field = literal.date(); + break; + } + case substrait::Expression_Literal::kTimestamp: { + type = std::make_shared(6); + field = DecimalField(literal.timestamp(), 6); + break; + } + case substrait::Expression_Literal::kDecimal: { + UInt32 precision = literal.decimal().precision(); + UInt32 scale = literal.decimal().scale(); + const auto & bytes = literal.decimal().value(); + + if (precision <= DataTypeDecimal32::maxPrecision()) + { + type = std::make_shared(precision, scale); + auto value = *reinterpret_cast(bytes.data()); + field = DecimalField(value, scale); + } + else if (precision <= DataTypeDecimal64::maxPrecision()) + { + type = std::make_shared(precision, scale); + auto value = *reinterpret_cast(bytes.data()); + field = DecimalField(value, scale); + } + else if (precision <= DataTypeDecimal128::maxPrecision()) + { + type = std::make_shared(precision, scale); + String bytes_copy(bytes); + auto value = *reinterpret_cast(bytes_copy.data()); + field = DecimalField(value, scale); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); + break; + } + /// TODO(taiyang-li) Other type: Struct/Map/List + case substrait::Expression_Literal::kList: { + /// TODO(taiyang-li) Implement empty list + if (literal.has_empty_list()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Empty list not support!"); + + DataTypePtr first_type; + std::tie(first_type, std::ignore) = parseLiteral(literal.list().values(0)); + + size_t list_len = literal.list().values_size(); + Array array(list_len); + for (size_t i = 0; i < list_len; ++i) + { + auto type_and_field = std::move(parseLiteral(literal.list().values(i))); + if (!first_type->equals(*type_and_field.first)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Literal list type mismatch:{} and {}", + first_type->getName(), + type_and_field.first->getName()); + array[i] = std::move(type_and_field.second); + } + + type = std::make_shared(first_type); + field = std::move(array); + break; + } + case substrait::Expression_Literal::kNull: { + type = parseType(literal.null()); + field = std::move(Field{}); + break; + } + default: { + throw Exception( + ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); + } + } + return std::make_pair(std::move(type), std::move(field)); +} + +const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr action_dag, const substrait::Expression & rel) +{ + auto add_column = [&](const DataTypePtr & type, const Field & field) -> auto + { + return &action_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, field), type, getUniqueName(toString(field)))); + }; + + switch (rel.rex_type_case()) + { + case substrait::Expression::RexTypeCase::kLiteral: { + DataTypePtr type; + Field field; + std::tie(type, field) = parseLiteral(rel.literal()); + return add_column(type, field); + } + + case substrait::Expression::RexTypeCase::kSelection: { + if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections"); + + const auto * field = action_dag->getInputs()[rel.selection().direct_reference().struct_field().field()]; + return action_dag->tryFindInOutputs(field->result_name); + } + + case substrait::Expression::RexTypeCase::kCast: { + if (!rel.cast().has_type() || !rel.cast().has_input()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); + + std::string ch_function_name = getCastFunction(rel.cast().type()); + DB::ActionsDAG::NodeRawConstPtrs args; + const auto & cast_input = rel.cast().input(); + args.emplace_back(parseExpression(action_dag, cast_input)); + + if (ch_function_name.starts_with("toDecimal")) + { + UInt32 scale = rel.cast().type().decimal().scale(); + args.emplace_back(add_column(std::make_shared(), scale)); + } + else if (ch_function_name.starts_with("toDateTime64")) + { + /// In Spark: cast(xx as TIMESTAMP) + /// In CH: toDateTime(xx, 6) + /// So we must add extra argument: 6 + args.emplace_back(add_column(std::make_shared(), 6)); + } + + const auto * function_node = toFunctionNode(action_dag, ch_function_name, args); + action_dag->addOrReplaceInOutputs(*function_node); + return function_node; + } + + case substrait::Expression::RexTypeCase::kIfThen: { + const auto & if_then = rel.if_then(); + auto function_multi_if = DB::FunctionFactory::instance().get("multiIf", context); + DB::ActionsDAG::NodeRawConstPtrs args; + + auto condition_nums = if_then.ifs_size(); + for (int i = 0; i < condition_nums; ++i) + { + const auto & ifs = if_then.ifs(i); + const auto * if_node = parseExpression(action_dag, ifs.if_()); + args.emplace_back(if_node); + + const auto * then_node = parseExpression(action_dag, ifs.then()); + args.emplace_back(then_node); + } + + const auto * else_node = parseExpression(action_dag, if_then.else_()); + args.emplace_back(else_node); + std::string args_name; + join(args, ',', args_name); + auto result_name = "multiIf(" + args_name + ")"; + const auto * function_node = &action_dag->addFunction(function_multi_if, args, result_name); + action_dag->addOrReplaceInOutputs(*function_node); + return function_node; + } + + case substrait::Expression::RexTypeCase::kScalarFunction: { + std::string result; + std::vector useless; + return parseFunctionWithDAG(rel, result, useless, action_dag, false); + } + + case substrait::Expression::RexTypeCase::kSingularOrList: { + const auto & options = rel.singular_or_list().options(); + /// options is empty always return false + if (options.empty()) + return add_column(std::make_shared(), 0); + /// options should be literals + if (!options[0].has_literal()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type"); + + DB::ActionsDAG::NodeRawConstPtrs args; + args.emplace_back(parseExpression(action_dag, rel.singular_or_list().value())); + + bool nullable = false; + size_t options_len = options.size(); + for (size_t i = 0; i < options_len; ++i) + { + if (!options[i].has_literal()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); + if (!nullable) + nullable = options[i].literal().has_null(); + } + + DataTypePtr elem_type; + std::tie(elem_type, std::ignore) = parseLiteral(options[0].literal()); + elem_type = wrapNullableType(nullable, elem_type); + + MutableColumnPtr elem_column = elem_type->createColumn(); + elem_column->reserve(options_len); + for (size_t i = 0; i < options_len; ++i) + { + auto type_and_field = std::move(parseLiteral(options[i].literal())); + auto option_type = wrapNullableType(nullable, type_and_field.first); + if (!elem_type->equals(*option_type)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "SingularOrList options type mismatch:{} and {}", + elem_type->getName(), + option_type->getName()); + + elem_column->insert(type_and_field.second); + } + + MutableColumns elem_columns; + elem_columns.emplace_back(std::move(elem_column)); + + auto name = getUniqueName("__set"); + Block elem_block; + elem_block.insert(ColumnWithTypeAndName(nullptr, elem_type, name)); + elem_block.setColumns(std::move(elem_columns)); + + SizeLimits limit; + auto elem_set = std::make_shared(limit, true, false); + elem_set->setHeader(elem_block.getColumnsWithTypeAndName()); + elem_set->insertFromBlock(elem_block.getColumnsWithTypeAndName()); + elem_set->finishInsert(); + + auto arg = ColumnSet::create(elem_set->getTotalRowCount(), elem_set); + args.emplace_back(&action_dag->addColumn(ColumnWithTypeAndName(std::move(arg), std::make_shared(), name))); + + const auto * function_node = toFunctionNode(action_dag, "in", args); + action_dag->addOrReplaceInOutputs(*function_node); + if (nullable) + { + /// if sets has `null` and value not in sets + /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920) + /// In CH: return `false` + /// So we used if(a, b, c) cast `false` to `null` if sets has `null` + auto type = wrapNullableType(true, function_node->result_type); + DB::ActionsDAG::NodeRawConstPtrs cast_args({function_node, add_column(type, true), add_column(type, Field())}); + auto cast = FunctionFactory::instance().get("if", context); + function_node = toFunctionNode(action_dag, "if", cast_args); + } + return function_node; + } + + default: + throw Exception( + ErrorCodes::UNKNOWN_TYPE, + "Unsupported spark expression type {} : {}", + magic_enum::enum_name(rel.rex_type_case()), + rel.DebugString()); + } +} + +QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) +{ + auto plan_ptr = std::make_unique(); + auto ok = plan_ptr->ParseFromString(plan); + if (!ok) + throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); + + auto res = std::move(parse(std::move(plan_ptr))); + + auto * logger = &Poco::Logger::get("SerializedPlanParser"); + if (logger->debug()) + { + auto out = PlanUtil::explainPlan(*res); + LOG_DEBUG(logger, "clickhouse plan:{}", out); + } + return std::move(res); +} + +QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) +{ + auto plan_ptr = std::make_unique(); + google::protobuf::util::JsonStringToMessage( + google::protobuf::stringpiece_internal::StringPiece(json_plan.c_str()), + plan_ptr.get()); + return parse(std::move(plan_ptr)); +} + +void SerializedPlanParser::initFunctionEnv() +{ + registerFunctions(); + registerAggregateFunctions(); +} +SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) +{ +} +ContextMutablePtr SerializedPlanParser::global_context = nullptr; + +Context::ConfigurationPtr SerializedPlanParser::config = nullptr; + +void SerializedPlanParser::collectJoinKeys( + const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start) +{ + auto condition_name = getFunctionName( + function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); + if (condition_name == "and") + { + collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); + collectJoinKeys(condition.scalar_function().arguments(1).value(), join_keys, right_key_start); + } + else if (condition_name == "equals") + { + const auto & function = condition.scalar_function(); + auto left_key_idx = function.arguments(0).value().selection().direct_reference().struct_field().field(); + auto right_key_idx = function.arguments(1).value().selection().direct_reference().struct_field().field() - right_key_start; + join_keys.emplace_back(std::pair(left_key_idx, right_key_idx)); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "doesn't support condition {}", condition_name); + } +} + +DB::QueryPlanPtr SerializedPlanParser::parseJoin(substrait::JoinRel join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) +{ + google::protobuf::StringValue optimization; + optimization.ParseFromString(join.advanced_extension().optimization().value()); + auto join_opt_info = parseJoinOptimizationInfo(optimization.value()); + auto table_join = std::make_shared(global_context->getSettings(), global_context->getTemporaryVolume()); + if (join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_INNER) + { + table_join->setKind(DB::JoinKind::Inner); + table_join->setStrictness(DB::JoinStrictness::All); + } + else if (join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI) + { + table_join->setKind(DB::JoinKind::Left); + table_join->setStrictness(DB::JoinStrictness::Semi); + } + else if (join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_ANTI) + { + table_join->setKind(DB::JoinKind::Left); + table_join->setStrictness(DB::JoinStrictness::Anti); + } + else if (join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_LEFT) + { + table_join->setKind(DB::JoinKind::Left); + table_join->setStrictness(DB::JoinStrictness::All); + } + else if (join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_OUTER) + { + table_join->setKind(DB::JoinKind::Full); + table_join->setStrictness(DB::JoinStrictness::All); + } + else + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join.type())); + } + + if (join_opt_info.is_broadcast) + { + auto storage_join = BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key); + ActionsDAGPtr project = ActionsDAG::makeConvertingActions( + right->getCurrentDataStream().header.getColumnsWithTypeAndName(), + storage_join->getRightSampleBlock().getColumnsWithTypeAndName(), + ActionsDAG::MatchColumnsMode::Position); + if (project) + { + QueryPlanStepPtr project_step = std::make_unique(right->getCurrentDataStream(), project); + project_step->setStepDescription("Rename Broadcast Table Name"); + right->addStep(std::move(project_step)); + } + } + + NameSet left_columns_set; + for (const auto & col : left->getCurrentDataStream().header.getNames()) + { + left_columns_set.emplace(col); + } + table_join->setColumnsFromJoinedTable(right->getCurrentDataStream().header.getNamesAndTypesList(), + left_columns_set, + getUniqueName("right") + "."); + // fix right table key duplicate + NamesWithAliases right_table_alias; + for (size_t idx = 0; idx < table_join->columnsFromJoinedTable().size(); idx++) + { + auto origin_name = right->getCurrentDataStream().header.getByPosition(idx).name; + auto dedup_name = table_join->columnsFromJoinedTable().getNames().at(idx); + if (origin_name != dedup_name) + { + right_table_alias.emplace_back(NameWithAlias(origin_name, dedup_name)); + } + } + if (!right_table_alias.empty()) + { + ActionsDAGPtr rename_dag = std::make_shared(right->getCurrentDataStream().header.getNamesAndTypesList()); + auto original_right_columns = right->getCurrentDataStream().header; + for (const auto & column_alias : right_table_alias) + { + if (original_right_columns.has(column_alias.first)) + { + auto pos = original_right_columns.getPositionByName(column_alias.first); + const auto & alias = rename_dag->addAlias(*rename_dag->getInputs()[pos], column_alias.second); + rename_dag->getOutputs()[pos] = &alias; + } + } + rename_dag->projectInput(); + QueryPlanStepPtr project_step = std::make_unique(right->getCurrentDataStream(), rename_dag); + project_step->setStepDescription("Right Table Rename"); + right->addStep(std::move(project_step)); + } + + for (const auto & column : table_join->columnsFromJoinedTable()) + { + table_join->addJoinedColumn(column); + } + ActionsDAGPtr left_convert_actions = nullptr; + ActionsDAGPtr right_convert_actions = nullptr; + std::tie(left_convert_actions, right_convert_actions) = table_join->createConvertingActions( + left->getCurrentDataStream().header.getColumnsWithTypeAndName(), right->getCurrentDataStream().header.getColumnsWithTypeAndName()); + + if (right_convert_actions) + { + auto converting_step = std::make_unique(right->getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + right->addStep(std::move(converting_step)); + } + + if (left_convert_actions) + { + auto converting_step = std::make_unique(left->getCurrentDataStream(), left_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + left->addStep(std::move(converting_step)); + } + QueryPlanPtr query_plan; + Names after_join_names; + auto left_names = left->getCurrentDataStream().header.getNames(); + after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end()); + auto right_name = table_join->columnsFromJoinedTable().getNames(); + after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end()); + + bool add_filter_step = false; + try + { + parseJoinKeysAndCondition(table_join, join, left, right, table_join->columnsFromJoinedTable(), after_join_names); + } + // if ch not support the join type or join conditions, it will throw an exception like 'not support'. + catch (Poco::Exception & e) + { + // CH not support join condition has 'or' and has different table in each side. + // But in inner join, we could execute join condition after join. so we have add filter step + if (e.code() == ErrorCodes::INVALID_JOIN_ON_EXPRESSION && table_join->kind() == DB::JoinKind::Inner) + { + add_filter_step = true; + } + else + { + throw; + } + } + + if (join_opt_info.is_broadcast) + { + auto storage_join = BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key); + if (!storage_join) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "broad cast table {} not found.", join_opt_info.storage_join_key); + } + auto hash_join = storage_join->getJoinLocked(table_join, context); + QueryPlanStepPtr join_step = std::make_unique(left->getCurrentDataStream(), hash_join, 8192); + + join_step->setStepDescription("JOIN"); + left->addStep(std::move(join_step)); + query_plan = std::move(left); + } + else + { + auto hash_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty()); + QueryPlanStepPtr join_step + = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false); + + join_step->setStepDescription("JOIN"); + + std::vector plans; + plans.emplace_back(std::move(left)); + plans.emplace_back(std::move(right)); + + query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + } + + reorderJoinOutput(*query_plan, after_join_names); + if (add_filter_step) + { + std::string filter_name; + std::vector useless; + auto actions_dag + = parseFunction(query_plan->getCurrentDataStream().header, join.post_join_filter(), filter_name, useless, nullptr, true); + auto filter_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag, filter_name, true); + filter_step->setStepDescription("Post Join Filter"); + query_plan->addStep(std::move(filter_step)); + } + return query_plan; +} + +void SerializedPlanParser::parseJoinKeysAndCondition( + std::shared_ptr table_join, + substrait::JoinRel & join, + DB::QueryPlanPtr & left, + DB::QueryPlanPtr & right, + const NamesAndTypesList & alias_right, + Names & names) +{ + ASTs args; + ASTParser astParser(context, function_mapping); + + if (join.has_expression()) + { + args.emplace_back(astParser.parseToAST(names, join.expression())); + } + + if (join.has_post_join_filter()) + { + args.emplace_back(astParser.parseToAST(names, join.post_join_filter())); + } + + if (args.empty()) + return; + + ASTPtr ast = args.size() == 1 ? args.back() : makeASTFunction("and", args); + + bool is_asof = (table_join->strictness() == JoinStrictness::Asof); + + Aliases aliases; + DatabaseAndTableWithAlias left_table_name; + DatabaseAndTableWithAlias right_table_name; + TableWithColumnNamesAndTypes left_table(left_table_name, left->getCurrentDataStream().header.getNamesAndTypesList()); + TableWithColumnNamesAndTypes right_table(right_table_name, alias_right); + + CollectJoinOnKeysVisitor::Data data{*table_join, left_table, right_table, aliases, is_asof}; + if (auto * or_func = ast->as(); or_func && or_func->name == "or") + { + for (auto & disjunct : or_func->arguments->children) + { + table_join->addDisjunct(); + CollectJoinOnKeysVisitor(data).visit(disjunct); + } + assert(table_join->getClauses().size() == or_func->arguments->children.size()); + } + else + { + table_join->addDisjunct(); + CollectJoinOnKeysVisitor(data).visit(ast); + assert(table_join->oneDisjunct()); + } + + if (join.has_post_join_filter()) + { + auto left_keys = table_join->leftKeysList(); + auto right_keys = table_join->rightKeysList(); + if (!left_keys->children.empty()) + { + auto actions = astParser.convertToActions(left->getCurrentDataStream().header.getNamesAndTypesList(), left_keys); + QueryPlanStepPtr before_join_step = std::make_unique(left->getCurrentDataStream(), actions); + before_join_step->setStepDescription("Before JOIN LEFT"); + left->addStep(std::move(before_join_step)); + } + + if (!right_keys->children.empty()) + { + auto actions = astParser.convertToActions(right->getCurrentDataStream().header.getNamesAndTypesList(), right_keys); + QueryPlanStepPtr before_join_step = std::make_unique(right->getCurrentDataStream(), actions); + before_join_step->setStepDescription("Before JOIN RIGHT"); + right->addStep(std::move(before_join_step)); + } + } +} + +ActionsDAGPtr ASTParser::convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast) +{ + NamesAndTypesList aggregation_keys; + ColumnNumbersList aggregation_keys_indexes_list; + AggregationKeysInfo info(aggregation_keys, aggregation_keys_indexes_list, GroupByKind::NONE); + SizeLimits size_limits_for_set; + ActionsVisitor::Data visitor_data( + context, + size_limits_for_set, + size_t(0), + name_and_types, + std::make_shared(name_and_types), + nullptr /* prepared_sets */, + false /* no_subqueries */, + false /* no_makeset */, + false /* only_consts */, + false /* create_source_for_in */, + info); + ActionsVisitor(visitor_data).visit(ast); + return visitor_data.getActions(); +} + +ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & rel) +{ + LOG_DEBUG(&Poco::Logger::get("ASTParser"), "substrait plan:{}", rel.DebugString()); + if (!rel.has_scalar_function()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); + + const auto & scalar_function = rel.scalar_function(); + auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); + auto function_name = SerializedPlanParser::getFunctionName(function_signature, scalar_function); + ASTs ast_args; + parseFunctionArgumentsToAST(names, scalar_function, ast_args); + + return makeASTFunction(function_name, ast_args); +} + +void ASTParser::parseFunctionArgumentsToAST( + const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) +{ + const auto & args = scalar_function.arguments(); + + for (const auto & arg : args) + { + if (arg.value().has_scalar_function()) + { + ast_args.emplace_back(parseToAST(names, arg.value())); + } + else + { + ast_args.emplace_back(parseArgumentToAST(names, arg.value())); + } + } +} + +ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expression & rel) +{ + switch (rel.rex_type_case()) + { + case substrait::Expression::RexTypeCase::kLiteral: { + DataTypePtr type; + Field field; + std::tie(std::ignore, field) = SerializedPlanParser::parseLiteral(rel.literal()); + return std::make_shared(field); + } + case substrait::Expression::RexTypeCase::kSelection: { + if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections"); + + const auto field = rel.selection().direct_reference().struct_field().field(); + return std::make_shared(names[field]); + } + case substrait::Expression::RexTypeCase::kCast: { + if (!rel.cast().has_type() || !rel.cast().has_input()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); + + std::string ch_function_name = getCastFunction(rel.cast().type()); + + ASTs args; + args.emplace_back(parseArgumentToAST(names, rel.cast().input())); + + if (ch_function_name.starts_with("toDecimal")) + { + UInt32 scale = rel.cast().type().decimal().scale(); + args.emplace_back(std::make_shared(scale)); + } + else if (ch_function_name.starts_with("toDateTime64")) + { + /// In Spark: cast(xx as TIMESTAMP) + /// In CH: toDateTime(xx, 6) + /// So we must add extra argument: 6 + args.emplace_back(std::make_shared(6)); + } + + return makeASTFunction(ch_function_name, args); + } + case substrait::Expression::RexTypeCase::kIfThen: { + const auto & if_then = rel.if_then(); + const auto * ch_function_name = "multiIf"; + auto function_multi_if = DB::FunctionFactory::instance().get(ch_function_name, context); + ASTs args; + + auto condition_nums = if_then.ifs_size(); + for (int i = 0; i < condition_nums; ++i) + { + const auto & ifs = if_then.ifs(i); + auto if_node = parseArgumentToAST(names, ifs.if_()); + args.emplace_back(if_node); + + auto then_node = parseArgumentToAST(names, ifs.then()); + args.emplace_back(then_node); + } + + auto else_node = parseArgumentToAST(names, if_then.else_()); + return makeASTFunction(ch_function_name, args); + } + case substrait::Expression::RexTypeCase::kScalarFunction: { + return parseToAST(names, rel); + } + case substrait::Expression::RexTypeCase::kSingularOrList: { + const auto & options = rel.singular_or_list().options(); + /// options is empty always return false + if (options.empty()) + return std::make_shared(0); + /// options should be literals + if (!options[0].has_literal()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type"); + + ASTs args; + args.emplace_back(parseArgumentToAST(names, rel.singular_or_list().value())); + + bool nullable = false; + size_t options_len = options.size(); + args.reserve(options_len); + + for (size_t i = 0; i < options_len; ++i) + { + if (!options[i].has_literal()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); + if (!nullable) + nullable = options[i].literal().has_null(); + } + + auto elem_type_and_field = SerializedPlanParser::parseLiteral(options[0].literal()); + DataTypePtr elem_type = wrapNullableType(nullable, elem_type_and_field.first); + for (size_t i = 0; i < options_len; ++i) + { + auto type_and_field = std::move(SerializedPlanParser::parseLiteral(options[i].literal())); + auto option_type = wrapNullableType(nullable, type_and_field.first); + if (!elem_type->equals(*option_type)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "SingularOrList options type mismatch:{} and {}", + elem_type->getName(), + option_type->getName()); + + args.emplace_back(std::make_shared(type_and_field.second)); + } + + auto ast = makeASTFunction("in", args); + if (nullable) + { + /// if sets has `null` and value not in sets + /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920) + /// In CH: return `false` + /// So we used if(a, b, c) cast `false` to `null` if sets has `null` + ast = makeASTFunction("if", ast, std::make_shared(true), std::make_shared(Field())); + } + + return ast; + } + default: + throw Exception( + ErrorCodes::UNKNOWN_TYPE, + "Join on condition error. Unsupported spark expression type {} : {}", + magic_enum::enum_name(rel.rex_type_case()), + rel.DebugString()); + } +} + +void SerializedPlanParser::reorderJoinOutput(QueryPlan & plan, DB::Names cols) +{ + ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); + NamesWithAliases project_cols; + for (const auto & col : cols) + { + project_cols.emplace_back(NameWithAlias(col, col)); + } + project->project(project_cols); + QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); + project_step->setStepDescription("Reorder Join Output"); + plan.addStep(std::move(project_step)); +} + +void SerializedPlanParser::removeNullable(std::vector require_columns, ActionsDAGPtr actionsDag) +{ + for (const auto & item : require_columns) + { + const auto * require_node = actionsDag->tryFindInOutputs(item); + if (require_node) + { + auto function_builder = FunctionFactory::instance().get("assumeNotNull", context); + ActionsDAG::NodeRawConstPtrs args = {require_node}; + const auto & node = actionsDag->addFunction(function_builder, args, item); + actionsDag->addOrReplaceInOutputs(node); + } + } +} + +void SerializedPlanParser::wrapNullable(std::vector columns, ActionsDAGPtr actionsDag, + std::map& nullable_measure_names) +{ + for (const auto & item : columns) + { + ActionsDAG::NodeRawConstPtrs args; + args.emplace_back(&actionsDag->findInOutputs(item)); + const auto * node = toFunctionNode(actionsDag, "toNullable", args); + actionsDag->addOrReplaceInOutputs(*node); + nullable_measure_names[item] = node->result_name; + } +} + +SharedContextHolder SerializedPlanParser::shared_context; + +LocalExecutor::~LocalExecutor() +{ + if (spark_buffer) + { + ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); + spark_buffer.reset(); + } +} + + +void LocalExecutor::execute(QueryPlanPtr query_plan) +{ + current_query_plan = std::move(query_plan); + Stopwatch stopwatch; + stopwatch.start(); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = true}; + auto pipeline_builder = current_query_plan->buildQueryPipeline( + optimization_settings, + BuildQueryPipelineSettings{ + .actions_settings = ExpressionActionsSettings{ + .can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}}); + query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + LOG_DEBUG(&Poco::Logger::get("LocalExecutor"), "clickhouse pipeline:{}", QueryPipelineUtil::explainPipeline(query_pipeline)); + auto t_pipeline = stopwatch.elapsedMicroseconds(); + executor = std::make_unique(query_pipeline); + auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline; + stopwatch.stop(); + LOG_INFO( + &Poco::Logger::get("SerializedPlanParser"), + "build pipeline {} ms; create executor {} ms;", + t_pipeline / 1000.0, + t_executor / 1000.0); + header = current_query_plan->getCurrentDataStream().header.cloneEmpty(); + ch_column_to_spark_row = std::make_unique(); + +} +std::unique_ptr LocalExecutor::writeBlockToSparkRow(Block & block) +{ + return ch_column_to_spark_row->convertCHColumnToSparkRow(block); +} +bool LocalExecutor::hasNext() +{ + bool has_next; + try + { + if (currentBlock().columns() == 0 || isConsumed()) + { + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + has_next = executor->pull(currentBlock()); + produce(); + } + else + { + has_next = true; + } + } + catch (DB::Exception & e) + { + LOG_ERROR( + &Poco::Logger::get("LocalExecutor"), "run query plan failed. {}\n{}", e.message(), PlanUtil::explainPlan(*current_query_plan)); + throw; + } + return has_next; +} +SparkRowInfoPtr LocalExecutor::next() +{ + checkNextValid(); + SparkRowInfoPtr row_info = writeBlockToSparkRow(currentBlock()); + consume(); + if (spark_buffer) + { + ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); + spark_buffer.reset(); + } + spark_buffer = std::make_unique(); + spark_buffer->address = row_info->getBufferAddress(); + spark_buffer->size = row_info->getTotalBytes(); + return row_info; +} + +Block * LocalExecutor::nextColumnar() +{ + checkNextValid(); + Block * columnar_batch; + if (currentBlock().columns() > 0) + { + columnar_batch = ¤tBlock(); + } + else + { + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + columnar_batch = ¤tBlock(); + } + consume(); + return columnar_batch; +} + +Block & LocalExecutor::getHeader() +{ + return header; +} +LocalExecutor::LocalExecutor(QueryContext & _query_context) + : query_context(_query_context) +{ +} +} diff --git a/utils/local-engine/Parser/SerializedPlanParser.h b/utils/local-engine/Parser/SerializedPlanParser.h new file mode 100644 index 000000000000..650e3fbac0b5 --- /dev/null +++ b/utils/local-engine/Parser/SerializedPlanParser.h @@ -0,0 +1,413 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ + +static const std::map SCALAR_FUNCTIONS = { + {"is_not_null","isNotNull"}, + {"is_null","isNull"}, + {"gte","greaterOrEquals"}, + {"gt", "greater"}, + {"lte", "lessOrEquals"}, + {"lt", "less"}, + {"equal", "equals"}, + + {"and", "and"}, + {"or", "or"}, + {"not", "not"}, + {"xor", "xor"}, + + {"extract", ""}, + {"cast", ""}, + {"alias", "alias"}, + + /// datetime functions + {"to_date", "toDate"}, + {"quarter", "toQuarter"}, + {"to_unix_timestamp", "toUnixTimestamp"}, + {"unix_timestamp", "toUnixTimestamp"}, + {"date_format", "formatDateTimeInJodaSyntax"}, + + /// arithmetic functions + {"subtract", "minus"}, + {"multiply", "multiply"}, + {"add", "plus"}, + {"divide", "divide"}, + {"modulus", "modulo"}, + {"pmod", "pmod"}, + {"abs", "abs"}, + {"ceil", "ceil"}, + {"floor", "floor"}, + {"round", "round"}, + {"bround", "roundBankers"}, + {"exp", "exp"}, + {"power", "power"}, + {"cos", "cos"}, + {"cosh", "cosh"}, + {"sin", "sin"}, + {"sinh", "sinh"}, + {"tan", "tan"}, + {"tanh", "tanh"}, + {"acos", "acos"}, + {"asin", "asin"}, + {"atan", "atan"}, + {"atan2", "atan2"}, + {"bitwise_not", "bitNot"}, + {"bitwise_and", "bitAnd"}, + {"bitwise_or", "bitOr"}, + {"bitwise_xor", "bitXor"}, + {"sqrt", "sqrt"}, + {"cbrt", "cbrt"}, + {"degrees", "degrees"}, + {"e", "e"}, + {"pi", "pi"}, + {"hex", "hex"}, + {"unhex", "unhex"}, + {"hypot", "hypot"}, + {"sign", "sign"}, + {"log10", "log10"}, + {"log1p", "log1p"}, + {"log2", "log2"}, + {"log", "log"}, + {"radians", "radians"}, + {"greatest", "greatest"}, + {"least", "least"}, + {"shiftleft", "bitShiftLeft"}, + {"shiftright", "bitShiftRight"}, + {"check_overflow", "check_overflow"}, + {"factorial", "factorial"}, + {"rand", "randCanonical"}, + {"isnan", "isNaN"}, + + /// string functions + {"like", "like"}, + {"not_like", "notLike"}, + {"starts_with", "startsWith"}, + {"ends_with", "endsWith"}, + {"contains", "countSubstrings"}, + {"substring", "substring"}, + {"lower", "lower"}, + {"upper", "upper"}, + {"trim", ""}, + {"ltrim", ""}, + {"rtrim", ""}, + {"concat", "concat"}, + {"strpos", "position"}, + {"char_length", "char_length"}, + {"replace", "replaceAll"}, + {"regexp_replace", "replaceRegexpAll"}, + {"chr", "char"}, + {"rlike", "match"}, + {"ascii", "ascii"}, + {"split", "splitByRegexp"}, + {"concat_ws", "concat_ws"}, + {"base64", "base64Encode"}, + {"unbase64","base64Decode"}, + {"lpad","leftPadUTF8"}, + {"rpad","rightPadUTF8"}, + {"reverse","reverseUTF8"}, + // {"hash","murmurHash3_32"}, + {"md5","MD5"}, + {"translate", "translateUTF8"}, + {"repeat","repeat"}, + {"position", "positionUTF8Spark"}, + {"locate", "positionUTF8Spark"}, + {"space","space"}, + + /// hash functions + {"hash", "murmurHashSpark3_32"}, + {"xxhash64", "xxHashSpark64"}, + + // in functions + {"in", "in"}, + + // null related functions + {"coalesce", "coalesce"}, + + // aggregate functions + {"count", "count"}, + {"avg", "avg"}, + {"sum", "sum"}, + {"min", "min"}, + {"max", "max"}, + {"collect_list", "groupArray"}, + {"stddev_samp", "stddev_samp"}, + {"stddev_pop", "stddev_pop"}, + + // date or datetime functions + {"from_unixtime", "fromUnixTimestampInJodaSyntax"}, + {"date_add", "addDays"}, + {"date_sub", "subtractDays"}, + {"datediff", "dateDiff"}, + {"second", "toSecond"}, + {"add_months", "addMonths"}, + {"trunc", ""}, /// dummy mapping + + // array functions + {"array", "array"}, + {"size", "length"}, + {"get_array_item", "arrayElement"}, + {"element_at", "arrayElement"}, + {"array_contains", "has"}, + {"range", "range"}, /// dummy mapping + + // map functions + {"map", "map"}, + {"get_map_value", "arrayElement"}, + {"map_keys", "mapKeys"}, + {"map_values", "mapValues"}, + {"map_from_arrays", "mapFromArrays"}, + + // tuple functions + {"get_struct_field", "tupleElement"}, + {"named_struct", "tuple"}, + + // table-valued generator function + {"explode", "arrayJoin"}, + {"posexplode", "arrayJoin"}, + + // json functions + {"get_json_object", "JSON_VALUE"}, + {"to_json", "toJSONString"}, + {"from_json", "JSONExtract"}, + {"json_tuple", "json_tuple"} +}; + +static const std::set FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"}; + +struct QueryContext +{ + StorageSnapshotPtr storage_snapshot; + std::shared_ptr metadata; + std::shared_ptr custom_storage_merge_tree; +}; + +DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type); +DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); + +class SerializedPlanParser +{ + friend class RelParser; + friend class ASTParser; +public: + explicit SerializedPlanParser(const ContextPtr & context); + static void initFunctionEnv(); + DB::QueryPlanPtr parse(const std::string & plan); + DB::QueryPlanPtr parseJson(const std::string & json_plan); + DB::QueryPlanPtr parse(std::unique_ptr plan); + + DB::QueryPlanPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel); + DB::QueryPlanPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel); + DB::QueryPlanPtr parseMergeTreeTable(const substrait::ReadRel & rel); + PrewhereInfoPtr parsePreWhereInfo(const substrait::Expression & rel, Block & input, std::vector& not_nullable_columns); + + static bool isReadRelFromJava(const substrait::ReadRel & rel); + static DB::Block parseNameStruct(const substrait::NamedStruct & struct_); + static DB::DataTypePtr parseType(const substrait::Type & type, std::list * names = nullptr); + // This is used for construct a data type from spark type name; + static DB::DataTypePtr parseType(const std::string & type); + + void addInputIter(jobject iter) { input_iters.emplace_back(iter); } + + void parseExtensions(const ::google::protobuf::RepeatedPtrField & extensions); + std::shared_ptr expressionsToActionsDAG( + const std::vector & expressions, + const DB::Block & header, + const DB::Block & read_schema); + + static ContextMutablePtr global_context; + static Context::ConfigurationPtr config; + static SharedContextHolder shared_context; + QueryContext query_context; + +private: + static DB::NamesAndTypesList blockToNameAndTypeList(const DB::Block & header); + DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list & rel_stack); + void + collectJoinKeys(const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start); + DB::QueryPlanPtr parseJoin(substrait::JoinRel join, DB::QueryPlanPtr left, DB::QueryPlanPtr right); + void parseJoinKeysAndCondition( + std::shared_ptr table_join, + substrait::JoinRel & join, + DB::QueryPlanPtr & left, + DB::QueryPlanPtr & right, + const NamesAndTypesList & alias_right, + Names & names); + + static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); + static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + DB::ActionsDAGPtr parseFunction( + const Block & input, + const substrait::Expression & rel, + std::string & result_name, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag = nullptr, + bool keep_result = false); + DB::ActionsDAGPtr parseArrayJoin( + const Block & input, + const substrait::Expression & rel, + std::vector & result_names, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag = nullptr, + bool keep_result = false, + bool position = false); + const ActionsDAG::Node * parseFunctionWithDAG( + const substrait::Expression & rel, + std::string & result_name, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag = nullptr, + bool keep_result = false); + ActionsDAG::NodeRawConstPtrs parseArrayJoinWithDAG( + const substrait::Expression & rel, + std::vector & result_name, + std::vector & required_columns, + DB::ActionsDAGPtr actions_dag = nullptr, + bool keep_result = false, + bool position = false); + void parseFunctionArguments( + DB::ActionsDAGPtr & actions_dag, + ActionsDAG::NodeRawConstPtrs & parsed_args, + std::vector & required_columns, + std::string & function_name, + const substrait::Expression_ScalarFunction & scalar_function); + void parseFunctionArgument( + DB::ActionsDAGPtr & actions_dag, + ActionsDAG::NodeRawConstPtrs & parsed_args, + std::vector & required_columns, + const std::string & function_name, + const substrait::FunctionArgument & arg); + const DB::ActionsDAG::Node * parseFunctionArgument( + DB::ActionsDAGPtr & actions_dag, + std::vector & required_columns, + const std::string & function_name, + const substrait::FunctionArgument & arg); + void addPreProjectStepIfNeeded( + QueryPlan & plan, + const substrait::AggregateRel & rel, + std::vector & measure_names, + std::map & nullable_measure_names); + DB::QueryPlanStepPtr parseAggregate(DB::QueryPlan & plan, const substrait::AggregateRel & rel, bool & is_final); + const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAGPtr action_dag, const substrait::Expression & rel); + const ActionsDAG::Node * + toFunctionNode(ActionsDAGPtr action_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args); + // remove nullable after isNotNull + void removeNullable(std::vector require_columns, ActionsDAGPtr actionsDag); + std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } + + static std::pair parseLiteral(const substrait::Expression_Literal & literal); + void wrapNullable(std::vector columns, ActionsDAGPtr actionsDag, + std::map& nullable_measure_names); + + static Aggregator::Params getAggregateParam(const Names & keys, + const AggregateDescriptions & aggregates) + { + Settings settings; + return Aggregator::Params( + keys, + aggregates, + false, + settings.max_rows_to_group_by, + settings.group_by_overflow_mode, + settings.group_by_two_level_threshold, + settings.group_by_two_level_threshold_bytes, + settings.max_bytes_before_external_group_by, + settings.empty_result_for_aggregation_by_empty_set, + nullptr, + settings.max_threads, + settings.min_free_disk_space_for_temporary_data, + true, + 3, + settings.max_block_size, + false, + false); + } + + static Aggregator::Params + getMergedAggregateParam(const Names & keys, const AggregateDescriptions & aggregates) + { + Settings settings; + return Aggregator::Params(keys, aggregates, false, settings.max_threads, settings.max_block_size); + } + + void addRemoveNullableStep(QueryPlan & plan, std::vector columns); + + static std::pair convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field); + + int name_no = 0; + std::unordered_map function_mapping; + std::vector input_iters; + const substrait::ProjectRel * last_project = nullptr; + ContextPtr context; + +}; + +struct SparkBuffer +{ + char * address; + size_t size; +}; + +class LocalExecutor : public BlockIterator +{ +public: + LocalExecutor() = default; + explicit LocalExecutor(QueryContext & _query_context); + void execute(QueryPlanPtr query_plan); + SparkRowInfoPtr next(); + Block * nextColumnar(); + bool hasNext(); + ~LocalExecutor(); + + Block & getHeader(); + +private: + QueryContext query_context; + std::unique_ptr writeBlockToSparkRow(DB::Block & block); + QueryPipeline query_pipeline; + std::unique_ptr executor; + Block header; + std::unique_ptr ch_column_to_spark_row; + std::unique_ptr spark_buffer; + DB::QueryPlanPtr current_query_plan; +}; + + +class ASTParser +{ +public: + explicit ASTParser(const ContextPtr & _context, std::unordered_map & _function_mapping) + : context(_context), function_mapping(_function_mapping){}; + ~ASTParser() = default; + + ASTPtr parseToAST(const Names & names, const substrait::Expression & rel); + ActionsDAGPtr convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast); + +private: + ContextPtr context; + std::unordered_map function_mapping; + + void parseFunctionArgumentsToAST(const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args); + ASTPtr parseArgumentToAST(const Names & names, const substrait::Expression & rel); +}; +} diff --git a/utils/local-engine/Parser/SortRelParser.cpp b/utils/local-engine/Parser/SortRelParser.cpp new file mode 100644 index 000000000000..09669b44dfe3 --- /dev/null +++ b/utils/local-engine/Parser/SortRelParser.cpp @@ -0,0 +1,98 @@ +#include "SortRelParser.h" +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} + +namespace local_engine +{ + +SortRelParser::SortRelParser(SerializedPlanParser * plan_paser_) + : RelParser(plan_paser_) +{} + +DB::QueryPlanPtr +SortRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list & rel_stack_) +{ + size_t limit = parseLimit(rel_stack_); + const auto & sort_rel = rel.sort(); + auto sort_descr = parseSortDescription(sort_rel.sorts(), query_plan->getCurrentDataStream().header); + const auto & settings = getContext()->getSettingsRef(); + auto sorting_step = std::make_unique( + query_plan->getCurrentDataStream(), + sort_descr, + limit, + SortingStep::Settings(*getContext()), + false); + sorting_step->setStepDescription("Sorting step"); + query_plan->addStep(std::move(sorting_step)); + return query_plan; +} + +DB::SortDescription +SortRelParser::parseSortDescription(const google::protobuf::RepeatedPtrField & sort_fields, const DB::Block & header) +{ + static std::map> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}}; + + DB::SortDescription sort_descr; + for (int i = 0, sz = sort_fields.size(); i < sz; ++i) + { + const auto & sort_field = sort_fields[i]; + + if (!sort_field.expr().has_selection() || !sort_field.expr().selection().has_direct_reference() + || !sort_field.expr().selection().direct_reference().has_struct_field()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field"); + } + auto field_pos = sort_field.expr().selection().direct_reference().struct_field().field(); + + auto direction_iter = direction_map.find(sort_field.direction()); + if (direction_iter == direction_map.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction()); + } + if (header.columns()) + { + const auto & col_name = header.getByPosition(field_pos).name; + sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second); + sort_descr.back().column_name = col_name; + } + else + { + const auto & col_name = header.getByPosition(field_pos).name; + sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second); + } + } + return sort_descr; +} + +size_t SortRelParser::parseLimit(std::list & rel_stack_) +{ + if (rel_stack_.empty()) + return 0; + const auto & last_rel = *rel_stack_.back(); + if (last_rel.has_fetch()) + { + const auto & fetch_rel = last_rel.fetch(); + return fetch_rel.count(); + } + return 0; +} + +void registerSortRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_parser) + { + return std::make_shared(plan_parser); + }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kSort, builder); +} +} diff --git a/utils/local-engine/Parser/SortRelParser.h b/utils/local-engine/Parser/SortRelParser.h new file mode 100644 index 000000000000..4661cd89e369 --- /dev/null +++ b/utils/local-engine/Parser/SortRelParser.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include +#include +#include +namespace local_engine +{ +class SortRelParser : public RelParser +{ +public: + explicit SortRelParser(SerializedPlanParser * plan_paser_); + ~SortRelParser() override = default; + + DB::QueryPlanPtr + parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list & rel_stack_) override; + static DB::SortDescription parseSortDescription(const google::protobuf::RepeatedPtrField & sort_fields, const DB::Block & header); +private: + size_t parseLimit(std::list & rel_stack_); +}; +} diff --git a/utils/local-engine/Parser/SparkRowToCHColumn.cpp b/utils/local-engine/Parser/SparkRowToCHColumn.cpp new file mode 100644 index 000000000000..6a0ea6a5a4dc --- /dev/null +++ b/utils/local-engine/Parser/SparkRowToCHColumn.cpp @@ -0,0 +1,417 @@ +#include "SparkRowToCHColumn.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; + extern const int LOGICAL_ERROR; +} +} + +using namespace DB; + +namespace local_engine +{ +jclass SparkRowToCHColumn::spark_row_interator_class = nullptr; +jmethodID SparkRowToCHColumn::spark_row_interator_hasNext = nullptr; +jmethodID SparkRowToCHColumn::spark_row_interator_next = nullptr; +jmethodID SparkRowToCHColumn::spark_row_iterator_nextBatch = nullptr; + +ALWAYS_INLINE static void writeRowToColumns(std::vector & columns, const SparkRowReader & spark_row_reader) +{ + auto num_fields = columns.size(); + const auto & field_types = spark_row_reader.getFieldTypes(); + for (size_t i = 0; i < num_fields; i++) + { + if (spark_row_reader.supportRawData(i)) + { + const StringRef str_ref{std::move(spark_row_reader.getStringRef(i))}; + if (str_ref.data == nullptr) + columns[i]->insertData(nullptr, str_ref.size); + else if (!spark_row_reader.isBigEndianInSparkRow(i)) + columns[i]->insertData(str_ref.data, str_ref.size); + else + columns[i]->insert(spark_row_reader.getField(i)); // read decimal128 + } + else + columns[i]->insert(spark_row_reader.getField(i)); + } +} + +std::unique_ptr +SparkRowToCHColumn::convertSparkRowInfoToCHColumn(const SparkRowInfo & spark_row_info, const Block & header) +{ + auto block = std::make_unique(); + const auto num_rows = spark_row_info.getNumRows(); + if (header.columns()) + { + *block = std::move(header.cloneEmpty()); + MutableColumns mutable_columns{std::move(block->mutateColumns())}; + for (size_t col_i = 0; col_i < header.columns(); ++col_i) + mutable_columns[col_i]->reserve(num_rows); + + DataTypes types{std::move(header.getDataTypes())}; + SparkRowReader row_reader(types); + for (int64_t i = 0; i < num_rows; i++) + { + row_reader.pointTo(spark_row_info.getBufferAddress() + spark_row_info.getOffsets()[i], spark_row_info.getLengths()[i]); + writeRowToColumns(mutable_columns, row_reader); + } + block->setColumns(std::move(mutable_columns)); + } + else + { + // This is a special case for count(1)/count(*) + *block = BlockUtil::buildRowCountBlock(num_rows); + } + return std::move(block); +} + +void SparkRowToCHColumn::appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, char * buffer, int32_t length) +{ + SparkRowReader row_reader(helper.data_types); + row_reader.pointTo(buffer, length); + writeRowToColumns(helper.mutable_columns, row_reader); + ++helper.rows; +} + +Block * SparkRowToCHColumn::getBlock(SparkRowToCHColumnHelper & helper) +{ + auto * block = new Block(); + if (helper.header.columns()) + { + *block = std::move(helper.header.cloneEmpty()); + block->setColumns(std::move(helper.mutable_columns)); + } + else + { + // In some cases, there is no required columns in spark plan, E.g. count(*). + // In these cases, the rows is the only needed information, so we try to create + // a block with a const column which will not be really used any where. + auto uint8_ty = std::make_shared(); + auto col = uint8_ty->createColumnConst(helper.rows, 0); + ColumnWithTypeAndName named_col(col, uint8_ty, "__anonymous_col__"); + block->insert(named_col); + } + return block; +} + +VariableLengthDataReader::VariableLengthDataReader(const DataTypePtr & type_) + : type(type_), type_without_nullable(removeNullable(type)), which(type_without_nullable) +{ + if (!BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); +} + +Field VariableLengthDataReader::read(const char *buffer, size_t length) const +{ + if (which.isStringOrFixedString()) + return std::move(readString(buffer, length)); + + if (which.isDecimal128()) + return std::move(readDecimal(buffer, length)); + + if (which.isArray()) + return std::move(readArray(buffer, length)); + + if (which.isMap()) + return std::move(readMap(buffer, length)); + + if (which.isTuple()) + return std::move(readStruct(buffer, length)); + + throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); +} + +StringRef VariableLengthDataReader::readUnalignedBytes(const char * buffer, size_t length) const +{ + return {buffer, length}; +} + +Field VariableLengthDataReader::readDecimal(const char * buffer, size_t length) const +{ + assert(sizeof(Decimal128) <= length); + + char decimal128_fix_data[sizeof(Decimal128)] = {}; + memcpy(decimal128_fix_data + sizeof(Decimal128) - length, buffer, length); // padding + String buf(decimal128_fix_data, sizeof(Decimal128)); + BackingDataLengthCalculator::swapDecimalEndianBytes(buf); // Big-endian to Little-endian + + auto * decimal128 = reinterpret_cast(buf.data()); + const auto * decimal128_type = typeid_cast(type_without_nullable.get()); + return std::move(DecimalField(std::move(*decimal128), decimal128_type->getScale())); +} + +Field VariableLengthDataReader::readString(const char * buffer, size_t length) const +{ + String str(buffer, length); + return std::move(Field(std::move(str))); +} + +Field VariableLengthDataReader::readArray(const char * buffer, [[maybe_unused]] size_t length) const +{ + /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing data + /// Read numElements + int64_t num_elems = 0; + memcpy(&num_elems, buffer, 8); + if (num_elems == 0 || length == 0) + return Array(); + + /// Skip null_bitmap + const auto len_null_bitmap = calculateBitSetWidthInBytes(num_elems); + + /// Read values + const auto * array_type = typeid_cast(type_without_nullable.get()); + const auto & nested_type = array_type->getNestedType(); + const auto elem_size = BackingDataLengthCalculator::getArrayElementSize(nested_type); + const auto len_values = roundNumberOfBytesToNearestWord(elem_size * num_elems); + Array array; + array.reserve(num_elems); + + if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(nested_type))) + { + FixedLengthDataReader reader(nested_type); + for (int64_t i = 0; i < num_elems; ++i) + { + if (isBitSet(buffer + 8, i)) + { + array.emplace_back(std::move(Null{})); + } + else + { + const auto elem = reader.read(buffer + 8 + len_null_bitmap + i * elem_size); + array.emplace_back(elem); + } + } + } + else if (BackingDataLengthCalculator::isVariableLengthDataType(removeNullable(nested_type))) + { + VariableLengthDataReader reader(nested_type); + for (int64_t i = 0; i < num_elems; ++i) + { + if (isBitSet(buffer + 8, i)) + { + array.emplace_back(std::move(Null{})); + } + else + { + int64_t offset_and_size = 0; + memcpy(&offset_and_size, buffer + 8 + len_null_bitmap + i * 8, 8); + const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); + const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); + + const auto elem = reader.read(buffer + offset, size); + array.emplace_back(elem); + } + } + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", nested_type->getName()); + + return std::move(array); +} + +Field VariableLengthDataReader::readMap(const char * buffer, size_t length) const +{ + /// 内存布局:Length of UnsafeArrayData of key(8B) | UnsafeArrayData of key | UnsafeArrayData of value + /// Read Length of UnsafeArrayData of key + int64_t key_array_size = 0; + memcpy(&key_array_size, buffer, 8); + if (key_array_size == 0 || length == 0) + return std::move(Map()); + + /// Read UnsafeArrayData of keys + const auto * map_type = typeid_cast(type_without_nullable.get()); + const auto & key_type = map_type->getKeyType(); + const auto key_array_type = std::make_shared(key_type); + VariableLengthDataReader key_reader(key_array_type); + auto key_field = key_reader.read(buffer + 8, key_array_size); + auto & key_array = key_field.safeGet(); + + /// Read UnsafeArrayData of values + const auto & val_type = map_type->getValueType(); + const auto val_array_type = std::make_shared(val_type); + VariableLengthDataReader val_reader(val_array_type); + auto val_field = val_reader.read(buffer + 8 + key_array_size, length - 8 - key_array_size); + auto & val_array = val_field.safeGet(); + + /// Construct map in CH way [(k1, v1), (k2, v2), ...] + if (key_array.size() != val_array.size()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Key size {} not equal to value size {} in map", key_array.size(), val_array.size()); + Map map(key_array.size()); + for (size_t i = 0; i < key_array.size(); ++i) + { + Tuple tuple(2); + tuple[0] = std::move(key_array[i]); + tuple[1] = std::move(val_array[i]); + + map[i] = std::move(tuple); + } + return std::move(map); +} + +Field VariableLengthDataReader::readStruct(const char * buffer, size_t /*length*/) const +{ + /// 内存布局:null_bitmap(字节数与字段数成正比) | values(num_fields * 8B) | backing data + const auto * tuple_type = typeid_cast(type_without_nullable.get()); + const auto & field_types = tuple_type->getElements(); + const auto num_fields = field_types.size(); + if (num_fields == 0) + return std::move(Tuple()); + + const auto len_null_bitmap = calculateBitSetWidthInBytes(num_fields); + + Tuple tuple(num_fields); + for (size_t i=0; igetName()); + } + return std::move(tuple); +} + +FixedLengthDataReader::FixedLengthDataReader(const DataTypePtr & type_) + : type(type_), type_without_nullable(removeNullable(type)), which(type_without_nullable) +{ + if (!BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable) || !type_without_nullable->isValueRepresentedByNumber()) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "VariableLengthDataReader doesn't support type {}", type->getName()); + + value_size = type_without_nullable->getSizeOfValueInMemory(); +} + +StringRef FixedLengthDataReader::unsafeRead(const char * buffer) const +{ + return {buffer, value_size}; +} + +Field FixedLengthDataReader::read(const char * buffer) const +{ + if (which.isUInt8()) + { + UInt8 value = 0; + memcpy(&value, buffer, 1); + return value; + } + + if (which.isUInt16() || which.isDate()) + { + UInt16 value = 0; + memcpy(&value, buffer, 2); + return value; + } + + if (which.isUInt32()) + { + UInt32 value = 0; + memcpy(&value, buffer, 4); + return value; + } + + if (which.isUInt64()) + { + UInt64 value = 0; + memcpy(&value, buffer, 8); + return value; + } + + if (which.isInt8()) + { + Int8 value = 0; + memcpy(&value, buffer, 1); + return value; + } + + if (which.isInt16()) + { + Int16 value = 0; + memcpy(&value, buffer, 2); + return value; + } + + if (which.isInt32() || which.isDate32()) + { + Int32 value = 0; + memcpy(&value, buffer, 4); + return value; + } + + if (which.isInt64()) + { + Int64 value = 0; + memcpy(&value, buffer, 8); + return value; + } + + if (which.isFloat32()) + { + Float32 value = 0.0; + memcpy(&value, buffer, 4); + return value; + } + + if (which.isFloat64()) + { + Float64 value = 0.0; + memcpy(&value, buffer, 8); + return value; + } + + if (which.isDecimal32()) + { + Decimal32 value = 0; + memcpy(&value, buffer, 4); + + const auto * decimal32_type = typeid_cast(type_without_nullable.get()); + return std::move(DecimalField{value, decimal32_type->getScale()}); + } + + if (which.isDecimal64() || which.isDateTime64()) + { + Decimal64 value = 0; + memcpy(&value, buffer, 8); + + UInt32 scale = which.isDecimal64() ? typeid_cast(type_without_nullable.get())->getScale() + : typeid_cast(type_without_nullable.get())->getScale(); + return std::move(DecimalField{value, scale}); + } + throw Exception(ErrorCodes::UNKNOWN_TYPE, "FixedLengthDataReader doesn't support type {}", type->getName()); +} + +} diff --git a/utils/local-engine/Parser/SparkRowToCHColumn.h b/utils/local-engine/Parser/SparkRowToCHColumn.h new file mode 100644 index 000000000000..a8de24e6749f --- /dev/null +++ b/utils/local-engine/Parser/SparkRowToCHColumn.h @@ -0,0 +1,374 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/types.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; + extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; +} +} +namespace local_engine +{ +using namespace DB; +using namespace std; +struct SparkRowToCHColumnHelper +{ + DataTypes data_types; + Block header; + MutableColumns mutable_columns; + UInt64 rows; + + SparkRowToCHColumnHelper(vector & names, vector & types) + : data_types(names.size()) + { + assert(names.size() == types.size()); + + ColumnsWithTypeAndName columns(names.size()); + for (size_t i = 0; i < names.size(); ++i) + { + data_types[i] = parseType(types[i]); + columns[i] = std::move(ColumnWithTypeAndName(data_types[i], names[i])); + } + + header = std::move(Block(columns)); + resetMutableColumns(); + } + + ~SparkRowToCHColumnHelper() = default; + + void resetMutableColumns() + { + rows = 0; + mutable_columns = std::move(header.mutateColumns()); + } + + static DataTypePtr parseType(const string & type) + { + auto substrait_type = std::make_unique(); + auto ok = substrait_type->ParseFromString(type); + if (!ok) + throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Type from string failed"); + return std::move(SerializedPlanParser::parseType(*substrait_type)); + } +}; + +class SparkRowToCHColumn +{ +public: + static jclass spark_row_interator_class; + static jmethodID spark_row_interator_hasNext; + static jmethodID spark_row_interator_next; + static jmethodID spark_row_iterator_nextBatch; + + // case 1: rows are batched (this is often directly converted from Block) + static std::unique_ptr convertSparkRowInfoToCHColumn(const SparkRowInfo & spark_row_info, const Block & header); + + // case 2: provided with a sequence of spark UnsafeRow, convert them to a Block + static Block * + convertSparkRowItrToCHColumn(jobject java_iter, vector & names, vector & types) + { + SparkRowToCHColumnHelper helper(names, types); + + GET_JNIENV(env) + while (safeCallBooleanMethod(env, java_iter, spark_row_interator_hasNext)) + { + jobject rows_buf = safeCallObjectMethod(env, java_iter, spark_row_iterator_nextBatch); + auto * rows_buf_ptr = static_cast(env->GetDirectBufferAddress(rows_buf)); + int len = *(reinterpret_cast(rows_buf_ptr)); + + // len = -1 means reaching the buf's end. + // len = 0 indicates no columns in the this row. e.g. count(1)/count(*) + while (len >= 0) + { + rows_buf_ptr += 4; + appendSparkRowToCHColumn(helper, rows_buf_ptr, len); + + rows_buf_ptr += len; + len = *(reinterpret_cast(rows_buf_ptr)); + } + + // Try to release reference. + env->DeleteLocalRef(rows_buf); + } + return getBlock(helper); + } + + static void freeBlock(Block * block) + { + delete block; + block = nullptr; + } + +private: + static void appendSparkRowToCHColumn(SparkRowToCHColumnHelper & helper, char * buffer, int32_t length); + static Block * getBlock(SparkRowToCHColumnHelper & helper); +}; + +class VariableLengthDataReader +{ +public: + explicit VariableLengthDataReader(const DataTypePtr& type_); + virtual ~VariableLengthDataReader() = default; + + virtual Field read(const char * buffer, size_t length) const; + virtual StringRef readUnalignedBytes(const char * buffer, size_t length) const; + +private: + virtual Field readDecimal(const char * buffer, size_t length) const; + virtual Field readString(const char * buffer, size_t length) const; + virtual Field readArray(const char * buffer, size_t length) const; + virtual Field readMap(const char * buffer, size_t length) const; + virtual Field readStruct(const char * buffer, size_t length) const; + + const DataTypePtr type; + const DataTypePtr type_without_nullable; + const WhichDataType which; +}; + +class FixedLengthDataReader +{ +public: + explicit FixedLengthDataReader(const DB::DataTypePtr & type_); + virtual ~FixedLengthDataReader() = default; + + virtual Field read(const char * buffer) const; + virtual StringRef unsafeRead(const char * buffer) const; + +private: + const DB::DataTypePtr type; + const DB::DataTypePtr type_without_nullable; + const DB::WhichDataType which; + size_t value_size; +}; +class SparkRowReader +{ +public: + explicit SparkRowReader(const DataTypes & field_types_) + : field_types(field_types_) + , num_fields(field_types.size()) + , bit_set_width_in_bytes(calculateBitSetWidthInBytes(num_fields)) + , field_offsets(num_fields) + , support_raw_datas(num_fields) + , is_big_endians_in_spark_row(num_fields) + , fixed_length_data_readers(num_fields) + , variable_length_data_readers(num_fields) + { + for (auto ordinal = 0; ordinal < num_fields; ++ordinal) + { + const auto type_without_nullable = removeNullable(field_types[ordinal]); + field_offsets[ordinal] = bit_set_width_in_bytes + ordinal * 8L; + support_raw_datas[ordinal] = BackingDataLengthCalculator::isDataTypeSupportRawData(type_without_nullable); + is_big_endians_in_spark_row[ordinal] = BackingDataLengthCalculator::isBigEndianInSparkRow(type_without_nullable); + if (BackingDataLengthCalculator::isFixedLengthDataType(type_without_nullable)) + fixed_length_data_readers[ordinal] = std::make_shared(field_types[ordinal]); + else if (BackingDataLengthCalculator::isVariableLengthDataType(type_without_nullable)) + variable_length_data_readers[ordinal] = std::make_shared(field_types[ordinal]); + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader doesn't support type {}", field_types[ordinal]->getName()); + } + } + + const DataTypes & getFieldTypes() const + { + return field_types; + } + + bool supportRawData(int ordinal) const + { + assertIndexIsValid(ordinal); + return support_raw_datas[ordinal]; + } + + bool isBigEndianInSparkRow(int ordinal) const + { + assertIndexIsValid(ordinal); + return is_big_endians_in_spark_row[ordinal]; + } + + std::shared_ptr getFixedLengthDataReader(int ordinal) const + { + assertIndexIsValid(ordinal); + return fixed_length_data_readers[ordinal]; + } + + std::shared_ptr getVariableLengthDataReader(int ordinal) const + { + assertIndexIsValid(ordinal); + return variable_length_data_readers[ordinal]; + } + + void assertIndexIsValid([[maybe_unused]] int index) const + { + assert(index >= 0); + assert(index < num_fields); + } + + bool isNullAt(int ordinal) const + { + assertIndexIsValid(ordinal); + return isBitSet(buffer, ordinal); + } + + const char* getRawDataForFixedNumber(int ordinal) const + { + assertIndexIsValid(ordinal); + return reinterpret_cast(getFieldOffset(ordinal)); + } + + int8_t getByte(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + uint8_t getUnsignedByte(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + int16_t getShort(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + uint16_t getUnsignedShort(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + int32_t getInt(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + uint32_t getUnsignedInt(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + int64_t getLong(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + float_t getFloat(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + double_t getDouble(int ordinal) const + { + assertIndexIsValid(ordinal); + return *reinterpret_cast(getFieldOffset(ordinal)); + } + + StringRef getString(int ordinal) const + { + assertIndexIsValid(ordinal); + int64_t offset_and_size = getLong(ordinal); + int32_t offset = static_cast(offset_and_size >> 32); + int32_t size = static_cast(offset_and_size); + return StringRef(reinterpret_cast(this->buffer + offset), size); + } + + int32_t getStringSize(int ordinal) const + { + assertIndexIsValid(ordinal); + return static_cast(getLong(ordinal)); + } + + void pointTo(const char * buffer_, int32_t length_) + { + buffer = buffer_; + length = length_; + } + + StringRef getStringRef(int ordinal) const + { + assertIndexIsValid(ordinal); + if (!support_raw_datas[ordinal]) + throw Exception( + ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getStringRef doesn't support type {}", field_types[ordinal]->getName()); + + if (isNullAt(ordinal)) + return StringRef(); + + const auto & fixed_length_data_reader = fixed_length_data_readers[ordinal]; + const auto & variable_length_data_reader = variable_length_data_readers[ordinal]; + if (fixed_length_data_reader) + return std::move(fixed_length_data_reader->unsafeRead(getFieldOffset(ordinal))); + else if (variable_length_data_reader) + { + int64_t offset_and_size = 0; + memcpy(&offset_and_size, buffer + bit_set_width_in_bytes + ordinal * 8, 8); + const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); + const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); + return std::move(variable_length_data_reader->readUnalignedBytes(buffer + offset, size)); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getStringRef doesn't support type {}", field_types[ordinal]->getName()); + } + + Field getField(int ordinal) const + { + assertIndexIsValid(ordinal); + + if (isNullAt(ordinal)) + return std::move(Null{}); + + const auto & fixed_length_data_reader = fixed_length_data_readers[ordinal]; + const auto & variable_length_data_reader = variable_length_data_readers[ordinal]; + + if (fixed_length_data_reader) + return std::move(fixed_length_data_reader->read(getFieldOffset(ordinal))); + else if (variable_length_data_reader) + { + int64_t offset_and_size = 0; + memcpy(&offset_and_size, buffer + bit_set_width_in_bytes + ordinal * 8, 8); + const int64_t offset = BackingDataLengthCalculator::extractOffset(offset_and_size); + const int64_t size = BackingDataLengthCalculator::extractSize(offset_and_size); + return std::move(variable_length_data_reader->read(buffer + offset, size)); + } + else + throw Exception(ErrorCodes::UNKNOWN_TYPE, "SparkRowReader::getField doesn't support type {}", field_types[ordinal]->getName()); + } + +private: + const char * getFieldOffset(int ordinal) const { return buffer + field_offsets[ordinal]; } + + const DataTypes field_types; + const int32_t num_fields; + const int32_t bit_set_width_in_bytes; + std::vector field_offsets; + std::vector support_raw_datas; + std::vector is_big_endians_in_spark_row; + std::vector> fixed_length_data_readers; + std::vector> variable_length_data_readers; + + const char * buffer; + int32_t length; +}; + +} diff --git a/utils/local-engine/Parser/WindowRelParser.cpp b/utils/local-engine/Parser/WindowRelParser.cpp new file mode 100644 index 000000000000..2e0b7a2c7765 --- /dev/null +++ b/utils/local-engine/Parser/WindowRelParser.cpp @@ -0,0 +1,398 @@ +#include "WindowRelParser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; + extern const int LOGICAL_ERROR; + extern const int BAD_ARGUMENTS; +} +} +namespace local_engine +{ + +WindowRelParser::WindowRelParser(SerializedPlanParser * plan_paser_) : RelParser(plan_paser_) +{ +} + +DB::QueryPlanPtr +WindowRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & /*rel_stack_*/) +{ + // rel_stack = rel_stack_; + const auto & win_rel_pb = rel.window(); + current_plan = std::move(current_plan_); + auto expected_header = current_plan->getCurrentDataStream().header; + for (const auto & measure : win_rel_pb.measures()) + { + const auto & win_function = measure.measure(); + ColumnWithTypeAndName named_col; + named_col.name = win_function.column_name(); + named_col.type = parseType(win_function.output_type()); + named_col.column = named_col.type->createColumn(); + expected_header.insert(named_col); + } + tryAddProjectionBeforeWindow(*current_plan, win_rel_pb); + + auto window_descriptions = parseWindowDescriptions(win_rel_pb); + + /// In spark plan, there is already a sort step before each window, so we don't need to add sort steps here. + for (auto & it : window_descriptions) + { + auto & win = it.second; + ; + auto window_step = std::make_unique(current_plan->getCurrentDataStream(), win, win.window_functions); + window_step->setStepDescription("Window step for window '" + win.window_name + "'"); + current_plan->addStep(std::move(window_step)); + } + + + auto current_header = current_plan->getCurrentDataStream().header; + if (!DB::blocksHaveEqualStructure(expected_header, current_header)) + { + ActionsDAGPtr convert_action = ActionsDAG::makeConvertingActions( + current_header.getColumnsWithTypeAndName(), + expected_header.getColumnsWithTypeAndName(), + DB::ActionsDAG::MatchColumnsMode::Name); + QueryPlanStepPtr convert_step = std::make_unique(current_plan->getCurrentDataStream(), convert_action); + convert_step->setStepDescription("Convert window Output"); + current_plan->addStep(std::move(convert_step)); + } + + return std::move(current_plan); +} +DB::WindowDescription +WindowRelParser::parseWindowDescrption(const substrait::WindowRel & win_rel, const substrait::Expression::WindowFunction & win_function) +{ + DB::WindowDescription win_descr; + win_descr.frame = parseWindowFrame(win_function); + win_descr.partition_by = parsePartitionBy(win_rel.partition_expressions()); + win_descr.order_by = SortRelParser::parseSortDescription(win_rel.sorts(), current_plan->getCurrentDataStream().header); + win_descr.full_sort_description = win_descr.partition_by; + win_descr.full_sort_description.insert(win_descr.full_sort_description.end(), win_descr.order_by.begin(), win_descr.order_by.end()); + + DB::WriteBufferFromOwnString ss; + ss << "partition by " << DB::dumpSortDescription(win_descr.partition_by); + ss << " order by " << DB::dumpSortDescription(win_descr.order_by); + ss << " " << win_descr.frame.toString(); + win_descr.window_name = ss.str(); + return win_descr; +} + +std::unordered_map WindowRelParser::parseWindowDescriptions(const substrait::WindowRel & win_rel) +{ + std::unordered_map window_descriptions; + for (int i = 0; i < win_rel.measures_size(); ++i) + { + const auto & measure = win_rel.measures(i); + const auto & win_function = measure.measure(); + auto win_descr = parseWindowDescrption(win_rel, win_function); + WindowDescription * description = nullptr; + const auto win_it = window_descriptions.find(win_descr.window_name); + if (win_it != window_descriptions.end()) + description = &win_it->second; + else + { + window_descriptions[win_descr.window_name] = win_descr; + description = &window_descriptions[win_descr.window_name]; + } + auto win_func = parseWindowFunctionDescription(win_rel, win_function, measures_arg_names[i], measures_arg_types[i]); + description->window_functions.emplace_back(win_func); + } + return window_descriptions; +} + +DB::WindowFrame WindowRelParser::parseWindowFrame(const substrait::Expression::WindowFunction & window_function) +{ + auto function_name = parseFunctionName(window_function.function_reference()); + if (!function_name) + function_name = ""; + DB::WindowFrame win_frame; + win_frame.type = parseWindowFrameType(*function_name, window_function); + parseBoundType( + *function_name, window_function.lower_bound(), true, win_frame.begin_type, win_frame.begin_offset, win_frame.begin_preceding); + parseBoundType(*function_name, window_function.upper_bound(), false, win_frame.end_type, win_frame.end_offset, win_frame.end_preceding); + + // special cases + if (*function_name == "lead" || *function_name == "lag") + { + win_frame.begin_preceding = true; + win_frame.end_preceding = false; + } + return win_frame; +} + +DB::WindowFrame::FrameType WindowRelParser::parseWindowFrameType(const std::string & function_name, const substrait::Expression::WindowFunction & window_function) +{ + // It's weird! The frame type only could be rows in spark for rank(). But in clickhouse + // it's should be range. If run rank() over rows frame, the result is different. The rank number + // is different for the same values. + static const std::unordered_map special_function_frame_type + = { + {"rank", substrait::RANGE}, + {"dense_rank", substrait::RANGE}, + }; + + substrait::WindowType frame_type; + auto iter = special_function_frame_type.find(function_name); + if (iter != special_function_frame_type.end()) + frame_type = iter->second; + else + frame_type = window_function.window_type(); + + if (frame_type == substrait::ROWS) + { + return DB::WindowFrame::FrameType::ROWS; + } + else if (frame_type == substrait::RANGE) + { + return DB::WindowFrame::FrameType::RANGE; + } + else + { + throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unknow window frame type:{}", frame_type); + } +} + +void WindowRelParser::parseBoundType( + const std::string & function_name, + const substrait::Expression::WindowFunction::Bound & bound, + bool is_begin_or_end, + DB::WindowFrame::BoundaryType & bound_type, + Field & offset, + bool & preceding_direction) +{ + /// some default settings. + offset = 0; + + if (bound.has_preceding()) + { + const auto & preceding = bound.preceding(); + bound_type = DB::WindowFrame::BoundaryType::Offset; + preceding_direction = preceding.offset() >= 0; + if (preceding.offset() < 0) + { + offset = 0 - preceding.offset(); + } + else + { + offset = preceding.offset(); + } + } + else if (bound.has_following()) + { + const auto & following = bound.following(); + bound_type = DB::WindowFrame::BoundaryType::Offset; + preceding_direction = following.offset() < 0; + if (following.offset() < 0) + { + offset = 0 - following.offset(); + } + else + { + offset = following.offset(); + } + } + else if (bound.has_current_row()) + { + const auto & current_row = bound.current_row(); + bound_type = DB::WindowFrame::BoundaryType::Current; + offset = 0; + preceding_direction = is_begin_or_end; + } + else if (bound.has_unbounded_preceding()) + { + bound_type = DB::WindowFrame::BoundaryType::Unbounded; + offset = 0; + preceding_direction = true; + } + else if (bound.has_unbounded_following()) + { + bound_type = DB::WindowFrame::BoundaryType::Unbounded; + offset = 0; + preceding_direction = false; + } + else + { + throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unknown bound type:{}", bound.DebugString()); + } +} + + +DB::SortDescription WindowRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions) +{ + DB::Block header = current_plan->getCurrentDataStream().header; + DB::SortDescription sort_descr; + for (const auto & expr : expressions) + { + if (!expr.has_selection()) + { + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Column reference is expected."); + } + auto pos = expr.selection().direct_reference().struct_field().field(); + auto col_name = header.getByPosition(pos).name; + sort_descr.push_back(DB::SortColumnDescription(col_name, 1, 1)); + } + return sort_descr; +} + +WindowFunctionDescription WindowRelParser::parseWindowFunctionDescription( + const substrait::WindowRel & win_rel, + const substrait::Expression::WindowFunction & window_function, + const DB::Names & arg_names, + const DB::DataTypes & arg_types) +{ + auto header = current_plan->getCurrentDataStream().header; + WindowFunctionDescription description; + description.column_name = window_function.column_name(); + description.function_node = nullptr; + + auto function_name = parseFunctionName(window_function.function_reference()); + if (!function_name) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Not found function for reference: {}", window_function.function_reference()); + + DB::AggregateFunctionProperties agg_function_props; + // Special transform for lead/lag + if (*function_name == "lead" || *function_name == "lag") + { + if (*function_name == "lead") + function_name = "leadInFrame"; + else + function_name = "lagInFrame"; + auto agg_function_ptr = getAggregateFunction(*function_name, arg_types, agg_function_props); + + description.argument_names = arg_names; + description.argument_types = arg_types; + description.aggregate_function = agg_function_ptr; + } + else + { + auto agg_function_ptr = getAggregateFunction(*function_name, arg_types, agg_function_props); + + description.argument_names = arg_names; + description.argument_types = arg_types; + description.aggregate_function = agg_function_ptr; + } + + return description; +} + +void WindowRelParser::tryAddProjectionBeforeWindow( + QueryPlan & plan, const substrait::WindowRel & win_rel) +{ + auto header = plan.getCurrentDataStream().header; + ActionsDAGPtr actions_dag = std::make_shared(header.getColumnsWithTypeAndName()); + bool need_project = false; + for (const auto & measure : win_rel.measures()) + { + DB::Names names; + DB::DataTypes types; + auto function_name = parseFunctionName(measure.measure().function_reference()); + if (function_name && (*function_name == "lead" || *function_name == "lag")) + { + const auto & arg0 = measure.measure().arguments(0).value(); + const auto & col = header.getByPosition(arg0.selection().direct_reference().struct_field().field()); + names.emplace_back(col.name); + types.emplace_back(col.type); + + auto arg1 = measure.measure().arguments(1).value(); + const DB::ActionsDAG::Node * node = nullptr; + // lag's offset is negative + if (*function_name == "lag") + { + auto literal_result = parseLiteral(arg1.literal()); + assert(literal_result.second.safeGet() < 0); + auto real_field = 0 - literal_result.second.safeGet(); + node = &actions_dag->addColumn(ColumnWithTypeAndName( + literal_result.first->createColumnConst(1, real_field), literal_result.first, getUniqueName(toString(real_field)))); + } + else + { + node = parseArgument(actions_dag, arg1); + } + node = ActionsDAGUtil::convertNodeType(actions_dag, node, DataTypeInt64().getName()); + actions_dag->addOrReplaceInOutputs(*node); + names.emplace_back(node->result_name); + types.emplace_back(node->result_type); + + const auto & arg2 = measure.measure().arguments(2).value(); + if (arg2.has_literal() && !arg2.literal().has_null()) + { + node = parseArgument(actions_dag, arg2); + actions_dag->addOrReplaceInOutputs(*node); + names.emplace_back(node->result_name); + types.emplace_back(node->result_type); + } + need_project = true; + } + else + { + for (int i = 0, n = measure.measure().arguments().size(); i < n; ++i) + { + const auto & arg = measure.measure().arguments(i).value(); + if (arg.has_selection()) + { + const auto & col = header.getByPosition(arg.selection().direct_reference().struct_field().field()); + names.push_back(col.name); + types.emplace_back(col.type); + } + else if (arg.has_literal()) + { + // for example, sum(2) over(...), we need to add new const column for 2, otherwise + // an exception of not found column(2) will throw. + const auto * node = parseArgument(actions_dag, arg); + names.push_back(node->result_name); + types.emplace_back(node->result_type); + actions_dag->addOrReplaceInOutputs(*node); + need_project = true; + } + else + { + // There should be a projections ahead to eliminate complex expressions. + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported aggregate argument type {}.", arg.DebugString()); + } + } + } + measures_arg_names.emplace_back(std::move(names)); + measures_arg_types.emplace_back(std::move(types)); + } + if (need_project) + { + auto project_step = std::make_unique(plan.getCurrentDataStream(), actions_dag); + project_step->setStepDescription("Add projections before aggregation"); + plan.addStep(std::move(project_step)); + } +} + + +void registerWindowRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kWindow, builder); + +} +} diff --git a/utils/local-engine/Parser/WindowRelParser.h b/utils/local-engine/Parser/WindowRelParser.h new file mode 100644 index 000000000000..e7e400507acc --- /dev/null +++ b/utils/local-engine/Parser/WindowRelParser.h @@ -0,0 +1,57 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ +class WindowRelParser : public RelParser +{ +public: + explicit WindowRelParser(SerializedPlanParser * plan_paser_); + ~WindowRelParser() override = default; + DB::QueryPlanPtr parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) override; + +private: + DB::QueryPlanPtr current_plan; + // std::list * rel_stack; + Poco::Logger * logger = &Poco::Logger::get("WindowRelParser"); + // for constructing aggregate function argument names + std::vector measures_arg_names; + std::vector measures_arg_types; + + /// There will be window descrptions generated for different window frame type; + std::unordered_map parseWindowDescriptions(const substrait::WindowRel & win_rel); + + // Build a window description in CH with respect to a window function, since the same + // function may have different window frame in CH and spark. + DB::WindowDescription + parseWindowDescrption(const substrait::WindowRel & win_rel, const substrait::Expression::WindowFunction & win_function); + DB::WindowFrame parseWindowFrame(const substrait::Expression::WindowFunction & window_function); + DB::WindowFrame::FrameType + parseWindowFrameType(const std::string & function_name, const substrait::Expression::WindowFunction & window_function); + static void parseBoundType( + const std::string & function_name, + const substrait::Expression::WindowFunction::Bound & bound, + bool is_begin_or_end, + DB::WindowFrame::BoundaryType & bound_type, + Field & offset, + bool & preceding); + DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions); + DB::WindowFunctionDescription parseWindowFunctionDescription( + const substrait::WindowRel & win_rel, + const substrait::Expression::WindowFunction & window_function, + const DB::Names & arg_names, + const DB::DataTypes & arg_types); + + void tryAddProjectionBeforeWindow(QueryPlan & plan, const substrait::WindowRel & win_rel); + +}; + + +} diff --git a/utils/local-engine/Shuffle/CMakeLists.txt b/utils/local-engine/Shuffle/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Shuffle/NativeSplitter.cpp b/utils/local-engine/Shuffle/NativeSplitter.cpp new file mode 100644 index 000000000000..07f11d9b2eaf --- /dev/null +++ b/utils/local-engine/Shuffle/NativeSplitter.cpp @@ -0,0 +1,238 @@ +#include "NativeSplitter.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +jclass NativeSplitter::iterator_class = nullptr; +jmethodID NativeSplitter::iterator_has_next = nullptr; +jmethodID NativeSplitter::iterator_next = nullptr; + +void NativeSplitter::split(DB::Block & block) +{ + if (block.rows() == 0) + { + return; + } + if (!output_header.columns()) [[unlikely]] + { + if (output_columns_indicies.empty()) + { + output_header = block.cloneEmpty(); + for (size_t i = 0; i < block.columns(); ++i) + { + output_columns_indicies.push_back(i); + } + } + else + { + DB::ColumnsWithTypeAndName cols; + for (const auto & index : output_columns_indicies) + { + cols.push_back(block.getByPosition(index)); + } + output_header = DB::Block(cols); + } + } + computePartitionId(block); + DB::Block out_block; + for (size_t col = 0; col < output_header.columns(); ++col) + { + out_block.insert(block.getByPosition(output_columns_indicies[col])); + } + for (size_t col = 0; col < output_header.columns(); ++col) + { + for (size_t j = 0; j < partition_info.partition_num; ++j) + { + size_t from = partition_info.partition_start_points[j]; + size_t length = partition_info.partition_start_points[j + 1] - from; + if (length == 0) + continue; // no data for this partition continue; + partition_buffer[j]->appendSelective(col, out_block, partition_info.partition_selector, from, length); + } + } + + bool has_active_sender = false; + for (size_t i = 0; i < options.partition_nums; ++i) + { + if (partition_buffer[i]->size() >= options.buffer_size) + { + output_buffer.emplace(std::pair(i, std::make_unique(partition_buffer[i]->releaseColumns()))); + } + } +} + +NativeSplitter::NativeSplitter(Options options_, jobject input_) : options(options_) +{ + GET_JNIENV(env) + input = env->NewGlobalRef(input_); + partition_buffer.reserve(options.partition_nums); + for (size_t i = 0; i < options.partition_nums; ++i) + { + partition_buffer.emplace_back(std::make_shared(options.buffer_size)); + } + CLEAN_JNIENV +} + +NativeSplitter::~NativeSplitter() +{ + GET_JNIENV(env) + env->DeleteGlobalRef(input); + CLEAN_JNIENV +} + +bool NativeSplitter::hasNext() +{ + while (output_buffer.empty()) + { + if (inputHasNext()) + { + split(*reinterpret_cast(inputNext())); + } + else + { + for (size_t i = 0; i < options.partition_nums; ++i) + { + auto buffer = partition_buffer.at(i); + if (buffer->size() > 0) + { + output_buffer.emplace(std::pair(i, new Block(buffer->releaseColumns()))); + } + } + break; + } + } + if (!output_buffer.empty()) + { + next_partition_id = output_buffer.top().first; + setCurrentBlock(*output_buffer.top().second); + produce(); + } + return !output_buffer.empty(); +} + +DB::Block * NativeSplitter::next() +{ + if (!output_buffer.empty()) + { + output_buffer.pop(); + } + consume(); + return ¤tBlock(); +} + +int32_t NativeSplitter::nextPartitionId() +{ + return next_partition_id; +} + +bool NativeSplitter::inputHasNext() +{ + GET_JNIENV(env) + bool next = safeCallBooleanMethod(env, input, iterator_has_next); + CLEAN_JNIENV + return next; +} + +int64_t NativeSplitter::inputNext() +{ + GET_JNIENV(env) + int64_t result = safeCallLongMethod(env, input, iterator_next); + CLEAN_JNIENV + return result; +} +std::unique_ptr NativeSplitter::create(const std::string & short_name, Options options_, jobject input) +{ + if (short_name == "rr") + { + return std::make_unique(options_, input); + } + else if (short_name == "hash") + { + return std::make_unique(options_, input); + } + else if (short_name == "single") + { + options_.partition_nums = 1; + return std::make_unique(options_, input); + } + else if (short_name == "range") + { + return std::make_unique(options_, input); + } + else + { + throw std::runtime_error("unsupported splitter " + short_name); + } +} + +HashNativeSplitter::HashNativeSplitter(NativeSplitter::Options options_, jobject input) + : NativeSplitter(options_, input) +{ + Poco::StringTokenizer exprs_list(options_.exprs_buffer, ","); + std::vector hash_fields; + for (auto iter = exprs_list.begin(); iter != exprs_list.end(); ++iter) + { + hash_fields.push_back(std::stoi(*iter)); + } + + Poco::StringTokenizer output_column_tokenizer(options_.schema_buffer, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + + selector_builder = std::make_unique(options.partition_nums, hash_fields, "cityHash64"); +} + +void HashNativeSplitter::computePartitionId(Block & block) +{ + partition_info = selector_builder->build(block); +} + +RoundRobinNativeSplitter::RoundRobinNativeSplitter(NativeSplitter::Options options_, jobject input) : NativeSplitter(options_, input) +{ + Poco::StringTokenizer output_column_tokenizer(options_.schema_buffer, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + selector_builder = std::make_unique(options_.partition_nums); +} + +void RoundRobinNativeSplitter::computePartitionId(Block & block) +{ + partition_info = selector_builder->build(block); +} + +RangePartitionNativeSplitter::RangePartitionNativeSplitter(NativeSplitter::Options options_, jobject input) + : NativeSplitter(options_, input) +{ + Poco::StringTokenizer output_column_tokenizer(options_.schema_buffer, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + selector_builder = std::make_unique(options_.exprs_buffer, options_.partition_nums); +} + +void RangePartitionNativeSplitter::computePartitionId(DB::Block & block) +{ + partition_info = selector_builder->build(block); +} + +} diff --git a/utils/local-engine/Shuffle/NativeSplitter.h b/utils/local-engine/Shuffle/NativeSplitter.h new file mode 100644 index 000000000000..30c647b2fc93 --- /dev/null +++ b/utils/local-engine/Shuffle/NativeSplitter.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +class NativeSplitter : BlockIterator +{ +public: + struct Options + { + size_t buffer_size = DEFAULT_BLOCK_SIZE; + size_t partition_nums; + std::string exprs_buffer; + std::string schema_buffer; + }; + + struct Holder + { + std::unique_ptr splitter = nullptr; + }; + + static jclass iterator_class; + static jmethodID iterator_has_next; + static jmethodID iterator_next; + static std::unique_ptr create(const std::string & short_name, Options options, jobject input); + + NativeSplitter(Options options, jobject input); + bool hasNext(); + DB::Block * next(); + int32_t nextPartitionId(); + + + virtual ~NativeSplitter(); + +protected: + virtual void computePartitionId(DB::Block &) { } + Options options; + PartitionInfo partition_info; + std::vector output_columns_indicies; + DB::Block output_header; + +private: + void split(DB::Block & block); + int64_t inputNext(); + bool inputHasNext(); + + + std::vector> partition_buffer; + std::stack>> output_buffer; + int32_t next_partition_id = -1; + jobject input; +}; + +class HashNativeSplitter : public NativeSplitter +{ + void computePartitionId(DB::Block & block) override; + +public: + HashNativeSplitter(NativeSplitter::Options options_, jobject input); + +private: + std::unique_ptr selector_builder; +}; + +class RoundRobinNativeSplitter : public NativeSplitter +{ + void computePartitionId(DB::Block & block) override; + +public: + RoundRobinNativeSplitter(NativeSplitter::Options options_, jobject input); + +private: + std::unique_ptr selector_builder; +}; + +class RangePartitionNativeSplitter : public NativeSplitter +{ + void computePartitionId(DB::Block & block) override; +public: + RangePartitionNativeSplitter(NativeSplitter::Options options_, jobject input); + ~RangePartitionNativeSplitter() override = default; +private: + std::unique_ptr selector_builder; +}; + +} diff --git a/utils/local-engine/Shuffle/NativeWriterInMemory.cpp b/utils/local-engine/Shuffle/NativeWriterInMemory.cpp new file mode 100644 index 000000000000..84135cdfc7e1 --- /dev/null +++ b/utils/local-engine/Shuffle/NativeWriterInMemory.cpp @@ -0,0 +1,25 @@ +#include "NativeWriterInMemory.h" + +using namespace DB; + +namespace local_engine +{ + +NativeWriterInMemory::NativeWriterInMemory() +{ + write_buffer = std::make_unique(); +} +void NativeWriterInMemory::write(Block & block) +{ + if (block.columns() == 0 || block.rows() == 0) return; + if (!writer) + { + writer = std::make_unique(*write_buffer, 0, block.cloneEmpty()); + } + writer->write(block); +} +std::string & NativeWriterInMemory::collect() +{ + return write_buffer->str(); +} +} diff --git a/utils/local-engine/Shuffle/NativeWriterInMemory.h b/utils/local-engine/Shuffle/NativeWriterInMemory.h new file mode 100644 index 000000000000..a92c14353a1f --- /dev/null +++ b/utils/local-engine/Shuffle/NativeWriterInMemory.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class NativeWriterInMemory +{ +public: + NativeWriterInMemory(); + void write(DB::Block & block); + std::string & collect(); + +private: + std::unique_ptr write_buffer; + //lazy init + std::unique_ptr writer; +}; +} diff --git a/utils/local-engine/Shuffle/SelectorBuilder.cpp b/utils/local-engine/Shuffle/SelectorBuilder.cpp new file mode 100644 index 000000000000..02f20bcdbc0d --- /dev/null +++ b/utils/local-engine/Shuffle/SelectorBuilder.cpp @@ -0,0 +1,342 @@ +#include "SelectorBuilder.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +} + +namespace local_engine +{ + +PartitionInfo PartitionInfo::fromSelector(DB::IColumn::Selector selector, size_t partition_num) +{ + auto rows = selector.size(); + std::vector partition_row_idx_start_points(partition_num + 1, 0); + IColumn::Selector partition_selector(rows, 0); + for (size_t i = 0; i < rows; ++i) + { + partition_row_idx_start_points[selector[i]]++; + } + + for (size_t i = 1; i <= partition_num; ++i) + { + partition_row_idx_start_points[i] += partition_row_idx_start_points[i - 1]; + } + for (size_t i = rows; i-- > 0;) + { + partition_selector[partition_row_idx_start_points[selector[i]] - 1] = i; + partition_row_idx_start_points[selector[i]]--; + } + return PartitionInfo{.partition_selector = std::move(partition_selector), .partition_start_points = partition_row_idx_start_points, + .partition_num = partition_num}; +} + +PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block) +{ + DB::IColumn::Selector result; + result.resize_fill(block.rows(), 0); + for (auto & pid : result) + { + pid = pid_selection; + pid_selection = (pid_selection + 1) % parts_num; + } + return PartitionInfo::fromSelector(std::move(result), parts_num); +} + +HashSelectorBuilder::HashSelectorBuilder( + UInt32 parts_num_, + const std::vector & exprs_index_, + const std::string & hash_function_name_) + : parts_num(parts_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_) +{ +} + +PartitionInfo HashSelectorBuilder::build(DB::Block & block) +{ + ColumnsWithTypeAndName args; + auto rows = block.rows(); + for (size_t i = 0; i < exprs_index.size(); i++) + { + args.emplace_back(block.getByPosition(exprs_index.at(i))); + } + + if (!hash_function) [[unlikely]] + { + auto & factory = DB::FunctionFactory::instance(); + auto function = factory.get(hash_function_name, local_engine::SerializedPlanParser::global_context); + + hash_function = function->build(args); + } + DB::IColumn::Selector partition_ids; + partition_ids.reserve(rows); + auto result_type = hash_function->getResultType(); + auto hash_column = hash_function->execute(args, result_type, rows, false); + + for (size_t i = 0; i < block.rows(); i++) + { + partition_ids.emplace_back(static_cast(hash_column->get64(i) % parts_num)); + } + return PartitionInfo::fromSelector(std::move(partition_ids), parts_num); +} + + +static std::map> direction_map = { + {1, {1, -1}}, + {2, {1, 1}}, + {3, {-1, 1}}, + {4, {-1, -1}} +}; + +RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option, const size_t partition_num_) +{ + Poco::JSON::Parser parser; + auto info = parser.parse(option).extract(); + auto ordering_infos = info->get("ordering").extract(); + initSortInformation(ordering_infos); + initRangeBlock(info->get("range_bounds").extract()); + partition_num = partition_num_; +} + +PartitionInfo RangeSelectorBuilder::build(DB::Block & block) +{ + DB::IColumn::Selector result; + computePartitionIdByBinarySearch(block, result); + return PartitionInfo::fromSelector(std::move(result), partition_num); +} + +void RangeSelectorBuilder::initSortInformation(Poco::JSON::Array::Ptr orderings) +{ + for (size_t i = 0; i < orderings->size(); ++i) + { + auto ordering = orderings->get(i).extract(); + auto col_pos = ordering->get("column_ref").convert(); + auto col_name = ordering->get("column_name").convert(); + + auto sort_direction = ordering->get("direction").convert(); + auto d_iter = direction_map.find(sort_direction); + if (d_iter == direction_map.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported sorting direction:{}", sort_direction); + } + DB::SortColumnDescription ch_col_sort_descr(col_name, d_iter->second.first, d_iter->second.second); + sort_descriptions.emplace_back(ch_col_sort_descr); + + auto type_name = ordering->get("data_type").convert(); + auto type = SerializedPlanParser::parseType(type_name); + SortFieldTypeInfo info; + info.inner_type = type; + info.is_nullable = ordering->get("is_nullable").convert(); + sort_field_types.emplace_back(info); + sorting_key_columns.emplace_back(col_pos); + } +} + +void RangeSelectorBuilder::initRangeBlock(Poco::JSON::Array::Ptr range_bounds) +{ + DB::ColumnsWithTypeAndName columns; + for (size_t i = 0; i < sort_field_types.size(); ++i) + { + auto & type_info = sort_field_types[i]; + auto inner_col = type_info.inner_type->createColumn(); + auto data_type = type_info.inner_type; + DB::MutableColumnPtr col = std::move(inner_col); + if (type_info.is_nullable) + { + col = ColumnNullable::create(std::move(col), DB::ColumnUInt8::create(0, 0)); + data_type = std::make_shared(data_type); + } + for (size_t r = 0; r < range_bounds->size(); ++r) + { + auto row = range_bounds->get(r).extract(); + auto field_info = row->get(i).extract(); + if (field_info->get("is_null").convert()) + { + col->insertData(nullptr, 0); + } + else + { + const auto & type_name = type_info.inner_type->getName(); + const auto & field_value = field_info->get("value"); + if (type_name == "UInt8") + { + col->insert(static_cast(field_value.convert())); + } + else if (type_name == "Int8") + { + col->insert(field_value.convert()); + } + else if (type_name == "Int16") + { + col->insert(field_value.convert()); + } + else if (type_name == "Int32") + { + col->insert(field_value.convert()); + } + else if(type_name == "Int64") + { + col->insert(field_value.convert()); + } + else if (type_name == "Float32") + { + col->insert(field_value.convert()); + } + else if (type_name == "Float64") + { + col->insert(field_value.convert()); + } + else if (type_name == "String") + { + col->insert(field_value.convert()); + } + else if (type_name == "Date") + { + int val = field_value.convert(); + col->insert(val); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported data type: {}", type_info.inner_type->getName()); + } + } + } + auto col_name = "sort_col_" + std::to_string(i); + columns.emplace_back(std::move(col), data_type, col_name); + } + range_bounds_block = DB::Block(columns); +} + +void RangeSelectorBuilder::initActionsDAG(const DB::Block & block) +{ + std::lock_guard lock(actions_dag_mutex); + if (has_init_actions_dag) + return; + SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context); + plan_parser.parseExtensions(projection_plan_pb->extensions()); + + const auto & expressions = projection_plan_pb->relations().at(0).root().input().project().expressions(); + std::vector exprs; + exprs.reserve(expressions.size()); + for (const auto & expression: expressions) + exprs.emplace_back(expression); + + auto projection_actions_dag + = plan_parser.expressionsToActionsDAG(exprs, block, block); + projection_expression_actions = std::make_unique(projection_actions_dag); + has_init_actions_dag = true; +} + +void RangeSelectorBuilder::computePartitionIdByBinarySearch(DB::Block & block, DB::IColumn::Selector & selector) +{ + Chunks chunks; + Chunk chunk(block.getColumns(), block.rows()); + chunks.emplace_back(std::move(chunk)); + selector.clear(); + selector.reserve(block.rows()); + auto input_columns = block.getColumns(); + auto total_rows = block.rows(); + const auto & bounds_columns = range_bounds_block.getColumns(); + auto max_part = bounds_columns[0]->size(); + for (size_t i = 0; i < bounds_columns.size(); i++) + { + if (bounds_columns[i]->isNullable() && !input_columns[sorting_key_columns[i]]->isNullable()) + { + input_columns[sorting_key_columns[i]] = makeNullable(input_columns[sorting_key_columns[i]]); + } + } + for (size_t r = 0; r < total_rows; ++r) + { + size_t selected_partition = 0; + auto ret = binarySearchBound(bounds_columns, 0, max_part - 1, input_columns, sorting_key_columns, r); + if (ret >= 0) + selected_partition = ret; + else + selected_partition = max_part; + selector.emplace_back(selected_partition); + } +} + +int RangeSelectorBuilder::compareRow( + const DB::Columns & columns, + const std::vector & required_columns, + size_t row, + const DB::Columns & bound_columns, + size_t bound_row) +{ + for(size_t i = 0, n = required_columns.size(); i < n; ++i) + { + auto lpos = required_columns[i]; + auto rpos = i; + auto res = columns[lpos]->compareAt(row, bound_row, *bound_columns[rpos], sort_descriptions[i].nulls_direction) + * sort_descriptions[i].direction; + if (res != 0) + { + return res; + } + } + return 0; +} + +// If there were elements in range[l,r] that are larger then the row +// the return the min element's index. otherwise return -1 +int RangeSelectorBuilder::binarySearchBound( + const DB::Columns & bound_columns, + Int64 l, + Int64 r, + const DB::Columns & columns, + const std::vector & used_cols, + size_t row) +{ + if (l > r) + { + return -1; + } + auto m = (l + r) >> 1; + auto cmp_ret = compareRow(columns, used_cols, row, bound_columns, m); + if (l == r) + { + if (cmp_ret <= 0) + return static_cast(m); + else + return -1; + } + + if (cmp_ret == 0) + return static_cast(m); + if (cmp_ret < 0) + { + cmp_ret = binarySearchBound(bound_columns, l, m - 1, columns, used_cols, row); + if (cmp_ret < 0) + { + // m is the upper bound + return static_cast(m); + } + return cmp_ret; + + } + else + { + cmp_ret = binarySearchBound(bound_columns, m + 1, r, columns, used_cols, row); + if (cmp_ret < 0) + return -1; + else + return cmp_ret; + } + __builtin_unreachable(); +} +} diff --git a/utils/local-engine/Shuffle/SelectorBuilder.h b/utils/local-engine/Shuffle/SelectorBuilder.h new file mode 100644 index 000000000000..58e708df5baf --- /dev/null +++ b/utils/local-engine/Shuffle/SelectorBuilder.h @@ -0,0 +1,95 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ + +struct PartitionInfo +{ + DB::IColumn::Selector partition_selector; + std::vector partition_start_points; + size_t partition_num; + + static PartitionInfo fromSelector(DB::IColumn::Selector selector, size_t partition_num); +}; + +class RoundRobinSelectorBuilder +{ +public: + explicit RoundRobinSelectorBuilder(size_t parts_num_) : parts_num(parts_num_) {} + PartitionInfo build(DB::Block & block); +private: + size_t parts_num; + Int32 pid_selection = 0; +}; + +class HashSelectorBuilder +{ +public: + explicit HashSelectorBuilder( + UInt32 parts_num_, + const std::vector & exprs_index_, + const std::string & hash_function_name_); + PartitionInfo build(DB::Block & block); +private: + UInt32 parts_num; + std::vector exprs_index; + std::string hash_function_name; + DB::FunctionBasePtr hash_function; +}; + +class RangeSelectorBuilder +{ +public: + explicit RangeSelectorBuilder(const std::string & options_, const size_t partition_num_); + PartitionInfo build(DB::Block & block); +private: + DB::SortDescription sort_descriptions; + std::vector sorting_key_columns; + struct SortFieldTypeInfo + { + DB::DataTypePtr inner_type; + bool is_nullable = false; + }; + std::vector sort_field_types; + DB::Block range_bounds_block; + + // If the ordering keys have expressions, we caculate the expressions here. + std::mutex actions_dag_mutex; + std::unique_ptr projection_plan_pb; + std::atomic has_init_actions_dag; + std::unique_ptr projection_expression_actions; + size_t partition_num; + + void initSortInformation(Poco::JSON::Array::Ptr orderings); + void initRangeBlock(Poco::JSON::Array::Ptr range_bounds); + void initActionsDAG(const DB::Block & block); + + void computePartitionIdByBinarySearch(DB::Block & block, DB::IColumn::Selector & selector); + int compareRow( + const DB::Columns & columns, + const std::vector & required_columns, + size_t row, + const DB::Columns & bound_columns, + size_t bound_row); + + int binarySearchBound( + const DB::Columns & bound_columns, + Int64 l, + Int64 r, + const DB::Columns & columns, + const std::vector & used_cols, + size_t row); +}; + +} diff --git a/utils/local-engine/Shuffle/ShuffleReader.cpp b/utils/local-engine/Shuffle/ShuffleReader.cpp new file mode 100644 index 000000000000..7b5ae1599ac0 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleReader.cpp @@ -0,0 +1,69 @@ +#include "ShuffleReader.h" +#include +#include +#include +#include + +using namespace DB; + +namespace local_engine +{ + +local_engine::ShuffleReader::ShuffleReader(std::unique_ptr in_, bool compressed) : in(std::move(in_)) +{ + if (compressed) + { + compressed_in = std::make_unique(*in); + input_stream = std::make_unique(*compressed_in, 0); + } + else + { + input_stream = std::make_unique(*in, 0); + } +} +Block * local_engine::ShuffleReader::read() +{ + auto block = input_stream->read(); + setCurrentBlock(block); + if (unlikely(header.columns() == 0)) + header = currentBlock().cloneEmpty(); + return ¤tBlock(); +} +ShuffleReader::~ShuffleReader() +{ + in.reset(); + compressed_in.reset(); + input_stream.reset(); +} + +jclass ShuffleReader::input_stream_class = nullptr; +jmethodID ShuffleReader::input_stream_read = nullptr; + +bool ReadBufferFromJavaInputStream::nextImpl() +{ + int count = readFromJava(); + if (count > 0) + { + working_buffer.resize(count); + } + return count > 0; +} +int ReadBufferFromJavaInputStream::readFromJava() +{ + GET_JNIENV(env) + jint count = safeCallIntMethod(env, java_in, ShuffleReader::input_stream_read, reinterpret_cast(working_buffer.begin()), buffer_size); + CLEAN_JNIENV + return count; +} +ReadBufferFromJavaInputStream::ReadBufferFromJavaInputStream(jobject input_stream, size_t customize_buffer_size) : java_in(input_stream), buffer_size(customize_buffer_size) +{ +} +ReadBufferFromJavaInputStream::~ReadBufferFromJavaInputStream() +{ + GET_JNIENV(env) + env->DeleteGlobalRef(java_in); + CLEAN_JNIENV + +} + +} diff --git a/utils/local-engine/Shuffle/ShuffleReader.h b/utils/local-engine/Shuffle/ShuffleReader.h new file mode 100644 index 000000000000..aa52a2a83208 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleReader.h @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include +#include + + +namespace local_engine +{ +class ReadBufferFromJavaInputStream; +class ShuffleReader : BlockIterator +{ +public: + explicit ShuffleReader(std::unique_ptr in_, bool compressed); + DB::Block* read(); + ~ShuffleReader(); + static jclass input_stream_class; + static jmethodID input_stream_read; + std::unique_ptr in; + +private: + std::unique_ptr compressed_in; + std::unique_ptr input_stream; + DB::Block header; +}; + + +class ReadBufferFromJavaInputStream : public DB::BufferWithOwnMemory +{ +public: + explicit ReadBufferFromJavaInputStream(jobject input_stream, size_t customize_buffer_size); + ~ReadBufferFromJavaInputStream() override; + +private: + jobject java_in; + size_t buffer_size; + int readFromJava(); + bool nextImpl() override; + +}; + +} diff --git a/utils/local-engine/Shuffle/ShuffleSplitter.cpp b/utils/local-engine/Shuffle/ShuffleSplitter.cpp new file mode 100644 index 000000000000..7a464bc2f348 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleSplitter.cpp @@ -0,0 +1,375 @@ +#include "ShuffleSplitter.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +void ShuffleSplitter::split(DB::Block & block) +{ + if (block.rows() == 0) + { + return; + } + Stopwatch watch; + watch.start(); + computeAndCountPartitionId(block); + splitBlockByPartition(block); + split_result.total_write_time += watch.elapsedNanoseconds(); +} +SplitResult ShuffleSplitter::stop() +{ + // spill all buffers + Stopwatch watch; + watch.start(); + for (size_t i = 0; i < options.partition_nums; i++) + { + spillPartition(i); + partition_outputs[i]->flush(); + partition_write_buffers[i].reset(); + } + partition_outputs.clear(); + partition_cached_write_buffers.clear(); + partition_write_buffers.clear(); + mergePartitionFiles(); + split_result.total_write_time += watch.elapsedNanoseconds(); + stopped = true; + return split_result; +} +void ShuffleSplitter::splitBlockByPartition(DB::Block & block) +{ + if (!output_header.columns()) [[unlikely]] + { + if (output_columns_indicies.empty()) + { + output_header = block.cloneEmpty(); + for (size_t i = 0; i < block.columns(); ++i) + { + output_columns_indicies.push_back(i); + } + } + else + { + DB::ColumnsWithTypeAndName cols; + for (const auto & index : output_columns_indicies) + { + cols.push_back(block.getByPosition(index)); + } + output_header = DB::Block(cols); + } + } + DB::Block out_block; + for (size_t col = 0; col < output_header.columns(); ++col) + { + out_block.insert(block.getByPosition(output_columns_indicies[col])); + } + for (size_t col = 0; col < output_header.columns(); ++col) + { + for (size_t j = 0; j < partition_info.partition_num; ++j) + { + size_t from = partition_info.partition_start_points[j]; + size_t length = partition_info.partition_start_points[j + 1] - from; + if (length == 0) + continue; // no data for this partition continue; + partition_buffer[j].appendSelective(col, out_block, partition_info.partition_selector, from, length); + } + } + + for (size_t i = 0; i < options.partition_nums; ++i) + { + ColumnsBuffer & buffer = partition_buffer[i]; + if (buffer.size() >= options.split_size) + { + spillPartition(i); + } + } + +} +void ShuffleSplitter::init() +{ + partition_buffer.reserve(options.partition_nums); + partition_outputs.reserve(options.partition_nums); + partition_write_buffers.reserve(options.partition_nums); + partition_cached_write_buffers.reserve(options.partition_nums); + split_result.partition_length.reserve(options.partition_nums); + split_result.raw_partition_length.reserve(options.partition_nums); + for (size_t i = 0; i < options.partition_nums; ++i) + { + partition_buffer.emplace_back(ColumnsBuffer()); + split_result.partition_length.emplace_back(0); + split_result.raw_partition_length.emplace_back(0); + partition_outputs.emplace_back(nullptr); + partition_write_buffers.emplace_back(nullptr); + partition_cached_write_buffers.emplace_back(nullptr); + } +} + +void ShuffleSplitter::spillPartition(size_t partition_id) +{ + Stopwatch watch; + watch.start(); + if (!partition_outputs[partition_id]) + { + partition_write_buffers[partition_id] = getPartitionWriteBuffer(partition_id); + partition_outputs[partition_id] + = std::make_unique(*partition_write_buffers[partition_id], 0, partition_buffer[partition_id].getHeader()); + } + DB::Block result = partition_buffer[partition_id].releaseColumns(); + if (result.rows() > 0) + { + partition_outputs[partition_id]->write(result); + } + split_result.total_spill_time += watch.elapsedNanoseconds(); + split_result.total_bytes_spilled += result.bytes(); +} + +void ShuffleSplitter::mergePartitionFiles() +{ + DB::WriteBufferFromFile data_write_buffer = DB::WriteBufferFromFile(options.data_file); + std::string buffer; + int buffer_size = options.io_buffer_size; + buffer.reserve(buffer_size); + for (size_t i = 0; i < options.partition_nums; ++i) + { + auto file = getPartitionTempFile(i); + DB::ReadBufferFromFile reader = DB::ReadBufferFromFile(file, options.io_buffer_size); + while (reader.next()) + { + auto bytes = reader.readBig(buffer.data(), buffer_size); + data_write_buffer.write(buffer.data(), bytes); + split_result.partition_length[i] += bytes; + split_result.total_bytes_written += bytes; + } + reader.close(); + std::filesystem::remove(file); + } + data_write_buffer.close(); +} + +ShuffleSplitter::ShuffleSplitter(SplitOptions && options_) : options(options_) +{ + init(); +} + +ShuffleSplitter::Ptr ShuffleSplitter::create(const std::string & short_name, SplitOptions options_) +{ + if (short_name == "rr") + { + return RoundRobinSplitter::create(std::move(options_)); + } + else if (short_name == "hash") + { + return HashSplitter::create(std::move(options_)); + } + else if (short_name == "single") + { + options_.partition_nums = 1; + return RoundRobinSplitter::create(std::move(options_)); + } + else if (short_name == "range") + { + return RangeSplitter::create(std::move(options_)); + } + else + { + throw std::runtime_error("unsupported splitter " + short_name); + } +} + +std::string ShuffleSplitter::getPartitionTempFile(size_t partition_id) +{ + auto file_name = std::to_string(options.shuffle_id) + "_" + std::to_string(options.map_id) + "_" + std::to_string(partition_id); + std::hash hasher; + auto hash = hasher(file_name); + auto dir_id = hash % options.local_dirs_list.size(); + auto sub_dir_id = (hash / options.local_dirs_list.size()) % options.num_sub_dirs; + + std::string dir = std::filesystem::path(options.local_dirs_list[dir_id]) / std::format("{:02x}", sub_dir_id); + if (!std::filesystem::exists(dir)) + std::filesystem::create_directories(dir); + return std::filesystem::path(dir) / file_name; +} + +std::unique_ptr ShuffleSplitter::getPartitionWriteBuffer(size_t partition_id) +{ + auto file = getPartitionTempFile(partition_id); + if (partition_cached_write_buffers[partition_id] == nullptr) + partition_cached_write_buffers[partition_id] + = std::make_unique(file, options.io_buffer_size, O_CREAT | O_WRONLY | O_APPEND); + if (!options.compress_method.empty() + && std::find(compress_methods.begin(), compress_methods.end(), options.compress_method) != compress_methods.end()) + { + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), {}); + return std::make_unique(*partition_cached_write_buffers[partition_id], codec); + } + else + { + return std::move(partition_cached_write_buffers[partition_id]); + } +} + +const std::vector ShuffleSplitter::compress_methods = {"", "ZSTD", "LZ4"}; + +void ShuffleSplitter::writeIndexFile() +{ + auto index_file = options.data_file + ".index"; + auto writer = std::make_unique(index_file, options.io_buffer_size, O_CREAT | O_WRONLY | O_TRUNC); + for (auto len : split_result.partition_length) + { + DB::writeIntText(len, *writer); + DB::writeChar('\n', *writer); + } +} + +void ColumnsBuffer::add(DB::Block & block, int start, int end) +{ + if (header.columns() == 0) + header = block.cloneEmpty(); + if (accumulated_columns.empty()) [[unlikely]] + { + accumulated_columns.reserve(block.columns()); + for (size_t i = 0; i < block.columns(); i++) + { + auto column = block.getColumns()[i]->cloneEmpty(); + column->reserve(prefer_buffer_size); + accumulated_columns.emplace_back(std::move(column)); + } + } + assert(!accumulated_columns.empty()); + for (size_t i = 0; i < block.columns(); ++i) + accumulated_columns[i]->insertRangeFrom(*block.getByPosition(i).column, start, end - start); +} + +void ColumnsBuffer::appendSelective(size_t column_idx, const DB::Block & source, const DB::IColumn::Selector & selector, size_t from, size_t length) +{ + if (header.columns() == 0) + header = source.cloneEmpty(); + if (accumulated_columns.empty()) [[unlikely]] + { + accumulated_columns.reserve(source.columns()); + for (size_t i = 0; i < source.columns(); i++) + { + auto column = source.getColumns()[i]->convertToFullColumnIfConst()->cloneEmpty(); + column->reserve(prefer_buffer_size); + accumulated_columns.emplace_back(std::move(column)); + } + } + accumulated_columns[column_idx]->insertRangeSelective(*source.getByPosition(column_idx).column->convertToFullColumnIfConst(), selector, from, length); +} + +size_t ColumnsBuffer::size() const +{ + if (accumulated_columns.empty()) + return 0; + return accumulated_columns.at(0)->size(); +} + +DB::Block ColumnsBuffer::releaseColumns() +{ + DB::Columns res(std::make_move_iterator(accumulated_columns.begin()), std::make_move_iterator(accumulated_columns.end())); + accumulated_columns.clear(); + if (res.empty()) + { + return header.cloneEmpty(); + } + else + { + return header.cloneWithColumns(res); + } +} + +DB::Block ColumnsBuffer::getHeader() +{ + return header; +} +ColumnsBuffer::ColumnsBuffer(size_t prefer_buffer_size_) : prefer_buffer_size(prefer_buffer_size_) +{ +} + +RoundRobinSplitter::RoundRobinSplitter(SplitOptions options_) : ShuffleSplitter(std::move(options_)) +{ + Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + selector_builder = std::make_unique(options.partition_nums); +} + +void RoundRobinSplitter::computeAndCountPartitionId(DB::Block & block) +{ + Stopwatch watch; + watch.start(); + partition_info = selector_builder->build(block); + split_result.total_compute_pid_time += watch.elapsedNanoseconds(); +} + +std::unique_ptr RoundRobinSplitter::create(SplitOptions && options_) +{ + return std::make_unique(std::move(options_)); +} + +HashSplitter::HashSplitter(SplitOptions options_) : ShuffleSplitter(std::move(options_)) +{ + Poco::StringTokenizer exprs_list(options_.hash_exprs, ","); + std::vector hash_fields; + for (auto iter = exprs_list.begin(); iter != exprs_list.end(); ++iter) + { + hash_fields.push_back(std::stoi(*iter)); + } + + Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + + selector_builder = std::make_unique(options.partition_nums, hash_fields, "cityHash64"); +} +std::unique_ptr HashSplitter::create(SplitOptions && options_) +{ + return std::make_unique(std::move(options_)); +} + +void HashSplitter::computeAndCountPartitionId(DB::Block & block) +{ + Stopwatch watch; + watch.start(); + partition_info = selector_builder->build(block); + split_result.total_compute_pid_time += watch.elapsedNanoseconds(); +} + +std::unique_ptr RangeSplitter::create(SplitOptions && options_) +{ + return std::make_unique(std::move(options_)); +} + +RangeSplitter::RangeSplitter(SplitOptions options_) : ShuffleSplitter(std::move(options_)) +{ + Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ","); + for (auto iter = output_column_tokenizer.begin(); iter != output_column_tokenizer.end(); ++iter) + { + output_columns_indicies.push_back(std::stoi(*iter)); + } + selector_builder = std::make_unique(options.hash_exprs, options.partition_nums); +} +void RangeSplitter::computeAndCountPartitionId(DB::Block & block) +{ + Stopwatch watch; + watch.start(); + partition_info = selector_builder->build(block); + split_result.total_compute_pid_time += watch.elapsedNanoseconds(); +} +} diff --git a/utils/local-engine/Shuffle/ShuffleSplitter.h b/utils/local-engine/Shuffle/ShuffleSplitter.h new file mode 100644 index 000000000000..ece3433447a7 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleSplitter.h @@ -0,0 +1,140 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace local_engine +{ +struct SplitOptions +{ + size_t split_size = DEFAULT_BLOCK_SIZE; + size_t io_buffer_size = DBMS_DEFAULT_BUFFER_SIZE; + std::string data_file; + std::vector local_dirs_list; + int num_sub_dirs; + int shuffle_id; + int map_id; + size_t partition_nums; + std::string hash_exprs; + std::string out_exprs; + // std::vector exprs; + std::string compress_method = "zstd"; + int compress_level; +}; + +class ColumnsBuffer +{ +public: + explicit ColumnsBuffer(size_t prefer_buffer_size = DEFAULT_BLOCK_SIZE); + void add(DB::Block & columns, int start, int end); + void appendSelective(size_t column_idx, const DB::Block & source, const DB::IColumn::Selector & selector, size_t from, size_t length); + size_t size() const; + DB::Block releaseColumns(); + DB::Block getHeader(); + +private: + DB::MutableColumns accumulated_columns; + DB::Block header; + size_t prefer_buffer_size; +}; + +struct SplitResult +{ + Int64 total_compute_pid_time = 0; + Int64 total_write_time = 0; + Int64 total_spill_time = 0; + Int64 total_bytes_written = 0; + Int64 total_bytes_spilled = 0; + std::vector partition_length; + std::vector raw_partition_length; +}; + +class ShuffleSplitter +{ +public: + static const std::vector compress_methods; + using Ptr = std::unique_ptr; + static Ptr create(const std::string & short_name, SplitOptions options_); + explicit ShuffleSplitter(SplitOptions && options); + virtual ~ShuffleSplitter() + { + if (!stopped) + stop(); + } + void split(DB::Block & block); + virtual void computeAndCountPartitionId(DB::Block &) { } + std::vector getPartitionLength() const { return split_result.partition_length; } + void writeIndexFile(); + SplitResult stop(); + +private: + void init(); + void splitBlockByPartition(DB::Block & block); + void spillPartition(size_t partition_id); + std::string getPartitionTempFile(size_t partition_id); + void mergePartitionFiles(); + std::unique_ptr getPartitionWriteBuffer(size_t partition_id); + +protected: + bool stopped = false; + PartitionInfo partition_info; + std::vector partition_buffer; + std::vector> partition_outputs; + std::vector> partition_write_buffers; + std::vector> partition_cached_write_buffers; + std::vector output_columns_indicies; + DB::Block output_header; + SplitOptions options; + SplitResult split_result; +}; + +class RoundRobinSplitter : public ShuffleSplitter +{ +public: + static std::unique_ptr create(SplitOptions && options); + + explicit RoundRobinSplitter(SplitOptions options_); + + ~RoundRobinSplitter() override = default; + void computeAndCountPartitionId(DB::Block & block) override; + +private: + std::unique_ptr selector_builder; +}; + +class HashSplitter : public ShuffleSplitter +{ +public: + static std::unique_ptr create(SplitOptions && options); + + explicit HashSplitter(SplitOptions options_); + + ~HashSplitter() override = default; + void computeAndCountPartitionId(DB::Block & block) override; + +private: + std::unique_ptr selector_builder; +}; + +class RangeSplitter : public ShuffleSplitter +{ +public: + static std::unique_ptr create(SplitOptions && options); + explicit RangeSplitter(SplitOptions options_); + void computeAndCountPartitionId(DB::Block & block) override; +private: + std::unique_ptr selector_builder; +}; +struct SplitterHolder +{ + ShuffleSplitter::Ptr splitter; +}; + + +} diff --git a/utils/local-engine/Shuffle/ShuffleWriter.cpp b/utils/local-engine/Shuffle/ShuffleWriter.cpp new file mode 100644 index 000000000000..fa52599a6af6 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleWriter.cpp @@ -0,0 +1,58 @@ +#include "ShuffleWriter.h" +#include +#include +#include + +using namespace DB; + +namespace local_engine +{ + +ShuffleWriter::ShuffleWriter(jobject output_stream, jbyteArray buffer, const std::string & codecStr, bool enable_compression, size_t customize_buffer_size) +{ + compression_enable = enable_compression; + write_buffer = std::make_unique(output_stream, buffer, customize_buffer_size); + if (compression_enable) + { + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(codecStr), {}); + compressed_out = std::make_unique(*write_buffer, codec); + } +} +void ShuffleWriter::write(const Block & block) +{ + if (!native_writer) + { + if (compression_enable) + { + native_writer = std::make_unique(*compressed_out, 0, block.cloneEmpty()); + } + else + { + native_writer = std::make_unique(*write_buffer, 0, block.cloneEmpty()); + } + } + if (block.rows() > 0) + { + native_writer->write(block); + } +} +void ShuffleWriter::flush() +{ + if (native_writer) + { + native_writer->flush(); + } +} +ShuffleWriter::~ShuffleWriter() +{ + if (native_writer) + { + native_writer->flush(); + if (compression_enable) + { + compressed_out->finalize(); + } + write_buffer->finalize(); + } +} +} diff --git a/utils/local-engine/Shuffle/ShuffleWriter.h b/utils/local-engine/Shuffle/ShuffleWriter.h new file mode 100644 index 000000000000..d26134bc1c65 --- /dev/null +++ b/utils/local-engine/Shuffle/ShuffleWriter.h @@ -0,0 +1,21 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class ShuffleWriter +{ +public: + ShuffleWriter(jobject output_stream, jbyteArray buffer, const std::string & codecStr, bool enable_compression, size_t customize_buffer_size); + virtual ~ShuffleWriter(); + void write(const DB::Block & block); + void flush(); + +private: + std::unique_ptr compressed_out; + std::unique_ptr write_buffer; + std::unique_ptr native_writer; + bool compression_enable; +}; +} diff --git a/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp new file mode 100644 index 000000000000..aa6023aaec5c --- /dev/null +++ b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.cpp @@ -0,0 +1,46 @@ +#include "WriteBufferFromJavaOutputStream.h" +#include +#include + +namespace local_engine +{ +jclass WriteBufferFromJavaOutputStream::output_stream_class = nullptr; +jmethodID WriteBufferFromJavaOutputStream::output_stream_write = nullptr; +jmethodID WriteBufferFromJavaOutputStream::output_stream_flush = nullptr; + +void WriteBufferFromJavaOutputStream::nextImpl() +{ + GET_JNIENV(env) + size_t bytes_write = 0; + while (offset() - bytes_write > 0) + { + size_t copy_num = std::min(offset() - bytes_write, buffer_size); + env->SetByteArrayRegion(buffer, 0 , copy_num, reinterpret_cast(this->working_buffer.begin() + bytes_write)); + safeCallVoidMethod(env, output_stream, output_stream_write, buffer, 0, copy_num); + bytes_write += copy_num; + } + CLEAN_JNIENV +} +WriteBufferFromJavaOutputStream::WriteBufferFromJavaOutputStream(jobject output_stream_, jbyteArray buffer_, size_t customize_buffer_size) +{ + GET_JNIENV(env) + buffer = static_cast(env->NewWeakGlobalRef(buffer_)); + output_stream = env->NewWeakGlobalRef(output_stream_); + buffer_size = customize_buffer_size; + CLEAN_JNIENV +} +void WriteBufferFromJavaOutputStream::finalizeImpl() +{ + next(); + GET_JNIENV(env) + safeCallVoidMethod(env, output_stream, output_stream_flush); + CLEAN_JNIENV +} +WriteBufferFromJavaOutputStream::~WriteBufferFromJavaOutputStream() +{ + GET_JNIENV(env) + env->DeleteWeakGlobalRef(output_stream); + env->DeleteWeakGlobalRef(buffer); + CLEAN_JNIENV +} +} diff --git a/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.h b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.h new file mode 100644 index 000000000000..2579edc2e75d --- /dev/null +++ b/utils/local-engine/Shuffle/WriteBufferFromJavaOutputStream.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include + +namespace local_engine +{ +class WriteBufferFromJavaOutputStream : public DB::BufferWithOwnMemory +{ +public: + static jclass output_stream_class; + static jmethodID output_stream_write; + static jmethodID output_stream_flush; + + WriteBufferFromJavaOutputStream(jobject output_stream, jbyteArray buffer, size_t customize_buffer_size); + ~WriteBufferFromJavaOutputStream() override; + +private: + void nextImpl() override; + +protected: + void finalizeImpl() override; + +private: + jobject output_stream; + jbyteArray buffer; + size_t buffer_size; +}; +} diff --git a/utils/local-engine/Storages/ArrowParquetBlockInputFormat.cpp b/utils/local-engine/Storages/ArrowParquetBlockInputFormat.cpp new file mode 100644 index 000000000000..21a390f3546c --- /dev/null +++ b/utils/local-engine/Storages/ArrowParquetBlockInputFormat.cpp @@ -0,0 +1,103 @@ +#include "ArrowParquetBlockInputFormat.h" + +#include +#include +#include +#include +#include + +#include "ch_parquet/OptimizedArrowColumnToCHColumn.h" + +using namespace DB; + +namespace local_engine +{ +ArrowParquetBlockInputFormat::ArrowParquetBlockInputFormat( + DB::ReadBuffer & in_, const DB::Block & header, const DB::FormatSettings & formatSettings, const std::vector & row_group_indices_) + : OptimizedParquetBlockInputFormat(in_, header, formatSettings) + , row_group_indices(row_group_indices_) +{ +} + +static size_t countIndicesForType(std::shared_ptr type) +{ + if (type->id() == arrow::Type::LIST) + return countIndicesForType(static_cast(type.get())->value_type()); + + if (type->id() == arrow::Type::STRUCT) + { + int indices = 0; + auto * struct_type = static_cast(type.get()); + for (int i = 0; i != struct_type->num_fields(); ++i) + indices += countIndicesForType(struct_type->field(i)->type()); + return indices; + } + + if (type->id() == arrow::Type::MAP) + { + auto * map_type = static_cast(type.get()); + return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); + } + + return 1; +} + +DB::Chunk ArrowParquetBlockInputFormat::generate() +{ + DB::Chunk res; + block_missing_values.clear(); + + if (!file_reader) + { + prepareReader(); + file_reader->set_batch_size(8192); + if (row_group_indices.empty() && file_reader->num_row_groups() == 0) + { + return {}; + } + else if (row_group_indices.empty()) + { + auto row_group_range = boost::irange(0, file_reader->num_row_groups()); + row_group_indices = std::vector(row_group_range.begin(), row_group_range.end()); + } + auto read_status = file_reader->GetRecordBatchReader(row_group_indices, column_indices, ¤t_record_batch_reader); + if (!read_status.ok()) + throw std::runtime_error{"Error while reading Parquet data: " + read_status.ToString()}; + } + + if (is_stopped) + return {}; + + + Stopwatch watch; + watch.start(); + auto batch = current_record_batch_reader->Next(); + if (*batch) + { + auto tmp_table = arrow::Table::FromRecordBatches({*batch}); + if (format_settings.use_lowercase_column_name) + { + tmp_table = (*tmp_table)->RenameColumns(column_names); + } + non_convert_time += watch.elapsedNanoseconds(); + watch.restart(); + arrow_column_to_ch_column->arrowTableToCHChunk(res, *tmp_table); + convert_time += watch.elapsedNanoseconds(); + } + else + { + current_record_batch_reader.reset(); + file_reader.reset(); + return {}; + } + + /// If defaults_for_omitted_fields is true, calculate the default values from default expression for omitted fields. + /// Otherwise fill the missing columns with zero values of its type. + if (format_settings.defaults_for_omitted_fields) + for (size_t row_idx = 0; row_idx < res.getNumRows(); ++row_idx) + for (const auto & column_idx : missing_columns) + block_missing_values.setBit(column_idx, row_idx); + return res; +} + +} diff --git a/utils/local-engine/Storages/ArrowParquetBlockInputFormat.h b/utils/local-engine/Storages/ArrowParquetBlockInputFormat.h new file mode 100644 index 000000000000..e01741f5be8b --- /dev/null +++ b/utils/local-engine/Storages/ArrowParquetBlockInputFormat.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include "ch_parquet/OptimizedParquetBlockInputFormat.h" +#include "ch_parquet/OptimizedArrowColumnToCHColumn.h" +#include "ch_parquet/arrow/reader.h" + +namespace arrow +{ +class RecordBatchReader; +class Table; +} + +namespace local_engine +{ +class ArrowParquetBlockInputFormat : public DB::OptimizedParquetBlockInputFormat +{ +public: + ArrowParquetBlockInputFormat(DB::ReadBuffer & in, const DB::Block & header, const DB::FormatSettings & formatSettings, const std::vector & row_group_indices_ = {}); + //virtual ~ArrowParquetBlockInputFormat(); + +private: + DB::Chunk generate() override; + + int64_t convert_time = 0; + int64_t non_convert_time = 0; + std::shared_ptr current_record_batch_reader; + std::vector row_group_indices; +}; + +} diff --git a/utils/local-engine/Storages/CMakeLists.txt b/utils/local-engine/Storages/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/local-engine/Storages/CustomMergeTreeSink.cpp b/utils/local-engine/Storages/CustomMergeTreeSink.cpp new file mode 100644 index 000000000000..2a4dd95f4642 --- /dev/null +++ b/utils/local-engine/Storages/CustomMergeTreeSink.cpp @@ -0,0 +1,18 @@ +#include "CustomMergeTreeSink.h" + +void local_engine::CustomMergeTreeSink::consume(Chunk chunk) +{ + auto block = metadata_snapshot->getSampleBlock().cloneWithColumns(chunk.detachColumns()); + DB::BlockWithPartition block_with_partition(Block(block), DB::Row{}); + auto part = storage.writer.writeTempPart(block_with_partition, metadata_snapshot, context); + MergeTreeData::Transaction transaction(storage, NO_TRANSACTION_RAW); + { + auto lock = storage.lockParts(); + storage.renameTempPartAndAdd(part.part, transaction, lock); + transaction.commit(&lock); + } +} +//std::list local_engine::CustomMergeTreeSink::getOutputs() +//{ +// return {}; +//} diff --git a/utils/local-engine/Storages/CustomMergeTreeSink.h b/utils/local-engine/Storages/CustomMergeTreeSink.h new file mode 100644 index 000000000000..44e84c488c3e --- /dev/null +++ b/utils/local-engine/Storages/CustomMergeTreeSink.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include "CustomStorageMergeTree.h" + +namespace local_engine +{ + +class CustomMergeTreeSink : public ISink +{ +public: + CustomMergeTreeSink( + CustomStorageMergeTree & storage_, + const StorageMetadataPtr metadata_snapshot_, + ContextPtr context_) + : ISink(metadata_snapshot_->getSampleBlock()) + , storage(storage_) + , metadata_snapshot(metadata_snapshot_) + , context(context_) + { + } + + String getName() const override { return "CustomMergeTreeSink"; } + void consume(Chunk chunk) override; +// std::list getOutputs(); +private: + CustomStorageMergeTree & storage; + StorageMetadataPtr metadata_snapshot; + ContextPtr context; +}; + +} + diff --git a/utils/local-engine/Storages/CustomStorageMergeTree.cpp b/utils/local-engine/Storages/CustomStorageMergeTree.cpp new file mode 100644 index 000000000000..d510dfd3acb9 --- /dev/null +++ b/utils/local-engine/Storages/CustomStorageMergeTree.cpp @@ -0,0 +1,84 @@ +#include "CustomStorageMergeTree.h" + +namespace local_engine +{ + +CustomStorageMergeTree::CustomStorageMergeTree( + const StorageID & table_id_, + const String & relative_data_path_, + const StorageInMemoryMetadata & metadata_, + bool attach, + ContextMutablePtr context_, + const String & date_column_name, + const MergingParams & merging_params_, + std::unique_ptr storage_settings_, + bool /*has_force_restore_data_flag*/) + : MergeTreeData( + table_id_, + metadata_, + context_, + date_column_name, + merging_params_, + std::move(storage_settings_), + false, /// require_part_metadata + attach) + , writer(*this) + , reader(*this) +{ + initializeDirectoriesAndFormatVersion(relative_data_path_, attach, date_column_name); +} +void CustomStorageMergeTree::dropPartNoWaitNoThrow(const String & /*part_name*/) +{ + throw std::runtime_error("not implement"); +} +void CustomStorageMergeTree::dropPart(const String & /*part_name*/, bool /*detach*/, ContextPtr /*context*/) +{ + throw std::runtime_error("not implement"); +} +void CustomStorageMergeTree::dropPartition(const ASTPtr & /*partition*/, bool /*detach*/, ContextPtr /*context*/) +{ +} +PartitionCommandsResultInfo CustomStorageMergeTree::attachPartition( + const ASTPtr & /*partition*/, const StorageMetadataPtr & /*metadata_snapshot*/, bool /*part*/, ContextPtr /*context*/) +{ + throw std::runtime_error("not implement"); +} +void CustomStorageMergeTree::replacePartitionFrom( + const StoragePtr & /*source_table*/, const ASTPtr & /*partition*/, bool /*replace*/, ContextPtr /*context*/) +{ + throw std::runtime_error("not implement"); +} +void CustomStorageMergeTree::movePartitionToTable(const StoragePtr & /*dest_table*/, const ASTPtr & /*partition*/, ContextPtr /*context*/) +{ + throw std::runtime_error("not implement"); +} +bool CustomStorageMergeTree::partIsAssignedToBackgroundOperation(const MergeTreeData::DataPartPtr & /*part*/) const +{ + throw std::runtime_error("not implement"); +} +MutationCommands CustomStorageMergeTree::getFirstAlterMutationCommandsForPart(const MergeTreeData::DataPartPtr & /*part*/) const +{ + return {}; +} +std::string CustomStorageMergeTree::getName() const +{ + throw std::runtime_error("not implement"); +} +std::vector CustomStorageMergeTree::getMutationsStatus() const +{ + throw std::runtime_error("not implement"); +} +bool CustomStorageMergeTree::scheduleDataProcessingJob(BackgroundJobsAssignee & /*executor*/) +{ + throw std::runtime_error("not implement"); +} +void CustomStorageMergeTree::startBackgroundMovesIfNeeded() +{ + throw std::runtime_error("not implement"); +} +std::unique_ptr CustomStorageMergeTree::getDefaultSettings() const +{ + throw std::runtime_error("not implement"); +} + +} diff --git a/utils/local-engine/Storages/CustomStorageMergeTree.h b/utils/local-engine/Storages/CustomStorageMergeTree.h new file mode 100644 index 000000000000..753b4932495a --- /dev/null +++ b/utils/local-engine/Storages/CustomStorageMergeTree.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include +#include +#include + +namespace local_engine +{ +using namespace DB; + +class CustomMergeTreeSink; + +class CustomStorageMergeTree final : public MergeTreeData +{ + friend class CustomMergeTreeSink; + +public: + CustomStorageMergeTree( + const StorageID & table_id_, + const String & relative_data_path_, + const StorageInMemoryMetadata & metadata, + bool attach, + ContextMutablePtr context_, + const String & date_column_name, + const MergingParams & merging_params_, + std::unique_ptr settings_, + bool has_force_restore_data_flag = false); + std::string getName() const override; + std::vector getMutationsStatus() const override; + bool scheduleDataProcessingJob(BackgroundJobsAssignee & executor) override; + + MergeTreeDataWriter writer; + MergeTreeDataSelectExecutor reader; + +private: + SimpleIncrement increment; + + void startBackgroundMovesIfNeeded() override; + std::unique_ptr getDefaultSettings() const override; + +protected: + void dropPartNoWaitNoThrow(const String & part_name) override; + void dropPart(const String & part_name, bool detach, ContextPtr context) override; + void dropPartition(const ASTPtr & partition, bool detach, ContextPtr context) override; + PartitionCommandsResultInfo + attachPartition(const ASTPtr & partition, const StorageMetadataPtr & metadata_snapshot, bool part, ContextPtr context) override; + void replacePartitionFrom(const StoragePtr & source_table, const ASTPtr & partition, bool replace, ContextPtr context) override; + void movePartitionToTable(const StoragePtr & dest_table, const ASTPtr & partition, ContextPtr context) override; + bool partIsAssignedToBackgroundOperation(const DataPartPtr & part) const override; + MutationCommands getFirstAlterMutationCommandsForPart(const DataPartPtr & part) const override; + void attachRestoredParts(MutableDataPartsVector && parts) override + { + throw std::runtime_error("not implement"); + }; +}; + +} diff --git a/utils/local-engine/Storages/SourceFromJavaIter.cpp b/utils/local-engine/Storages/SourceFromJavaIter.cpp new file mode 100644 index 000000000000..9764d8267454 --- /dev/null +++ b/utils/local-engine/Storages/SourceFromJavaIter.cpp @@ -0,0 +1,94 @@ +#include "SourceFromJavaIter.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; +jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext = nullptr; +jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next = nullptr; + + +static DB::Block getRealHeader(const DB::Block & header) +{ + if (header.columns()) + return header; + return BlockUtil::buildRowCountHeader(); +} +SourceFromJavaIter::SourceFromJavaIter(DB::Block header, jobject java_iter_) + : DB::ISource(getRealHeader(header)) + , java_iter(java_iter_) + , original_header(header) +{ +} +DB::Chunk SourceFromJavaIter::generate() +{ + GET_JNIENV(env) + jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext); + DB::Chunk result; + if (has_next) + { + jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); + DB::Block * data = reinterpret_cast(byteArrayToLong(env, block)); + if (data->rows() > 0) + { + size_t rows = data->rows(); + if (original_header.columns()) + { + result.setColumns(data->mutateColumns(), rows); + convertNullable(result); + auto info = std::make_shared(); + info->is_overflows = data->info.is_overflows; + info->bucket_num = data->info.bucket_num; + result.setChunkInfo(info); + } + else + { + result = BlockUtil::buildRowCountChunk(rows); + } + } + } + CLEAN_JNIENV + return result; +} +SourceFromJavaIter::~SourceFromJavaIter() +{ + GET_JNIENV(env) + env->DeleteGlobalRef(java_iter); + CLEAN_JNIENV +} +Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, jbyteArray arr) +{ + jsize len = env->GetArrayLength(arr); + assert(len == sizeof(Int64)); + char * c_arr = new char[len]; + env->GetByteArrayRegion(arr, 0, len, reinterpret_cast(c_arr)); + std::reverse(c_arr, c_arr + 8); + Int64 result = reinterpret_cast(c_arr)[0]; + delete[] c_arr; + return result; +} +void SourceFromJavaIter::convertNullable(DB::Chunk & chunk) +{ + auto output = this->getOutputs().front().getHeader(); + auto rows = chunk.getNumRows(); + auto columns = chunk.detachColumns(); + for (size_t i = 0; i < columns.size(); ++i) + { + DB::WhichDataType which(columns.at(i)->getDataType()); + if (output.getByPosition(i).type->isNullable() && !which.isNullable() && !which.isAggregateFunction()) + { + columns[i] = DB::makeNullable(columns.at(i)); + } + } + chunk.setColumns(columns, rows); +} +} diff --git a/utils/local-engine/Storages/SourceFromJavaIter.h b/utils/local-engine/Storages/SourceFromJavaIter.h new file mode 100644 index 000000000000..cf326734ba9b --- /dev/null +++ b/utils/local-engine/Storages/SourceFromJavaIter.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include + +namespace local_engine +{ +class SourceFromJavaIter : public DB::ISource +{ +public: + static jclass serialized_record_batch_iterator_class; + static jmethodID serialized_record_batch_iterator_hasNext; + static jmethodID serialized_record_batch_iterator_next; + + static Int64 byteArrayToLong(JNIEnv * env, jbyteArray arr); + + SourceFromJavaIter(DB::Block header, jobject java_iter_); + ~SourceFromJavaIter() override; + + String getName() const override { return "SourceFromJavaIter"; } + +private: + DB::Chunk generate() override; + void convertNullable(DB::Chunk & chunk); + + jobject java_iter; + DB::Block original_header; +}; + +} diff --git a/utils/local-engine/Storages/StorageJoinFromReadBuffer.cpp b/utils/local-engine/Storages/StorageJoinFromReadBuffer.cpp new file mode 100644 index 000000000000..829817374d75 --- /dev/null +++ b/utils/local-engine/Storages/StorageJoinFromReadBuffer.cpp @@ -0,0 +1,386 @@ +#include "StorageJoinFromReadBuffer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; + extern const int UNSUPPORTED_JOIN_KEYS; + extern const int NO_SUCH_COLUMN_IN_TABLE; + extern const int INCOMPATIBLE_TYPE_OF_JOIN; +} + +template +static const char * rawData(T & t) +{ + return reinterpret_cast(&t); +} +template +static size_t rawSize(T &) +{ + return sizeof(T); +} +template <> +const char * rawData(const StringRef & t) +{ + return t.data; +} +template <> +size_t rawSize(const StringRef & t) +{ + return t.size; +} + +class JoinSource : public ISource +{ +public: + JoinSource(HashJoinPtr join_, TableLockHolder lock_holder_, UInt64 max_block_size_, Block sample_block_) + : ISource(sample_block_) + , join(join_) + , lock_holder(lock_holder_) + , max_block_size(max_block_size_) + , sample_block(std::move(sample_block_)) + { + if (!join->getTableJoin().oneDisjunct()) + throw DB::Exception(ErrorCodes::NOT_IMPLEMENTED, "StorageJoin does not support OR for keys in JOIN ON section"); + + column_indices.resize(sample_block.columns()); + + auto & saved_block = join->getJoinedData()->sample_block; + + for (size_t i = 0; i < sample_block.columns(); ++i) + { + auto & [_, type, name] = sample_block.getByPosition(i); + if (join->right_table_keys.has(name)) + { + key_pos = i; + const auto & column = join->right_table_keys.getByName(name); + restored_block.insert(column); + } + else + { + size_t pos = saved_block.getPositionByName(name); + column_indices[i] = pos; + + const auto & column = saved_block.getByPosition(pos); + restored_block.insert(column); + } + } + } + + String getName() const override { return "Join"; } + +protected: + Chunk generate() override + { + if (join->data->blocks.empty()) + return {}; + + Chunk chunk; + if (!joinDispatch(join->kind, join->strictness, join->data->maps.front(), + [&](auto kind, auto strictness, auto & map) { chunk = createChunk(map); })) + throw Exception("Logical error: unknown JOIN strictness", ErrorCodes::LOGICAL_ERROR); + return chunk; + } + +private: + HashJoinPtr join; + TableLockHolder lock_holder; + + UInt64 max_block_size; + Block sample_block; + Block restored_block; /// sample_block with parent column types + + ColumnNumbers column_indices; + std::optional key_pos; + + std::unique_ptr> position; /// type erasure + + template + Chunk createChunk(const Maps & maps) + { + MutableColumns mut_columns = restored_block.cloneEmpty().mutateColumns(); + + size_t rows_added = 0; + + switch (join->data->type) + { +#define M(TYPE) \ + case HashJoin::Type::TYPE: \ + rows_added = fillColumns(*maps.TYPE, mut_columns); \ + break; + APPLY_FOR_JOIN_VARIANTS_LIMITED(M) +#undef M + + default: + throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast(join->data->type)), + ErrorCodes::UNSUPPORTED_JOIN_KEYS); + } + + if (!rows_added) + return {}; + + Columns columns; + columns.reserve(mut_columns.size()); + for (auto & col : mut_columns) + columns.emplace_back(std::move(col)); + + /// Correct nullability and LowCardinality types + for (size_t i = 0; i < columns.size(); ++i) + { + const auto & src = restored_block.getByPosition(i); + const auto & dst = sample_block.getByPosition(i); + + if (!src.type->equals(*dst.type)) + { + auto arg = src; + arg.column = std::move(columns[i]); + columns[i] = castColumn(arg, dst.type); + } + } + + UInt64 num_rows = columns.at(0)->size(); + return Chunk(std::move(columns), num_rows); + } + + template + size_t fillColumns(const Map & map, MutableColumns & columns) + { + size_t rows_added = 0; + + if (!position) + position = decltype(position)( + static_cast(new typename Map::const_iterator(map.begin())), //-V572 + [](void * ptr) { delete reinterpret_cast(ptr); }); + + auto & it = *reinterpret_cast(position.get()); + auto end = map.end(); + + for (; it != end; ++it) + { + if constexpr (STRICTNESS == JoinStrictness::RightAny) + { + fillOne(columns, column_indices, it, key_pos, rows_added); + } + else if constexpr (STRICTNESS == JoinStrictness::All) + { + fillAll(columns, column_indices, it, key_pos, rows_added); + } + else if constexpr (STRICTNESS == JoinStrictness::Any) + { + if constexpr (KIND == JoinKind::Left || KIND == JoinKind::Inner) + fillOne(columns, column_indices, it, key_pos, rows_added); + else if constexpr (KIND == JoinKind::Right) + fillAll(columns, column_indices, it, key_pos, rows_added); + } + else if constexpr (STRICTNESS == JoinStrictness::Semi) + { + if constexpr (KIND == JoinKind::Left) + fillOne(columns, column_indices, it, key_pos, rows_added); + else if constexpr (KIND == JoinKind::Right) + fillAll(columns, column_indices, it, key_pos, rows_added); + } + else if constexpr (STRICTNESS == JoinStrictness::Anti) + { + if constexpr (KIND == JoinKind::Left) + fillOne(columns, column_indices, it, key_pos, rows_added); + else if constexpr (KIND == JoinKind::Right) + fillAll(columns, column_indices, it, key_pos, rows_added); + } + else + throw Exception("This JOIN is not implemented yet", ErrorCodes::NOT_IMPLEMENTED); + + if (rows_added >= max_block_size) + { + ++it; + break; + } + } + + return rows_added; + } + + template + static void fillOne(MutableColumns & columns, const ColumnNumbers & column_indices, typename Map::const_iterator & it, + const std::optional & key_pos, size_t & rows_added) + { + for (size_t j = 0; j < columns.size(); ++j) + if (j == key_pos) + columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey())); + else + columns[j]->insertFrom(*it->getMapped().block->getByPosition(column_indices[j]).column.get(), it->getMapped().row_num); + ++rows_added; + } + + template + static void fillAll(MutableColumns & columns, const ColumnNumbers & column_indices, typename Map::const_iterator & it, + const std::optional & key_pos, size_t & rows_added) + { + for (auto ref_it = it->getMapped().begin(); ref_it.ok(); ++ref_it) + { + for (size_t j = 0; j < columns.size(); ++j) + if (j == key_pos) + columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey())); + else + columns[j]->insertFrom(*ref_it->block->getByPosition(column_indices[j]).column.get(), ref_it->row_num); + ++rows_added; + } + } +}; + +} + +using namespace DB; + +namespace local_engine +{ + +void StorageJoinFromReadBuffer::rename(const String & /*new_path_to_table_data*/, const DB::StorageID & /*new_table_id*/) +{ + throw std::runtime_error("unsupported operation"); +} +DB::SinkToStoragePtr +StorageJoinFromReadBuffer::write(const DB::ASTPtr & /*query*/, const DB::StorageMetadataPtr & /*ptr*/, DB::ContextPtr /*context*/) +{ + throw std::runtime_error("unsupported operation"); +} +bool StorageJoinFromReadBuffer::storesDataOnDisk() const +{ + return false; +} +DB::Strings StorageJoinFromReadBuffer::getDataPaths() const +{ + throw std::runtime_error("unsupported operation"); +} + +void StorageJoinFromReadBuffer::finishInsert() +{ + in.reset(); +} + +DB::Pipe StorageJoinFromReadBuffer::read( + const DB::Names & column_names, + const DB::StorageSnapshotPtr & storage_snapshot, + DB::SelectQueryInfo & /*query_info*/, + DB::ContextPtr context, + DB::QueryProcessingStage::Enum /*processed_stage*/, + size_t max_block_size, + size_t /*num_streams*/) +{ + storage_snapshot->check(column_names); + + Block source_sample_block = storage_snapshot->getSampleBlockForColumns(column_names); + RWLockImpl::LockHolder holder = tryLockTimedWithContext(rwlock, RWLockImpl::Read, context); + return Pipe(std::make_shared(join, std::move(holder), max_block_size, source_sample_block)); +} + +void StorageJoinFromReadBuffer::restore() +{ + if (!in) + { + throw std::runtime_error("input reader buffer is not available"); + } + ContextPtr ctx = nullptr; + NativeReader block_stream(*in, 0); + + ProfileInfo info; + while (Block block = block_stream.read()) + { + auto final_block = sample_block.cloneWithColumns(block.mutateColumns()); + info.update(final_block); + insertBlock(final_block, ctx); + } + + finishInsert(); +} +void StorageJoinFromReadBuffer::insertBlock(const Block & block, DB::ContextPtr context) +{ + TableLockHolder holder = tryLockTimedWithContext(rwlock, RWLockImpl::Write, context); + join->addJoinedBlock(block, true); +} +size_t StorageJoinFromReadBuffer::getSize(DB::ContextPtr context) const +{ + TableLockHolder holder = tryLockTimedWithContext(rwlock, RWLockImpl::Read, context); + return join->getTotalRowCount(); +} +DB::RWLockImpl::LockHolder +StorageJoinFromReadBuffer::tryLockTimedWithContext(const RWLock & lock, DB::RWLockImpl::Type type, DB::ContextPtr context) const +{ + const String query_id = context ? context->getInitialQueryId() : RWLockImpl::NO_QUERY; + const std::chrono::milliseconds acquire_timeout + = context ? context->getSettingsRef().lock_acquire_timeout : std::chrono::seconds(DBMS_DEFAULT_LOCK_ACQUIRE_TIMEOUT_SEC); + return tryLockTimed(lock, type, query_id, acquire_timeout); +} +StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( + std::unique_ptr in_, + const StorageID & table_id_, + const Names & key_names_, + bool use_nulls_, + DB::SizeLimits limits_, + DB::JoinKind kind_, + DB::JoinStrictness strictness_, + const ColumnsDescription & columns_, + const ConstraintsDescription & constraints_, + const String & comment, + const bool overwrite_, + const String & relative_path_) : StorageSetOrJoinBase{nullptr, relative_path_, table_id_, columns_, constraints_, comment, false} + , key_names(key_names_) + , use_nulls(use_nulls_) + , limits(limits_) + , kind(kind_) + , strictness(strictness_) + , overwrite(overwrite_) + , in(std::move(in_)) +{ + auto metadata_snapshot = getInMemoryMetadataPtr(); + sample_block = metadata_snapshot->getSampleBlock(); + for (const auto & key : key_names) + if (!metadata_snapshot->getColumns().hasPhysical(key)) + throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE}; + + table_join = std::make_shared(limits, use_nulls, kind, strictness, key_names); + join = std::make_shared(table_join, getRightSampleBlock(), overwrite); + restore(); +} +DB::HashJoinPtr StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr analyzed_join, DB::ContextPtr context) const +{ + auto metadata_snapshot = getInMemoryMetadataPtr(); + if (!analyzed_join->sameStrictnessAndKind(strictness, kind)) + throw Exception("Table " + getStorageID().getNameForLogs() + " has incompatible type of JOIN.", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); + + if ((analyzed_join->forceNullableRight() && !use_nulls) || + (!analyzed_join->forceNullableRight() && isLeftOrFull(analyzed_join->kind()) && use_nulls)) + throw Exception( + ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN, + "Table {} needs the same join_use_nulls setting as present in LEFT or FULL JOIN", + getStorageID().getNameForLogs()); + + /// TODO: check key columns + + /// Set names qualifiers: table.column -> column + /// It's required because storage join stores non-qualified names + /// Qualifies will be added by join implementation (HashJoin) + analyzed_join->setRightKeys(key_names); + + HashJoinPtr join_clone = std::make_shared(analyzed_join, getRightSampleBlock()); + + RWLockImpl::LockHolder holder = tryLockTimedWithContext(rwlock, RWLockImpl::Read, context); + join_clone->setLock(holder); + join_clone->reuseJoinedData(*join); + + return join_clone; +} +} diff --git a/utils/local-engine/Storages/StorageJoinFromReadBuffer.h b/utils/local-engine/Storages/StorageJoinFromReadBuffer.h new file mode 100644 index 000000000000..ba114f519da6 --- /dev/null +++ b/utils/local-engine/Storages/StorageJoinFromReadBuffer.h @@ -0,0 +1,91 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace DB +{ +class QueryPlan; +} + +namespace local_engine +{ +class StorageJoinFromReadBuffer : public DB::StorageSetOrJoinBase +{ + +public: + StorageJoinFromReadBuffer( + std::unique_ptr in_, + const DB::StorageID & table_id_, + const DB::Names & key_names_, + bool use_nulls_, + DB::SizeLimits limits_, + DB::JoinKind kind_, + DB::JoinStrictness strictness_, + const DB::ColumnsDescription & columns_, + const DB::ConstraintsDescription & constraints_, + const String & comment, + bool overwrite_, + const String & relative_path_ = "/tmp" /* useless variable */); + + String getName() const override { return "Join"; } + + void rename(const String & new_path_to_table_data, const DB::StorageID & new_table_id) override; + DB::HashJoinPtr getJoinLocked(std::shared_ptr analyzed_join, DB::ContextPtr context) const; + DB::SinkToStoragePtr write(const DB::ASTPtr & query, const DB::StorageMetadataPtr & ptr, DB::ContextPtr context) override; + bool storesDataOnDisk() const override; + DB::Strings getDataPaths() const override; + DB::Pipe read( + const DB::Names & column_names, + const DB::StorageSnapshotPtr & storage_snapshot, + DB::SelectQueryInfo & query_info, + DB::ContextPtr context, + DB::QueryProcessingStage::Enum processed_stage, + size_t max_block_size, + size_t num_streams) override; + DB::Block getRightSampleBlock() const + { + auto metadata_snapshot = getInMemoryMetadataPtr(); + DB::Block block = metadata_snapshot->getSampleBlock(); + if (use_nulls && isLeftOrFull(kind)) + { + for (auto & col : block) + { + DB::JoinCommon::convertColumnToNullable(col); + } + } + return block; + } +protected: + void restore(); + +private: + void insertBlock(const DB::Block & block, DB::ContextPtr context) override; + void finishInsert() override; + size_t getSize(DB::ContextPtr context) const override; + DB::RWLockImpl::LockHolder tryLockTimedWithContext(const DB::RWLock & lock, DB::RWLockImpl::Type type, DB::ContextPtr context) const; + + DB::Block sample_block; + const DB::Names key_names; + bool use_nulls; + DB::SizeLimits limits; + DB::JoinKind kind; /// LEFT | INNER ... + DB::JoinStrictness strictness; /// ANY | ALL + bool overwrite; + + std::shared_ptr table_join; + DB::HashJoinPtr join; + + std::unique_ptr in; + + /// Protect state for concurrent use in insertFromBlock and joinBlock. + /// Lock is stored in HashJoin instance during query and blocks concurrent insertions. + mutable DB::RWLock rwlock = DB::RWLockImpl::create(); + mutable std::mutex mutate_mutex; +}; +} + + diff --git a/utils/local-engine/Storages/StorageMergeTreeFactory.cpp b/utils/local-engine/Storages/StorageMergeTreeFactory.cpp new file mode 100644 index 000000000000..c47488e23ebb --- /dev/null +++ b/utils/local-engine/Storages/StorageMergeTreeFactory.cpp @@ -0,0 +1,67 @@ +#include "StorageMergeTreeFactory.h" + +namespace local_engine +{ + +StorageMergeTreeFactory & StorageMergeTreeFactory::instance() +{ + static StorageMergeTreeFactory ret; + return ret; +} + +CustomStorageMergeTreePtr +StorageMergeTreeFactory::getStorage(StorageID id, ColumnsDescription columns, std::function creator) +{ + auto table_name = id.database_name + "." + id.table_name; + std::lock_guard lock(storage_map_mutex); + if (!storage_map.contains(table_name)) + { + if (storage_map.contains(table_name)) + { + std::set existed_columns = storage_columns_map.at(table_name); + for (const auto & column : columns) + { + if (!existed_columns.contains(column.name)) + { + storage_map.erase(table_name); + storage_columns_map.erase(table_name); + } + } + } + if (!storage_map.contains(table_name)) + { + storage_map.emplace(table_name, creator()); + storage_columns_map.emplace(table_name, std::set()); + for (const auto & column : storage_map.at(table_name)->getInMemoryMetadataPtr()->columns) + { + storage_columns_map.at(table_name).emplace(column.name); + } + } + } + return storage_map.at(table_name); +} + +StorageInMemoryMetadataPtr StorageMergeTreeFactory::getMetadata(StorageID id, std::function creator) +{ + auto table_name = id.database_name + "." + id.table_name; + + std::lock_guard lock(metadata_map_mutex); + if (!metadata_map.contains(table_name)) + { + if (!metadata_map.contains(table_name)) + { + metadata_map.emplace(table_name, creator()); + } + } + return metadata_map.at(table_name); +} + + +std::unordered_map StorageMergeTreeFactory::storage_map; +std::unordered_map> StorageMergeTreeFactory::storage_columns_map; +std::mutex StorageMergeTreeFactory::storage_map_mutex; + +std::unordered_map StorageMergeTreeFactory::metadata_map; +std::mutex StorageMergeTreeFactory::metadata_map_mutex; + +} diff --git a/utils/local-engine/Storages/StorageMergeTreeFactory.h b/utils/local-engine/Storages/StorageMergeTreeFactory.h new file mode 100644 index 000000000000..f811b0bee180 --- /dev/null +++ b/utils/local-engine/Storages/StorageMergeTreeFactory.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace local_engine +{ +using CustomStorageMergeTreePtr = std::shared_ptr; +using StorageInMemoryMetadataPtr = std::shared_ptr; + +class StorageMergeTreeFactory +{ +public: + static StorageMergeTreeFactory & instance(); + static CustomStorageMergeTreePtr getStorage(StorageID id, ColumnsDescription columns, std::function creator); + static StorageInMemoryMetadataPtr getMetadata(StorageID id, std::function creator); + +private: + static std::unordered_map storage_map; + static std::unordered_map> storage_columns_map; + static std::mutex storage_map_mutex; + + static std::unordered_map metadata_map; + static std::mutex metadata_map_mutex; +}; +} diff --git a/utils/local-engine/Storages/SubstraitSource/CMakeLists.txt b/utils/local-engine/Storages/SubstraitSource/CMakeLists.txt new file mode 100644 index 000000000000..9f0b55aee81f --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/CMakeLists.txt @@ -0,0 +1,46 @@ + +set(ARROW_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src") + + +macro(add_headers_and_sources_including_cc prefix common_path) + add_glob(${prefix}_headers ${CMAKE_CURRENT_SOURCE_DIR} ${common_path}/*.h) + add_glob(${prefix}_sources ${common_path}/*.cpp ${common_path}/*.c ${common_path}/*.cc ${common_path}/*.h) +endmacro() + +add_headers_and_sources(substait_source .) +add_headers_and_sources_including_cc(ch_parquet arrow) +add_library(substait_source ${substait_source_sources}) +target_compile_options(substait_source PUBLIC -fPIC + -Wno-shorten-64-to-32 + -Wno-shadow-field-in-constructor + -Wno-return-type + -Wno-reserved-identifier + -Wno-extra-semi-stmt + -Wno-extra-semi + -Wno-unused-result + -Wno-unreachable-code-return + -Wno-unused-parameter + -Wno-unreachable-code + -Wno-pessimizing-move + -Wno-unreachable-code-break + -Wno-unused-variable + -Wno-inconsistent-missing-override + -Wno-shadow-uncaptured-local + -Wno-suggest-override + -Wno-unused-member-function + -Wno-deprecated-this-capture +) + +target_link_libraries(substait_source PUBLIC + boost::headers_only + ch_contrib::protobuf + clickhouse_common_io + ch_contrib::hdfs + substrait +) + +target_include_directories(substait_source SYSTEM BEFORE PUBLIC + ${ARROW_INCLUDE_DIR} + ${CMAKE_BINARY_DIR}/contrib/arrow-cmake/cpp/src + ${ClickHouse_SOURCE_DIR}/contrib/arrow-cmake/cpp/src +) \ No newline at end of file diff --git a/utils/local-engine/Storages/SubstraitSource/FormatFile.cpp b/utils/local-engine/Storages/SubstraitSource/FormatFile.cpp new file mode 100644 index 000000000000..fc804c793ec1 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/FormatFile.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; +} +} +namespace local_engine +{ +FormatFile::FormatFile( + DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) + : context(context_), file_info(file_info_), read_buffer_builder(read_buffer_builder_) +{ + PartitionValues part_vals = StringUtils::parsePartitionTablePath(file_info.uri_file()); + for (const auto & part : part_vals) + { + partition_keys.push_back(part.first); + partition_values[part.first] = part.second; + } +} + +FormatFilePtr FormatFileUtil::createFile(DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file) +{ + if (file.has_parquet()) + { + return std::make_shared(context, file, read_buffer_builder); + } + else if (file.has_orc()) + { + return std::make_shared(context, file, read_buffer_builder); + } + else + { + throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Format not suupported:{}", file.DebugString()); + } + + __builtin_unreachable(); +} +} diff --git a/utils/local-engine/Storages/SubstraitSource/FormatFile.h b/utils/local-engine/Storages/SubstraitSource/FormatFile.h new file mode 100644 index 000000000000..351ba8773287 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/FormatFile.h @@ -0,0 +1,66 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +class FormatFile +{ +public: + struct InputFormat + { + public: + DB::InputFormatPtr input; + std::unique_ptr read_buffer; + }; + using InputFormatPtr = std::shared_ptr; + + FormatFile( + DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_); + virtual ~FormatFile() = default; + + /// create a new input format for reading this file + virtual InputFormatPtr createInputFormat(const DB::Block & header) = 0; + + /// Spark would split a large file into small segements and read in different tasks + /// If this file doesn't support the split feacture, only the task with offset 0 will generate data. + virtual bool supportSplit() { return false; } + + /// try to get rows from file metadata + virtual std::optional getTotalRows() { return {}; } + + /// get partition keys from file path + inline const std::vector & getFilePartitionKeys() const { return partition_keys; } + + inline const std::map & getFilePartitionValues() const { return partition_values; } + + virtual String getURIPath() const { return file_info.uri_file(); } + + virtual size_t getStartOffset() const { return file_info.start(); } + virtual size_t getLength() const { return file_info.length(); } + +protected: + DB::ContextPtr context; + substrait::ReadRel::LocalFiles::FileOrFiles file_info; + ReadBufferBuilderPtr read_buffer_builder; + std::vector partition_keys; + std::map partition_values; + +}; +using FormatFilePtr = std::shared_ptr; +using FormatFiles = std::vector; + +class FormatFileUtil +{ +public: + static FormatFilePtr createFile(DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file); +}; +} diff --git a/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.cpp b/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.cpp new file mode 100644 index 000000000000..ee04f7e85ce3 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.cpp @@ -0,0 +1,239 @@ +#include "OrcFormatFile.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_READ_ALL_DATA; +} +} +namespace local_engine +{ + +ORCBlockInputFormat::ORCBlockInputFormat( + DB::ReadBuffer & in_, DB::Block header_, const DB::FormatSettings & format_settings_, const std::vector & stripes_) + : IInputFormat(std::move(header_), in_), format_settings(format_settings_), stripes(stripes_) +{ +} + + +void ORCBlockInputFormat::resetParser() +{ + IInputFormat::resetParser(); + + file_reader.reset(); + include_indices.clear(); + include_column_names.clear(); + block_missing_values.clear(); + current_stripe = 0; +} + + +DB::Chunk ORCBlockInputFormat::generate() +{ + DB::Chunk res; + block_missing_values.clear(); + + if (!file_reader) + prepareReader(); + + if (is_stopped) + return {}; + + std::shared_ptr batch_reader; + batch_reader = fetchNextStripe(); + if (!batch_reader) + { + return res; + } + + std::shared_ptr table; + arrow::Status table_status = batch_reader->ReadAll(&table); + if (!table_status.ok()) + { + throw DB::ParsingException( + DB::ErrorCodes::CANNOT_READ_ALL_DATA, "Error while reading batch of ORC data: {}", table_status.ToString()); + } + + if (!table || !table->num_rows()) + { + return res; + } + + if (format_settings.use_lowercase_column_name) + table = *table->RenameColumns(include_column_names); + + arrow_column_to_ch_column->arrowTableToCHChunk(res, table); + /// If defaults_for_omitted_fields is true, calculate the default values from default expression for omitted fields. + /// Otherwise fill the missing columns with zero values of its type. + if (format_settings.defaults_for_omitted_fields) + for (size_t row_idx = 0; row_idx < res.getNumRows(); ++row_idx) + for (const auto & column_idx : missing_columns) + block_missing_values.setBit(column_idx, row_idx); + return res; +} + + +void ORCBlockInputFormat::prepareReader() +{ + std::shared_ptr schema; + OrcUtil::getFileReaderAndSchema(*in, file_reader, schema, format_settings, is_stopped); + if (is_stopped) + return; + + arrow_column_to_ch_column = std::make_unique( + getPort().getHeader(), "ORC", format_settings.orc.import_nested, format_settings.orc.allow_missing_columns); + missing_columns = arrow_column_to_ch_column->getMissingColumns(*schema); + + std::unordered_set nested_table_names; + if (format_settings.orc.import_nested) + nested_table_names = DB::Nested::getAllTableNames(getPort().getHeader()); + + + /// In ReadStripe column indices should be started from 1, + /// because 0 indicates to select all columns. + int index = 1; + for (int i = 0; i < schema->num_fields(); ++i) + { + /// LIST type require 2 indices, STRUCT - the number of elements + 1, + /// so we should recursively count the number of indices we need for this type. + int indexes_count = OrcUtil::countIndicesForType(schema->field(i)->type()); + const auto & name = schema->field(i)->name(); + if (getPort().getHeader().has(name) || nested_table_names.contains(name)) + { + for (int j = 0; j != indexes_count; ++j) + { + include_indices.push_back(index + j); + include_column_names.push_back(name); + } + } + index += indexes_count; + } +} + +std::shared_ptr ORCBlockInputFormat::stepOneStripe() +{ + auto result = file_reader->NextStripeReader(format_settings.orc.row_batch_size, include_indices); + current_stripe += 1; + if (!result.ok()) + { + throw DB::ParsingException(DB::ErrorCodes::CANNOT_READ_ALL_DATA, "Failed to create batch reader: {}", result.status().ToString()); + } + std::shared_ptr batch_reader; + batch_reader = std::move(result).ValueOrDie(); + return batch_reader; +} + +std::shared_ptr ORCBlockInputFormat::fetchNextStripe() +{ + if (current_stripe >= stripes.size()) + return nullptr; + auto & strip = stripes[current_stripe]; + file_reader->Seek(strip.start_row); + return stepOneStripe(); +} + +OrcFormatFile::OrcFormatFile( + DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) + : FormatFile(context_, file_info_, read_buffer_builder_) +{ +} + +FormatFile::InputFormatPtr OrcFormatFile::createInputFormat(const DB::Block & header) +{ + auto read_buffer = read_buffer_builder->build(file_info); + auto format_settings = DB::getFormatSettings(context); + format_settings.orc.import_nested = true; + auto file_format = std::make_shared(); + file_format->read_buffer = std::move(read_buffer); + std::vector stripes; + if (auto * seekable_in = dynamic_cast(file_format->read_buffer.get())) + { + stripes = collectRequiredStripes(seekable_in); + seekable_in->seek(0, SEEK_SET); + } + else + { + stripes = collectRequiredStripes(); + } + auto input_format = std::make_shared(*file_format->read_buffer, header, format_settings, stripes); + file_format->input = input_format; + return file_format; +} + +std::optional OrcFormatFile::getTotalRows() +{ + { + std::lock_guard lock(mutex); + if (total_rows) + return total_rows; + } + auto required_stripes = collectRequiredStripes(); + { + std::lock_guard lock(mutex); + if (total_rows) + return total_rows; + size_t num_rows = 0; + for (const auto stipe_info : required_stripes) + { + num_rows += stipe_info.num_rows; + } + total_rows = num_rows; + return total_rows; + } +} + +std::vector OrcFormatFile::collectRequiredStripes() +{ + auto in = read_buffer_builder->build(file_info); + return collectRequiredStripes(in.get()); +} + +std::vector OrcFormatFile::collectRequiredStripes(DB::ReadBuffer* read_buffer) +{ + std::vector stripes; + DB::FormatSettings format_settings; + format_settings.seekable_read = true; + std::atomic is_stopped{0}; + auto arrow_file = DB::asArrowFile(*read_buffer, format_settings, is_stopped, "ORC", ORC_MAGIC_BYTES); + auto orc_reader = OrcUtil::createOrcReader(arrow_file); + auto num_stripes = orc_reader->getNumberOfStripes(); + + size_t total_num_rows = 0; + for (size_t i = 0; i < num_stripes; ++i) + { + auto stripe_metadata = orc_reader->getStripe(i); + auto offset = stripe_metadata->getOffset(); + if (file_info.start() <= offset && offset < file_info.start() + file_info.length()) + { + StripeInformation stripe_info; + stripe_info.index = i; + stripe_info.offset = stripe_metadata->getLength(); + stripe_info.length = stripe_metadata->getLength(); + stripe_info.num_rows = stripe_metadata->getNumberOfRows(); + stripe_info.start_row = total_num_rows; + stripes.emplace_back(stripe_info); + } + total_num_rows += stripe_metadata->getNumberOfRows(); + } + return stripes; +} +} diff --git a/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.h b/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.h new file mode 100644 index 000000000000..1539b3821299 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/OrcFormatFile.h @@ -0,0 +1,92 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ + +struct StripeInformation +{ + UInt64 index; + UInt64 offset; + UInt64 length; + UInt64 num_rows; + UInt64 start_row; +}; + +// local engine's orc block input formatter +// the behavior of generate is different from DB::ORCBlockInputFormat +class ORCBlockInputFormat : public DB::IInputFormat +{ +public: + explicit ORCBlockInputFormat( + DB::ReadBuffer & in_, + DB::Block header_, + const DB::FormatSettings & format_settings_, + const std::vector & stripes_); + + String getName() const override { return "LocalEngineORCBlockInputFormat"; } + + void resetParser() override; + + const DB::BlockMissingValues & getMissingValues() const override { return block_missing_values; } + +protected: + DB::Chunk generate() override; + + void onCancel() override { is_stopped = 1; } +private: + + // TODO: check that this class implements every part of its parent + + std::unique_ptr file_reader; + + std::unique_ptr arrow_column_to_ch_column; + + // indices of columns to read from ORC file + std::vector include_indices; + std::vector include_column_names; + + std::vector missing_columns; + DB::BlockMissingValues block_missing_values; + + const DB::FormatSettings format_settings; + + std::vector stripes; + UInt64 current_stripe = 0; + + std::atomic is_stopped{0}; + + void prepareReader(); + + std::shared_ptr stepOneStripe(); + + std::shared_ptr fetchNextStripe(); +}; + +class OrcFormatFile : public FormatFile +{ +public: + + explicit OrcFormatFile( + DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_); + ~OrcFormatFile() override = default; + FormatFile::InputFormatPtr createInputFormat(const DB::Block & header) override; + std::optional getTotalRows() override; + + bool supportSplit() override { return true; } + +private: + std::mutex mutex; + std::optional total_rows; + + std::vector collectRequiredStripes(); + std::vector collectRequiredStripes(DB::ReadBuffer * read_buffer); +}; +} diff --git a/utils/local-engine/Storages/SubstraitSource/OrcUtil.cpp b/utils/local-engine/Storages/SubstraitSource/OrcUtil.cpp new file mode 100644 index 000000000000..4e618916c59e --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/OrcUtil.cpp @@ -0,0 +1,164 @@ +#include "OrcUtil.h" +#include +#include +#include +#include +#include +#include +#include + + +#define ORC_THROW_NOT_OK(s) \ + do { \ + arrow::Status _s = (s); \ + if (!_s.ok()) { \ + DB::WriteBufferFromOwnString ss; \ + ss << "Arrow error: " << _s.ToString(); \ + throw orc::ParseError(ss.str()); \ + } \ + } while (0) + +#define ORC_ASSIGN_OR_THROW_IMPL(status_name, lhs, rexpr) \ + auto status_name = (rexpr); \ + ORC_THROW_NOT_OK(status_name.status()); \ + lhs = std::move(status_name).ValueOrDie(); + +#define ORC_ASSIGN_OR_THROW(lhs, rexpr) \ + ORC_ASSIGN_OR_THROW_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \ + lhs, rexpr); + +#define ORC_BEGIN_CATCH_NOT_OK try { +#define ORC_END_CATCH_NOT_OK \ + } \ + catch (const orc::ParseError& e) { \ + return arrow::Status::IOError(e.what()); \ + } \ + catch (const orc::InvalidArgument& e) { \ + return arrow::Status::Invalid(e.what()); \ + } \ + catch (const orc::NotImplementedYet& e) { \ + return arrow::Status::NotImplemented(e.what()); \ + } + +#define ORC_CATCH_NOT_OK(_s) \ + ORC_BEGIN_CATCH_NOT_OK(_s); \ + ORC_END_CATCH_NOT_OK + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int BAD_ARGUMENTS; +} +} +namespace local_engine +{ +uint64_t ArrowInputFile::getLength() const +{ + ORC_ASSIGN_OR_THROW(int64_t size, file->GetSize()); + return static_cast(size); +} + +uint64_t ArrowInputFile::getNaturalReadSize() const +{ + return 128 * 1024; +} + +void ArrowInputFile::read(void * buf, uint64_t length, uint64_t offset) +{ + ORC_ASSIGN_OR_THROW(int64_t bytes_read, file->ReadAt(offset, length, buf)); + + if (static_cast(bytes_read) != length) + { + throw orc::ParseError("Short read from arrow input file"); + } +} + +const std::string & ArrowInputFile::getName() const +{ + static const std::string filename("ArrowInputFile"); + return filename; +} + +arrow::Status innerCreateOrcReader(std::shared_ptr file_, std::unique_ptr * orc_reader) +{ + std::unique_ptr io_wrapper(new ArrowInputFile(file_)); + orc::ReaderOptions options; + ORC_CATCH_NOT_OK(*orc_reader = std::move(orc::createReader(std::move(io_wrapper), options))); + + return arrow::Status::OK(); + +} + +std::unique_ptr OrcUtil::createOrcReader(std::shared_ptr file_) +{ + std::unique_ptr orc_reader; + auto status = innerCreateOrcReader(file_, &orc_reader); + if (!status.ok()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Create orc reader failed. {}", status.message()); + } + return orc_reader; +} + +size_t OrcUtil::countIndicesForType(std::shared_ptr type) +{ + if (type->id() == arrow::Type::LIST) + return countIndicesForType(static_cast(type.get())->value_type()) + 1; + + if (type->id() == arrow::Type::STRUCT) + { + int indices = 1; + auto * struct_type = static_cast(type.get()); + for (int i = 0; i != struct_type->num_fields(); ++i) + indices += countIndicesForType(struct_type->field(i)->type()); + return indices; + } + + if (type->id() == arrow::Type::MAP) + { + auto * map_type = static_cast(type.get()); + return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()) + 1; + } + + return 1; +} + +void OrcUtil::getFileReaderAndSchema( + DB::ReadBuffer & in, + std::unique_ptr & file_reader, + std::shared_ptr & schema, + const DB::FormatSettings & format_settings, + std::atomic & is_stopped) +{ + auto arrow_file = DB::asArrowFile(in, format_settings, is_stopped, "ORC", ORC_MAGIC_BYTES); + if (is_stopped) + return; + + auto result = arrow::adapters::orc::ORCFileReader::Open(arrow_file, arrow::default_memory_pool()); + if (!result.ok()) + throw DB::Exception(result.status().ToString(), DB::ErrorCodes::BAD_ARGUMENTS); + file_reader = std::move(result).ValueOrDie(); + + auto read_schema_result = file_reader->ReadSchema(); + if (!read_schema_result.ok()) + throw DB::Exception(read_schema_result.status().ToString(), DB::ErrorCodes::BAD_ARGUMENTS); + schema = std::move(read_schema_result).ValueOrDie(); + + if (format_settings.use_lowercase_column_name) + { + std::vector> fields; + fields.reserve(schema->num_fields()); + for (int i = 0; i < schema->num_fields(); ++i) + { + const auto& field = schema->field(i); + auto name = field->name(); + boost::to_lower(name); + fields.push_back(field->WithName(name)); + } + schema = arrow::schema(fields, schema->metadata()); + } +} +} diff --git a/utils/local-engine/Storages/SubstraitSource/OrcUtil.h b/utils/local-engine/Storages/SubstraitSource/OrcUtil.h new file mode 100644 index 000000000000..983e72a8a173 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/OrcUtil.h @@ -0,0 +1,46 @@ +#pragma once +#include +#include +/// there are destructor not be overrided warnings in orc lib, ignore them +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsuggest-destructor-override" +#include +#include +#pragma GCC diagnostic pop +#include +#include + +namespace local_engine{ + +class ArrowInputFile : public orc::InputStream { + public: + explicit ArrowInputFile(const std::shared_ptr& file_) + : file(file_) {} + + uint64_t getLength() const override; + + uint64_t getNaturalReadSize() const override; + + void read(void* buf, uint64_t length, uint64_t offset) override; + + const std::string& getName() const override; + + private: + std::shared_ptr file; +}; + +class OrcUtil +{ +public: + static std::unique_ptr createOrcReader(std::shared_ptr file_); + + static size_t countIndicesForType(std::shared_ptr type); + static void getFileReaderAndSchema( + DB::ReadBuffer & in, + std::unique_ptr & file_reader, + std::shared_ptr & schema, + const DB::FormatSettings & format_settings, + std::atomic & is_stopped); +}; + +} diff --git a/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.cpp b/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.cpp new file mode 100644 index 000000000000..8f35bad1b934 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ +ParquetFormatFile::ParquetFormatFile(DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) + : FormatFile(context_, file_info_, read_buffer_builder_) +{ +} + +FormatFile::InputFormatPtr ParquetFormatFile::createInputFormat(const DB::Block & header) +{ + auto res = std::make_shared(); + res->read_buffer = std::move(read_buffer_builder->build(file_info)); + std::vector row_group_indices; + std::vector required_row_groups; + if (auto * seekable_in = dynamic_cast(res->read_buffer.get())) + { + // reuse the read_buffer to avoid opening the file twice. + // especially,the cost of opening a hdfs file is large. + required_row_groups = collectRequiredRowGroups(seekable_in); + seekable_in->seek(0, SEEK_SET); + } + else + { + required_row_groups = collectRequiredRowGroups(); + } + for (const auto & row_group : required_row_groups) + { + row_group_indices.emplace_back(row_group.index); + } + auto format_settings = DB::getFormatSettings(context); + format_settings.parquet.import_nested = true; + auto input_format = std::make_shared(*(res->read_buffer), header, format_settings, row_group_indices); + res->input = input_format; + return res; +} + +std::optional ParquetFormatFile::getTotalRows() +{ + { + std::lock_guard lock(mutex); + if (total_rows) + return total_rows; + } + auto rowgroups = collectRequiredRowGroups(); + size_t rows = 0; + for (const auto & rowgroup : rowgroups) + { + rows += rowgroup.num_rows; + } + { + std::lock_guard lock(mutex); + total_rows = rows; + return total_rows; + } +} + +std::vector ParquetFormatFile::collectRequiredRowGroups() +{ + auto in = read_buffer_builder->build(file_info); + return collectRequiredRowGroups(in.get()); +} + +std::vector ParquetFormatFile::collectRequiredRowGroups(DB::ReadBuffer * read_buffer) +{ + std::unique_ptr reader; + DB::FormatSettings format_settings; + format_settings.seekable_read = true; + std::atomic is_stopped{0}; + auto status = parquet::arrow::OpenFile( + asArrowFile(*read_buffer, format_settings, is_stopped, "Parquet", PARQUET_MAGIC_BYTES), arrow::default_memory_pool(), &reader); + if (!status.ok()) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Open file({}) failed. {}", file_info.uri_file(), status.ToString()); + } + + auto file_meta = reader->parquet_reader()->metadata(); + std::vector row_group_metadatas; + for (int i = 0, n = file_meta->num_row_groups(); i < n; ++i) + { + auto row_group_meta = file_meta->RowGroup(i); + auto offset = static_cast(row_group_meta->file_offset()); + if (!offset) + { + offset = static_cast(row_group_meta->ColumnChunk(0)->file_offset()); + } + if (file_info.start() <= offset && offset < file_info.start() + file_info.length()) + { + RowGroupInfomation info; + info.index = i; + info.num_rows = row_group_meta->num_rows(); + info.start = row_group_meta->file_offset(); + info.total_compressed_size = row_group_meta->total_compressed_size(); + info.total_size = row_group_meta->total_byte_size(); + row_group_metadatas.emplace_back(info); + } + } + return row_group_metadatas; + +} +} diff --git a/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.h b/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.h new file mode 100644 index 000000000000..d26e5a28ac39 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/ParquetFormatFile.h @@ -0,0 +1,34 @@ +#pragma once +#include +#include +#include +#include +#include +namespace local_engine +{ +struct RowGroupInfomation +{ + UInt32 index = 0; + UInt64 start = 0; + UInt64 total_compressed_size = 0; + UInt64 total_size = 0; + UInt64 num_rows = 0; +}; +class ParquetFormatFile : public FormatFile +{ +public: + explicit ParquetFormatFile(DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_); + ~ParquetFormatFile() override = default; + FormatFile::InputFormatPtr createInputFormat(const DB::Block & header) override; + std::optional getTotalRows() override; + bool supportSplit() override { return true; } + +private: + std::mutex mutex; + std::optional total_rows; + + std::vector collectRequiredRowGroups(); + std::vector collectRequiredRowGroups(DB::ReadBuffer * read_buffer); +}; + +} diff --git a/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp b/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp new file mode 100644 index 000000000000..12390b0ffd10 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp @@ -0,0 +1,204 @@ +#include +#include +#include "IO/ReadSettings.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ + +class LocalFileReadBufferBuilder : public ReadBufferBuilder +{ +public: + explicit LocalFileReadBufferBuilder(DB::ContextPtr context_) : ReadBufferBuilder(context_) {} + ~LocalFileReadBufferBuilder() override = default; + + std::unique_ptr build(const substrait::ReadRel::LocalFiles::FileOrFiles & file_info) override + { + Poco::URI file_uri(file_info.uri_file()); + std::unique_ptr read_buffer; + const String & file_path = file_uri.getPath(); + struct stat file_stat; + if (stat(file_path.c_str(), &file_stat)) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "file stat failed for {}", file_path); + + if (S_ISREG(file_stat.st_mode)) + read_buffer = std::make_unique(file_path); + else + read_buffer = std::make_unique(file_path); + return read_buffer; + } +}; + +#if USE_HDFS +class HDFSFileReadBufferBuilder : public ReadBufferBuilder +{ +public: + explicit HDFSFileReadBufferBuilder(DB::ContextPtr context_) : ReadBufferBuilder(context_) {} + ~HDFSFileReadBufferBuilder() override = default; + + std::unique_ptr build(const substrait::ReadRel::LocalFiles::FileOrFiles & file_info) override + { + Poco::URI file_uri(file_info.uri_file()); + std::unique_ptr read_buffer; + + std::string uri_path = "hdfs://" + file_uri.getHost(); + if (file_uri.getPort()) + uri_path += ":" + std::to_string(file_uri.getPort()); + DB::ReadSettings read_settings; + read_buffer = std::make_unique( + uri_path, file_uri.getPath(), context->getGlobalContext()->getConfigRef(), + read_settings); + return read_buffer; + } +}; +#endif + +#if USE_AWS_S3 +class S3FileReadBufferBuilder : public ReadBufferBuilder +{ +public: + explicit S3FileReadBufferBuilder(DB::ContextPtr context_) : ReadBufferBuilder(context_) {} + ~S3FileReadBufferBuilder() override = default; + + std::unique_ptr build(const substrait::ReadRel::LocalFiles::FileOrFiles & file_info) override + { + Poco::URI file_uri(file_info.uri_file()); + auto client = getClient(); + std::unique_ptr readbuffer; + readbuffer + = std::make_unique(client, file_uri.getHost(), file_uri.getPath().substr(1), "", DB::S3Settings::RequestSettings(),DB::ReadSettings()); + return readbuffer; + } +private: + std::shared_ptr shared_client; + + std::shared_ptr getClient() + { + if (shared_client) + return shared_client; + const auto & config = context->getConfigRef(); + String config_prefix = "s3"; + DB::S3::PocoHTTPClientConfiguration client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + config.getString(config_prefix + ".region", ""), + context->getRemoteHostFilter(), + context->getGlobalContext()->getSettingsRef().s3_max_redirects, + false, + false, + nullptr, + nullptr); + + DB::S3::URI uri(config.getString(config_prefix + ".endpoint")); + + client_configuration.connectTimeoutMs = config.getUInt(config_prefix + ".connect_timeout_ms", 10000); + client_configuration.requestTimeoutMs = config.getUInt(config_prefix + ".request_timeout_ms", 5000); + client_configuration.maxConnections = config.getUInt(config_prefix + ".max_connections", 100); + client_configuration.endpointOverride = uri.endpoint; + + client_configuration.retryStrategy + = std::make_shared(config.getUInt(config_prefix + ".retry_attempts", 10)); + + shared_client = DB::S3::ClientFactory::instance().create( + client_configuration, + uri.is_virtual_hosted_style, + config.getString(config_prefix + ".access_key_id", ""), + config.getString(config_prefix + ".secret_access_key", ""), + config.getString(config_prefix + ".server_side_encryption_customer_key_base64", ""), + {}, + config.getBool(config_prefix + ".use_environment_credentials", config.getBool("s3.use_environment_credentials", false)), + config.getBool(config_prefix + ".use_insecure_imds_request", config.getBool("s3.use_insecure_imds_request", false))); + return shared_client; + } +}; +#endif + +#if USE_AZURE_BLOB_STORAGE +class AzureBlobReadBuffer : public ReadBufferBuilder +{ +public: + explicit AzureBlobReadBuffer(DB::ContextPtr context_) : ReadBufferBuilder(context_) {} + ~AzureBlobReadBuffer() override = default; + + std::unique_ptr build(const substrait::ReadRel::LocalFiles::FileOrFiles & file_info) + { + Poco::URI file_uri(file_info.uri_file()); + std::unique_ptr read_buffer; + read_buffer = std::make_unique(getClient(), file_uri.getPath(), DB::ReadSettings(), 5, 5); + return read_buffer; + } +private: + std::shared_ptr shared_client; + + std::shared_ptr getClient() + { + if (shared_client) + return shared_client; + shared_client = DB::getAzureBlobContainerClient(context->getConfigRef(), "blob"); + return shared_client; + } +}; +#endif + +void registerReadBufferBuilders() +{ + auto & factory = ReadBufferBuilderFactory::instance(); + factory.registerBuilder("file", [](DB::ContextPtr context_) { return std::make_shared(context_); }); + +#if USE_HDFS + factory.registerBuilder("hdfs", [](DB::ContextPtr context_) { return std::make_shared(context_); }); +#endif + +#if USE_AWS_S3 + factory.registerBuilder("s3", [](DB::ContextPtr context_) { return std::make_shared(context_); }); +#endif + +#if USE_AZURE_BLOB_STORAGE + factory.registerBuilder("wasb", [](DB::ContextPtr context_) { return std::make_shared(context_); }); + factory.registerBuilder("wasbs", [](DB::ContextPtr context_) { return std::make_shared(context_); }); +#endif +} + +ReadBufferBuilderFactory & ReadBufferBuilderFactory::instance() +{ + static ReadBufferBuilderFactory instance; + return instance; +} + +ReadBufferBuilderPtr ReadBufferBuilderFactory::createBuilder(const String & schema, DB::ContextPtr context) +{ + auto it = builders.find(schema); + if (it == builders.end()) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Not found read buffer builder for {}", schema); + return it->second(context); +} + +void ReadBufferBuilderFactory::registerBuilder(const String & schema, NewBuilder newer) +{ + auto it = builders.find(schema); + if (it != builders.end()) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "readbuffer builder for {} has been registered", schema); + builders[schema] = newer; +} + +} diff --git a/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.h b/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.h new file mode 100644 index 000000000000..4632fb4c0a5b --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/ReadBufferBuilder.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ +class ReadBufferBuilder +{ +public: + explicit ReadBufferBuilder(DB::ContextPtr context_) : context(context_) {} + virtual ~ReadBufferBuilder() = default; + /// build a new read buffer + virtual std::unique_ptr build(const substrait::ReadRel::LocalFiles::FileOrFiles & file_info) = 0; +protected: + DB::ContextPtr context; +}; + +using ReadBufferBuilderPtr = std::shared_ptr; + +class ReadBufferBuilderFactory : public boost::noncopyable +{ +public: + using NewBuilder = std::function; + static ReadBufferBuilderFactory & instance(); + ReadBufferBuilderPtr createBuilder(const String & schema, DB::ContextPtr context); + + void registerBuilder(const String & schema, NewBuilder newer); + +private: + std::map builders; +}; + +void registerReadBufferBuilders(); +} diff --git a/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp b/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp new file mode 100644 index 000000000000..e38943964522 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp @@ -0,0 +1,417 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +namespace DB +{ +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; + extern const int LOGICAL_ERROR; +} +} +namespace local_engine +{ + +// When run query "select count(*) from t", there is no any column to be read. +// The number of rows is the only needed information. To handle these cases, we +// build blocks with a const virtual column to indicate how many rows is in it. +static DB::Block getRealHeader(const DB::Block & header) +{ + if (header.columns()) + return header; + return BlockUtil::buildRowCountHeader(); +} + +SubstraitFileSource::SubstraitFileSource(DB::ContextPtr context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos) + : DB::ISource(getRealHeader(header_), false) + , context(context_) + , output_header(header_) +{ + /** + * We may query part fields of a struct column. For example, we have a column c in type + * struct{x:int, y:int, z:int}, and just want fields c.x and c.y. In the substraint plan, we get + * a column c described in type struct{x:int, y:int} which is not matched with the original + * struct type and cause some exceptions. To solve this, we flatten all struct columns into + * independent field columns recursively, and fold the field columns back into struct columns + * at the end. + */ + to_read_header = BlockUtil::flattenBlock(output_header, BlockUtil::FLAT_STRUCT, true); + flatten_output_header = to_read_header; + if (file_infos.items_size()) + { + Poco::URI file_uri(file_infos.items().Get(0).uri_file()); + read_buffer_builder = ReadBufferBuilderFactory::instance().createBuilder(file_uri.getScheme(), context); + for (const auto & item : file_infos.items()) + { + files.emplace_back(FormatFileUtil::createFile(context, read_buffer_builder, item)); + } + + auto partition_keys = files[0]->getFilePartitionKeys(); + /// file partition keys are read from the file path + for (const auto & key : partition_keys) + { + to_read_header.erase(key); + } + } +} + + +DB::Chunk SubstraitFileSource::generate() +{ + while(true) + { + if (!tryPrepareReader()) + { + /// all files finished + return {}; + } + + DB::Chunk chunk; + if (file_reader->pull(chunk)) + { + if (output_header.columns()) + { + auto result_block = foldFlattenColumns(chunk.detachColumns(), output_header); + auto cols = result_block.getColumns(); + return DB::Chunk(cols, result_block.rows()); + } + else + { + // The count(*)/count(1) case + return chunk; + } + } + + /// try to read from next file + file_reader.reset(); + } + return {}; +} + +bool SubstraitFileSource::tryPrepareReader() +{ + if (file_reader) [[likely]] + return true; + + if (current_file_index >= files.size()) + return false; + + auto current_file = files[current_file_index]; + current_file_index += 1; + + if (!current_file->supportSplit() && current_file->getStartOffset()) + { + /// For the files do not support split strategy, the task with not 0 offset will generate empty data + file_reader = std::make_unique(current_file); + return true; + } + if (!to_read_header.columns()) + { + auto total_rows = current_file->getTotalRows(); + if (total_rows) + { + file_reader = std::make_unique(current_file, context, output_header, *total_rows); + } + else + { + /// TODO: It may be a text format file that we do not have the stat metadata, e.g. total rows. + /// If we can get the file's schema, we can try to read all columns out. maybe consider make this + /// scan action fallback into spark + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "All columns to read is partition columns, but this file({}) doesn't support this case.", + current_file->getURIPath()); + } + } + else + { + file_reader = std::make_unique(current_file, context, to_read_header, flatten_output_header); + } + + return true; +} + +DB::Block SubstraitFileSource::foldFlattenColumns(const DB::Columns & cols, const DB::Block & header) +{ + DB::ColumnsWithTypeAndName result_cols; + size_t pos = 0; + for (size_t i = 0; i < header.columns(); ++i) + { + const auto & named_col = header.getByPosition(i); + auto new_col = foldFlattenColumn(named_col.type, named_col.name, pos, cols); + result_cols.push_back(new_col); + } + return DB::Block(std::move(result_cols)); +} + +DB::ColumnWithTypeAndName +SubstraitFileSource::foldFlattenColumn(DB::DataTypePtr col_type, const std::string & col_name, size_t & pos, const DB::Columns & cols) +{ + DB::DataTypePtr nested_type = nullptr; + if (col_type->isNullable()) + { + nested_type = typeid_cast(col_type.get())->getNestedType(); + } + else + { + nested_type = col_type; + } + + const DB::DataTypeTuple * type_tuple = typeid_cast(nested_type.get()); + if (type_tuple) + { + if (type_tuple->haveExplicitNames()) + { + const auto & field_types = type_tuple->getElements(); + const auto & field_names = type_tuple->getElementNames(); + size_t fields_num = field_names.size(); + DB::Columns tuple_cols; + for (size_t i = 0; i < fields_num; ++i) + { + auto named_col = foldFlattenColumn(field_types[i], field_names[i], pos, cols); + tuple_cols.push_back(named_col.column); + } + auto tuple_col = DB::ColumnTuple::create(std::move(tuple_cols)); + + // The original type col_type may be wrapped by nullable, so add a cast here. + DB::ColumnWithTypeAndName ret_col(std::move(tuple_col), nested_type, col_name); + ret_col.column = DB::castColumn(ret_col, col_type); + ret_col.type = col_type; + return ret_col; + } + else + { + size_t current_pos = pos; + pos += 1; + return DB::ColumnWithTypeAndName(cols[current_pos], col_type, col_name); + } + } + + size_t current_pos = pos; + pos += 1; + return DB::ColumnWithTypeAndName(cols[current_pos], col_type, col_name); +} + +DB::ColumnPtr FileReaderWrapper::createConstColumn(DB::DataTypePtr data_type, const DB::Field & field, size_t rows) +{ + auto nested_type = DB::removeNullable(data_type); + auto column = nested_type->createColumnConst(rows, field); + + if (data_type->isNullable()) + column = DB::ColumnNullable::create(column, DB::ColumnUInt8::create(rows, 0)); + return column; +} + +DB::ColumnPtr FileReaderWrapper::createColumn(DB::DataTypePtr data_type, size_t rows, const String & value) +{ + if (StringUtils::isNullPartitionValue(value)) + { + auto nested_type = DB::removeNullable(data_type); + auto column = nested_type->createColumnConstWithDefaultValue(rows); + return DB::ColumnNullable::create(column, DB::ColumnUInt8::create(rows, 1)); + } + else + { + auto field = buildFieldFromString(value, data_type); + return createConstColumn(data_type, field, rows); + } +} + +#define BUILD_INT_FIELD(type) \ + [](DB::ReadBuffer & in, const String &) \ + {\ + type value = 0;\ + DB::readIntText(value, in);\ + return DB::Field(value);\ + } + +#define BUILD_FP_FIELD(type) \ + [](DB::ReadBuffer & in, const String &) \ + {\ + type value = 0.0;\ + DB::readFloatText(value, in);\ + return DB::Field(value);\ + } + +DB::Field FileReaderWrapper::buildFieldFromString(const String & str_value, DB::DataTypePtr type) +{ + using FieldBuilder = std::function; + static std::map field_builders + = {{magic_enum::enum_integer(DB::TypeIndex::Int8), BUILD_INT_FIELD(Int8) }, + {magic_enum::enum_integer(DB::TypeIndex::Int16), BUILD_INT_FIELD(Int16) }, + {magic_enum::enum_integer(DB::TypeIndex::Int32), BUILD_INT_FIELD(Int32) }, + {magic_enum::enum_integer(DB::TypeIndex::Int64), BUILD_INT_FIELD(Int64) }, + {magic_enum::enum_integer(DB::TypeIndex::Float32), BUILD_FP_FIELD(DB::Float32) }, + {magic_enum::enum_integer(DB::TypeIndex::Float64), BUILD_FP_FIELD(DB::Float64)}, + {magic_enum::enum_integer(DB::TypeIndex::String), [](DB::ReadBuffer &, const String & val) { return DB::Field(val); }}, + {magic_enum::enum_integer(DB::TypeIndex::Date), + [](DB::ReadBuffer & in, const String &) + { + DayNum value; + readDateText(value, in); + return DB::Field(value); + }}, + {magic_enum::enum_integer(DB::TypeIndex::Date32), + [](DB::ReadBuffer & in, const String &) + { + ExtendedDayNum value; + readDateText(value, in); + return DB::Field(value.toUnderType()); + }}}; + + auto nested_type = DB::removeNullable(type); + DB::ReadBufferFromString read_buffer(str_value); + auto it = field_builders.find(magic_enum::enum_integer(nested_type->getTypeId())); + if (it == field_builders.end()) + throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unsupported data type {}", type->getFamilyName()); + return it->second(read_buffer, str_value); +} + +ConstColumnsFileReader::ConstColumnsFileReader(FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & header_, size_t block_size_) + : FileReaderWrapper(file_) + , context(context_) + , header(header_) + , remained_rows(0) + , block_size(block_size_) +{ + auto rows = file->getTotalRows(); + if (!rows) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Cannot get total rows number from file : {}", file->getURIPath()); + remained_rows = *rows; +} + +bool ConstColumnsFileReader::pull(DB::Chunk & chunk) +{ + if (!remained_rows) [[unlikely]] + return false; + size_t to_read_rows = 0; + if (remained_rows < block_size) + { + to_read_rows = remained_rows; + remained_rows = 0; + } + else + { + to_read_rows = block_size; + remained_rows -= block_size; + } + DB::Columns res_columns; + size_t columns_num = header.columns(); + if (columns_num) + { + res_columns.reserve(columns_num); + const auto & partition_values = file->getFilePartitionValues(); + for (size_t pos = 0; pos < columns_num; ++pos) + { + auto col_with_name_and_type = header.getByPosition(pos); + auto type = col_with_name_and_type.type; + const auto & name = col_with_name_and_type.name; + auto it = partition_values.find(name); + if (it == partition_values.end()) [[unlikely]] + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow partition column : {}", name); + } + res_columns.emplace_back(createColumn(type, to_read_rows, it->second)); + } + } + else + { + // the original header is empty, build a block to represent the row count. + res_columns = BlockUtil::buildRowCountChunk(to_read_rows).detachColumns(); + } + + chunk = DB::Chunk(std::move(res_columns), to_read_rows); + return true; +} + +NormalFileReader::NormalFileReader(FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & to_read_header_, const DB::Block & output_header_) + : FileReaderWrapper(file_) + , context(context_) + , to_read_header(to_read_header_) + , output_header(output_header_) +{ + input_format = file->createInputFormat(to_read_header); + DB::Pipe pipe(input_format->input); + pipeline = std::make_unique(std::move(pipe)); + reader = std::make_unique(*pipeline); +} + + +bool NormalFileReader::pull(DB::Chunk & chunk) +{ + DB::Chunk tmp_chunk; + auto status = reader->pull(tmp_chunk); + if (!status) + { + return false; + } + + size_t rows = tmp_chunk.getNumRows(); + if (!rows) + return false; + + auto read_columns = tmp_chunk.detachColumns(); + DB::Columns res_columns; + auto columns_with_name_and_type = output_header.getColumnsWithTypeAndName(); + auto partition_values = file->getFilePartitionValues(); + + for (auto & column : columns_with_name_and_type) + { + if (to_read_header.has(column.name)) + { + auto pos = to_read_header.getPositionByName(column.name); + res_columns.push_back(read_columns[pos]); + } + else + { + auto it = partition_values.find(column.name); + if (it == partition_values.end()) + { + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, "Not found column({}) from file({}) partition keys.", column.name, file->getURIPath()); + } + res_columns.push_back(createColumn(column.type, rows, it->second)); + } + } + chunk = DB::Chunk(std::move(res_columns), rows); + return true; +} +} diff --git a/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.h b/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.h new file mode 100644 index 000000000000..80caed8d9134 --- /dev/null +++ b/utils/local-engine/Storages/SubstraitSource/SubstraitFileSource.h @@ -0,0 +1,108 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace local_engine +{ + +class FileReaderWrapper +{ +public: + explicit FileReaderWrapper(FormatFilePtr file_) : file(file_) {} + virtual ~FileReaderWrapper() = default; + virtual bool pull(DB::Chunk & chunk) = 0; + +protected: + FormatFilePtr file; + + static DB::ColumnPtr createConstColumn(DB::DataTypePtr type, const DB::Field & field, size_t rows); + static DB::ColumnPtr createColumn(DB::DataTypePtr data_type, size_t rows, const String & value); + static DB::Field buildFieldFromString(const String & value, DB::DataTypePtr type); +}; + +class NormalFileReader : public FileReaderWrapper +{ +public: + NormalFileReader(FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & to_read_header_, const DB::Block & output_header_); + ~NormalFileReader() override = default; + bool pull(DB::Chunk & chunk) override; + +private: + DB::ContextPtr context; + DB::Block to_read_header; + DB::Block output_header; + + FormatFile::InputFormatPtr input_format; + std::unique_ptr pipeline; + std::unique_ptr reader; +}; + +class EmptyFileReader : public FileReaderWrapper +{ +public: + explicit EmptyFileReader(FormatFilePtr file_) : FileReaderWrapper(file_) {} + ~EmptyFileReader() override = default; + bool pull(DB::Chunk &) override { return false; } +}; + +class ConstColumnsFileReader : public FileReaderWrapper +{ +public: + ConstColumnsFileReader(FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & header_, size_t block_size_ = DEFAULT_BLOCK_SIZE); + ~ConstColumnsFileReader() override = default; + bool pull(DB::Chunk & chunk); +private: + DB::ContextPtr context; + DB::Block header; + size_t remained_rows; + size_t block_size; +}; + +class SubstraitFileSource : public DB::ISource +{ +public: + SubstraitFileSource(DB::ContextPtr context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos); + ~SubstraitFileSource() override = default; + + String getName() const override + { + return "SubstraitFileSource"; + } +protected: + DB::Chunk generate() override; +private: + DB::ContextPtr context; + DB::Block output_header; + DB::Block flatten_output_header; // flatten a struct column into independent field columns recursively + DB::Block to_read_header; // Not include partition keys + FormatFiles files; + + UInt32 current_file_index = 0; + std::unique_ptr file_reader; + ReadBufferBuilderPtr read_buffer_builder; + + bool tryPrepareReader(); + + // E.g we have flatten columns correspond to header {a:int, b.x.i: int, b.x.j: string, b.y: string} + // but we want to fold all the flatten struct columns into one struct column, + // {a:int, b: {x: {i: int, j: string}, y: string}} + // Notice, don't support list with named struct. ClickHouse may take advantage of this to support + // nested table, but not the case in spark. + static DB::Block foldFlattenColumns(const DB::Columns & cols, const DB::Block & header); + static DB::ColumnWithTypeAndName + foldFlattenColumn(DB::DataTypePtr col_type, const std::string & col_name, size_t & pos, const DB::Columns & cols); +}; +} diff --git a/utils/local-engine/Storages/ch_parquet/CMakeLists.txt b/utils/local-engine/Storages/ch_parquet/CMakeLists.txt new file mode 100644 index 000000000000..5fc365c983d7 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/CMakeLists.txt @@ -0,0 +1,44 @@ + +set(ARROW_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src") + +macro(add_headers_and_sources_including_cc prefix common_path) + add_glob(${prefix}_headers ${CMAKE_CURRENT_SOURCE_DIR} ${common_path}/*.h) + add_glob(${prefix}_sources ${common_path}/*.cpp ${common_path}/*.c ${common_path}/*.cc ${common_path}/*.h) +endmacro() + +add_headers_and_sources(ch_parquet .) +add_headers_and_sources_including_cc(ch_parquet arrow) + +add_library(ch_parquet ${ch_parquet_sources}) + +target_compile_options(ch_parquet PUBLIC -fPIC + -Wno-shorten-64-to-32 + -Wno-shadow-field-in-constructor + -Wno-return-type + -Wno-reserved-identifier + -Wno-extra-semi-stmt + -Wno-extra-semi + -Wno-unused-result + -Wno-unreachable-code-return + -Wno-unused-parameter + -Wno-unreachable-code + -Wno-pessimizing-move + -Wno-unreachable-code-break + -Wno-unused-variable + -Wno-inconsistent-missing-override + -Wno-shadow-uncaptured-local + -Wno-suggest-override + -Wno-unused-member-function + -Wno-deprecated-this-capture +) + +target_link_libraries(ch_parquet PUBLIC + boost::headers_only + clickhouse_common_io +) + +target_include_directories(ch_parquet SYSTEM BEFORE PUBLIC + ${ARROW_INCLUDE_DIR} + ${CMAKE_BINARY_DIR}/contrib/arrow-cmake/cpp/src + ${ClickHouse_SOURCE_DIR}/contrib/arrow-cmake/cpp/src +) \ No newline at end of file diff --git a/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.cpp b/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.cpp new file mode 100644 index 000000000000..cdc71c5f30ef --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.cpp @@ -0,0 +1,682 @@ +#include "OptimizedArrowColumnToCHColumn.h" + +#if USE_ARROW || USE_ORC || USE_PARQUET + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/column_reader.h" + +#include +#include + +/// UINT16 and UINT32 are processed separately, see comments in readColumnFromArrowColumn. +#define FOR_ARROW_NUMERIC_TYPES(M) \ + M(arrow::Type::UINT8, DB::UInt8) \ + M(arrow::Type::INT8, DB::Int8) \ + M(arrow::Type::INT16, DB::Int16) \ + M(arrow::Type::INT32, DB::Int32) \ + M(arrow::Type::UINT64, DB::UInt64) \ + M(arrow::Type::INT64, DB::Int64) \ + M(arrow::Type::HALF_FLOAT, DB::Float32) \ + M(arrow::Type::FLOAT, DB::Float32) \ + M(arrow::Type::DOUBLE, DB::Float64) + +#define FOR_ARROW_INDEXES_TYPES(M) \ + M(arrow::Type::UINT8, DB::UInt8) \ + M(arrow::Type::INT8, DB::UInt8) \ + M(arrow::Type::UINT16, DB::UInt16) \ + M(arrow::Type::INT16, DB::UInt16) \ + M(arrow::Type::UINT32, DB::UInt32) \ + M(arrow::Type::INT32, DB::UInt32) \ + M(arrow::Type::UINT64, DB::UInt64) \ + M(arrow::Type::INT64, DB::UInt64) + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_TYPE; + extern const int VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE; + extern const int BAD_ARGUMENTS; + extern const int DUPLICATE_COLUMN; + extern const int THERE_IS_NO_COLUMN; + extern const int UNKNOWN_EXCEPTION; + extern const int INCORRECT_NUMBER_OF_COLUMNS; +} + +/// Inserts numeric data right into internal column data to reduce an overhead +template > +static ColumnWithTypeAndName readColumnWithNumericData(std::shared_ptr & arrow_column, const String & column_name) +{ + auto internal_type = std::make_shared>(); + auto internal_column = internal_type->createColumn(); + auto & column_data = static_cast(*internal_column).getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + std::shared_ptr chunk = arrow_column->chunk(chunk_i); + if (chunk->length() == 0) + continue; + + /// buffers[0] is a null bitmap and buffers[1] are actual values + std::shared_ptr buffer = chunk->data()->buffers[1]; + const auto * raw_data = reinterpret_cast(buffer->data()); + column_data.insert_assume_reserved(raw_data, raw_data + chunk->length()); + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +/// Inserts chars and offsets right into internal column data to reduce an overhead. +/// Internal offsets are shifted by one to the right in comparison with Arrow ones. So the last offset should map to the end of all chars. +/// Also internal strings are null terminated. +static ColumnWithTypeAndName readColumnWithStringData(std::shared_ptr & arrow_column, const String & column_name) +{ + auto internal_type = std::make_shared(); + auto internal_column = internal_type->createColumn(); + PaddedPODArray & column_chars_t = assert_cast(*internal_column).getChars(); + PaddedPODArray & column_offsets = assert_cast(*internal_column).getOffsets(); + + size_t chars_t_size = 0; + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::BinaryArray & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + const size_t chunk_length = chunk.length(); + + if (chunk_length > 0) + { + chars_t_size += chunk.value_offset(chunk_length - 1) + chunk.value_length(chunk_length - 1); + chars_t_size += chunk_length; /// additional space for null bytes + } + } + + column_chars_t.reserve(chars_t_size); + column_offsets.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::BinaryArray & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + std::shared_ptr buffer = chunk.value_data(); + const size_t chunk_length = chunk.length(); + + for (size_t offset_i = 0; offset_i != chunk_length; ++offset_i) + { + if (!chunk.IsNull(offset_i) && buffer) + { + const auto * raw_data = buffer->data() + chunk.value_offset(offset_i); + column_chars_t.insert_assume_reserved(raw_data, raw_data + chunk.value_length(offset_i)); + } + column_chars_t.emplace_back('\0'); + + column_offsets.emplace_back(column_chars_t.size()); + } + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +static ColumnWithTypeAndName readColumnWithBooleanData(std::shared_ptr & arrow_column, const String & column_name) +{ + auto internal_type = std::make_shared(); + auto internal_column = internal_type->createColumn(); + auto & column_data = assert_cast &>(*internal_column).getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::BooleanArray & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + if (chunk.length() == 0) + continue; + + /// buffers[0] is a null bitmap and buffers[1] are actual values + std::shared_ptr buffer = chunk.data()->buffers[1]; + + for (size_t bool_i = 0; bool_i != static_cast(chunk.length()); ++bool_i) + column_data.emplace_back(chunk.Value(bool_i)); + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +static ColumnWithTypeAndName readColumnWithDate32Data(std::shared_ptr & arrow_column, const String & column_name) +{ + auto internal_type = std::make_shared(); + auto internal_column = internal_type->createColumn(); + PaddedPODArray & column_data = assert_cast &>(*internal_column).getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::Date32Array & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + + /// buffers[0] is a null bitmap and buffers[1] are actual values + std::shared_ptr buffer = chunk.data()->buffers[1]; + const auto * raw_data = reinterpret_cast(buffer->data()); + column_data.insert_assume_reserved(raw_data, raw_data + chunk.length()); + + const Int32* p_end = raw_data + chunk.length(); + for (Int32* p = const_cast(raw_data); p < p_end; ++p) + { + if (unlikely(*p > DATE_LUT_MAX_EXTEND_DAY_NUM)) + throw Exception{ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, + "Input value {} of a column \"{}\" is greater than max allowed Date value, which is {}", *p, column_name, DATE_LUT_MAX_DAY_NUM}; + } + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +/// Arrow stores Parquet::DATETIME in Int64, while ClickHouse stores DateTime in UInt32. Therefore, it should be checked before saving +static ColumnWithTypeAndName readColumnWithDate64Data(std::shared_ptr & arrow_column, const String & column_name) +{ + auto internal_type = std::make_shared(); + auto internal_column = internal_type->createColumn(); + auto & column_data = assert_cast &>(*internal_column).getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + auto & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + for (size_t value_i = 0, length = static_cast(chunk.length()); value_i < length; ++value_i) + { + auto timestamp = static_cast(chunk.Value(value_i) / 1000); // Always? in ms + column_data.emplace_back(timestamp); + } + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +static ColumnWithTypeAndName readColumnWithTimestampData(std::shared_ptr & arrow_column, const String & column_name) +{ + const auto & arrow_type = static_cast(*(arrow_column->type())); + const UInt8 scale = arrow_type.unit() * 3; + auto internal_type = std::make_shared(scale, arrow_type.timezone()); + auto internal_column = internal_type->createColumn(); + auto & column_data = assert_cast &>(*internal_column).getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + const auto & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + for (size_t value_i = 0, length = static_cast(chunk.length()); value_i < length; ++value_i) + { + column_data.emplace_back(chunk.Value(value_i)); + } + } + return {std::move(internal_column), std::move(internal_type), column_name}; +} + +template +static ColumnWithTypeAndName readColumnWithDecimalDataImpl(std::shared_ptr & arrow_column, const String & column_name, DataTypePtr internal_type) +{ + auto internal_column = internal_type->createColumn(); + auto & column = assert_cast &>(*internal_column); + auto & column_data = column.getData(); + column_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + auto & chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + for (size_t value_i = 0, length = static_cast(chunk.length()); value_i < length; ++value_i) + { + column_data.emplace_back(chunk.IsNull(value_i) ? DecimalType(0) : *reinterpret_cast(chunk.Value(value_i))); // TODO: copy column + } + } + return {std::move(internal_column), internal_type, column_name}; +} + +template +static ColumnWithTypeAndName readColumnWithDecimalData(std::shared_ptr & arrow_column, const String & column_name) +{ + const auto * arrow_decimal_type = static_cast(arrow_column->type().get()); + size_t precision = arrow_decimal_type->precision(); + auto internal_type = createDecimal(precision, arrow_decimal_type->scale()); + if (precision <= DecimalUtils::max_precision) + return readColumnWithDecimalDataImpl(arrow_column, column_name, internal_type); + else if (precision <= DecimalUtils::max_precision) + return readColumnWithDecimalDataImpl(arrow_column, column_name, internal_type); + else if (precision <= DecimalUtils::max_precision) + return readColumnWithDecimalDataImpl(arrow_column, column_name, internal_type); + return readColumnWithDecimalDataImpl(arrow_column, column_name, internal_type); +} + +/// Creates a null bytemap from arrow's null bitmap +static ColumnPtr readByteMapFromArrowColumn(std::shared_ptr & arrow_column) +{ + auto nullmap_column = ColumnUInt8::create(); + PaddedPODArray & bytemap_data = assert_cast &>(*nullmap_column).getData(); + bytemap_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0; chunk_i != static_cast(arrow_column->num_chunks()); ++chunk_i) + { + std::shared_ptr chunk = arrow_column->chunk(chunk_i); + + for (size_t value_i = 0; value_i != static_cast(chunk->length()); ++value_i) + bytemap_data.emplace_back(chunk->IsNull(value_i)); + } + return std::move(nullmap_column); +} + +static ColumnPtr readOffsetsFromArrowListColumn(std::shared_ptr & arrow_column) +{ + auto offsets_column = ColumnUInt64::create(); + ColumnArray::Offsets & offsets_data = assert_cast &>(*offsets_column).getData(); + offsets_data.reserve(arrow_column->length()); + + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::ListArray & list_chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + auto arrow_offsets_array = list_chunk.offsets(); + auto & arrow_offsets = dynamic_cast(*arrow_offsets_array); + auto start = offsets_data.back(); + for (int64_t i = 1; i < arrow_offsets.length(); ++i) + offsets_data.emplace_back(start + arrow_offsets.Value(i)); + } + return std::move(offsets_column); +} + +static ColumnPtr readColumnWithIndexesData(std::shared_ptr & arrow_column) +{ + switch (arrow_column->type()->id()) + { +# define DISPATCH(ARROW_NUMERIC_TYPE, CPP_NUMERIC_TYPE) \ + case ARROW_NUMERIC_TYPE: \ + { \ + return readColumnWithNumericData(arrow_column, "").column; \ + } + FOR_ARROW_INDEXES_TYPES(DISPATCH) +# undef DISPATCH + default: + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type for indexes in LowCardinality: {}.", arrow_column->type()->name()); + } +} + +static std::shared_ptr getNestedArrowColumn(std::shared_ptr & arrow_column) +{ + arrow::ArrayVector array_vector; + array_vector.reserve(arrow_column->num_chunks()); + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::ListArray & list_chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + std::shared_ptr chunk = list_chunk.values(); + array_vector.emplace_back(std::move(chunk)); + } + return std::make_shared(array_vector); +} + +static ColumnWithTypeAndName readColumnFromArrowColumn( + const std::shared_ptr & arrow_field, + std::shared_ptr & arrow_column, + const std::string & format_name, + std::unordered_map> & dictionary_values, + bool read_ints_as_dates) +{ + const auto is_nullable = arrow_field->nullable(); + const auto column_name = arrow_field->name(); + if (is_nullable) + { + auto nested_column = readColumnFromArrowColumn(arrow_field->WithNullable(false), arrow_column, format_name, dictionary_values, read_ints_as_dates); + auto nullmap_column = readByteMapFromArrowColumn(arrow_column); + auto nullable_type = std::make_shared(std::move(nested_column.type)); + auto nullable_column = ColumnNullable::create(nested_column.column, nullmap_column); + return {std::move(nullable_column), std::move(nullable_type), column_name}; + } + + auto * ch_chunk_array_p = dynamic_cast(arrow_column->chunk(0).get()); + if (ch_chunk_array_p != nullptr) + { + //the values are already written into CH Column, not arrow array + ch_chunk_array_p->column.name = column_name; + return ch_chunk_array_p->column; + } + + switch (arrow_column->type()->id()) + { + case arrow::Type::STRING: + case arrow::Type::BINARY: + //case arrow::Type::FIXED_SIZE_BINARY: + return readColumnWithStringData(arrow_column, column_name); + case arrow::Type::BOOL: + return readColumnWithBooleanData(arrow_column, column_name); + case arrow::Type::DATE32: + return readColumnWithDate32Data(arrow_column, column_name); + case arrow::Type::DATE64: + return readColumnWithDate64Data(arrow_column, column_name); + // ClickHouse writes Date as arrow UINT16 and DateTime as arrow UINT32, + // so, read UINT16 as Date and UINT32 as DateTime to perform correct conversion + // between Date and DateTime further. + case arrow::Type::UINT16: + { + auto column = readColumnWithNumericData(arrow_column, column_name); + if (read_ints_as_dates) + column.type = std::make_shared(); + return column; + } + case arrow::Type::UINT32: + { + auto column = readColumnWithNumericData(arrow_column, column_name); + if (read_ints_as_dates) + column.type = std::make_shared(); + return column; + } + case arrow::Type::TIMESTAMP: + return readColumnWithTimestampData(arrow_column, column_name); + case arrow::Type::DECIMAL128: + return readColumnWithDecimalData(arrow_column, column_name); + case arrow::Type::DECIMAL256: + return readColumnWithDecimalData(arrow_column, column_name); + case arrow::Type::MAP: + { + const auto arrow_nested_field = arrow_field->type()->field(0); + auto arrow_nested_column = getNestedArrowColumn(arrow_column); + auto nested_column + = readColumnFromArrowColumn(arrow_nested_field, arrow_nested_column, format_name, dictionary_values, read_ints_as_dates); + auto offsets_column = readOffsetsFromArrowListColumn(arrow_column); + + const auto * tuple_column = assert_cast(nested_column.column.get()); + const auto * tuple_type = assert_cast(nested_column.type.get()); + auto map_column = ColumnMap::create(tuple_column->getColumnPtr(0), tuple_column->getColumnPtr(1), offsets_column); + auto map_type = std::make_shared(tuple_type->getElements()[0], tuple_type->getElements()[1]); + return {std::move(map_column), std::move(map_type), column_name}; + } + case arrow::Type::LIST: + { + const auto arrow_nested_field = arrow_field->type()->field(0); + auto arrow_nested_column = getNestedArrowColumn(arrow_column); + auto nested_column + = readColumnFromArrowColumn(arrow_nested_field, arrow_nested_column, format_name, dictionary_values, read_ints_as_dates); + auto offsets_column = readOffsetsFromArrowListColumn(arrow_column); + auto array_column = ColumnArray::create(nested_column.column, offsets_column); + auto array_type = std::make_shared(nested_column.type); + return {std::move(array_column), std::move(array_type), column_name}; + } + case arrow::Type::STRUCT: + { + auto arrow_type = arrow_field->type(); + auto * arrow_struct_type = assert_cast(arrow_type.get()); + std::vector nested_arrow_columns(arrow_struct_type->num_fields()); + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::StructArray & struct_chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + for (int i = 0; i < arrow_struct_type->num_fields(); ++i) + nested_arrow_columns[i].emplace_back(struct_chunk.field(i)); + } + + std::vector tuple_names; + DataTypes tuple_types; + Columns tuple_elements; + + for (int i = 0; i != arrow_struct_type->num_fields(); ++i) + { + const auto & nested_arrow_field = arrow_struct_type->field(i); + auto nested_arrow_column = std::make_shared(nested_arrow_columns[i]); + auto element = readColumnFromArrowColumn( + nested_arrow_field, nested_arrow_column, format_name, dictionary_values, read_ints_as_dates); + tuple_names.emplace_back(std::move(element.name)); + tuple_types.emplace_back(std::move(element.type)); + tuple_elements.emplace_back(std::move(element.column)); + } + + auto tuple_column = ColumnTuple::create(std::move(tuple_elements)); + auto tuple_type = std::make_shared(std::move(tuple_types), std::move(tuple_names)); + return {std::move(tuple_column), std::move(tuple_type), column_name}; + } + case arrow::Type::DICTIONARY: + { + auto & dict_values = dictionary_values[column_name]; + /// Load dictionary values only once and reuse it. + if (!dict_values) + { + arrow::ArrayVector dict_array; + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::DictionaryArray & dict_chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + dict_array.emplace_back(dict_chunk.dictionary()); + } + + auto * arrow_dict_type = assert_cast(arrow_field->type().get()); + auto arrow_dict_field = arrow::field("dict", arrow_dict_type->value_type()); + auto arrow_dict_column = std::make_shared(dict_array); + auto dict_column = readColumnFromArrowColumn(arrow_dict_field, arrow_dict_column, format_name, dictionary_values, read_ints_as_dates); + + /// We should convert read column to ColumnUnique. + auto tmp_lc_column = DataTypeLowCardinality(dict_column.type).createColumn(); + auto tmp_dict_column = IColumn::mutate(assert_cast(tmp_lc_column.get())->getDictionaryPtr()); + static_cast(tmp_dict_column.get())->uniqueInsertRangeFrom(*dict_column.column, 0, dict_column.column->size()); + dict_column.column = std::move(tmp_dict_column); + dict_values = std::make_shared(std::move(dict_column)); + } + + arrow::ArrayVector indexes_array; + for (size_t chunk_i = 0, num_chunks = static_cast(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i) + { + arrow::DictionaryArray & dict_chunk = dynamic_cast(*(arrow_column->chunk(chunk_i))); + indexes_array.emplace_back(dict_chunk.indices()); + } + + auto arrow_indexes_column = std::make_shared(indexes_array); + auto indexes_column = readColumnWithIndexesData(arrow_indexes_column); + auto lc_column = ColumnLowCardinality::create(dict_values->column, indexes_column); + auto lc_type = std::make_shared(dict_values->type); + return {std::move(lc_column), std::move(lc_type), column_name}; + } +# define DISPATCH(ARROW_NUMERIC_TYPE, CPP_NUMERIC_TYPE) \ + case ARROW_NUMERIC_TYPE: \ + return readColumnWithNumericData(arrow_column, column_name); + FOR_ARROW_NUMERIC_TYPES(DISPATCH) +# undef DISPATCH + // TODO: read JSON as a string? + // TODO: read UUID as a string? + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, + "Unsupported {} type '{}' of an input column '{}'.", format_name, arrow_column->type()->name(), column_name); + } +} + + +// Creating CH header by arrow schema. Will be useful in task about inserting +// data from file without knowing table structure. + +static void checkStatus(const arrow::Status & status, const String & column_name, const String & format_name) +{ + if (!status.ok()) + throw Exception{ErrorCodes::UNKNOWN_EXCEPTION, "Error with a {} column '{}': {}.", format_name, column_name, status.ToString()}; +} + +Block OptimizedArrowColumnToCHColumn::arrowSchemaToCHHeader(const arrow::Schema & schema, const std::string & format_name) +{ + ColumnsWithTypeAndName sample_columns; + for (const auto & field : schema.fields()) + { + /// Create empty arrow column by it's type and convert it to ClickHouse column. + arrow::MemoryPool* pool = arrow::default_memory_pool(); + std::unique_ptr array_builder; + arrow::Status status = MakeBuilder(pool, field->type(), &array_builder); + checkStatus(status, field->name(), format_name); + + std::shared_ptr arrow_array; + status = array_builder->Finish(&arrow_array); + checkStatus(status, field->name(), format_name); + + arrow::ArrayVector array_vector = {arrow_array}; + auto arrow_column = std::make_shared(array_vector); + std::unordered_map> dict_values; + ColumnWithTypeAndName sample_column = readColumnFromArrowColumn(field, arrow_column, format_name, dict_values, false); + // std::cerr << "field:" << field->ToString() << ", datatype:" << sample_column.type->getName() << std::endl; + + sample_columns.emplace_back(std::move(sample_column)); + } + return Block(std::move(sample_columns)); +} + +OptimizedArrowColumnToCHColumn::OptimizedArrowColumnToCHColumn( + const Block & header_, const std::string & format_name_, bool import_nested_, bool allow_missing_columns_) + : header(header_), format_name(format_name_), import_nested(import_nested_), allow_missing_columns(allow_missing_columns_) +{ +} + +void OptimizedArrowColumnToCHColumn::arrowTableToCHChunk(Chunk & res, std::shared_ptr & table) +{ + NameToColumnPtr name_to_column_ptr; + for (const auto & column_name : table->ColumnNames()) + { + std::shared_ptr arrow_column = table->GetColumnByName(column_name); + if (!arrow_column) + throw Exception(ErrorCodes::DUPLICATE_COLUMN, "Column '{}' is duplicated", column_name); + name_to_column_ptr[column_name] = arrow_column; + } + + Stopwatch sw; + sw.start(); + arrowColumnsToCHChunk(res, name_to_column_ptr, table->schema()); + real_convert += sw.elapsedNanoseconds(); +} + +void OptimizedArrowColumnToCHColumn::arrowColumnsToCHChunk( + Chunk & res, NameToColumnPtr & name_to_column_ptr, const std::shared_ptr & schema) +{ + if (unlikely(name_to_column_ptr.empty())) + throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Columns is empty"); + + Columns columns_list; + UInt64 num_rows = name_to_column_ptr.begin()->second->length(); + columns_list.reserve(header.columns()); + std::unordered_map>> nested_tables; + for (size_t column_i = 0, columns = header.columns(); column_i < columns; ++column_i) + { + const ColumnWithTypeAndName & header_column = header.getByPosition(column_i); + auto search_column_name = header_column.name; + ColumnWithTypeAndName column; + + if (!name_to_column_ptr.contains(search_column_name)) + { + bool read_from_nested = false; + /// Check if it's a column from nested table. + if (import_nested) + { + String search_nested_table_name = Nested::extractTableName(header_column.name); + if (name_to_column_ptr.contains(search_nested_table_name)) + { + if (!nested_tables.contains(search_nested_table_name)) + { + const auto & arrow_field = schema->field(schema->GetFieldIndex(search_nested_table_name)); + std::shared_ptr arrow_column = name_to_column_ptr[search_nested_table_name]; + ColumnsWithTypeAndName cols + = {readColumnFromArrowColumn(arrow_field, arrow_column, format_name, dictionary_values, true)}; + BlockPtr block_ptr = std::make_shared(cols); + auto column_extractor = std::make_shared(*block_ptr, true); + nested_tables[search_nested_table_name] = {block_ptr, column_extractor}; + } + auto nested_column = nested_tables[search_nested_table_name].second->extractColumn(search_column_name); + if (nested_column) + { + column = *nested_column; + read_from_nested = true; + } + } + } + if (!read_from_nested) + { + if (!allow_missing_columns) + { + throw Exception{ErrorCodes::THERE_IS_NO_COLUMN, "Column '{}' is not presented in input data.", header_column.name}; + } + else + { + column.name = header_column.name; + column.type = header_column.type; + column.column = header_column.column->cloneResized(num_rows); + columns_list.push_back(std::move(column.column)); + continue; + } + } + } + else + { + auto arrow_column = name_to_column_ptr[search_column_name]; + const auto & arrow_field = schema->field(schema->GetFieldIndex(search_column_name)); + column = readColumnFromArrowColumn(arrow_field, arrow_column, format_name, dictionary_values, true); + } + try + { + column.column = castColumn(column, header_column.type); + } + catch (Exception & e) + { + e.addMessage(fmt::format( + "while converting column {} from type {} to type {}", + backQuote(header_column.name), + column.type->getName(), + header_column.type->getName())); + throw; + } + column.type = header_column.type; + columns_list.push_back(std::move(column.column)); + } + res.setColumns(columns_list, num_rows); +} + +std::vector OptimizedArrowColumnToCHColumn::getMissingColumns(const arrow::Schema & schema) const +{ + std::vector missing_columns; + auto block_from_arrow = arrowSchemaToCHHeader(schema, format_name); + local_engine::NestedColumnExtractHelper nested_table(block_from_arrow, true); + for (size_t i = 0, columns = header.columns(); i < columns; ++i) + { + const auto & column = header.getByPosition(i); + bool read_from_nested = false; + if (!block_from_arrow.has(column.name)) + { + String nested_table_name = Nested::extractTableName(column.name); + if (import_nested && block_from_arrow.has(nested_table_name)) + { + if (nested_table.extractColumn(column.name)) + read_from_nested = true; + } + + if (!read_from_nested) + { + if (!allow_missing_columns) + { + throw Exception{ErrorCodes::THERE_IS_NO_COLUMN, "Column '{}' is not presented in input data.", column.name}; + } + + missing_columns.push_back(i); + } + } + } + return missing_columns; +} + +} + +#endif diff --git a/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.h b/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.h new file mode 100644 index 000000000000..1b12e411031e --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/OptimizedArrowColumnToCHColumn.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#if USE_ARROW || USE_ORC || USE_PARQUET + +#include +#include +#include +#include +#include + + +namespace DB +{ + +class Block; +class Chunk; + +class OptimizedArrowColumnToCHColumn +{ +public: + using NameToColumnPtr = std::unordered_map>; + + OptimizedArrowColumnToCHColumn( + const Block & header_, + const std::string & format_name_, + bool import_nested_, + bool allow_missing_columns_); + + void arrowTableToCHChunk(Chunk & res, std::shared_ptr & table); + + void arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & name_to_column_ptr, const std::shared_ptr & schema); + + /// Get missing columns that exists in header but not in arrow::Schema + std::vector getMissingColumns(const arrow::Schema & schema) const; + + static Block arrowSchemaToCHHeader(const arrow::Schema & schema, const std::string & format_name); + + int64_t real_convert = 0; + int64_t cast_time = 0; + +private: + const Block & header; + const std::string format_name; + bool import_nested; + /// If false, throw exception if some columns in header not exists in arrow table. + bool allow_missing_columns; + + /// Map {column name : dictionary column}. + /// To avoid converting dictionary from Arrow Dictionary + /// to LowCardinality every chunk we save it and reuse. + std::unordered_map> dictionary_values; +}; + +} + +#endif diff --git a/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.cpp b/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.cpp new file mode 100644 index 000000000000..732968c3abe9 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.cpp @@ -0,0 +1,237 @@ +#include "OptimizedParquetBlockInputFormat.h" +#include + +#if USE_PARQUET + +#include +#include +#include +#include +#include +#include +#include "Storages/ch_parquet/arrow/reader.h" +#include +#include +#include "OptimizedArrowColumnToCHColumn.h" +#include + +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int CANNOT_READ_ALL_DATA; +} + +#define THROW_ARROW_NOT_OK(status) \ + do \ + { \ + if (::arrow::Status _s = (status); !_s.ok()) \ + throw Exception(_s.ToString(), ErrorCodes::BAD_ARGUMENTS); \ + } while (false) + +OptimizedParquetBlockInputFormat::OptimizedParquetBlockInputFormat(ReadBuffer & in_, Block header_, const FormatSettings & format_settings_) + : IInputFormat(std::move(header_), in_), format_settings(format_settings_) +{ +} + +Chunk OptimizedParquetBlockInputFormat::generate() +{ + Chunk res; + block_missing_values.clear(); + + if (!file_reader) + prepareReader(); + + if (is_stopped) + return {}; + + if (row_group_current >= row_group_total) + return res; + + std::shared_ptr table; + arrow::Status read_status = file_reader->ReadRowGroup(row_group_current, column_indices, &table); + if (!read_status.ok()) + throw ParsingException{"Error while reading Parquet data: " + read_status.ToString(), + ErrorCodes::CANNOT_READ_ALL_DATA}; + + if (format_settings.use_lowercase_column_name) + table = *table->RenameColumns(column_names); + + ++row_group_current; + + arrow_column_to_ch_column->arrowTableToCHChunk(res, table); + + /// If defaults_for_omitted_fields is true, calculate the default values from default expression for omitted fields. + /// Otherwise fill the missing columns with zero values of its type. + if (format_settings.defaults_for_omitted_fields) + for (size_t row_idx = 0; row_idx < res.getNumRows(); ++row_idx) + for (const auto & column_idx : missing_columns) + block_missing_values.setBit(column_idx, row_idx); + return res; +} + +void OptimizedParquetBlockInputFormat::resetParser() +{ + IInputFormat::resetParser(); + + file_reader.reset(); + column_indices.clear(); + column_names.clear(); + row_group_current = 0; + block_missing_values.clear(); +} + +const BlockMissingValues & OptimizedParquetBlockInputFormat::getMissingValues() const +{ + return block_missing_values; +} + +static size_t countIndicesForType(std::shared_ptr type) +{ + if (type->id() == arrow::Type::LIST) + return countIndicesForType(static_cast(type.get())->value_type()); + + if (type->id() == arrow::Type::STRUCT) + { + int indices = 0; + auto * struct_type = static_cast(type.get()); + for (int i = 0; i != struct_type->num_fields(); ++i) + indices += countIndicesForType(struct_type->field(i)->type()); + return indices; + } + + if (type->id() == arrow::Type::MAP) + { + auto * map_type = static_cast(type.get()); + return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); + } + + return 1; +} + +static void getFileReaderAndSchema( + ReadBuffer & in, + std::unique_ptr & file_reader, + std::shared_ptr & schema, + const FormatSettings & format_settings, + std::atomic & is_stopped) +{ + auto arrow_file = asArrowFile(in, format_settings, is_stopped, "Parquet", PARQUET_MAGIC_BYTES); + if (is_stopped) + return; + THROW_ARROW_NOT_OK(ch_parquet::arrow::OpenFile(std::move(arrow_file), arrow::default_memory_pool(), &file_reader)); + THROW_ARROW_NOT_OK(file_reader->GetSchema(&schema)); + + if (format_settings.use_lowercase_column_name) + { + std::vector> fields; + fields.reserve(schema->num_fields()); + for (int i = 0; i < schema->num_fields(); ++i) + { + const auto& field = schema->field(i); + auto name = field->name(); + boost::to_lower(name); + fields.push_back(field->WithName(name)); + } + schema = arrow::schema(fields, schema->metadata()); + } +} + +void OptimizedParquetBlockInputFormat::prepareReader() +{ + std::shared_ptr schema; + getFileReaderAndSchema(*in, file_reader, schema, format_settings, is_stopped); + if (is_stopped) + return; + + row_group_total = file_reader->num_row_groups(); + row_group_current = 0; + + arrow_column_to_ch_column = std::make_unique(getPort().getHeader(), "Parquet", format_settings.parquet.import_nested, format_settings.parquet.allow_missing_columns); + missing_columns = arrow_column_to_ch_column->getMissingColumns(*schema); + + std::unordered_set nested_table_names; + if (format_settings.parquet.import_nested) + nested_table_names = Nested::getAllTableNames(getPort().getHeader()); + + int index = 0; + for (int i = 0; i < schema->num_fields(); ++i) + { + /// STRUCT type require the number of indexes equal to the number of + /// nested elements, so we should recursively + /// count the number of indices we need for this type. + int indexes_count = countIndicesForType(schema->field(i)->type()); + const auto & name = schema->field(i)->name(); + if (getPort().getHeader().has(name) || nested_table_names.contains(name)) + { + for (int j = 0; j != indexes_count; ++j) + { + column_indices.push_back(index + j); + column_names.push_back(name); + } + } + index += indexes_count; + } +} + +OptimizedParquetSchemaReader::OptimizedParquetSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_) : ISchemaReader(in_), format_settings(format_settings_) +{ +} + +NamesAndTypesList OptimizedParquetSchemaReader::readSchema() +{ + std::unique_ptr file_reader; + std::shared_ptr schema; + std::atomic is_stopped = 0; + getFileReaderAndSchema(in, file_reader, schema, format_settings, is_stopped); + auto header = OptimizedArrowColumnToCHColumn::arrowSchemaToCHHeader(*schema, "Parquet"); + return header.getNamesAndTypesList(); +} + +void registerInputFormatParquet(FormatFactory & factory) +{ + factory.registerInputFormat( + "Parquet", + [](ReadBuffer &buf, + const Block &sample, + const RowInputFormatParams &, + const FormatSettings & settings) + { + return std::make_shared(buf, sample, settings); + }); + factory.markFormatSupportsSubcolumns("Parquet"); + factory.markFormatSupportsSubsetOfColumns("Parquet"); +} + +void registerOptimizedParquetSchemaReader(FormatFactory & factory) +{ + factory.registerSchemaReader( + "Parquet", + [](ReadBuffer & buf, const FormatSettings & settings) { return std::make_shared(buf, settings); }); + + factory.registerAdditionalInfoForSchemaCacheGetter( + "Parquet", + [](const FormatSettings & settings) + { return fmt::format("schema_inference_make_columns_nullable={}", settings.schema_inference_make_columns_nullable); }); +} + +} + +#else + +namespace DB +{ +class FormatFactory; +void registerInputFormatParquet(FormatFactory &) +{ +} + +void registerOptimizedParquetSchemaReader(FormatFactory &) {} +} + +#endif diff --git a/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.h b/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.h new file mode 100644 index 000000000000..f997085e5828 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/OptimizedParquetBlockInputFormat.h @@ -0,0 +1,67 @@ +#pragma once +#include +#if USE_PARQUET + +#include +#include +#include + +namespace ch_parquet::arrow { class FileReader; } + +namespace arrow { class Buffer; } + +namespace DB +{ + +class OptimizedArrowColumnToCHColumn; + +class OptimizedParquetBlockInputFormat : public IInputFormat +{ +public: + OptimizedParquetBlockInputFormat(ReadBuffer & in_, Block header_, const FormatSettings & format_settings_); + + void resetParser() override; + + String getName() const override { return "OptimizedParquetBlockInputFormat"; } + + const BlockMissingValues & getMissingValues() const override; + +private: + Chunk generate() override; + +protected: + void prepareReader(); + + void onCancel() override + { + is_stopped = 1; + } + + std::unique_ptr file_reader; + int row_group_total = 0; + // indices of columns to read from Parquet file + std::vector column_indices; + std::vector column_names; + std::unique_ptr arrow_column_to_ch_column; + int row_group_current = 0; + std::vector missing_columns; + BlockMissingValues block_missing_values; + const FormatSettings format_settings; + + std::atomic is_stopped{0}; +}; + +class OptimizedParquetSchemaReader : public ISchemaReader +{ +public: + OptimizedParquetSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_); + + NamesAndTypesList readSchema() override; + +private: + const FormatSettings format_settings; +}; + +} + +#endif diff --git a/utils/local-engine/Storages/ch_parquet/arrow/column_reader.cc b/utils/local-engine/Storages/ch_parquet/arrow/column_reader.cc new file mode 100644 index 000000000000..ca092e634e9a --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/column_reader.cc @@ -0,0 +1,1974 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "column_reader.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/array.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_dict.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/chunked_array.h" +#include "arrow/type.h" +#include "arrow/util/bit_stream_utils.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_writer.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/compression.h" +#include "arrow/util/int_util_internal.h" +#include "arrow/util/logging.h" +#include "arrow/util/rle_encoding.h" +#include "parquet/column_page.h" +#include "Storages/ch_parquet/arrow/encoding.h" +#include "parquet/encryption/encryption_internal.h" +#include "parquet/encryption/internal_file_decryptor.h" +#include "parquet/level_comparison.h" +#include "parquet/level_conversion.h" +#include "parquet/properties.h" +#include "parquet/statistics.h" +#include "parquet/thrift_internal.h" // IWYU pragma: keep +// Required after "arrow/util/int_util_internal.h" (for OPTIONAL) +#include "parquet/windows_compatibility.h" + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +using arrow::MemoryPool; +using arrow::internal::AddWithOverflow; +using arrow::internal::checked_cast; +using arrow::internal::MultiplyWithOverflow; + +namespace BitUtil = arrow::BitUtil; + +namespace ch_parquet { +using namespace parquet; +using namespace DB; +namespace { + inline bool HasSpacedValues(const ColumnDescriptor* descr) { + if (descr->max_repetition_level() > 0) { + // repeated+flat case + return !descr->schema_node()->is_required(); + } else { + // non-repeated+nested case + // Find if a node forces nulls in the lowest level along the hierarchy + const schema::Node* node = descr->schema_node().get(); + while (node) { + if (node->is_optional()) { + return true; + } + node = node->parent(); + } + return false; + } + } +} // namespace + +LevelDecoder::LevelDecoder() : num_values_remaining_(0) {} + +LevelDecoder::~LevelDecoder() {} + +int LevelDecoder::SetData(Encoding::type encoding, int16_t max_level, + int num_buffered_values, const uint8_t* data, + int32_t data_size) { + max_level_ = max_level; + int32_t num_bytes = 0; + encoding_ = encoding; + num_values_remaining_ = num_buffered_values; + bit_width_ = BitUtil::Log2(max_level + 1); + switch (encoding) { + case Encoding::RLE: { + if (data_size < 4) { + throw ParquetException("Received invalid levels (corrupt data page?)"); + } + num_bytes = ::arrow::util::SafeLoadAs(data); + if (num_bytes < 0 || num_bytes > data_size - 4) { + throw ParquetException("Received invalid number of bytes (corrupt data page?)"); + } + const uint8_t* decoder_data = data + 4; + if (!rle_decoder_) { + rle_decoder_.reset( + new ::arrow::util::RleDecoder(decoder_data, num_bytes, bit_width_)); + } else { + rle_decoder_->Reset(decoder_data, num_bytes, bit_width_); + } + return 4 + num_bytes; + } + case Encoding::BIT_PACKED: { + int num_bits = 0; + if (MultiplyWithOverflow(num_buffered_values, bit_width_, &num_bits)) { + throw ParquetException( + "Number of buffered values too large (corrupt data page?)"); + } + num_bytes = static_cast(BitUtil::BytesForBits(num_bits)); + if (num_bytes < 0 || num_bytes > data_size - 4) { + throw ParquetException("Received invalid number of bytes (corrupt data page?)"); + } + if (!bit_packed_decoder_) { + bit_packed_decoder_.reset(new ::arrow::BitUtil::BitReader(data, num_bytes)); + } else { + bit_packed_decoder_->Reset(data, num_bytes); + } + return num_bytes; + } + default: + throw ParquetException("Unknown encoding type for levels."); + } + return -1; +} + +void LevelDecoder::SetDataV2(int32_t num_bytes, int16_t max_level, + int num_buffered_values, const uint8_t* data) { + max_level_ = max_level; + // Repetition and definition levels always uses RLE encoding + // in the DataPageV2 format. + if (num_bytes < 0) { + throw ParquetException("Invalid page header (corrupt data page?)"); + } + encoding_ = Encoding::RLE; + num_values_remaining_ = num_buffered_values; + bit_width_ = BitUtil::Log2(max_level + 1); + + if (!rle_decoder_) { + rle_decoder_.reset(new ::arrow::util::RleDecoder(data, num_bytes, bit_width_)); + } else { + rle_decoder_->Reset(data, num_bytes, bit_width_); + } +} + +int LevelDecoder::Decode(int batch_size, int16_t* levels) { + int num_decoded = 0; + + int num_values = std::min(num_values_remaining_, batch_size); + if (encoding_ == Encoding::RLE) { + num_decoded = rle_decoder_->GetBatch(levels, num_values); + } else { + num_decoded = bit_packed_decoder_->GetBatch(bit_width_, levels, num_values); + } + if (num_decoded > 0) { + internal::MinMax min_max = internal::FindMinMax(levels, num_decoded); + if (ARROW_PREDICT_FALSE(min_max.min < 0 || min_max.max > max_level_)) { + std::stringstream ss; + ss << "Malformed levels. min: " << min_max.min << " max: " << min_max.max + << " out of range. Max Level: " << max_level_; + throw ParquetException(ss.str()); + } + } + num_values_remaining_ -= num_decoded; + return num_decoded; +} + +ReaderProperties default_reader_properties() { + static ReaderProperties default_reader_properties; + return default_reader_properties; +} + +namespace { + + // Extracts encoded statistics from V1 and V2 data page headers + template + EncodedStatistics ExtractStatsFromHeader(const H& header) { + EncodedStatistics page_statistics; + if (!header.__isset.statistics) { + return page_statistics; + } + const format::Statistics& stats = header.statistics; + if (stats.__isset.max) { + page_statistics.set_max(stats.max); + } + if (stats.__isset.min) { + page_statistics.set_min(stats.min); + } + if (stats.__isset.null_count) { + page_statistics.set_null_count(stats.null_count); + } + if (stats.__isset.distinct_count) { + page_statistics.set_distinct_count(stats.distinct_count); + } + return page_statistics; + } + + // ---------------------------------------------------------------------- + // SerializedPageReader deserializes Thrift metadata and pages that have been + // assembled in a serialized stream for storing in a Parquet files + + // This subclass delimits pages appearing in a serialized stream, each preceded + // by a serialized Thrift format::PageHeader indicating the type of each page + // and the page metadata. + class SerializedPageReader : public PageReader { + public: + SerializedPageReader(std::shared_ptr stream, int64_t total_num_rows, + Compression::type codec, ::arrow::MemoryPool* pool, + const CryptoContext* crypto_ctx) + : stream_(std::move(stream)), + decompression_buffer_(AllocateBuffer(pool, 0)), + page_ordinal_(0), + seen_num_rows_(0), + total_num_rows_(total_num_rows), + decryption_buffer_(AllocateBuffer(pool, 0)) { + if (crypto_ctx != nullptr) { + crypto_ctx_ = *crypto_ctx; + InitDecryption(); + } + max_page_header_size_ = kDefaultMaxPageHeaderSize; + decompressor_ = GetCodec(codec); + } + + // Implement the PageReader interface + std::shared_ptr NextPage() override; + + void set_max_page_header_size(uint32_t size) override { max_page_header_size_ = size; } + + private: + void UpdateDecryption(const std::shared_ptr& decryptor, int8_t module_type, + const std::string& page_aad); + + void InitDecryption(); + + std::shared_ptr DecompressIfNeeded(std::shared_ptr page_buffer, + int compressed_len, int uncompressed_len, + int levels_byte_len = 0); + + std::shared_ptr stream_; + + format::PageHeader current_page_header_; + std::shared_ptr current_page_; + + // Compression codec to use. + std::unique_ptr<::arrow::util::Codec> decompressor_; + std::shared_ptr decompression_buffer_; + + // The fields below are used for calculation of AAD (additional authenticated data) + // suffix which is part of the Parquet Modular Encryption. + // The AAD suffix for a parquet module is built internally by + // concatenating different parts some of which include + // the row group ordinal, column ordinal and page ordinal. + // Please refer to the encryption specification for more details: + // https://github.com/apache/parquet-format/blob/encryption/Encryption.md#44-additional-authenticated-data + + // The ordinal fields in the context below are used for AAD suffix calculation. + CryptoContext crypto_ctx_; + int16_t page_ordinal_; // page ordinal does not count the dictionary page + + // Maximum allowed page size + uint32_t max_page_header_size_; + + // Number of rows read in data pages so far + int64_t seen_num_rows_; + + // Number of rows in all the data pages + int64_t total_num_rows_; + + // data_page_aad_ and data_page_header_aad_ contain the AAD for data page and data page + // header in a single column respectively. + // While calculating AAD for different pages in a single column the pages AAD is + // updated by only the page ordinal. + std::string data_page_aad_; + std::string data_page_header_aad_; + // Encryption + std::shared_ptr decryption_buffer_; + }; + + void SerializedPageReader::InitDecryption() { + // Prepare the AAD for quick update later. + if (crypto_ctx_.data_decryptor != nullptr) { + DCHECK(!crypto_ctx_.data_decryptor->file_aad().empty()); + data_page_aad_ = encryption::CreateModuleAad( + crypto_ctx_.data_decryptor->file_aad(), encryption::kDataPage, + crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal); + } + if (crypto_ctx_.meta_decryptor != nullptr) { + DCHECK(!crypto_ctx_.meta_decryptor->file_aad().empty()); + data_page_header_aad_ = encryption::CreateModuleAad( + crypto_ctx_.meta_decryptor->file_aad(), encryption::kDataPageHeader, + crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal); + } + } + + void SerializedPageReader::UpdateDecryption(const std::shared_ptr& decryptor, + int8_t module_type, + const std::string& page_aad) { + DCHECK(decryptor != nullptr); + if (crypto_ctx_.start_decrypt_with_dictionary_page) { + std::string aad = encryption::CreateModuleAad( + decryptor->file_aad(), module_type, crypto_ctx_.row_group_ordinal, + crypto_ctx_.column_ordinal, kNonPageOrdinal); + decryptor->UpdateAad(aad); + } else { + encryption::QuickUpdatePageAad(page_aad, page_ordinal_); + decryptor->UpdateAad(page_aad); + } + } + + std::shared_ptr SerializedPageReader::NextPage() { + // Loop here because there may be unhandled page types that we skip until + // finding a page that we do know what to do with + + while (seen_num_rows_ < total_num_rows_) { + uint32_t header_size = 0; + uint32_t allowed_page_size = kDefaultPageHeaderSize; + + // Page headers can be very large because of page statistics + // We try to deserialize a larger buffer progressively + // until a maximum allowed header limit + while (true) { + PARQUET_ASSIGN_OR_THROW(auto view, stream_->Peek(allowed_page_size)); + if (view.size() == 0) { + return std::shared_ptr(nullptr); + } + + // This gets used, then set by DeserializeThriftMsg + header_size = static_cast(view.size()); + try { + if (crypto_ctx_.meta_decryptor != nullptr) { + UpdateDecryption(crypto_ctx_.meta_decryptor, encryption::kDictionaryPageHeader, + data_page_header_aad_); + } + DeserializeThriftMsg(reinterpret_cast(view.data()), &header_size, + ¤t_page_header_, crypto_ctx_.meta_decryptor); + break; + } catch (std::exception& e) { + // Failed to deserialize. Double the allowed page header size and try again + std::stringstream ss; + ss << e.what(); + allowed_page_size *= 2; + if (allowed_page_size > max_page_header_size_) { + ss << "Deserializing page header failed.\n"; + throw ParquetException(ss.str()); + } + } + } + // Advance the stream offset + PARQUET_THROW_NOT_OK(stream_->Advance(header_size)); + + int compressed_len = current_page_header_.compressed_page_size; + int uncompressed_len = current_page_header_.uncompressed_page_size; + if (compressed_len < 0 || uncompressed_len < 0) { + throw ParquetException("Invalid page header"); + } + + if (crypto_ctx_.data_decryptor != nullptr) { + UpdateDecryption(crypto_ctx_.data_decryptor, encryption::kDictionaryPage, + data_page_aad_); + } + + // Read the compressed data page. + PARQUET_ASSIGN_OR_THROW(auto page_buffer, stream_->Read(compressed_len)); + if (page_buffer->size() != compressed_len) { + std::stringstream ss; + ss << "Page was smaller (" << page_buffer->size() << ") than expected (" + << compressed_len << ")"; + ParquetException::EofException(ss.str()); + } + + // Decrypt it if we need to + if (crypto_ctx_.data_decryptor != nullptr) { + PARQUET_THROW_NOT_OK(decryption_buffer_->Resize( + compressed_len - crypto_ctx_.data_decryptor->CiphertextSizeDelta(), false)); + compressed_len = crypto_ctx_.data_decryptor->Decrypt( + page_buffer->data(), compressed_len, decryption_buffer_->mutable_data()); + + page_buffer = decryption_buffer_; + } + + const PageType::type page_type = LoadEnumSafe(¤t_page_header_.type); + + if (page_type == PageType::DICTIONARY_PAGE) { + crypto_ctx_.start_decrypt_with_dictionary_page = false; + const format::DictionaryPageHeader& dict_header = + current_page_header_.dictionary_page_header; + + bool is_sorted = dict_header.__isset.is_sorted ? dict_header.is_sorted : false; + if (dict_header.num_values < 0) { + throw ParquetException("Invalid page header (negative number of values)"); + } + + // Uncompress if needed + page_buffer = + DecompressIfNeeded(std::move(page_buffer), compressed_len, uncompressed_len); + + return std::make_shared(page_buffer, dict_header.num_values, + LoadEnumSafe(&dict_header.encoding), + is_sorted); + } else if (page_type == PageType::DATA_PAGE) { + ++page_ordinal_; + const format::DataPageHeader& header = current_page_header_.data_page_header; + + if (header.num_values < 0) { + throw ParquetException("Invalid page header (negative number of values)"); + } + EncodedStatistics page_statistics = ExtractStatsFromHeader(header); + seen_num_rows_ += header.num_values; + + // Uncompress if needed + page_buffer = + DecompressIfNeeded(std::move(page_buffer), compressed_len, uncompressed_len); + + return std::make_shared(page_buffer, header.num_values, + LoadEnumSafe(&header.encoding), + LoadEnumSafe(&header.definition_level_encoding), + LoadEnumSafe(&header.repetition_level_encoding), + uncompressed_len, page_statistics); + } else if (page_type == PageType::DATA_PAGE_V2) { + ++page_ordinal_; + const format::DataPageHeaderV2& header = current_page_header_.data_page_header_v2; + + if (header.num_values < 0) { + throw ParquetException("Invalid page header (negative number of values)"); + } + if (header.definition_levels_byte_length < 0 || + header.repetition_levels_byte_length < 0) { + throw ParquetException("Invalid page header (negative levels byte length)"); + } + bool is_compressed = header.__isset.is_compressed ? header.is_compressed : false; + EncodedStatistics page_statistics = ExtractStatsFromHeader(header); + seen_num_rows_ += header.num_values; + + // Uncompress if needed + int levels_byte_len; + if (AddWithOverflow(header.definition_levels_byte_length, + header.repetition_levels_byte_length, &levels_byte_len)) { + throw ParquetException("Levels size too large (corrupt file?)"); + } + // DecompressIfNeeded doesn't take `is_compressed` into account as + // it's page type-agnostic. + if (is_compressed) { + page_buffer = DecompressIfNeeded(std::move(page_buffer), compressed_len, + uncompressed_len, levels_byte_len); + } + + return std::make_shared( + page_buffer, header.num_values, header.num_nulls, header.num_rows, + LoadEnumSafe(&header.encoding), header.definition_levels_byte_length, + header.repetition_levels_byte_length, uncompressed_len, is_compressed, + page_statistics); + } else { + // We don't know what this page type is. We're allowed to skip non-data + // pages. + continue; + } + } + return std::shared_ptr(nullptr); + } + + std::shared_ptr SerializedPageReader::DecompressIfNeeded( + std::shared_ptr page_buffer, int compressed_len, int uncompressed_len, + int levels_byte_len) { + if (decompressor_ == nullptr) { + return page_buffer; + } + if (compressed_len < levels_byte_len || uncompressed_len < levels_byte_len) { + throw ParquetException("Invalid page header"); + } + + // Grow the uncompressed buffer if we need to. + if (uncompressed_len > static_cast(decompression_buffer_->size())) { + PARQUET_THROW_NOT_OK(decompression_buffer_->Resize(uncompressed_len, false)); + } + + if (levels_byte_len > 0) { + // First copy the levels as-is + uint8_t* decompressed = decompression_buffer_->mutable_data(); + memcpy(decompressed, page_buffer->data(), levels_byte_len); + } + + // Decompress the values + PARQUET_THROW_NOT_OK(decompressor_->Decompress( + compressed_len - levels_byte_len, page_buffer->data() + levels_byte_len, + uncompressed_len - levels_byte_len, + decompression_buffer_->mutable_data() + levels_byte_len)); + + return decompression_buffer_; + } + +} // namespace +} + +namespace parquet +{ + std::unique_ptr PageReader::Open( + std::shared_ptr stream, + int64_t total_num_rows, + Compression::type codec, + ::arrow::MemoryPool * pool, + const CryptoContext * ctx) + { + return std::unique_ptr(new SerializedPageReader(std::move(stream), total_num_rows, codec, pool, ctx)); + } +} + +namespace ch_parquet +{ +using namespace parquet; +using namespace parquet::internal; + +namespace { + + // ---------------------------------------------------------------------- + // Impl base class for TypedColumnReader and RecordReader + + // PLAIN_DICTIONARY is deprecated but used to be used as a dictionary index + // encoding. + static bool IsDictionaryIndexEncoding(const Encoding::type& e) { + return e == Encoding::RLE_DICTIONARY || e == Encoding::PLAIN_DICTIONARY; + } + + template + class ColumnReaderImplBase { + public: + using T = typename DType::c_type; + + ColumnReaderImplBase(const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) + : descr_(descr), + max_def_level_(descr->max_definition_level()), + max_rep_level_(descr->max_repetition_level()), + num_buffered_values_(0), + num_decoded_values_(0), + pool_(pool), + current_decoder_(nullptr), + current_encoding_(Encoding::UNKNOWN) {} + + virtual ~ColumnReaderImplBase() = default; + + protected: + // Read up to batch_size values from the current data page into the + // pre-allocated memory T* + // + // @returns: the number of values read into the out buffer + int64_t ReadValues(int64_t batch_size, T* out) { + int64_t num_decoded = current_decoder_->Decode(out, static_cast(batch_size)); + return num_decoded; + } + + // Read up to batch_size values from the current data page into the + // pre-allocated memory T*, leaving spaces for null entries according + // to the def_levels. + // + // @returns: the number of values read into the out buffer + int64_t ReadValuesSpaced(int64_t batch_size, T* out, int64_t null_count, + uint8_t* valid_bits, int64_t valid_bits_offset) { + return current_decoder_->DecodeSpaced(out, static_cast(batch_size), + static_cast(null_count), valid_bits, + valid_bits_offset); + } + + // Read multiple definition levels into preallocated memory + // + // Returns the number of decoded definition levels + int64_t ReadDefinitionLevels(int64_t batch_size, int16_t* levels) { + if (max_def_level_ == 0) { + return 0; + } + return definition_level_decoder_.Decode(static_cast(batch_size), levels); + } + + bool HasNextInternal() { + // Either there is no data page available yet, or the data page has been + // exhausted + if (num_buffered_values_ == 0 || num_decoded_values_ == num_buffered_values_) { + if (!ReadNewPage() || num_buffered_values_ == 0) { + return false; + } + } + return true; + } + + // Read multiple repetition levels into preallocated memory + // Returns the number of decoded repetition levels + int64_t ReadRepetitionLevels(int64_t batch_size, int16_t* levels) { + if (max_rep_level_ == 0) { + return 0; + } + return repetition_level_decoder_.Decode(static_cast(batch_size), levels); + } + + // Advance to the next data page + bool ReadNewPage() { + // Loop until we find the next data page. + while (true) { + current_page_ = pager_->NextPage(); + if (!current_page_) { + // EOS + return false; + } + + if (current_page_->type() == PageType::DICTIONARY_PAGE) { + ConfigureDictionary(static_cast(current_page_.get())); + continue; + } else if (current_page_->type() == PageType::DATA_PAGE) { + const auto page = std::static_pointer_cast(current_page_); + const int64_t levels_byte_size = InitializeLevelDecoders( + *page, page->repetition_level_encoding(), page->definition_level_encoding()); + InitializeDataDecoder(*page, levels_byte_size); + return true; + } else if (current_page_->type() == PageType::DATA_PAGE_V2) { + const auto page = std::static_pointer_cast(current_page_); + int64_t levels_byte_size = InitializeLevelDecodersV2(*page); + InitializeDataDecoder(*page, levels_byte_size); + return true; + } else { + // We don't know what this page type is. We're allowed to skip non-data + // pages. + continue; + } + } + return true; + } + + void ConfigureDictionary(const DictionaryPage* page) { + int encoding = static_cast(page->encoding()); + if (page->encoding() == Encoding::PLAIN_DICTIONARY || + page->encoding() == Encoding::PLAIN) { + encoding = static_cast(Encoding::RLE_DICTIONARY); + } + + auto it = decoders_.find(encoding); + if (it != decoders_.end()) { + throw ParquetException("Column cannot have more than one dictionary."); + } + + if (page->encoding() == Encoding::PLAIN_DICTIONARY || + page->encoding() == Encoding::PLAIN) { + auto dictionary = MakeTypedDecoder(Encoding::PLAIN, descr_); + dictionary->SetData(page->num_values(), page->data(), page->size()); + + // The dictionary is fully decoded during DictionaryDecoder::Init, so the + // DictionaryPage buffer is no longer required after this step + // + // TODO(wesm): investigate whether this all-or-nothing decoding of the + // dictionary makes sense and whether performance can be improved + + std::unique_ptr> decoder = MakeDictDecoder(descr_, pool_); + decoder->SetDict(dictionary.get()); + decoders_[encoding] = + std::unique_ptr(dynamic_cast(decoder.release())); + } else { + ParquetException::NYI("only plain dictionary encoding has been implemented"); + } + + new_dictionary_ = true; + current_decoder_ = decoders_[encoding].get(); + DCHECK(current_decoder_); + } + + // Initialize repetition and definition level decoders on the next data page. + + // If the data page includes repetition and definition levels, we + // initialize the level decoders and return the number of encoded level bytes. + // The return value helps determine the number of bytes in the encoded data. + int64_t InitializeLevelDecoders(const DataPage& page, + Encoding::type repetition_level_encoding, + Encoding::type definition_level_encoding) { + // Read a data page. + num_buffered_values_ = page.num_values(); + + // Have not decoded any values from the data page yet + num_decoded_values_ = 0; + + const uint8_t* buffer = page.data(); + int32_t levels_byte_size = 0; + int32_t max_size = page.size(); + + // Data page Layout: Repetition Levels - Definition Levels - encoded values. + // Levels are encoded as rle or bit-packed. + // Init repetition levels + if (max_rep_level_ > 0) { + int32_t rep_levels_bytes = repetition_level_decoder_.SetData( + repetition_level_encoding, max_rep_level_, + static_cast(num_buffered_values_), buffer, max_size); + buffer += rep_levels_bytes; + levels_byte_size += rep_levels_bytes; + max_size -= rep_levels_bytes; + } + // TODO figure a way to set max_def_level_ to 0 + // if the initial value is invalid + + // Init definition levels + if (max_def_level_ > 0) { + int32_t def_levels_bytes = definition_level_decoder_.SetData( + definition_level_encoding, max_def_level_, + static_cast(num_buffered_values_), buffer, max_size); + levels_byte_size += def_levels_bytes; + max_size -= def_levels_bytes; + } + + return levels_byte_size; + } + + int64_t InitializeLevelDecodersV2(const DataPageV2& page) { + // Read a data page. + num_buffered_values_ = page.num_values(); + + // Have not decoded any values from the data page yet + num_decoded_values_ = 0; + const uint8_t* buffer = page.data(); + + const int64_t total_levels_length = + static_cast(page.repetition_levels_byte_length()) + + page.definition_levels_byte_length(); + + if (total_levels_length > page.size()) { + throw ParquetException("Data page too small for levels (corrupt header?)"); + } + + if (max_rep_level_ > 0) { + repetition_level_decoder_.SetDataV2(page.repetition_levels_byte_length(), + max_rep_level_, + static_cast(num_buffered_values_), buffer); + buffer += page.repetition_levels_byte_length(); + } + + if (max_def_level_ > 0) { + definition_level_decoder_.SetDataV2(page.definition_levels_byte_length(), + max_def_level_, + static_cast(num_buffered_values_), buffer); + } + + return total_levels_length; + } + + // Get a decoder object for this page or create a new decoder if this is the + // first page with this encoding. + void InitializeDataDecoder(const DataPage& page, int64_t levels_byte_size) { + const uint8_t* buffer = page.data() + levels_byte_size; + const int64_t data_size = page.size() - levels_byte_size; + + if (data_size < 0) { + throw ParquetException("Page smaller than size of encoded levels"); + } + + Encoding::type encoding = page.encoding(); + + if (IsDictionaryIndexEncoding(encoding)) { + encoding = Encoding::RLE_DICTIONARY; + } + + auto it = decoders_.find(static_cast(encoding)); + if (it != decoders_.end()) { + DCHECK(it->second.get() != nullptr); + if (encoding == Encoding::RLE_DICTIONARY) { + DCHECK(current_decoder_->encoding() == Encoding::RLE_DICTIONARY); + } + current_decoder_ = it->second.get(); + } else { + switch (encoding) { + case Encoding::PLAIN: { + auto decoder = MakeTypedDecoder(Encoding::PLAIN, descr_); + current_decoder_ = decoder.get(); + decoders_[static_cast(encoding)] = std::move(decoder); + break; + } + case Encoding::BYTE_STREAM_SPLIT: { + auto decoder = MakeTypedDecoder(Encoding::BYTE_STREAM_SPLIT, descr_); + current_decoder_ = decoder.get(); + decoders_[static_cast(encoding)] = std::move(decoder); + break; + } + case Encoding::RLE_DICTIONARY: + throw ParquetException("Dictionary page must be before data page."); + + case Encoding::DELTA_BINARY_PACKED: { + auto decoder = MakeTypedDecoder(Encoding::DELTA_BINARY_PACKED, descr_); + current_decoder_ = decoder.get(); + decoders_[static_cast(encoding)] = std::move(decoder); + break; + } + case Encoding::DELTA_LENGTH_BYTE_ARRAY: + case Encoding::DELTA_BYTE_ARRAY: + ParquetException::NYI("Unsupported encoding"); + + default: + throw ParquetException("Unknown encoding type."); + } + } + current_encoding_ = encoding; + current_decoder_->SetData(static_cast(num_buffered_values_), buffer, + static_cast(data_size)); + } + + const ColumnDescriptor* descr_; + const int16_t max_def_level_; + const int16_t max_rep_level_; + + std::unique_ptr pager_; + std::shared_ptr current_page_; + + // Not set if full schema for this field has no optional or repeated elements + LevelDecoder definition_level_decoder_; + + // Not set for flat schemas. + LevelDecoder repetition_level_decoder_; + + // The total number of values stored in the data page. This is the maximum of + // the number of encoded definition levels or encoded values. For + // non-repeated, required columns, this is equal to the number of encoded + // values. For repeated or optional values, there may be fewer data values + // than levels, and this tells you how many encoded levels there are in that + // case. + int64_t num_buffered_values_; + + // The number of values from the current data page that have been decoded + // into memory + int64_t num_decoded_values_; + + ::arrow::MemoryPool* pool_; + + using DecoderType = TypedDecoder; + DecoderType* current_decoder_; + Encoding::type current_encoding_; + + /// Flag to signal when a new dictionary has been set, for the benefit of + /// DictionaryRecordReader + bool new_dictionary_; + + // The exposed encoding + ExposedEncoding exposed_encoding_ = ExposedEncoding::NO_ENCODING; + + // Map of encoding type to the respective decoder object. For example, a + // column chunk's data pages may include both dictionary-encoded and + // plain-encoded data. + std::unordered_map> decoders_; + + void ConsumeBufferedValues(int64_t num_values) { num_decoded_values_ += num_values; } + }; + + // ---------------------------------------------------------------------- + // TypedColumnReader implementations + + template + class TypedColumnReaderImpl : public TypedColumnReader, + public ColumnReaderImplBase { + public: + using T = typename DType::c_type; + + TypedColumnReaderImpl(const ColumnDescriptor* descr, std::unique_ptr pager, + ::arrow::MemoryPool* pool) + : ColumnReaderImplBase(descr, pool) { + this->pager_ = std::move(pager); + } + + bool HasNext() override { return this->HasNextInternal(); } + + int64_t ReadBatch(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, + T* values, int64_t* values_read) override; + + int64_t ReadBatchSpaced(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, + T* values, uint8_t* valid_bits, int64_t valid_bits_offset, + int64_t* levels_read, int64_t* values_read, + int64_t* null_count) override; + + int64_t Skip(int64_t num_rows_to_skip) override; + + Type::type type() const override { return this->descr_->physical_type(); } + + const ColumnDescriptor* descr() const override { return this->descr_; } + + ExposedEncoding GetExposedEncoding() override { return this->exposed_encoding_; }; + + int64_t ReadBatchWithDictionary(int64_t batch_size, int16_t* def_levels, + int16_t* rep_levels, int32_t* indices, + int64_t* indices_read, const T** dict, + int32_t* dict_len) override; + + protected: + void SetExposedEncoding(ExposedEncoding encoding) override { + this->exposed_encoding_ = encoding; + } + + private: + // Read dictionary indices. Similar to ReadValues but decode data to dictionary indices. + // This function is called only by ReadBatchWithDictionary(). + int64_t ReadDictionaryIndices(int64_t indices_to_read, int32_t* indices) { + auto decoder = dynamic_cast*>(this->current_decoder_); + return decoder->DecodeIndices(static_cast(indices_to_read), indices); + } + + // Get dictionary. The dictionary should have been set by SetDict(). The dictionary is + // owned by the internal decoder and is destroyed when the reader is destroyed. This + // function is called only by ReadBatchWithDictionary() after dictionary is configured. + void GetDictionary(const T** dictionary, int32_t* dictionary_length) { + auto decoder = dynamic_cast*>(this->current_decoder_); + decoder->GetDictionary(dictionary, dictionary_length); + } + + // Read definition and repetition levels. Also return the number of definition levels + // and number of values to read. This function is called before reading values. + void ReadLevels(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, + int64_t* num_def_levels, int64_t* values_to_read) { + batch_size = + std::min(batch_size, this->num_buffered_values_ - this->num_decoded_values_); + + // If the field is required and non-repeated, there are no definition levels + if (this->max_def_level_ > 0 && def_levels != nullptr) { + *num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels); + // TODO(wesm): this tallying of values-to-decode can be performed with better + // cache-efficiency if fused with the level decoding. + for (int64_t i = 0; i < *num_def_levels; ++i) { + if (def_levels[i] == this->max_def_level_) { + ++(*values_to_read); + } + } + } else { + // Required field, read all values + *values_to_read = batch_size; + } + + // Not present for non-repeated fields + if (this->max_rep_level_ > 0 && rep_levels != nullptr) { + int64_t num_rep_levels = this->ReadRepetitionLevels(batch_size, rep_levels); + if (def_levels != nullptr && *num_def_levels != num_rep_levels) { + throw ParquetException("Number of decoded rep / def levels did not match"); + } + } + } + }; + + template + int64_t TypedColumnReaderImpl::ReadBatchWithDictionary( + int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, int32_t* indices, + int64_t* indices_read, const T** dict, int32_t* dict_len) { + bool has_dict_output = dict != nullptr && dict_len != nullptr; + // Similar logic as ReadValues to get pages. + if (!HasNext()) { + *indices_read = 0; + if (has_dict_output) { + *dict = nullptr; + *dict_len = 0; + } + return 0; + } + + // Verify the current data page is dictionary encoded. + if (this->current_encoding_ != Encoding::RLE_DICTIONARY) { + std::stringstream ss; + ss << "Data page is not dictionary encoded. Encoding: " + << EncodingToString(this->current_encoding_); + throw ParquetException(ss.str()); + } + + // Get dictionary pointer and length. + if (has_dict_output) { + GetDictionary(dict, dict_len); + } + + // Similar logic as ReadValues to get def levels and rep levels. + int64_t num_def_levels = 0; + int64_t indices_to_read = 0; + ReadLevels(batch_size, def_levels, rep_levels, &num_def_levels, &indices_to_read); + + // Read dictionary indices. + *indices_read = ReadDictionaryIndices(indices_to_read, indices); + int64_t total_indices = std::max(num_def_levels, *indices_read); + this->ConsumeBufferedValues(total_indices); + + return total_indices; + } + + template + int64_t TypedColumnReaderImpl::ReadBatch(int64_t batch_size, int16_t* def_levels, + int16_t* rep_levels, T* values, + int64_t* values_read) { + // HasNext invokes ReadNewPage + if (!HasNext()) { + *values_read = 0; + return 0; + } + + // TODO(wesm): keep reading data pages until batch_size is reached, or the + // row group is finished + int64_t num_def_levels = 0; + int64_t values_to_read = 0; + ReadLevels(batch_size, def_levels, rep_levels, &num_def_levels, &values_to_read); + + *values_read = this->ReadValues(values_to_read, values); + int64_t total_values = std::max(num_def_levels, *values_read); + this->ConsumeBufferedValues(total_values); + + return total_values; + } + + template + int64_t TypedColumnReaderImpl::ReadBatchSpaced( + int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, T* values, + uint8_t* valid_bits, int64_t valid_bits_offset, int64_t* levels_read, + int64_t* values_read, int64_t* null_count_out) { + // HasNext invokes ReadNewPage + if (!HasNext()) { + *levels_read = 0; + *values_read = 0; + *null_count_out = 0; + return 0; + } + + int64_t total_values; + // TODO(wesm): keep reading data pages until batch_size is reached, or the + // row group is finished + batch_size = + std::min(batch_size, this->num_buffered_values_ - this->num_decoded_values_); + + // If the field is required and non-repeated, there are no definition levels + if (this->max_def_level_ > 0) { + int64_t num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels); + + // Not present for non-repeated fields + if (this->max_rep_level_ > 0) { + int64_t num_rep_levels = this->ReadRepetitionLevels(batch_size, rep_levels); + if (num_def_levels != num_rep_levels) { + throw ParquetException("Number of decoded rep / def levels did not match"); + } + } + + const bool has_spaced_values = HasSpacedValues(this->descr_); + int64_t null_count = 0; + if (!has_spaced_values) { + int values_to_read = 0; + for (int64_t i = 0; i < num_def_levels; ++i) { + if (def_levels[i] == this->max_def_level_) { + ++values_to_read; + } + } + total_values = this->ReadValues(values_to_read, values); + ::arrow::BitUtil::SetBitsTo(valid_bits, valid_bits_offset, + /*length=*/total_values, + /*bits_are_set=*/true); + *values_read = total_values; + } else { + internal::LevelInfo info; + info.repeated_ancestor_def_level = this->max_def_level_ - 1; + info.def_level = this->max_def_level_; + info.rep_level = this->max_rep_level_; + internal::ValidityBitmapInputOutput validity_io; + validity_io.values_read_upper_bound = num_def_levels; + validity_io.valid_bits = valid_bits; + validity_io.valid_bits_offset = valid_bits_offset; + validity_io.null_count = null_count; + validity_io.values_read = *values_read; + + internal::DefLevelsToBitmap(def_levels, num_def_levels, info, &validity_io); + null_count = validity_io.null_count; + *values_read = validity_io.values_read; + + total_values = + this->ReadValuesSpaced(*values_read, values, static_cast(null_count), + valid_bits, valid_bits_offset); + } + *levels_read = num_def_levels; + *null_count_out = null_count; + + } else { + // Required field, read all values + total_values = this->ReadValues(batch_size, values); + ::arrow::BitUtil::SetBitsTo(valid_bits, valid_bits_offset, + /*length=*/total_values, + /*bits_are_set=*/true); + *null_count_out = 0; + *values_read = total_values; + *levels_read = total_values; + } + + this->ConsumeBufferedValues(*levels_read); + return total_values; + } + + template + int64_t TypedColumnReaderImpl::Skip(int64_t num_rows_to_skip) { + int64_t rows_to_skip = num_rows_to_skip; + while (HasNext() && rows_to_skip > 0) { + // If the number of rows to skip is more than the number of undecoded values, skip the + // Page. + if (rows_to_skip > (this->num_buffered_values_ - this->num_decoded_values_)) { + rows_to_skip -= this->num_buffered_values_ - this->num_decoded_values_; + this->num_decoded_values_ = this->num_buffered_values_; + } else { + // We need to read this Page + // Jump to the right offset in the Page + int64_t batch_size = 1024; // ReadBatch with a smaller memory footprint + int64_t values_read = 0; + + // This will be enough scratch space to accommodate 16-bit levels or any + // value type + int value_size = type_traits::value_byte_size; + std::shared_ptr scratch = AllocateBuffer( + this->pool_, batch_size * std::max(sizeof(int16_t), value_size)); + + do { + batch_size = std::min(batch_size, rows_to_skip); + values_read = + ReadBatch(static_cast(batch_size), + reinterpret_cast(scratch->mutable_data()), + reinterpret_cast(scratch->mutable_data()), + reinterpret_cast(scratch->mutable_data()), &values_read); + rows_to_skip -= values_read; + } while (values_read > 0 && rows_to_skip > 0); + } + } + return num_rows_to_skip - rows_to_skip; + } + +} // namespace + +// ---------------------------------------------------------------------- +// Dynamic column reader constructor + +std::shared_ptr ColumnReader::Make(const ColumnDescriptor* descr, + std::unique_ptr pager, + MemoryPool* pool) { + switch (descr->physical_type()) { + case Type::BOOLEAN: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::INT32: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::INT64: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::INT96: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::FLOAT: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::DOUBLE: + return std::make_shared>(descr, std::move(pager), + pool); + case Type::BYTE_ARRAY: + return std::make_shared>( + descr, std::move(pager), pool); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::make_shared>(descr, std::move(pager), + pool); + default: + ParquetException::NYI("type reader not implemented"); + } + // Unreachable code, but suppress compiler warning + return std::shared_ptr(nullptr); +} + +// ---------------------------------------------------------------------- +// RecordReader + +namespace internal { + namespace { + + // The minimum number of repetition/definition levels to decode at a time, for + // better vectorized performance when doing many smaller record reads + constexpr int64_t kMinLevelBatchSize = 1024; + + template + class TypedRecordReader : public ColumnReaderImplBase, + virtual public RecordReader { + public: + using T = typename DType::c_type; + using BASE = ColumnReaderImplBase; + TypedRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info, MemoryPool* pool) + : BASE(descr, pool) { + leaf_info_ = leaf_info; + nullable_values_ = leaf_info.HasNullableValues(); + at_record_start_ = true; + records_read_ = 0; + values_written_ = 0; + values_capacity_ = 0; + null_count_ = 0; + levels_written_ = 0; + levels_position_ = 0; + levels_capacity_ = 0; + uses_values_ = !(descr->physical_type() == Type::BYTE_ARRAY); + + if (uses_values_) { + values_ = AllocateBuffer(pool); + } + valid_bits_ = AllocateBuffer(pool); + def_levels_ = AllocateBuffer(pool); + rep_levels_ = AllocateBuffer(pool); + Reset(); + } + + int64_t available_values_current_page() const { + return this->num_buffered_values_ - this->num_decoded_values_; + } + + // Compute the values capacity in bytes for the given number of elements + int64_t bytes_for_values(int64_t nitems) const { + int64_t type_size = GetTypeByteSize(this->descr_->physical_type()); + int64_t bytes_for_values = -1; + if (MultiplyWithOverflow(nitems, type_size, &bytes_for_values)) { + throw ParquetException("Total size of items too large"); + } + return bytes_for_values; + } + + int64_t ReadRecords(int64_t num_records) override { + // Delimit records, then read values at the end + int64_t records_read = 0; + + if (levels_position_ < levels_written_) { + records_read += ReadRecordData(num_records); + } + + int64_t level_batch_size = std::max(kMinLevelBatchSize, num_records); + + // If we are in the middle of a record, we continue until reaching the + // desired number of records or the end of the current record if we've found + // enough records + while (!at_record_start_ || records_read < num_records) { + // Is there more data to read in this row group? + if (!this->HasNextInternal()) { + if (!at_record_start_) { + // We ended the row group while inside a record that we haven't seen + // the end of yet. So increment the record count for the last record in + // the row group + ++records_read; + at_record_start_ = true; + } + break; + } + + /// We perform multiple batch reads until we either exhaust the row group + /// or observe the desired number of records + int64_t batch_size = std::min(level_batch_size, available_values_current_page()); + + // No more data in column + if (batch_size == 0) { + break; + } + + if (this->max_def_level_ > 0) { + ReserveLevels(batch_size); + + int16_t* def_levels = this->def_levels() + levels_written_; + int16_t* rep_levels = this->rep_levels() + levels_written_; + + // Not present for non-repeated fields + int64_t levels_read = 0; + if (this->max_rep_level_ > 0) { + levels_read = this->ReadDefinitionLevels(batch_size, def_levels); + if (this->ReadRepetitionLevels(batch_size, rep_levels) != levels_read) { + throw ParquetException("Number of decoded rep / def levels did not match"); + } + } else if (this->max_def_level_ > 0) { + levels_read = this->ReadDefinitionLevels(batch_size, def_levels); + } + + // Exhausted column chunk + if (levels_read == 0) { + break; + } + + levels_written_ += levels_read; + records_read += ReadRecordData(num_records - records_read); + } else { + // No repetition or definition levels + batch_size = std::min(num_records - records_read, batch_size); + records_read += ReadRecordData(batch_size); + } + } + + return records_read; + } + + // We may outwardly have the appearance of having exhausted a column chunk + // when in fact we are in the middle of processing the last batch + bool has_values_to_process() const { return levels_position_ < levels_written_; } + + std::shared_ptr ReleaseValues() override { + if (uses_values_) { + auto result = values_; + PARQUET_THROW_NOT_OK(result->Resize(bytes_for_values(values_written_), true)); + values_ = AllocateBuffer(this->pool_); + values_capacity_ = 0; + return result; + } else { + return nullptr; + } + } + + std::shared_ptr ReleaseIsValid() override { + if (leaf_info_.HasNullableValues()) { + auto result = valid_bits_; + PARQUET_THROW_NOT_OK(result->Resize(BitUtil::BytesForBits(values_written_), true)); + valid_bits_ = AllocateBuffer(this->pool_); + return result; + } else { + return nullptr; + } + } + + // Process written repetition/definition levels to reach the end of + // records. Process no more levels than necessary to delimit the indicated + // number of logical records. Updates internal state of RecordReader + // + // \return Number of records delimited + int64_t DelimitRecords(int64_t num_records, int64_t* values_seen) { + int64_t values_to_read = 0; + int64_t records_read = 0; + + const int16_t* def_levels = this->def_levels() + levels_position_; + const int16_t* rep_levels = this->rep_levels() + levels_position_; + + DCHECK_GT(this->max_rep_level_, 0); + + // Count logical records and number of values to read + while (levels_position_ < levels_written_) { + const int16_t rep_level = *rep_levels++; + if (rep_level == 0) { + // If at_record_start_ is true, we are seeing the start of a record + // for the second time, such as after repeated calls to + // DelimitRecords. In this case we must continue until we find + // another record start or exhausting the ColumnChunk + if (!at_record_start_) { + // We've reached the end of a record; increment the record count. + ++records_read; + if (records_read == num_records) { + // We've found the number of records we were looking for. Set + // at_record_start_ to true and break + at_record_start_ = true; + break; + } + } + } + // We have decided to consume the level at this position; therefore we + // must advance until we find another record boundary + at_record_start_ = false; + + const int16_t def_level = *def_levels++; + if (def_level == this->max_def_level_) { + ++values_to_read; + } + ++levels_position_; + } + *values_seen = values_to_read; + return records_read; + } + + void Reserve(int64_t capacity) override { + ReserveLevels(capacity); + ReserveValues(capacity); + } + + int64_t UpdateCapacity(int64_t capacity, int64_t size, int64_t extra_size) { + if (extra_size < 0) { + throw ParquetException("Negative size (corrupt file?)"); + } + int64_t target_size = -1; + if (AddWithOverflow(size, extra_size, &target_size)) { + throw ParquetException("Allocation size too large (corrupt file?)"); + } + if (target_size >= (1LL << 62)) { + throw ParquetException("Allocation size too large (corrupt file?)"); + } + if (capacity >= target_size) { + return capacity; + } + return BitUtil::NextPower2(target_size); + } + + void ReserveLevels(int64_t extra_levels) { + if (this->max_def_level_ > 0) { + const int64_t new_levels_capacity = + UpdateCapacity(levels_capacity_, levels_written_, extra_levels); + if (new_levels_capacity > levels_capacity_) { + constexpr auto kItemSize = static_cast(sizeof(int16_t)); + int64_t capacity_in_bytes = -1; + if (MultiplyWithOverflow(new_levels_capacity, kItemSize, &capacity_in_bytes)) { + throw ParquetException("Allocation size too large (corrupt file?)"); + } + PARQUET_THROW_NOT_OK(def_levels_->Resize(capacity_in_bytes, false)); + if (this->max_rep_level_ > 0) { + PARQUET_THROW_NOT_OK(rep_levels_->Resize(capacity_in_bytes, false)); + } + levels_capacity_ = new_levels_capacity; + } + } + } + + void ReserveValues(int64_t extra_values) { + const int64_t new_values_capacity = + UpdateCapacity(values_capacity_, values_written_, extra_values); + if (new_values_capacity > values_capacity_) { + // XXX(wesm): A hack to avoid memory allocation when reading directly + // into builder classes + if (uses_values_) { + PARQUET_THROW_NOT_OK( + values_->Resize(bytes_for_values(new_values_capacity), false)); + } + values_capacity_ = new_values_capacity; + } + if (leaf_info_.HasNullableValues()) { + int64_t valid_bytes_new = BitUtil::BytesForBits(values_capacity_); + if (valid_bits_->size() < valid_bytes_new) { + int64_t valid_bytes_old = BitUtil::BytesForBits(values_written_); + PARQUET_THROW_NOT_OK(valid_bits_->Resize(valid_bytes_new, false)); + + // Avoid valgrind warnings + memset(valid_bits_->mutable_data() + valid_bytes_old, 0, + valid_bytes_new - valid_bytes_old); + } + } + } + + void Reset() override { + ResetValues(); + + if (levels_written_ > 0) { + const int64_t levels_remaining = levels_written_ - levels_position_; + // Shift remaining levels to beginning of buffer and trim to only the number + // of decoded levels remaining + int16_t* def_data = def_levels(); + int16_t* rep_data = rep_levels(); + + std::copy(def_data + levels_position_, def_data + levels_written_, def_data); + PARQUET_THROW_NOT_OK( + def_levels_->Resize(levels_remaining * sizeof(int16_t), false)); + + if (this->max_rep_level_ > 0) { + std::copy(rep_data + levels_position_, rep_data + levels_written_, rep_data); + PARQUET_THROW_NOT_OK( + rep_levels_->Resize(levels_remaining * sizeof(int16_t), false)); + } + + levels_written_ -= levels_position_; + levels_position_ = 0; + levels_capacity_ = levels_remaining; + } + + records_read_ = 0; + + // Call Finish on the binary builders to reset them + } + + void SetPageReader(std::unique_ptr reader) override { + at_record_start_ = true; + this->pager_ = std::move(reader); + ResetDecoders(); + } + + bool HasMoreData() const override { return this->pager_ != nullptr; } + + // Dictionary decoders must be reset when advancing row groups + void ResetDecoders() { this->decoders_.clear(); } + + virtual void ReadValuesSpaced(int64_t values_with_nulls, int64_t null_count) { + uint8_t* valid_bits = valid_bits_->mutable_data(); + const int64_t valid_bits_offset = values_written_; + + int64_t num_decoded = this->current_decoder_->DecodeSpaced( + ValuesHead(), static_cast(values_with_nulls), + static_cast(null_count), valid_bits, valid_bits_offset); + DCHECK_EQ(num_decoded, values_with_nulls); + } + + virtual void ReadValuesDense(int64_t values_to_read) { + int64_t num_decoded = + this->current_decoder_->Decode(ValuesHead(), static_cast(values_to_read)); + DCHECK_EQ(num_decoded, values_to_read); + } + + // Return number of logical records read + int64_t ReadRecordData(int64_t num_records) { + // Conservative upper bound + const int64_t possible_num_values = + std::max(num_records, levels_written_ - levels_position_); + ReserveValues(possible_num_values); + + const int64_t start_levels_position = levels_position_; + + int64_t values_to_read = 0; + int64_t records_read = 0; + if (this->max_rep_level_ > 0) { + records_read = DelimitRecords(num_records, &values_to_read); + } else if (this->max_def_level_ > 0) { + // No repetition levels, skip delimiting logic. Each level represents a + // null or not null entry + records_read = std::min(levels_written_ - levels_position_, num_records); + + // This is advanced by DelimitRecords, which we skipped + levels_position_ += records_read; + } else { + records_read = values_to_read = num_records; + } + + int64_t null_count = 0; + if (leaf_info_.HasNullableValues()) { + ValidityBitmapInputOutput validity_io; + validity_io.values_read_upper_bound = levels_position_ - start_levels_position; + validity_io.valid_bits = valid_bits_->mutable_data(); + validity_io.valid_bits_offset = values_written_; + + DefLevelsToBitmap(def_levels() + start_levels_position, + levels_position_ - start_levels_position, leaf_info_, + &validity_io); + values_to_read = validity_io.values_read - validity_io.null_count; + null_count = validity_io.null_count; + DCHECK_GE(values_to_read, 0); + ReadValuesSpaced(validity_io.values_read, null_count); + } else { + DCHECK_GE(values_to_read, 0); + ReadValuesDense(values_to_read); + } + if (this->leaf_info_.def_level > 0) { + // Optional, repeated, or some mix thereof + this->ConsumeBufferedValues(levels_position_ - start_levels_position); + } else { + // Flat, non-repeated + this->ConsumeBufferedValues(values_to_read); + } + // Total values, including null spaces, if any + values_written_ += values_to_read + null_count; + null_count_ += null_count; + + return records_read; + } + + void DebugPrintState() override { + const int16_t* def_levels = this->def_levels(); + const int16_t* rep_levels = this->rep_levels(); + const int64_t total_levels_read = levels_position_; + + const T* vals = reinterpret_cast(this->values()); + + std::cout << "def levels: "; + for (int64_t i = 0; i < total_levels_read; ++i) { + std::cout << def_levels[i] << " "; + } + std::cout << std::endl; + + std::cout << "rep levels: "; + for (int64_t i = 0; i < total_levels_read; ++i) { + std::cout << rep_levels[i] << " "; + } + std::cout << std::endl; + + std::cout << "values: "; + for (int64_t i = 0; i < this->values_written(); ++i) { + std::cout << vals[i] << " "; + } + std::cout << std::endl; + } + + void ResetValues() { + if (values_written_ > 0) { + // Resize to 0, but do not shrink to fit + if (uses_values_) { + PARQUET_THROW_NOT_OK(values_->Resize(0, false)); + } + PARQUET_THROW_NOT_OK(valid_bits_->Resize(0, false)); + values_written_ = 0; + values_capacity_ = 0; + null_count_ = 0; + } + } + + protected: + template + T* ValuesHead() { + return reinterpret_cast(values_->mutable_data()) + values_written_; + } + LevelInfo leaf_info_; + }; + + class FLBARecordReader : public TypedRecordReader, + virtual public BinaryRecordReader { + public: + FLBARecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : TypedRecordReader(descr, leaf_info, pool), builder_(nullptr) { + DCHECK_EQ(descr_->physical_type(), Type::FIXED_LEN_BYTE_ARRAY); + int byte_width = descr_->type_length(); + std::shared_ptr<::arrow::DataType> type = ::arrow::fixed_size_binary(byte_width); + builder_.reset(new ::arrow::FixedSizeBinaryBuilder(type, this->pool_)); + } + + ::arrow::ArrayVector GetBuilderChunks() override { + std::shared_ptr<::arrow::Array> chunk; + PARQUET_THROW_NOT_OK(builder_->Finish(&chunk)); + return ::arrow::ArrayVector({chunk}); + } + + void ReadValuesDense(int64_t values_to_read) override { + auto values = ValuesHead(); + int64_t num_decoded = + this->current_decoder_->Decode(values, static_cast(values_to_read)); + DCHECK_EQ(num_decoded, values_to_read); + + for (int64_t i = 0; i < num_decoded; i++) { + PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr)); + } + ResetValues(); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + uint8_t* valid_bits = valid_bits_->mutable_data(); + const int64_t valid_bits_offset = values_written_; + auto values = ValuesHead(); + + int64_t num_decoded = this->current_decoder_->DecodeSpaced( + values, static_cast(values_to_read), static_cast(null_count), + valid_bits, valid_bits_offset); + DCHECK_EQ(num_decoded, values_to_read); + + for (int64_t i = 0; i < num_decoded; i++) { + if (::arrow::BitUtil::GetBit(valid_bits, valid_bits_offset + i)) { + PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr)); + } else { + PARQUET_THROW_NOT_OK(builder_->AppendNull()); + } + } + ResetValues(); + } + + private: + std::unique_ptr<::arrow::FixedSizeBinaryBuilder> builder_; + }; + + + class ByteArrayChunkedRecordReader : public TypedRecordReader, + virtual public BinaryRecordReader { + public: + ByteArrayChunkedRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : TypedRecordReader(descr, leaf_info, pool) { + DCHECK_EQ(descr_->physical_type(), Type::BYTE_ARRAY); + accumulator_.builder.reset(new ::arrow::BinaryBuilder(pool)); + } + + ::arrow::ArrayVector GetBuilderChunks() override { + ::arrow::ArrayVector result = accumulator_.chunks; + if (result.size() == 0 || accumulator_.builder->length() > 0) { + std::shared_ptr<::arrow::Array> last_chunk; + PARQUET_THROW_NOT_OK(accumulator_.builder->Finish(&last_chunk)); + result.push_back(std::move(last_chunk)); + } + accumulator_.chunks = {}; + return result; + } + + void ReadValuesDense(int64_t values_to_read) override { + int64_t num_decoded = this->current_decoder_->DecodeArrowNonNull( + static_cast(values_to_read), &accumulator_); + DCHECK_EQ(num_decoded, values_to_read); + ResetValues(); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + int64_t num_decoded = this->current_decoder_->DecodeArrow( + static_cast(values_to_read), static_cast(null_count), + valid_bits_->mutable_data(), values_written_, &accumulator_); + DCHECK_EQ(num_decoded, values_to_read - null_count); + ResetValues(); + } + + private: + // Helper data structure for accumulating builder chunks + typename EncodingTraits::Accumulator accumulator_; + }; + + + + class CHByteArrayChunkedRecordReader : public TypedRecordReader, virtual public BinaryRecordReader + { + public: + CHByteArrayChunkedRecordReader(const ColumnDescriptor * descr, LevelInfo leaf_info, ::arrow::MemoryPool * pool) + : TypedRecordReader(descr, leaf_info, pool) + { + DCHECK_EQ(descr_->physical_type(), Type::BYTE_ARRAY); + this -> pool = pool; + //accumulator_.builder.reset(new ::arrow::BinaryBuilder(pool)); + } + + bool inited = false; + PaddedPODArray * column_chars_t_p; + PaddedPODArray * column_offsets_p; + std::unique_ptr internal_column; + ::arrow::MemoryPool * pool; + + std::shared_ptr<::arrow::Array> fake_array; + int64_t null_counter = 0; + int64_t value_counter = 0; + + void initialize() { + accumulator_.builder = std::make_unique<::arrow::BinaryBuilder>(pool); + if (!fake_array) { + accumulator_.builder->AppendNulls(8192); + accumulator_.builder->Finish(&fake_array); + } + inited = true; + } + + void createColumnIfNeeded() { + if (!internal_column) { + auto internal_type = std::make_shared(); + internal_column = std::make_unique(std::move(internal_type->createColumn())); + column_chars_t_p = &assert_cast(**internal_column).getChars(); + column_offsets_p = &assert_cast(**internal_column).getOffsets(); + } + } + + ::arrow::ArrayVector GetBuilderChunks() override + { + if (!internal_column) { // !internal_column happens at the last empty chunk + ::arrow::ArrayVector result = accumulator_.chunks; + if (accumulator_.builder->length() > 0) { + throw ::parquet::ParquetException("unexpected data existing"); + } + accumulator_.chunks = {}; + return result; + } else { + MutableColumnPtr temp = std::move(*internal_column); + internal_column.reset(); + fake_array->data()->length = temp->size();//the last batch's size may < 8192 + + fake_array->data()->SetNullCount(null_counter); + null_counter = 0; + value_counter = 0; + return {std::make_shared( + ColumnWithTypeAndName(std::move(temp), std::make_shared(), ""),fake_array)}; + } + } + + void ReadValuesDense(int64_t values_to_read) override + { + if (unlikely(!inited)) {initialize();} + + ::arrow::internal::BitmapWriter bitmap_writer( + const_cast(fake_array->data()->buffers[0]->data()), + value_counter, values_to_read); + + createColumnIfNeeded(); + int64_t num_decoded + = this->current_decoder_->DecodeCHNonNull(static_cast(values_to_read), column_chars_t_p, column_offsets_p, bitmap_writer); + DCHECK_EQ(num_decoded, values_to_read); + ResetValues(); + + value_counter += values_to_read; + bitmap_writer.Finish(); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override + { + if (unlikely(!inited)) {initialize();} + + ::arrow::internal::BitmapWriter bitmap_writer( + const_cast(fake_array->data()->buffers[0]->data()), + value_counter, values_to_read); + + createColumnIfNeeded(); + int64_t num_decoded = this->current_decoder_->DecodeCH( + static_cast(values_to_read), + static_cast(null_count), + valid_bits_->mutable_data(), + values_written_, + column_chars_t_p, + column_offsets_p, + bitmap_writer); + + null_counter += null_count; + value_counter += values_to_read; + DCHECK_EQ(num_decoded, values_to_read - null_count); + ResetValues(); + + bitmap_writer.Finish(); + } + + private: + // Helper data structure for accumulating builder chunks + typename EncodingTraits::Accumulator accumulator_; + }; + + + + class ByteArrayDictionaryRecordReader : public TypedRecordReader, + virtual public DictionaryRecordReader { + public: + ByteArrayDictionaryRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : TypedRecordReader(descr, leaf_info, pool), builder_(pool) { + this->read_dictionary_ = true; + } + + std::shared_ptr<::arrow::ChunkedArray> GetResult() override { + FlushBuilder(); + std::vector> result; + std::swap(result, result_chunks_); + return std::make_shared<::arrow::ChunkedArray>(std::move(result), builder_.type()); + } + + void FlushBuilder() { + if (builder_.length() > 0) { + std::shared_ptr<::arrow::Array> chunk; + PARQUET_THROW_NOT_OK(builder_.Finish(&chunk)); + result_chunks_.emplace_back(std::move(chunk)); + + // Also clears the dictionary memo table + builder_.Reset(); + } + } + + void MaybeWriteNewDictionary() { + if (this->new_dictionary_) { + /// If there is a new dictionary, we may need to flush the builder, then + /// insert the new dictionary values + FlushBuilder(); + builder_.ResetFull(); + auto decoder = dynamic_cast(this->current_decoder_); + decoder->InsertDictionary(&builder_); + this->new_dictionary_ = false; + } + } + + void ReadValuesDense(int64_t values_to_read) override { + int64_t num_decoded = 0; + if (current_encoding_ == Encoding::RLE_DICTIONARY) { + MaybeWriteNewDictionary(); + auto decoder = dynamic_cast(this->current_decoder_); + num_decoded = decoder->DecodeIndices(static_cast(values_to_read), &builder_); + } else { + num_decoded = this->current_decoder_->DecodeArrowNonNull( + static_cast(values_to_read), &builder_); + + /// Flush values since they have been copied into the builder + ResetValues(); + } + DCHECK_EQ(num_decoded, values_to_read); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + int64_t num_decoded = 0; + if (current_encoding_ == Encoding::RLE_DICTIONARY) { + MaybeWriteNewDictionary(); + auto decoder = dynamic_cast(this->current_decoder_); + num_decoded = decoder->DecodeIndicesSpaced( + static_cast(values_to_read), static_cast(null_count), + valid_bits_->mutable_data(), values_written_, &builder_); + } else { + num_decoded = this->current_decoder_->DecodeArrow( + static_cast(values_to_read), static_cast(null_count), + valid_bits_->mutable_data(), values_written_, &builder_); + + /// Flush values since they have been copied into the builder + ResetValues(); + } + DCHECK_EQ(num_decoded, values_to_read - null_count); + } + + private: + using BinaryDictDecoder = DictDecoder; + + ::arrow::BinaryDictionary32Builder builder_; + std::vector> result_chunks_; + }; + + // TODO(wesm): Implement these to some satisfaction + template <> + void TypedRecordReader::DebugPrintState() {} + + template <> + void TypedRecordReader::DebugPrintState() {} + + template <> + void TypedRecordReader::DebugPrintState() {} + + std::shared_ptr MakeByteArrayRecordReader(const ColumnDescriptor* descr, + LevelInfo leaf_info, + ::arrow::MemoryPool* pool, + bool read_dictionary) { + if (read_dictionary) { + return std::make_shared(descr, leaf_info, pool); + } else if (descr->path()->ToDotVector().size() == 1 && descr->logical_type()->type() == LogicalType::Type::type::STRING) { + /// CHByteArrayChunkedRecordReader is only for reading columns with type String and is not nested in complex type + /// This fixes issue: https://github.com/Kyligence/ClickHouse/issues/166 + return std::make_shared(descr, leaf_info, pool); + } else { + return std::make_shared(descr, leaf_info, pool); + } + } + + } // namespace + + std::shared_ptr RecordReader::Make(const ColumnDescriptor* descr, + LevelInfo leaf_info, MemoryPool* pool, + const bool read_dictionary) { + switch (descr->physical_type()) { + case Type::BOOLEAN: + return std::make_shared>(descr, leaf_info, pool); + case Type::INT32: + return std::make_shared>(descr, leaf_info, pool); + case Type::INT64: + return std::make_shared>(descr, leaf_info, pool); + case Type::INT96: + return std::make_shared>(descr, leaf_info, pool); + case Type::FLOAT: + return std::make_shared>(descr, leaf_info, pool); + case Type::DOUBLE: + return std::make_shared>(descr, leaf_info, pool); + case Type::BYTE_ARRAY: + return MakeByteArrayRecordReader(descr, leaf_info, pool, read_dictionary); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::make_shared(descr, leaf_info, pool); + default: { + // PARQUET-1481: This can occur if the file is corrupt + std::stringstream ss; + ss << "Invalid physical column type: " << static_cast(descr->physical_type()); + throw ParquetException(ss.str()); + } + } + // Unreachable code, but suppress compiler warning + return nullptr; + } + +} // namespace internal +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/column_reader.h b/utils/local-engine/Storages/ch_parquet/arrow/column_reader.h new file mode 100644 index 000000000000..64e15de764a6 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/column_reader.h @@ -0,0 +1,403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "parquet/exception.h" +#include "parquet/level_conversion.h" +#include "parquet/platform.h" +#include "parquet/schema.h" +#include "parquet/types.h" + +#include "arrow/array.h" +#include "arrow/chunked_array.h" +#include "arrow/array/builder_binary.h" +#include "arrow/type.h" +#include + +namespace arrow { + +class Array; +class ChunkedArray; + +namespace BitUtil { + class BitReader; +} // namespace BitUtil + +namespace util { + class RleDecoder; +} // namespace util + +} // namespace arrow + +namespace parquet{ +class Decryptor; +class Page; +} + +namespace ch_parquet +{ +using namespace parquet; + + +// 16 MB is the default maximum page header size +static constexpr uint32_t kDefaultMaxPageHeaderSize = 16 * 1024 * 1024; + +// 16 KB is the default expected page header size +static constexpr uint32_t kDefaultPageHeaderSize = 16 * 1024; + +class PARQUET_EXPORT LevelDecoder +{ +public: + LevelDecoder(); + ~LevelDecoder(); + + // Initialize the LevelDecoder state with new data + // and return the number of bytes consumed + int SetData(Encoding::type encoding, int16_t max_level, int num_buffered_values, const uint8_t * data, int32_t data_size); + + void SetDataV2(int32_t num_bytes, int16_t max_level, int num_buffered_values, const uint8_t * data); + + // Decodes a batch of levels into an array and returns the number of levels decoded + int Decode(int batch_size, int16_t * levels); + +private: + int bit_width_; + int num_values_remaining_; + Encoding::type encoding_; + std::unique_ptr<::arrow::util::RleDecoder> rle_decoder_; + std::unique_ptr<::arrow::BitUtil::BitReader> bit_packed_decoder_; + int16_t max_level_; +}; + +struct CryptoContext { + CryptoContext(bool start_with_dictionary_page, int16_t rg_ordinal, int16_t col_ordinal, + std::shared_ptr meta, std::shared_ptr data) + : start_decrypt_with_dictionary_page(start_with_dictionary_page), + row_group_ordinal(rg_ordinal), + column_ordinal(col_ordinal), + meta_decryptor(std::move(meta)), + data_decryptor(std::move(data)) {} + CryptoContext() {} + + bool start_decrypt_with_dictionary_page = false; + int16_t row_group_ordinal = -1; + int16_t column_ordinal = -1; + std::shared_ptr meta_decryptor; + std::shared_ptr data_decryptor; +}; + +} +namespace parquet{ +using namespace ch_parquet; +// Abstract page iterator interface. This way, we can feed column pages to the +// ColumnReader through whatever mechanism we choose +class PARQUET_EXPORT PageReader { +public: + virtual ~PageReader() = default; + + static std::unique_ptr Open( + std::shared_ptr stream, int64_t total_num_rows, + Compression::type codec, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), + const CryptoContext* ctx = NULLPTR); + + // @returns: shared_ptr(nullptr) on EOS, std::shared_ptr + // containing new Page otherwise + virtual std::shared_ptr NextPage() = 0; + + virtual void set_max_page_header_size(uint32_t size) = 0; +}; +} +namespace ch_parquet{ + +class PARQUET_EXPORT ColumnReader { +public: + virtual ~ColumnReader() = default; + + static std::shared_ptr Make( + const ColumnDescriptor* descr, std::unique_ptr pager, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + // Returns true if there are still values in this column. + virtual bool HasNext() = 0; + + virtual Type::type type() const = 0; + + virtual const ColumnDescriptor* descr() const = 0; + + // Get the encoding that can be exposed by this reader. If it returns + // dictionary encoding, then ReadBatchWithDictionary can be used to read data. + // + // \note API EXPERIMENTAL + virtual ExposedEncoding GetExposedEncoding() = 0; + +protected: + friend class RowGroupReader; + // Set the encoding that can be exposed by this reader. + // + // \note API EXPERIMENTAL + virtual void SetExposedEncoding(ExposedEncoding encoding) = 0; +}; + +// API to read values from a single column. This is a main client facing API. +template +class TypedColumnReader : public ColumnReader { +public: + typedef typename DType::c_type T; + + // Read a batch of repetition levels, definition levels, and values from the + // column. + // + // Since null values are not stored in the values, the number of values read + // may be less than the number of repetition and definition levels. With + // nested data this is almost certainly true. + // + // Set def_levels or rep_levels to nullptr if you want to skip reading them. + // This is only safe if you know through some other source that there are no + // undefined values. + // + // To fully exhaust a row group, you must read batches until the number of + // values read reaches the number of stored values according to the metadata. + // + // This API is the same for both V1 and V2 of the DataPage + // + // @returns: actual number of levels read (see values_read for number of values read) + virtual int64_t ReadBatch(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, + T* values, int64_t* values_read) = 0; + + /// Read a batch of repetition levels, definition levels, and values from the + /// column and leave spaces for null entries on the lowest level in the values + /// buffer. + /// + /// In comparison to ReadBatch the length of repetition and definition levels + /// is the same as of the number of values read for max_definition_level == 1. + /// In the case of max_definition_level > 1, the repetition and definition + /// levels are larger than the values but the values include the null entries + /// with definition_level == (max_definition_level - 1). + /// + /// To fully exhaust a row group, you must read batches until the number of + /// values read reaches the number of stored values according to the metadata. + /// + /// @param batch_size the number of levels to read + /// @param[out] def_levels The Parquet definition levels, output has + /// the length levels_read. + /// @param[out] rep_levels The Parquet repetition levels, output has + /// the length levels_read. + /// @param[out] values The values in the lowest nested level including + /// spacing for nulls on the lowest levels; output has the length + /// values_read. + /// @param[out] valid_bits Memory allocated for a bitmap that indicates if + /// the row is null or on the maximum definition level. For performance + /// reasons the underlying buffer should be able to store 1 bit more than + /// required. If this requires an additional byte, this byte is only read + /// but never written to. + /// @param valid_bits_offset The offset in bits of the valid_bits where the + /// first relevant bit resides. + /// @param[out] levels_read The number of repetition/definition levels that were read. + /// @param[out] values_read The number of values read, this includes all + /// non-null entries as well as all null-entries on the lowest level + /// (i.e. definition_level == max_definition_level - 1) + /// @param[out] null_count The number of nulls on the lowest levels. + /// (i.e. (values_read - null_count) is total number of non-null entries) + /// + /// \deprecated Since 4.0.0 + ARROW_DEPRECATED("Doesn't handle nesting correctly and unused outside of unit tests.") + virtual int64_t ReadBatchSpaced(int64_t batch_size, int16_t* def_levels, + int16_t* rep_levels, T* values, uint8_t* valid_bits, + int64_t valid_bits_offset, int64_t* levels_read, + int64_t* values_read, int64_t* null_count) = 0; + + // Skip reading levels + // Returns the number of levels skipped + virtual int64_t Skip(int64_t num_rows_to_skip) = 0; + + // Read a batch of repetition levels, definition levels, and indices from the + // column. And read the dictionary if a dictionary page is encountered during + // reading pages. This API is similar to ReadBatch(), with ability to read + // dictionary and indices. It is only valid to call this method when the reader can + // expose dictionary encoding. (i.e., the reader's GetExposedEncoding() returns + // DICTIONARY). + // + // The dictionary is read along with the data page. When there's no data page, + // the dictionary won't be returned. + // + // @param batch_size The batch size to read + // @param[out] def_levels The Parquet definition levels. + // @param[out] rep_levels The Parquet repetition levels. + // @param[out] indices The dictionary indices. + // @param[out] indices_read The number of indices read. + // @param[out] dict The pointer to dictionary values. It will return nullptr if + // there's no data page. Each column chunk only has one dictionary page. The dictionary + // is owned by the reader, so the caller is responsible for copying the dictionary + // values before the reader gets destroyed. + // @param[out] dict_len The dictionary length. It will return 0 if there's no data + // page. + // @returns: actual number of levels read (see indices_read for number of + // indices read + // + // \note API EXPERIMENTAL + virtual int64_t ReadBatchWithDictionary(int64_t batch_size, int16_t* def_levels, + int16_t* rep_levels, int32_t* indices, + int64_t* indices_read, const T** dict, + int32_t* dict_len) = 0; +}; + +namespace internal { +using namespace parquet::internal; + + /// \brief Stateful column reader that delimits semantic records for both flat + /// and nested columns + /// + /// \note API EXPERIMENTAL + /// \since 1.3.0 + class RecordReader { + public: + static std::shared_ptr Make( + const ColumnDescriptor* descr, LevelInfo leaf_info, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), + const bool read_dictionary = false); + + virtual ~RecordReader() = default; + + /// \brief Attempt to read indicated number of records from column chunk + /// \return number of records read + virtual int64_t ReadRecords(int64_t num_records) = 0; + + /// \brief Pre-allocate space for data. Results in better flat read performance + virtual void Reserve(int64_t num_values) = 0; + + /// \brief Clear consumed values and repetition/definition levels as the + /// result of calling ReadRecords + virtual void Reset() = 0; + + /// \brief Transfer filled values buffer to caller. A new one will be + /// allocated in subsequent ReadRecords calls + virtual std::shared_ptr ReleaseValues() = 0; + + /// \brief Transfer filled validity bitmap buffer to caller. A new one will + /// be allocated in subsequent ReadRecords calls + virtual std::shared_ptr ReleaseIsValid() = 0; + + /// \brief Return true if the record reader has more internal data yet to + /// process + virtual bool HasMoreData() const = 0; + + /// \brief Advance record reader to the next row group + /// \param[in] reader obtained from RowGroupReader::GetColumnPageReader + virtual void SetPageReader(std::unique_ptr reader) = 0; + + virtual void DebugPrintState() = 0; + + /// \brief Decoded definition levels + int16_t* def_levels() const { + return reinterpret_cast(def_levels_->mutable_data()); + } + + /// \brief Decoded repetition levels + int16_t* rep_levels() const { + return reinterpret_cast(rep_levels_->mutable_data()); + } + + /// \brief Decoded values, including nulls, if any + uint8_t* values() const { return values_->mutable_data(); } + + /// \brief Number of values written including nulls (if any) + int64_t values_written() const { return values_written_; } + + /// \brief Number of definition / repetition levels (from those that have + /// been decoded) that have been consumed inside the reader. + int64_t levels_position() const { return levels_position_; } + + /// \brief Number of definition / repetition levels that have been written + /// internally in the reader + int64_t levels_written() const { return levels_written_; } + + /// \brief Number of nulls in the leaf + int64_t null_count() const { return null_count_; } + + /// \brief True if the leaf values are nullable + bool nullable_values() const { return nullable_values_; } + + /// \brief True if reading directly as Arrow dictionary-encoded + bool read_dictionary() const { return read_dictionary_; } + + protected: + bool nullable_values_; + + bool at_record_start_; + int64_t records_read_; + + int64_t values_written_; + int64_t values_capacity_; + int64_t null_count_; + + int64_t levels_written_; + int64_t levels_position_; + int64_t levels_capacity_; + + std::shared_ptr<::arrow::ResizableBuffer> values_; + // In the case of false, don't allocate the values buffer (when we directly read into + // builder classes). + bool uses_values_; + + std::shared_ptr<::arrow::ResizableBuffer> valid_bits_; + std::shared_ptr<::arrow::ResizableBuffer> def_levels_; + std::shared_ptr<::arrow::ResizableBuffer> rep_levels_; + + bool read_dictionary_ = false; + }; + + class CHStringArray : public ::arrow::BinaryArray + { + public: + CHStringArray(DB::ColumnWithTypeAndName column, std::shared_ptr<::arrow::Array> fake_array) : BinaryArray(fake_array -> data()) + { + this->column = column; + } + + DB::ColumnWithTypeAndName column; + }; + + class BinaryRecordReader : virtual public RecordReader { + public: + virtual std::vector> GetBuilderChunks() = 0; + }; + + /// \brief Read records directly to dictionary-encoded Arrow form (int32 + /// indices). Only valid for BYTE_ARRAY columns + class DictionaryRecordReader : virtual public RecordReader { + public: + virtual std::shared_ptr<::arrow::ChunkedArray> GetResult() = 0; + }; + +} // namespace internal + +using BoolReader = TypedColumnReader; +using Int32Reader = TypedColumnReader; +using Int64Reader = TypedColumnReader; +using Int96Reader = TypedColumnReader; +using FloatReader = TypedColumnReader; +using DoubleReader = TypedColumnReader; +using ByteArrayReader = TypedColumnReader; +using FixedLenByteArrayReader = TypedColumnReader; + +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/encoding.cc b/utils/local-engine/Storages/ch_parquet/arrow/encoding.cc new file mode 100644 index 000000000000..3da6cc0abf18 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/encoding.cc @@ -0,0 +1,2835 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "encoding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/array/builder_dict.h" +#include "arrow/stl_allocator.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_stream_utils.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/bitmap_writer.h" +#include "arrow/util/byte_stream_split.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/hashing.h" +#include "arrow/util/logging.h" +#include "arrow/util/rle_encoding.h" +#include "arrow/util/ubsan.h" +#include "arrow/visitor_inline.h" +#include "parquet/exception.h" +#include "parquet/platform.h" +#include "parquet/schema.h" +#include "parquet/types.h" + +namespace BitUtil = arrow::BitUtil; + +using arrow::Status; +using arrow::VisitNullBitmapInline; +using arrow::internal::checked_cast; + +template +using ArrowPoolVector = std::vector>; + +namespace ch_parquet { +using namespace parquet; +namespace { + +constexpr int64_t kInMemoryDefaultCapacity = 1024; +// The Parquet spec isn't very clear whether ByteArray lengths are signed or +// unsigned, but the Java implementation uses signed ints. +constexpr size_t kMaxByteArraySize = std::numeric_limits::max(); + +class EncoderImpl : virtual public Encoder { + public: + EncoderImpl(const ColumnDescriptor* descr, Encoding::type encoding, MemoryPool* pool) + : descr_(descr), + encoding_(encoding), + pool_(pool), + type_length_(descr ? descr->type_length() : -1) {} + + Encoding::type encoding() const override { return encoding_; } + + MemoryPool* memory_pool() const override { return pool_; } + + protected: + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + const ColumnDescriptor* descr_; + const Encoding::type encoding_; + MemoryPool* pool_; + + /// Type length from descr + int type_length_; +}; + +// ---------------------------------------------------------------------- +// Plain encoder implementation + +template +class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { + public: + using T = typename DType::c_type; + + explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) + : EncoderImpl(descr, Encoding::PLAIN, pool), sink_(pool) {} + + int64_t EstimatedDataEncodedSize() override { return sink_.length(); } + + std::shared_ptr FlushValues() override { + std::shared_ptr buffer; + PARQUET_THROW_NOT_OK(sink_.Finish(&buffer)); + return buffer; + } + + using TypedEncoder::Put; + + void Put(const T* buffer, int num_values) override; + + void Put(const ::arrow::Array& values) override; + + void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset) override { + if (valid_bits != NULLPTR) { + PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T), + this->memory_pool())); + T* data = reinterpret_cast(buffer->mutable_data()); + int num_valid_values = ::arrow::util::internal::SpacedCompress( + src, num_values, valid_bits, valid_bits_offset, data); + Put(data, num_valid_values); + } else { + Put(src, num_values); + } + } + + void UnsafePutByteArray(const void* data, uint32_t length) { + DCHECK(length == 0 || data != nullptr) << "Value ptr cannot be NULL"; + sink_.UnsafeAppend(&length, sizeof(uint32_t)); + sink_.UnsafeAppend(data, static_cast(length)); + } + + void Put(const ByteArray& val) { + // Write the result to the output stream + const int64_t increment = static_cast(val.len + sizeof(uint32_t)); + if (ARROW_PREDICT_FALSE(sink_.length() + increment > sink_.capacity())) { + PARQUET_THROW_NOT_OK(sink_.Reserve(increment)); + } + UnsafePutByteArray(val.ptr, val.len); + } + + protected: + template + void PutBinaryArray(const ArrayType& array) { + const int64_t total_bytes = + array.value_offset(array.length()) - array.value_offset(0); + PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes + array.length() * sizeof(uint32_t))); + + PARQUET_THROW_NOT_OK(::arrow::VisitArrayDataInline( + *array.data(), + [&](::arrow::util::string_view view) { + if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) { + return Status::Invalid("Parquet cannot store strings with size 2GB or more"); + } + UnsafePutByteArray(view.data(), static_cast(view.size())); + return Status::OK(); + }, + []() { return Status::OK(); })); + } + + ::arrow::BufferBuilder sink_; +}; + +template +void PlainEncoder::Put(const T* buffer, int num_values) { + if (num_values > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); + } +} + +template <> +inline void PlainEncoder::Put(const ByteArray* src, int num_values) { + for (int i = 0; i < num_values; ++i) { + Put(src[i]); + } +} + +template +void DirectPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) { + if (values.type_id() != ArrayType::TypeClass::type_id) { + std::string type_name = ArrayType::TypeClass::type_name(); + throw ParquetException("direct put to " + type_name + " from " + + values.type()->ToString() + " not supported"); + } + + using value_type = typename ArrayType::value_type; + constexpr auto value_size = sizeof(value_type); + auto raw_values = checked_cast(values).raw_values(); + + if (values.null_count() == 0) { + // no nulls, just dump the data + PARQUET_THROW_NOT_OK(sink->Append(raw_values, values.length() * value_size)); + } else { + PARQUET_THROW_NOT_OK( + sink->Reserve((values.length() - values.null_count()) * value_size)); + + for (int64_t i = 0; i < values.length(); i++) { + if (values.IsValid(i)) { + sink->UnsafeAppend(&raw_values[i], value_size); + } + } + } +} + +template <> +void PlainEncoder::Put(const ::arrow::Array& values) { + DirectPutImpl<::arrow::Int32Array>(values, &sink_); +} + +template <> +void PlainEncoder::Put(const ::arrow::Array& values) { + DirectPutImpl<::arrow::Int64Array>(values, &sink_); +} + +template <> +void PlainEncoder::Put(const ::arrow::Array& values) { + ParquetException::NYI("direct put to Int96"); +} + +template <> +void PlainEncoder::Put(const ::arrow::Array& values) { + DirectPutImpl<::arrow::FloatArray>(values, &sink_); +} + +template <> +void PlainEncoder::Put(const ::arrow::Array& values) { + DirectPutImpl<::arrow::DoubleArray>(values, &sink_); +} + +template +void PlainEncoder::Put(const ::arrow::Array& values) { + ParquetException::NYI("direct put of " + values.type()->ToString()); +} + +void AssertBaseBinary(const ::arrow::Array& values) { + if (!::arrow::is_base_binary_like(values.type_id())) { + throw ParquetException("Only BaseBinaryArray and subclasses supported"); + } +} + +template <> +inline void PlainEncoder::Put(const ::arrow::Array& values) { + AssertBaseBinary(values); + + if (::arrow::is_binary_like(values.type_id())) { + PutBinaryArray(checked_cast(values)); + } else { + DCHECK(::arrow::is_large_binary_like(values.type_id())); + PutBinaryArray(checked_cast(values)); + } +} + +void AssertFixedSizeBinary(const ::arrow::Array& values, int type_length) { + if (values.type_id() != ::arrow::Type::FIXED_SIZE_BINARY && + values.type_id() != ::arrow::Type::DECIMAL) { + throw ParquetException("Only FixedSizeBinaryArray and subclasses supported"); + } + if (checked_cast(*values.type()).byte_width() != + type_length) { + throw ParquetException("Size mismatch: " + values.type()->ToString() + + " should have been " + std::to_string(type_length) + " wide"); + } +} + +template <> +inline void PlainEncoder::Put(const ::arrow::Array& values) { + AssertFixedSizeBinary(values, descr_->type_length()); + const auto& data = checked_cast(values); + + if (data.null_count() == 0) { + // no nulls, just dump the data + PARQUET_THROW_NOT_OK( + sink_.Append(data.raw_values(), data.length() * data.byte_width())); + } else { + const int64_t total_bytes = + data.length() * data.byte_width() - data.null_count() * data.byte_width(); + PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes)); + for (int64_t i = 0; i < data.length(); i++) { + if (data.IsValid(i)) { + sink_.UnsafeAppend(data.Value(i), data.byte_width()); + } + } + } +} + +template <> +inline void PlainEncoder::Put(const FixedLenByteArray* src, int num_values) { + if (descr_->type_length() == 0) { + return; + } + for (int i = 0; i < num_values; ++i) { + // Write the result to the output stream + DCHECK(src[i].ptr != nullptr) << "Value ptr cannot be NULL"; + PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->type_length())); + } +} + +template <> +class PlainEncoder : public EncoderImpl, virtual public BooleanEncoder { + public: + explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool) + : EncoderImpl(descr, Encoding::PLAIN, pool), + bits_available_(kInMemoryDefaultCapacity * 8), + bits_buffer_(AllocateBuffer(pool, kInMemoryDefaultCapacity)), + sink_(pool), + bit_writer_(bits_buffer_->mutable_data(), + static_cast(bits_buffer_->size())) {} + + int64_t EstimatedDataEncodedSize() override; + std::shared_ptr FlushValues() override; + + void Put(const bool* src, int num_values) override; + + void Put(const std::vector& src, int num_values) override; + + void PutSpaced(const bool* src, int num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset) override { + if (valid_bits != NULLPTR) { + PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T), + this->memory_pool())); + T* data = reinterpret_cast(buffer->mutable_data()); + int num_valid_values = ::arrow::util::internal::SpacedCompress( + src, num_values, valid_bits, valid_bits_offset, data); + Put(data, num_valid_values); + } else { + Put(src, num_values); + } + } + + void Put(const ::arrow::Array& values) override { + if (values.type_id() != ::arrow::Type::BOOL) { + throw ParquetException("direct put to boolean from " + values.type()->ToString() + + " not supported"); + } + + const auto& data = checked_cast(values); + if (data.null_count() == 0) { + PARQUET_THROW_NOT_OK(sink_.Reserve(BitUtil::BytesForBits(data.length()))); + // no nulls, just dump the data + ::arrow::internal::CopyBitmap(data.data()->GetValues(1), data.offset(), + data.length(), sink_.mutable_data(), sink_.length()); + } else { + auto n_valid = BitUtil::BytesForBits(data.length() - data.null_count()); + PARQUET_THROW_NOT_OK(sink_.Reserve(n_valid)); + ::arrow::internal::FirstTimeBitmapWriter writer(sink_.mutable_data(), + sink_.length(), n_valid); + + for (int64_t i = 0; i < data.length(); i++) { + if (data.IsValid(i)) { + if (data.Value(i)) { + writer.Set(); + } else { + writer.Clear(); + } + writer.Next(); + } + } + writer.Finish(); + } + sink_.UnsafeAdvance(data.length()); + } + + private: + int bits_available_; + std::shared_ptr bits_buffer_; + ::arrow::BufferBuilder sink_; + ::arrow::BitUtil::BitWriter bit_writer_; + + template + void PutImpl(const SequenceType& src, int num_values); +}; + +template +void PlainEncoder::PutImpl(const SequenceType& src, int num_values) { + int bit_offset = 0; + if (bits_available_ > 0) { + int bits_to_write = std::min(bits_available_, num_values); + for (int i = 0; i < bits_to_write; i++) { + bit_writer_.PutValue(src[i], 1); + } + bits_available_ -= bits_to_write; + bit_offset = bits_to_write; + + if (bits_available_ == 0) { + bit_writer_.Flush(); + PARQUET_THROW_NOT_OK( + sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written())); + bit_writer_.Clear(); + } + } + + int bits_remaining = num_values - bit_offset; + while (bit_offset < num_values) { + bits_available_ = static_cast(bits_buffer_->size()) * 8; + + int bits_to_write = std::min(bits_available_, bits_remaining); + for (int i = bit_offset; i < bit_offset + bits_to_write; i++) { + bit_writer_.PutValue(src[i], 1); + } + bit_offset += bits_to_write; + bits_available_ -= bits_to_write; + bits_remaining -= bits_to_write; + + if (bits_available_ == 0) { + bit_writer_.Flush(); + PARQUET_THROW_NOT_OK( + sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written())); + bit_writer_.Clear(); + } + } +} + +int64_t PlainEncoder::EstimatedDataEncodedSize() { + int64_t position = sink_.length(); + return position + bit_writer_.bytes_written(); +} + +std::shared_ptr PlainEncoder::FlushValues() { + if (bits_available_ > 0) { + bit_writer_.Flush(); + PARQUET_THROW_NOT_OK(sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written())); + bit_writer_.Clear(); + bits_available_ = static_cast(bits_buffer_->size()) * 8; + } + + std::shared_ptr buffer; + PARQUET_THROW_NOT_OK(sink_.Finish(&buffer)); + return buffer; +} + +void PlainEncoder::Put(const bool* src, int num_values) { + PutImpl(src, num_values); +} + +void PlainEncoder::Put(const std::vector& src, int num_values) { + PutImpl(src, num_values); +} + +// ---------------------------------------------------------------------- +// DictEncoder implementations + +template +struct DictEncoderTraits { + using c_type = typename DType::c_type; + using MemoTableType = ::arrow::internal::ScalarMemoTable; +}; + +template <> +struct DictEncoderTraits { + using MemoTableType = ::arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>; +}; + +template <> +struct DictEncoderTraits { + using MemoTableType = ::arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>; +}; + +// Initially 1024 elements +static constexpr int32_t kInitialHashTableSize = 1 << 10; + +/// See the dictionary encoding section of +/// https://github.com/Parquet/parquet-format. The encoding supports +/// streaming encoding. Values are encoded as they are added while the +/// dictionary is being constructed. At any time, the buffered values +/// can be written out with the current dictionary size. More values +/// can then be added to the encoder, including new dictionary +/// entries. +template +class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { + using MemoTableType = typename DictEncoderTraits::MemoTableType; + + public: + typedef typename DType::c_type T; + + explicit DictEncoderImpl(const ColumnDescriptor* desc, MemoryPool* pool) + : EncoderImpl(desc, Encoding::PLAIN_DICTIONARY, pool), + buffered_indices_(::arrow::stl::allocator(pool)), + dict_encoded_size_(0), + memo_table_(pool, kInitialHashTableSize) {} + + ~DictEncoderImpl() override { DCHECK(buffered_indices_.empty()); } + + int dict_encoded_size() override { return dict_encoded_size_; } + + int WriteIndices(uint8_t* buffer, int buffer_len) override { + // Write bit width in first byte + *buffer = static_cast(bit_width()); + ++buffer; + --buffer_len; + + ::arrow::util::RleEncoder encoder(buffer, buffer_len, bit_width()); + + for (int32_t index : buffered_indices_) { + if (!encoder.Put(index)) return -1; + } + encoder.Flush(); + + ClearIndices(); + return 1 + encoder.len(); + } + + void set_type_length(int type_length) { this->type_length_ = type_length; } + + /// Returns a conservative estimate of the number of bytes needed to encode the buffered + /// indices. Used to size the buffer passed to WriteIndices(). + int64_t EstimatedDataEncodedSize() override { + // Note: because of the way RleEncoder::CheckBufferFull() is called, we have to + // reserve + // an extra "RleEncoder::MinBufferSize" bytes. These extra bytes won't be used + // but not reserving them would cause the encoder to fail. + return 1 + + ::arrow::util::RleEncoder::MaxBufferSize( + bit_width(), static_cast(buffered_indices_.size())) + + ::arrow::util::RleEncoder::MinBufferSize(bit_width()); + } + + /// The minimum bit width required to encode the currently buffered indices. + int bit_width() const override { + if (ARROW_PREDICT_FALSE(num_entries() == 0)) return 0; + if (ARROW_PREDICT_FALSE(num_entries() == 1)) return 1; + return BitUtil::Log2(num_entries()); + } + + /// Encode value. Note that this does not actually write any data, just + /// buffers the value's index to be written later. + inline void Put(const T& value); + + // Not implemented for other data types + inline void PutByteArray(const void* ptr, int32_t length); + + void Put(const T* src, int num_values) override { + for (int32_t i = 0; i < num_values; i++) { + Put(src[i]); + } + } + + void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset) override { + ::arrow::internal::VisitSetBitRunsVoid(valid_bits, valid_bits_offset, num_values, + [&](int64_t position, int64_t length) { + for (int64_t i = 0; i < length; i++) { + Put(src[i + position]); + } + }); + } + + using TypedEncoder::Put; + + void Put(const ::arrow::Array& values) override; + void PutDictionary(const ::arrow::Array& values) override; + + template + void PutIndicesTyped(const ::arrow::Array& data) { + auto values = data.data()->GetValues(1); + size_t buffer_position = buffered_indices_.size(); + buffered_indices_.resize(buffer_position + + static_cast(data.length() - data.null_count())); + ::arrow::internal::VisitSetBitRunsVoid( + data.null_bitmap_data(), data.offset(), data.length(), + [&](int64_t position, int64_t length) { + for (int64_t i = 0; i < length; ++i) { + buffered_indices_[buffer_position++] = + static_cast(values[i + position]); + } + }); + } + + void PutIndices(const ::arrow::Array& data) override { + switch (data.type()->id()) { + case ::arrow::Type::UINT8: + case ::arrow::Type::INT8: + return PutIndicesTyped<::arrow::UInt8Type>(data); + case ::arrow::Type::UINT16: + case ::arrow::Type::INT16: + return PutIndicesTyped<::arrow::UInt16Type>(data); + case ::arrow::Type::UINT32: + case ::arrow::Type::INT32: + return PutIndicesTyped<::arrow::UInt32Type>(data); + case ::arrow::Type::UINT64: + case ::arrow::Type::INT64: + return PutIndicesTyped<::arrow::UInt64Type>(data); + default: + throw ParquetException("Passed non-integer array to PutIndices"); + } + } + + std::shared_ptr FlushValues() override { + std::shared_ptr buffer = + AllocateBuffer(this->pool_, EstimatedDataEncodedSize()); + int result_size = WriteIndices(buffer->mutable_data(), + static_cast(EstimatedDataEncodedSize())); + PARQUET_THROW_NOT_OK(buffer->Resize(result_size, false)); + return std::move(buffer); + } + + /// Writes out the encoded dictionary to buffer. buffer must be preallocated to + /// dict_encoded_size() bytes. + void WriteDict(uint8_t* buffer) override; + + /// The number of entries in the dictionary. + int num_entries() const override { return memo_table_.size(); } + + private: + /// Clears all the indices (but leaves the dictionary). + void ClearIndices() { buffered_indices_.clear(); } + + /// Indices that have not yet be written out by WriteIndices(). + ArrowPoolVector buffered_indices_; + + template + void PutBinaryArray(const ArrayType& array) { + PARQUET_THROW_NOT_OK(::arrow::VisitArrayDataInline( + *array.data(), + [&](::arrow::util::string_view view) { + if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) { + return Status::Invalid("Parquet cannot store strings with size 2GB or more"); + } + PutByteArray(view.data(), static_cast(view.size())); + return Status::OK(); + }, + []() { return Status::OK(); })); + } + + template + void PutBinaryDictionaryArray(const ArrayType& array) { + DCHECK_EQ(array.null_count(), 0); + for (int64_t i = 0; i < array.length(); i++) { + auto v = array.GetView(i); + if (ARROW_PREDICT_FALSE(v.size() > kMaxByteArraySize)) { + throw ParquetException("Parquet cannot store strings with size 2GB or more"); + } + dict_encoded_size_ += static_cast(v.size() + sizeof(uint32_t)); + int32_t unused_memo_index; + PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert( + v.data(), static_cast(v.size()), &unused_memo_index)); + } + } + + /// The number of bytes needed to encode the dictionary. + int dict_encoded_size_; + + MemoTableType memo_table_; +}; + +template +void DictEncoderImpl::WriteDict(uint8_t* buffer) { + // For primitive types, only a memcpy + DCHECK_EQ(static_cast(dict_encoded_size_), sizeof(T) * memo_table_.size()); + memo_table_.CopyValues(0 /* start_pos */, reinterpret_cast(buffer)); +} + +// ByteArray and FLBA already have the dictionary encoded in their data heaps +template <> +void DictEncoderImpl::WriteDict(uint8_t* buffer) { + memo_table_.VisitValues(0, [&buffer](const ::arrow::util::string_view& v) { + uint32_t len = static_cast(v.length()); + memcpy(buffer, &len, sizeof(len)); + buffer += sizeof(len); + memcpy(buffer, v.data(), len); + buffer += len; + }); +} + +template <> +void DictEncoderImpl::WriteDict(uint8_t* buffer) { + memo_table_.VisitValues(0, [&](const ::arrow::util::string_view& v) { + DCHECK_EQ(v.length(), static_cast(type_length_)); + memcpy(buffer, v.data(), type_length_); + buffer += type_length_; + }); +} + +template +inline void DictEncoderImpl::Put(const T& v) { + // Put() implementation for primitive types + auto on_found = [](int32_t memo_index) {}; + auto on_not_found = [this](int32_t memo_index) { + dict_encoded_size_ += static_cast(sizeof(T)); + }; + + int32_t memo_index; + PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert(v, on_found, on_not_found, &memo_index)); + buffered_indices_.push_back(memo_index); +} + +template +inline void DictEncoderImpl::PutByteArray(const void* ptr, int32_t length) { + DCHECK(false); +} + +template <> +inline void DictEncoderImpl::PutByteArray(const void* ptr, + int32_t length) { + static const uint8_t empty[] = {0}; + + auto on_found = [](int32_t memo_index) {}; + auto on_not_found = [&](int32_t memo_index) { + dict_encoded_size_ += static_cast(length + sizeof(uint32_t)); + }; + + DCHECK(ptr != nullptr || length == 0); + ptr = (ptr != nullptr) ? ptr : empty; + int32_t memo_index; + PARQUET_THROW_NOT_OK( + memo_table_.GetOrInsert(ptr, length, on_found, on_not_found, &memo_index)); + buffered_indices_.push_back(memo_index); +} + +template <> +inline void DictEncoderImpl::Put(const ByteArray& val) { + return PutByteArray(val.ptr, static_cast(val.len)); +} + +template <> +inline void DictEncoderImpl::Put(const FixedLenByteArray& v) { + static const uint8_t empty[] = {0}; + + auto on_found = [](int32_t memo_index) {}; + auto on_not_found = [this](int32_t memo_index) { dict_encoded_size_ += type_length_; }; + + DCHECK(v.ptr != nullptr || type_length_ == 0); + const void* ptr = (v.ptr != nullptr) ? v.ptr : empty; + int32_t memo_index; + PARQUET_THROW_NOT_OK( + memo_table_.GetOrInsert(ptr, type_length_, on_found, on_not_found, &memo_index)); + buffered_indices_.push_back(memo_index); +} + +template <> +void DictEncoderImpl::Put(const ::arrow::Array& values) { + ParquetException::NYI("Direct put to Int96"); +} + +template <> +void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { + ParquetException::NYI("Direct put to Int96"); +} + +template +void DictEncoderImpl::Put(const ::arrow::Array& values) { + using ArrayType = typename ::arrow::CTypeTraits::ArrayType; + const auto& data = checked_cast(values); + if (data.null_count() == 0) { + // no nulls, just dump the data + for (int64_t i = 0; i < data.length(); i++) { + Put(data.Value(i)); + } + } else { + for (int64_t i = 0; i < data.length(); i++) { + if (data.IsValid(i)) { + Put(data.Value(i)); + } + } + } +} + +template <> +void DictEncoderImpl::Put(const ::arrow::Array& values) { + AssertFixedSizeBinary(values, type_length_); + const auto& data = checked_cast(values); + if (data.null_count() == 0) { + // no nulls, just dump the data + for (int64_t i = 0; i < data.length(); i++) { + Put(FixedLenByteArray(data.Value(i))); + } + } else { + std::vector empty(type_length_, 0); + for (int64_t i = 0; i < data.length(); i++) { + if (data.IsValid(i)) { + Put(FixedLenByteArray(data.Value(i))); + } + } + } +} + +template <> +void DictEncoderImpl::Put(const ::arrow::Array& values) { + AssertBaseBinary(values); + if (::arrow::is_binary_like(values.type_id())) { + PutBinaryArray(checked_cast(values)); + } else { + DCHECK(::arrow::is_large_binary_like(values.type_id())); + PutBinaryArray(checked_cast(values)); + } +} + +template +void AssertCanPutDictionary(DictEncoderImpl* encoder, const ::arrow::Array& dict) { + if (dict.null_count() > 0) { + throw ParquetException("Inserted dictionary cannot cannot contain nulls"); + } + + if (encoder->num_entries() > 0) { + throw ParquetException("Can only call PutDictionary on an empty DictEncoder"); + } +} + +template +void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { + AssertCanPutDictionary(this, values); + + using ArrayType = typename ::arrow::CTypeTraits::ArrayType; + const auto& data = checked_cast(values); + + dict_encoded_size_ += static_cast(sizeof(typename DType::c_type) * data.length()); + for (int64_t i = 0; i < data.length(); i++) { + int32_t unused_memo_index; + PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert(data.Value(i), &unused_memo_index)); + } +} + +template <> +void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { + AssertFixedSizeBinary(values, type_length_); + AssertCanPutDictionary(this, values); + + const auto& data = checked_cast(values); + + dict_encoded_size_ += static_cast(type_length_ * data.length()); + for (int64_t i = 0; i < data.length(); i++) { + int32_t unused_memo_index; + PARQUET_THROW_NOT_OK( + memo_table_.GetOrInsert(data.Value(i), type_length_, &unused_memo_index)); + } +} + +template <> +void DictEncoderImpl::PutDictionary(const ::arrow::Array& values) { + AssertBaseBinary(values); + AssertCanPutDictionary(this, values); + + if (::arrow::is_binary_like(values.type_id())) { + PutBinaryDictionaryArray(checked_cast(values)); + } else { + DCHECK(::arrow::is_large_binary_like(values.type_id())); + PutBinaryDictionaryArray(checked_cast(values)); + } +} + +// ---------------------------------------------------------------------- +// ByteStreamSplitEncoder implementations + +template +class ByteStreamSplitEncoder : public EncoderImpl, virtual public TypedEncoder { + public: + using T = typename DType::c_type; + using TypedEncoder::Put; + + explicit ByteStreamSplitEncoder( + const ColumnDescriptor* descr, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + int64_t EstimatedDataEncodedSize() override; + std::shared_ptr FlushValues() override; + + void Put(const T* buffer, int num_values) override; + void Put(const ::arrow::Array& values) override; + void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset) override; + + protected: + template + void PutImpl(const ::arrow::Array& values) { + if (values.type_id() != ArrowType::type_id) { + throw ParquetException(std::string() + "direct put to " + ArrowType::type_name() + + " from " + values.type()->ToString() + " not supported"); + } + const auto& data = *values.data(); + PutSpaced(data.GetValues(1), + static_cast(data.length), data.GetValues(0, 0), data.offset); + } + + ::arrow::BufferBuilder sink_; + int64_t num_values_in_buffer_; +}; + +template +ByteStreamSplitEncoder::ByteStreamSplitEncoder(const ColumnDescriptor* descr, + ::arrow::MemoryPool* pool) + : EncoderImpl(descr, Encoding::BYTE_STREAM_SPLIT, pool), + sink_{pool}, + num_values_in_buffer_{0} {} + +template +int64_t ByteStreamSplitEncoder::EstimatedDataEncodedSize() { + return sink_.length(); +} + +template +std::shared_ptr ByteStreamSplitEncoder::FlushValues() { + std::shared_ptr output_buffer = + AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize()); + uint8_t* output_buffer_raw = output_buffer->mutable_data(); + const uint8_t* raw_values = sink_.data(); + ::arrow::util::internal::ByteStreamSplitEncode(raw_values, num_values_in_buffer_, + output_buffer_raw); + sink_.Reset(); + num_values_in_buffer_ = 0; + return std::move(output_buffer); +} + +template +void ByteStreamSplitEncoder::Put(const T* buffer, int num_values) { + if (num_values > 0) { + PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T))); + num_values_in_buffer_ += num_values; + } +} + +template <> +void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { + PutImpl<::arrow::FloatType>(values); +} + +template <> +void ByteStreamSplitEncoder::Put(const ::arrow::Array& values) { + PutImpl<::arrow::DoubleType>(values); +} + +template +void ByteStreamSplitEncoder::PutSpaced(const T* src, int num_values, + const uint8_t* valid_bits, + int64_t valid_bits_offset) { + if (valid_bits != NULLPTR) { + PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T), + this->memory_pool())); + T* data = reinterpret_cast(buffer->mutable_data()); + int num_valid_values = ::arrow::util::internal::SpacedCompress( + src, num_values, valid_bits, valid_bits_offset, data); + Put(data, num_valid_values); + } else { + Put(src, num_values); + } +} + +class DecoderImpl : virtual public Decoder { + public: + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + data_ = data; + len_ = len; + } + + int values_left() const override { return num_values_; } + Encoding::type encoding() const override { return encoding_; } + + protected: + explicit DecoderImpl(const ColumnDescriptor* descr, Encoding::type encoding) + : descr_(descr), encoding_(encoding), num_values_(0), data_(NULLPTR), len_(0) {} + + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + const ColumnDescriptor* descr_; + + const Encoding::type encoding_; + int num_values_; + const uint8_t* data_; + int len_; + int type_length_; +}; + +template +class PlainDecoder : public DecoderImpl, virtual public TypedDecoder { + public: + using T = typename DType::c_type; + explicit PlainDecoder(const ColumnDescriptor* descr); + + int Decode(T* buffer, int max_values) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) override; +}; + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + ParquetException::NYI("DecodeArrow not supported for Int96"); +} + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("DecodeArrow not supported for Int96"); +} + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("dictionaries of BooleanType"); +} + +template +int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + using value_type = typename DType::c_type; + + constexpr int value_size = static_cast(sizeof(value_type)); + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + builder->UnsafeAppend(::arrow::util::SafeLoadAs(data_)); + data_ += sizeof(value_type); + }, + [&]() { builder->UnsafeAppendNull(); }); + + num_values_ -= values_decoded; + len_ -= sizeof(value_type) * values_decoded; + return values_decoded; +} + +template +int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + using value_type = typename DType::c_type; + + constexpr int value_size = static_cast(sizeof(value_type)); + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + PARQUET_THROW_NOT_OK( + builder->Append(::arrow::util::SafeLoadAs(data_))); + data_ += sizeof(value_type); + }, + [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + + num_values_ -= values_decoded; + len_ -= sizeof(value_type) * values_decoded; + return values_decoded; +} + +// Decode routine templated on C++ type rather than type enum +template +inline int DecodePlain(const uint8_t* data, int64_t data_size, int num_values, + int type_length, T* out) { + int64_t bytes_to_decode = num_values * static_cast(sizeof(T)); + if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { + ParquetException::EofException(); + } + // If bytes_to_decode == 0, data could be null + if (bytes_to_decode > 0) { + memcpy(out, data, bytes_to_decode); + } + return static_cast(bytes_to_decode); +} + +template +PlainDecoder::PlainDecoder(const ColumnDescriptor* descr) + : DecoderImpl(descr, Encoding::PLAIN) { + if (descr_ && descr_->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) { + type_length_ = descr_->type_length(); + } else { + type_length_ = -1; + } +} + +// Template specialization for BYTE_ARRAY. The written values do not own their +// own data. + +static inline int64_t ReadByteArray(const uint8_t* data, int64_t data_size, + ByteArray* out) { + if (ARROW_PREDICT_FALSE(data_size < 4)) { + ParquetException::EofException(); + } + const int32_t len = ::arrow::util::SafeLoadAs(data); + if (len < 0) { + throw ParquetException("Invalid BYTE_ARRAY value"); + } + const int64_t consumed_length = static_cast(len) + 4; + if (ARROW_PREDICT_FALSE(data_size < consumed_length)) { + ParquetException::EofException(); + } + *out = ByteArray{static_cast(len), data + 4}; + return consumed_length; +} + +template <> +inline int DecodePlain(const uint8_t* data, int64_t data_size, int num_values, + int type_length, ByteArray* out) { + int bytes_decoded = 0; + for (int i = 0; i < num_values; ++i) { + const auto increment = ReadByteArray(data, data_size, out + i); + if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) { + throw ParquetException("BYTE_ARRAY chunk too large"); + } + data += increment; + data_size -= increment; + bytes_decoded += static_cast(increment); + } + return bytes_decoded; +} + +// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not +// own their own data. +template <> +inline int DecodePlain(const uint8_t* data, int64_t data_size, + int num_values, int type_length, + FixedLenByteArray* out) { + int64_t bytes_to_decode = static_cast(type_length) * num_values; + if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { + ParquetException::EofException(); + } + for (int i = 0; i < num_values; ++i) { + out[i].ptr = data; + data += type_length; + data_size -= type_length; + } + return static_cast(bytes_to_decode); +} + +template +int PlainDecoder::Decode(T* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + int bytes_consumed = DecodePlain(data_, len_, max_values, type_length_, buffer); + data_ += bytes_consumed; + len_ -= bytes_consumed; + num_values_ -= max_values; + return max_values; +} + +class PlainBooleanDecoder : public DecoderImpl, + virtual public TypedDecoder, + virtual public BooleanDecoder { + public: + explicit PlainBooleanDecoder(const ColumnDescriptor* descr); + void SetData(int num_values, const uint8_t* data, int len) override; + + // Two flavors of bool decoding + int Decode(uint8_t* buffer, int max_values) override; + int Decode(bool* buffer, int max_values) override; + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* out) override; + + private: + std::unique_ptr<::arrow::BitUtil::BitReader> bit_reader_; +}; + +PlainBooleanDecoder::PlainBooleanDecoder(const ColumnDescriptor* descr) + : DecoderImpl(descr, Encoding::PLAIN) {} + +void PlainBooleanDecoder::SetData(int num_values, const uint8_t* data, int len) { + num_values_ = num_values; + bit_reader_.reset(new BitUtil::BitReader(data, len)); +} + +int PlainBooleanDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(num_values_ < values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + bool value; + ARROW_IGNORE_EXPR(bit_reader_->GetValue(1, &value)); + builder->UnsafeAppend(value); + }, + [&]() { builder->UnsafeAppendNull(); }); + + num_values_ -= values_decoded; + return values_decoded; +} + +inline int PlainBooleanDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("dictionaries of BooleanType"); +} + +int PlainBooleanDecoder::Decode(uint8_t* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + bool val; + ::arrow::internal::BitmapWriter bit_writer(buffer, 0, max_values); + for (int i = 0; i < max_values; ++i) { + if (!bit_reader_->GetValue(1, &val)) { + ParquetException::EofException(); + } + if (val) { + bit_writer.Set(); + } + bit_writer.Next(); + } + bit_writer.Finish(); + num_values_ -= max_values; + return max_values; +} + +int PlainBooleanDecoder::Decode(bool* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + if (bit_reader_->GetBatch(1, buffer, max_values) != max_values) { + ParquetException::EofException(); + } + num_values_ -= max_values; + return max_values; +} + +struct ArrowBinaryHelper { + explicit ArrowBinaryHelper(typename EncodingTraits::Accumulator* out) { + this->out = out; + this->builder = out->builder.get(); + this->chunk_space_remaining = + ::arrow::kBinaryMemoryLimit - this->builder->value_data_length(); + } + + Status PushChunk() { + std::shared_ptr<::arrow::Array> result; + RETURN_NOT_OK(builder->Finish(&result)); + out->chunks.push_back(result); + chunk_space_remaining = ::arrow::kBinaryMemoryLimit; + return Status::OK(); + } + + bool CanFit(int64_t length) const { return length <= chunk_space_remaining; } + + void UnsafeAppend(const uint8_t* data, int32_t length) { + chunk_space_remaining -= length; + builder->UnsafeAppend(data, length); + } + + void UnsafeAppendNull() { builder->UnsafeAppendNull(); } + + Status Append(const uint8_t* data, int32_t length) { + chunk_space_remaining -= length; + return builder->Append(data, length); + } + + Status AppendNull() { return builder->AppendNull(); } + + typename EncodingTraits::Accumulator* out; + ::arrow::BinaryBuilder* builder; + int64_t chunk_space_remaining; +}; + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + ParquetException::NYI(); +} + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI(); +} + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + builder->UnsafeAppend(data_); + data_ += descr_->type_length(); + }, + [&]() { builder->UnsafeAppendNull(); }); + + num_values_ -= values_decoded; + len_ -= descr_->type_length() * values_decoded; + return values_decoded; +} + +template <> +inline int PlainDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + PARQUET_THROW_NOT_OK(builder->Append(data_)); + data_ += descr_->type_length(); + }, + [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + + num_values_ -= values_decoded; + len_ -= descr_->type_length() * values_decoded; + return values_decoded; +} + +class PlainByteArrayDecoder : public PlainDecoder, + virtual public ByteArrayDecoder { + public: + using Base = PlainDecoder; + using Base::DecodeSpaced; + using Base::PlainDecoder; + + // ---------------------------------------------------------------------- + // Dictionary read paths + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + ::arrow::BinaryDictionary32Builder* builder) override { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrow(num_values, null_count, valid_bits, + valid_bits_offset, builder, &result)); + return result; + } + + // ---------------------------------------------------------------------- + // Optimized dense binary read paths + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, valid_bits, + valid_bits_offset, out, &result)); + return result; + } + + int DecodeCH(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer + ) { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeCHDense(num_values, null_count, valid_bits, + valid_bits_offset, column_chars_t_p, + column_offsets_p, bitmap_writer, &result)); + return result; + } + + + private: + + Status DecodeCHDense(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer, + int* out_values_decoded) { + //ArrowBinaryHelper helper(out); + int values_decoded = 0; + +// RETURN_NOT_OK(helper.builder->Reserve(num_values)); +// RETURN_NOT_OK(helper.builder->ReserveData( +// std::min(len_, helper.chunk_space_remaining))); + column_offsets_p->reserve(num_values); + column_chars_t_p->reserve(num_values + len_); + + if (null_count == 0) { + for (int i = 0 ; i < num_values; i++) { + if (ARROW_PREDICT_FALSE(len_ < 4)) + { + ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) + { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) + { + ParquetException::EofException(); + } + + column_chars_t_p->insert_assume_reserved(data_ + 4, data_ + 4 + value_len); + column_chars_t_p->emplace_back('\0'); + column_offsets_p->emplace_back(column_chars_t_p->size()); + + bitmap_writer.Set(); + bitmap_writer.Next(); + + data_ += increment; + len_ -= increment; + ++values_decoded; + } + } else { + RETURN_NOT_OK(VisitNullBitmapInline( + valid_bits, + valid_bits_offset, + num_values, + null_count, + [&]() + { + if (ARROW_PREDICT_FALSE(len_ < 4)) + { + ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) + { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) + { + ParquetException::EofException(); + } + + column_chars_t_p->insert_assume_reserved(data_ + 4, data_ + 4 + value_len); + column_chars_t_p->emplace_back('\0'); + column_offsets_p->emplace_back(column_chars_t_p->size()); + + bitmap_writer.Set(); + bitmap_writer.Next(); + + data_ += increment; + len_ -= increment; + ++values_decoded; + return Status::OK(); + }, + [&]() + { + //helper.UnsafeAppendNull(); + column_chars_t_p->emplace_back('\0'); + column_offsets_p->emplace_back(column_chars_t_p->size()); + + bitmap_writer.Clear(); + bitmap_writer.Next(); + + return Status::OK(); + })); + } + + num_values_ -= values_decoded; + *out_values_decoded = values_decoded; + return Status::OK(); + } + + Status DecodeArrowDense(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out, + int* out_values_decoded) { + ArrowBinaryHelper helper(out); + int values_decoded = 0; + + RETURN_NOT_OK(helper.builder->Reserve(num_values)); + RETURN_NOT_OK(helper.builder->ReserveData( + std::min(len_, helper.chunk_space_remaining))); + + int i = 0; + RETURN_NOT_OK(VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + if (ARROW_PREDICT_FALSE(len_ < 4)) { + ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) { + ParquetException::EofException(); + } + if (ARROW_PREDICT_FALSE(!helper.CanFit(value_len))) { + // This element would exceed the capacity of a chunk + RETURN_NOT_OK(helper.PushChunk()); + RETURN_NOT_OK(helper.builder->Reserve(num_values - i)); + RETURN_NOT_OK(helper.builder->ReserveData( + std::min(len_, helper.chunk_space_remaining))); + } + helper.UnsafeAppend(data_ + 4, value_len); + data_ += increment; + len_ -= increment; + ++values_decoded; + ++i; + return Status::OK(); + }, + [&]() { + helper.UnsafeAppendNull(); + ++i; + return Status::OK(); + })); + + num_values_ -= values_decoded; + *out_values_decoded = values_decoded; + return Status::OK(); + } + + template + Status DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, BuilderType* builder, + int* out_values_decoded) { + RETURN_NOT_OK(builder->Reserve(num_values)); + int values_decoded = 0; + + RETURN_NOT_OK(VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + if (ARROW_PREDICT_FALSE(len_ < 4)) { + ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) { + ParquetException::EofException(); + } + RETURN_NOT_OK(builder->Append(data_ + 4, value_len)); + data_ += increment; + len_ -= increment; + ++values_decoded; + return Status::OK(); + }, + [&]() { return builder->AppendNull(); })); + + num_values_ -= values_decoded; + *out_values_decoded = values_decoded; + return Status::OK(); + } +}; + +class PlainFLBADecoder : public PlainDecoder, virtual public FLBADecoder { + public: + using Base = PlainDecoder; + using Base::PlainDecoder; +}; + +// ---------------------------------------------------------------------- +// Dictionary encoding and decoding + +template +class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder { + public: + typedef typename Type::c_type T; + + // Initializes the dictionary with values from 'dictionary'. The data in + // dictionary is not guaranteed to persist in memory after this call so the + // dictionary decoder needs to copy the data out if necessary. + explicit DictDecoderImpl(const ColumnDescriptor* descr, + MemoryPool* pool = ::arrow::default_memory_pool()) + : DecoderImpl(descr, Encoding::RLE_DICTIONARY), + dictionary_(AllocateBuffer(pool, 0)), + dictionary_length_(0), + byte_array_data_(AllocateBuffer(pool, 0)), + byte_array_offsets_(AllocateBuffer(pool, 0)), + indices_scratch_space_(AllocateBuffer(pool, 0)) {} + + // Perform type-specific initiatialization + void SetDict(TypedDecoder* dictionary) override; + + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + if (len == 0) { + // Initialize dummy decoder to avoid crashes later on + idx_decoder_ = ::arrow::util::RleDecoder(data, len, /*bit_width=*/1); + return; + } + uint8_t bit_width = *data; + if (ARROW_PREDICT_FALSE(bit_width >= 64)) { + throw ParquetException("Invalid or corrupted bit_width"); + } + idx_decoder_ = ::arrow::util::RleDecoder(++data, --len, bit_width); + } + + int Decode(T* buffer, int num_values) override { + num_values = std::min(num_values, num_values_); + int decoded_values = + idx_decoder_.GetBatchWithDict(reinterpret_cast(dictionary_->data()), + dictionary_length_, buffer, num_values); + if (decoded_values != num_values) { + ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + int DecodeSpaced(T* buffer, int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset) override { + num_values = std::min(num_values, num_values_); + if (num_values != idx_decoder_.GetBatchWithDictSpaced( + reinterpret_cast(dictionary_->data()), + dictionary_length_, buffer, num_values, null_count, valid_bits, + valid_bits_offset)) { + ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* out) override; + + void InsertDictionary(::arrow::ArrayBuilder* builder) override; + + int DecodeIndicesSpaced(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + ::arrow::ArrayBuilder* builder) override { + if (num_values > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. It is not + // trivial because the null_count is relative to the entire bitmap + PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( + num_values, /*shrink_to_fit=*/false)); + } + + auto indices_buffer = + reinterpret_cast(indices_scratch_space_->mutable_data()); + + if (num_values != idx_decoder_.GetBatchSpaced(num_values, null_count, valid_bits, + valid_bits_offset, indices_buffer)) { + ParquetException::EofException(); + } + + /// XXX(wesm): Cannot append "valid bits" directly to the builder + std::vector valid_bytes(num_values); + ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values); + for (int64_t i = 0; i < num_values; ++i) { + valid_bytes[i] = static_cast(bit_reader.IsSet()); + bit_reader.Next(); + } + + auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder); + PARQUET_THROW_NOT_OK( + binary_builder->AppendIndices(indices_buffer, num_values, valid_bytes.data())); + num_values_ -= num_values - null_count; + return num_values - null_count; + } + + int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) override { + num_values = std::min(num_values, num_values_); + if (num_values > 0) { + // TODO(wesm): Refactor to batch reads for improved memory use. This is + // relatively simple here because we don't have to do any bookkeeping of + // nulls + PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize( + num_values, /*shrink_to_fit=*/false)); + } + auto indices_buffer = + reinterpret_cast(indices_scratch_space_->mutable_data()); + if (num_values != idx_decoder_.GetBatch(indices_buffer, num_values)) { + ParquetException::EofException(); + } + auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder); + PARQUET_THROW_NOT_OK(binary_builder->AppendIndices(indices_buffer, num_values)); + num_values_ -= num_values; + return num_values; + } + + int DecodeIndices(int num_values, int32_t* indices) override { + if (num_values != idx_decoder_.GetBatch(indices, num_values)) { + ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + void GetDictionary(const T** dictionary, int32_t* dictionary_length) override { + *dictionary_length = dictionary_length_; + *dictionary = reinterpret_cast(dictionary_->mutable_data()); + } + + protected: + Status IndexInBounds(int32_t index) { + if (ARROW_PREDICT_TRUE(0 <= index && index < dictionary_length_)) { + return Status::OK(); + } + return Status::Invalid("Index not in dictionary bounds"); + } + + inline void DecodeDict(TypedDecoder* dictionary) { + dictionary_length_ = static_cast(dictionary->values_left()); + PARQUET_THROW_NOT_OK(dictionary_->Resize(dictionary_length_ * sizeof(T), + /*shrink_to_fit=*/false)); + dictionary->Decode(reinterpret_cast(dictionary_->mutable_data()), + dictionary_length_); + } + + // Only one is set. + std::shared_ptr dictionary_; + + int32_t dictionary_length_; + + // Data that contains the byte array data (byte_array_dictionary_ just has the + // pointers). + std::shared_ptr byte_array_data_; + + // Arrow-style byte offsets for each dictionary value. We maintain two + // representations of the dictionary, one as ByteArray* for non-Arrow + // consumers and this one for Arrow consumers. Since dictionaries are + // generally pretty small to begin with this doesn't mean too much extra + // memory use in most cases + std::shared_ptr byte_array_offsets_; + + // Reusable buffer for decoding dictionary indices to be appended to a + // BinaryDictionary32Builder + std::shared_ptr indices_scratch_space_; + + ::arrow::util::RleDecoder idx_decoder_; +}; + +template +void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { + DecodeDict(dictionary); +} + +template <> +void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { + ParquetException::NYI("Dictionary encoding is not implemented for boolean values"); +} + +template <> +void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { + DecodeDict(dictionary); + + auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + + int total_size = 0; + for (int i = 0; i < dictionary_length_; ++i) { + total_size += dict_values[i].len; + } + PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size, + /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK( + byte_array_offsets_->Resize((dictionary_length_ + 1) * sizeof(int32_t), + /*shrink_to_fit=*/false)); + + int32_t offset = 0; + uint8_t* bytes_data = byte_array_data_->mutable_data(); + int32_t* bytes_offsets = + reinterpret_cast(byte_array_offsets_->mutable_data()); + for (int i = 0; i < dictionary_length_; ++i) { + memcpy(bytes_data + offset, dict_values[i].ptr, dict_values[i].len); + bytes_offsets[i] = offset; + dict_values[i].ptr = bytes_data + offset; + offset += dict_values[i].len; + } + bytes_offsets[dictionary_length_] = offset; +} + +template <> +inline void DictDecoderImpl::SetDict(TypedDecoder* dictionary) { + DecodeDict(dictionary); + + auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + + int fixed_len = descr_->type_length(); + int total_size = dictionary_length_ * fixed_len; + + PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size, + /*shrink_to_fit=*/false)); + uint8_t* bytes_data = byte_array_data_->mutable_data(); + for (int32_t i = 0, offset = 0; i < dictionary_length_; ++i, offset += fixed_len) { + memcpy(bytes_data + offset, dict_values[i].ptr, fixed_len); + dict_values[i].ptr = bytes_data + offset; + } +} + +template <> +inline int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + ParquetException::NYI("DecodeArrow to Int96Type"); +} + +template <> +inline int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("DecodeArrow to Int96Type"); +} + +template <> +inline int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + ParquetException::NYI("DecodeArrow implemented elsewhere"); +} + +template <> +inline int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("DecodeArrow implemented elsewhere"); +} + +template +int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + int32_t index; + if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + throw ParquetException(""); + } + PARQUET_THROW_NOT_OK(IndexInBounds(index)); + PARQUET_THROW_NOT_OK(builder->Append(dict_values[index])); + }, + [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + + return num_values - null_count; +} + +template <> +int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("No dictionary encoding for BooleanType"); +} + +template <> +inline int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + if (builder->byte_width() != descr_->type_length()) { + throw ParquetException("Byte width mismatch: builder was " + + std::to_string(builder->byte_width()) + " but decoder was " + + std::to_string(descr_->type_length())); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + int32_t index; + if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + throw ParquetException(""); + } + PARQUET_THROW_NOT_OK(IndexInBounds(index)); + builder->UnsafeAppend(dict_values[index].ptr); + }, + [&]() { builder->UnsafeAppendNull(); }); + + return num_values - null_count; +} + +template <> +int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + auto value_type = + checked_cast(*builder->type()).value_type(); + auto byte_width = + checked_cast(*value_type).byte_width(); + if (byte_width != descr_->type_length()) { + throw ParquetException("Byte width mismatch: builder was " + + std::to_string(byte_width) + " but decoder was " + + std::to_string(descr_->type_length())); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + int32_t index; + if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + throw ParquetException(""); + } + PARQUET_THROW_NOT_OK(IndexInBounds(index)); + PARQUET_THROW_NOT_OK(builder->Append(dict_values[index].ptr)); + }, + [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); }); + + return num_values - null_count; +} + +template +int DictDecoderImpl::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + using value_type = typename Type::c_type; + auto dict_values = reinterpret_cast(dictionary_->data()); + + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + int32_t index; + if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) { + throw ParquetException(""); + } + PARQUET_THROW_NOT_OK(IndexInBounds(index)); + builder->UnsafeAppend(dict_values[index]); + }, + [&]() { builder->UnsafeAppendNull(); }); + + return num_values - null_count; +} + +template +void DictDecoderImpl::InsertDictionary(::arrow::ArrayBuilder* builder) { + ParquetException::NYI("InsertDictionary only implemented for BYTE_ARRAY types"); +} + +template <> +void DictDecoderImpl::InsertDictionary(::arrow::ArrayBuilder* builder) { + auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder); + + // Make a BinaryArray referencing the internal dictionary data + auto arr = std::make_shared<::arrow::BinaryArray>( + dictionary_length_, byte_array_offsets_, byte_array_data_); + PARQUET_THROW_NOT_OK(binary_builder->InsertMemoValues(*arr)); +} + +class DictByteArrayDecoderImpl : public DictDecoderImpl, + virtual public ByteArrayDecoder { + public: + using BASE = DictDecoderImpl; + using BASE::DictDecoderImpl; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + ::arrow::BinaryDictionary32Builder* builder) override { + int result = 0; + if (null_count == 0) { + PARQUET_THROW_NOT_OK(DecodeArrowNonNull(num_values, builder, &result)); + } else { + PARQUET_THROW_NOT_OK(DecodeArrow(num_values, null_count, valid_bits, + valid_bits_offset, builder, &result)); + } + return result; + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override { + int result = 0; + if (null_count == 0) { + PARQUET_THROW_NOT_OK(DecodeArrowDenseNonNull(num_values, out, &result)); + } else { + PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, valid_bits, + valid_bits_offset, out, &result)); + } + return result; + } + + + int DecodeCH(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer + ) override { + int result = 0; + if (null_count == 0) { + PARQUET_THROW_NOT_OK(DecodeCHDenseNonNull(num_values, column_chars_t_p, column_offsets_p, bitmap_writer, &result)); + } else { + PARQUET_THROW_NOT_OK(DecodeCHDense(num_values, null_count, valid_bits, + valid_bits_offset, column_chars_t_p, column_offsets_p, bitmap_writer, &result)); + } + return result; + } + + private: + Status DecodeCHDense(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer, + int* out_num_values) { + constexpr int32_t kBufferSize = 1024; + int32_t indices[kBufferSize]; + + column_offsets_p->reserve(num_values); + column_chars_t_p->reserve(num_values * 20); // approx + + ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values); + + auto dict_values = reinterpret_cast(dictionary_->data()); + int values_decoded = 0; + int num_appended = 0; + while (num_appended < num_values) { + bool is_valid = bit_reader.IsSet(); + bit_reader.Next(); + + if (is_valid) { + int32_t batch_size = + std::min(kBufferSize, num_values - num_appended - null_count); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + + if (ARROW_PREDICT_FALSE(num_indices < 1)) { + return Status::Invalid("Invalid number of indices '", num_indices, "'"); + } + + int i = 0; + while (true) { + // Consume all indices + if (is_valid) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + column_chars_t_p -> insert(val.ptr, val.ptr + static_cast(val.len)); + column_chars_t_p -> emplace_back('\0'); + column_offsets_p -> emplace_back(column_chars_t_p -> size()); + ++i; + ++values_decoded; + + bitmap_writer.Set(); + bitmap_writer.Next(); + } else { + column_chars_t_p -> emplace_back('\0'); + column_offsets_p -> emplace_back(column_chars_t_p -> size()); + --null_count; + + bitmap_writer.Clear(); + bitmap_writer.Next(); + } + ++num_appended; + if (i == num_indices) { + // Do not advance the bit_reader if we have fulfilled the decode + // request + break; + } + is_valid = bit_reader.IsSet(); + bit_reader.Next(); + } + } else { + column_chars_t_p -> emplace_back('\0'); + column_offsets_p -> emplace_back(column_chars_t_p -> size()); + --null_count; + ++num_appended; + + bitmap_writer.Clear(); + bitmap_writer.Next(); + } + } + *out_num_values = values_decoded; + return Status::OK(); + } + + Status DecodeCHDenseNonNull(int num_values, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer, + int* out_num_values) { + constexpr int32_t kBufferSize = 2048; + int32_t indices[kBufferSize]; + int values_decoded = 0; + + auto dict_values = reinterpret_cast(dictionary_->data()); + + while (values_decoded < num_values) { + int32_t batch_size = std::min(kBufferSize, num_values - values_decoded); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (num_indices == 0) ParquetException::EofException(); + for (int i = 0; i < num_indices; ++i) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + column_chars_t_p -> insert(val.ptr, val.ptr + static_cast(val.len)); + column_chars_t_p -> emplace_back('\0'); + column_offsets_p -> emplace_back(column_chars_t_p -> size()); + + bitmap_writer.Set(); + bitmap_writer.Next(); + } + values_decoded += num_indices; + } + *out_num_values = values_decoded; + return Status::OK(); + } + + + Status DecodeArrowDense(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out, + int* out_num_values) { + constexpr int32_t kBufferSize = 1024; + int32_t indices[kBufferSize]; + + ArrowBinaryHelper helper(out); + + ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values); + + auto dict_values = reinterpret_cast(dictionary_->data()); + int values_decoded = 0; + int num_appended = 0; + while (num_appended < num_values) { + bool is_valid = bit_reader.IsSet(); + bit_reader.Next(); + + if (is_valid) { + int32_t batch_size = + std::min(kBufferSize, num_values - num_appended - null_count); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + + if (ARROW_PREDICT_FALSE(num_indices < 1)) { + return Status::Invalid("Invalid number of indices '", num_indices, "'"); + } + + int i = 0; + while (true) { + // Consume all indices + if (is_valid) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { + RETURN_NOT_OK(helper.PushChunk()); + } + RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); + ++i; + ++values_decoded; + } else { + RETURN_NOT_OK(helper.AppendNull()); + --null_count; + } + ++num_appended; + if (i == num_indices) { + // Do not advance the bit_reader if we have fulfilled the decode + // request + break; + } + is_valid = bit_reader.IsSet(); + bit_reader.Next(); + } + } else { + RETURN_NOT_OK(helper.AppendNull()); + --null_count; + ++num_appended; + } + } + *out_num_values = values_decoded; + return Status::OK(); + } + + Status DecodeArrowDenseNonNull(int num_values, + typename EncodingTraits::Accumulator* out, + int* out_num_values) { + constexpr int32_t kBufferSize = 2048; + int32_t indices[kBufferSize]; + int values_decoded = 0; + + ArrowBinaryHelper helper(out); + auto dict_values = reinterpret_cast(dictionary_->data()); + + while (values_decoded < num_values) { + int32_t batch_size = std::min(kBufferSize, num_values - values_decoded); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (num_indices == 0) ParquetException::EofException(); + for (int i = 0; i < num_indices; ++i) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) { + RETURN_NOT_OK(helper.PushChunk()); + } + RETURN_NOT_OK(helper.Append(val.ptr, static_cast(val.len))); + } + values_decoded += num_indices; + } + *out_num_values = values_decoded; + return Status::OK(); + } + + template + Status DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, BuilderType* builder, + int* out_num_values) { + constexpr int32_t kBufferSize = 1024; + int32_t indices[kBufferSize]; + + RETURN_NOT_OK(builder->Reserve(num_values)); + ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + int values_decoded = 0; + int num_appended = 0; + while (num_appended < num_values) { + bool is_valid = bit_reader.IsSet(); + bit_reader.Next(); + + if (is_valid) { + int32_t batch_size = + std::min(kBufferSize, num_values - num_appended - null_count); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + + int i = 0; + while (true) { + // Consume all indices + if (is_valid) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + ++i; + ++values_decoded; + } else { + RETURN_NOT_OK(builder->AppendNull()); + --null_count; + } + ++num_appended; + if (i == num_indices) { + // Do not advance the bit_reader if we have fulfilled the decode + // request + break; + } + is_valid = bit_reader.IsSet(); + bit_reader.Next(); + } + } else { + RETURN_NOT_OK(builder->AppendNull()); + --null_count; + ++num_appended; + } + } + *out_num_values = values_decoded; + return Status::OK(); + } + + template + Status DecodeArrowNonNull(int num_values, BuilderType* builder, int* out_num_values) { + constexpr int32_t kBufferSize = 2048; + int32_t indices[kBufferSize]; + + RETURN_NOT_OK(builder->Reserve(num_values)); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + int values_decoded = 0; + while (values_decoded < num_values) { + int32_t batch_size = std::min(kBufferSize, num_values - values_decoded); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (num_indices == 0) ParquetException::EofException(); + for (int i = 0; i < num_indices; ++i) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + RETURN_NOT_OK(builder->Append(val.ptr, val.len)); + } + values_decoded += num_indices; + } + *out_num_values = values_decoded; + return Status::OK(); + } +}; + +// ---------------------------------------------------------------------- +// DeltaBitPackDecoder + +template +class DeltaBitPackDecoder : public DecoderImpl, virtual public TypedDecoder { + public: + typedef typename DType::c_type T; + + explicit DeltaBitPackDecoder(const ColumnDescriptor* descr, + MemoryPool* pool = ::arrow::default_memory_pool()) + : DecoderImpl(descr, Encoding::DELTA_BINARY_PACKED), pool_(pool) { + if (DType::type_num != Type::INT32 && DType::type_num != Type::INT64) { + throw ParquetException("Delta bit pack encoding should only be for integer data."); + } + } + + void SetData(int num_values, const uint8_t* data, int len) override { + this->num_values_ = num_values; + decoder_ = ::arrow::BitUtil::BitReader(data, len); + InitHeader(); + } + + int Decode(T* buffer, int max_values) override { + return GetInternal(buffer, max_values); + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override { + if (null_count != 0) { + ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); + } + std::vector values(num_values); + GetInternal(values.data(), num_values); + PARQUET_THROW_NOT_OK(out->AppendValues(values)); + return num_values; + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* out) override { + if (null_count != 0) { + ParquetException::NYI("Delta bit pack DecodeArrow with null slots"); + } + std::vector values(num_values); + GetInternal(values.data(), num_values); + PARQUET_THROW_NOT_OK(out->Reserve(num_values)); + for (T value : values) { + PARQUET_THROW_NOT_OK(out->Append(value)); + } + return num_values; + } + + private: + static constexpr int kMaxDeltaBitWidth = static_cast(sizeof(T) * 8); + + void InitHeader() { + if (!decoder_.GetVlqInt(&values_per_block_) || + !decoder_.GetVlqInt(&mini_blocks_per_block_) || + !decoder_.GetVlqInt(&total_value_count_) || + !decoder_.GetZigZagVlqInt(&last_value_)) { + ParquetException::EofException(); + } + + if (values_per_block_ == 0) { + throw ParquetException("cannot have zero value per block"); + } + if (mini_blocks_per_block_ == 0) { + throw ParquetException("cannot have zero miniblock per block"); + } + values_per_mini_block_ = values_per_block_ / mini_blocks_per_block_; + if (values_per_mini_block_ == 0) { + throw ParquetException("cannot have zero value per miniblock"); + } + if (values_per_mini_block_ % 32 != 0) { + throw ParquetException( + "the number of values in a miniblock must be multiple of 32, but it's " + + std::to_string(values_per_mini_block_)); + } + + delta_bit_widths_ = AllocateBuffer(pool_, mini_blocks_per_block_); + block_initialized_ = false; + values_current_mini_block_ = 0; + } + + void InitBlock() { + if (!decoder_.GetZigZagVlqInt(&min_delta_)) ParquetException::EofException(); + + // read the bitwidth of each miniblock + uint8_t* bit_width_data = delta_bit_widths_->mutable_data(); + for (uint32_t i = 0; i < mini_blocks_per_block_; ++i) { + if (!decoder_.GetAligned(1, bit_width_data + i)) { + ParquetException::EofException(); + } + if (bit_width_data[i] > kMaxDeltaBitWidth) { + throw ParquetException("delta bit width larger than integer bit width"); + } + } + mini_block_idx_ = 0; + delta_bit_width_ = bit_width_data[0]; + values_current_mini_block_ = values_per_mini_block_; + block_initialized_ = true; + } + + int GetInternal(T* buffer, int max_values) { + max_values = std::min(max_values, this->num_values_); + DCHECK_LE(static_cast(max_values), total_value_count_); + int i = 0; + while (i < max_values) { + if (ARROW_PREDICT_FALSE(values_current_mini_block_ == 0)) { + if (ARROW_PREDICT_FALSE(!block_initialized_)) { + buffer[i++] = last_value_; + --total_value_count_; + if (ARROW_PREDICT_FALSE(i == max_values)) break; + InitBlock(); + } else { + ++mini_block_idx_; + if (mini_block_idx_ < mini_blocks_per_block_) { + delta_bit_width_ = delta_bit_widths_->data()[mini_block_idx_]; + values_current_mini_block_ = values_per_mini_block_; + } else { + InitBlock(); + } + } + } + + int values_decode = + std::min(values_current_mini_block_, static_cast(max_values - i)); + if (decoder_.GetBatch(delta_bit_width_, buffer + i, values_decode) != + values_decode) { + ParquetException::EofException(); + } + for (int j = 0; j < values_decode; ++j) { + // Addition between min_delta, packed int and last_value should be treated as + // unsigned addtion. Overflow is as expected. + uint64_t delta = + static_cast(min_delta_) + static_cast(buffer[i + j]); + buffer[i + j] = static_cast(delta + static_cast(last_value_)); + last_value_ = buffer[i + j]; + } + values_current_mini_block_ -= values_decode; + total_value_count_ -= values_decode; + i += values_decode; + } + this->num_values_ -= max_values; + return max_values; + } + + MemoryPool* pool_; + ::arrow::BitUtil::BitReader decoder_; + uint32_t values_per_block_; + uint32_t mini_blocks_per_block_; + uint32_t values_per_mini_block_; + uint32_t values_current_mini_block_; + uint32_t total_value_count_; + + bool block_initialized_; + T min_delta_; + uint32_t mini_block_idx_; + std::shared_ptr delta_bit_widths_; + int delta_bit_width_; + + T last_value_; +}; + +// ---------------------------------------------------------------------- +// DELTA_LENGTH_BYTE_ARRAY + +class DeltaLengthByteArrayDecoder : public DecoderImpl, + virtual public TypedDecoder { + public: + explicit DeltaLengthByteArrayDecoder(const ColumnDescriptor* descr, + MemoryPool* pool = ::arrow::default_memory_pool()) + : DecoderImpl(descr, Encoding::DELTA_LENGTH_BYTE_ARRAY), + len_decoder_(nullptr, pool), + pool_(pool) {} + + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + if (len == 0) return; + int total_lengths_len = ::arrow::util::SafeLoadAs(data); + data += 4; + this->len_decoder_.SetData(num_values, data, total_lengths_len); + data_ = data + total_lengths_len; + this->len_ = len - 4 - total_lengths_len; + } + + int Decode(ByteArray* buffer, int max_values) override { + using VectorT = ArrowPoolVector; + max_values = std::min(max_values, num_values_); + VectorT lengths(max_values, 0, ::arrow::stl::allocator(pool_)); + len_decoder_.Decode(lengths.data(), max_values); + for (int i = 0; i < max_values; ++i) { + buffer[i].len = lengths[i]; + buffer[i].ptr = data_; + this->data_ += lengths[i]; + this->len_ -= lengths[i]; + } + this->num_values_ -= max_values; + return max_values; + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out) override { + ParquetException::NYI("DecodeArrow for DeltaLengthByteArrayDecoder"); + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* out) override { + ParquetException::NYI("DecodeArrow for DeltaLengthByteArrayDecoder"); + } + + private: + DeltaBitPackDecoder len_decoder_; + ::arrow::MemoryPool* pool_; +}; + +// ---------------------------------------------------------------------- +// DELTA_BYTE_ARRAY + +class DeltaByteArrayDecoder : public DecoderImpl, + virtual public TypedDecoder { + public: + explicit DeltaByteArrayDecoder(const ColumnDescriptor* descr, + MemoryPool* pool = ::arrow::default_memory_pool()) + : DecoderImpl(descr, Encoding::DELTA_BYTE_ARRAY), + prefix_len_decoder_(nullptr, pool), + suffix_decoder_(nullptr, pool), + last_value_(0, nullptr) {} + + virtual void SetData(int num_values, const uint8_t* data, int len) { + num_values_ = num_values; + if (len == 0) return; + int prefix_len_length = ::arrow::util::SafeLoadAs(data); + data += 4; + len -= 4; + prefix_len_decoder_.SetData(num_values, data, prefix_len_length); + data += prefix_len_length; + len -= prefix_len_length; + suffix_decoder_.SetData(num_values, data, len); + } + + // TODO: this doesn't work and requires memory management. We need to allocate + // new strings to store the results. + virtual int Decode(ByteArray* buffer, int max_values) { + max_values = std::min(max_values, this->num_values_); + for (int i = 0; i < max_values; ++i) { + int prefix_len = 0; + prefix_len_decoder_.Decode(&prefix_len, 1); + ByteArray suffix = {0, nullptr}; + suffix_decoder_.Decode(&suffix, 1); + buffer[i].len = prefix_len + suffix.len; + + uint8_t* result = reinterpret_cast(malloc(buffer[i].len)); + memcpy(result, last_value_.ptr, prefix_len); + memcpy(result + prefix_len, suffix.ptr, suffix.len); + + buffer[i].ptr = result; + last_value_ = buffer[i]; + } + this->num_values_ -= max_values; + return max_values; + } + + private: + DeltaBitPackDecoder prefix_len_decoder_; + DeltaLengthByteArrayDecoder suffix_decoder_; + ByteArray last_value_; +}; + +// ---------------------------------------------------------------------- +// BYTE_STREAM_SPLIT + +template +class ByteStreamSplitDecoder : public DecoderImpl, virtual public TypedDecoder { + public: + using T = typename DType::c_type; + explicit ByteStreamSplitDecoder(const ColumnDescriptor* descr); + + int Decode(T* buffer, int max_values) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) override; + + void SetData(int num_values, const uint8_t* data, int len) override; + + T* EnsureDecodeBuffer(int64_t min_values) { + const int64_t size = sizeof(T) * min_values; + if (!decode_buffer_ || decode_buffer_->size() < size) { + PARQUET_ASSIGN_OR_THROW(decode_buffer_, ::arrow::AllocateBuffer(size)); + } + return reinterpret_cast(decode_buffer_->mutable_data()); + } + + private: + int num_values_in_buffer_{0}; + std::shared_ptr decode_buffer_; + + static constexpr size_t kNumStreams = sizeof(T); +}; + +template +ByteStreamSplitDecoder::ByteStreamSplitDecoder(const ColumnDescriptor* descr) + : DecoderImpl(descr, Encoding::BYTE_STREAM_SPLIT) {} + +template +void ByteStreamSplitDecoder::SetData(int num_values, const uint8_t* data, + int len) { + DecoderImpl::SetData(num_values, data, len); + if (num_values * static_cast(sizeof(T)) > len) { + throw ParquetException("Data size too small for number of values (corrupted file?)"); + } + num_values_in_buffer_ = num_values; +} + +template +int ByteStreamSplitDecoder::Decode(T* buffer, int max_values) { + const int values_to_decode = std::min(num_values_, max_values); + const int num_decoded_previously = num_values_in_buffer_ - num_values_; + const uint8_t* data = data_ + num_decoded_previously; + + ::arrow::util::internal::ByteStreamSplitDecode(data, values_to_decode, + num_values_in_buffer_, buffer); + num_values_ -= values_to_decode; + len_ -= sizeof(T) * values_to_decode; + return values_to_decode; +} + +template +int ByteStreamSplitDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* builder) { + constexpr int value_size = static_cast(kNumStreams); + int values_decoded = num_values - null_count; + if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) { + ParquetException::EofException(); + } + + PARQUET_THROW_NOT_OK(builder->Reserve(num_values)); + + const int num_decoded_previously = num_values_in_buffer_ - num_values_; + const uint8_t* data = data_ + num_decoded_previously; + int offset = 0; + +#if defined(ARROW_HAVE_SIMD_SPLIT) + // Use fast decoding into intermediate buffer. This will also decode + // some null values, but it's fast enough that we don't care. + T* decode_out = EnsureDecodeBuffer(values_decoded); + ::arrow::util::internal::ByteStreamSplitDecode(data, values_decoded, + num_values_in_buffer_, decode_out); + + // XXX If null_count is 0, we could even append in bulk or decode directly into + // builder + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + builder->UnsafeAppend(decode_out[offset]); + ++offset; + }, + [&]() { builder->UnsafeAppendNull(); }); + +#else + VisitNullBitmapInline( + valid_bits, valid_bits_offset, num_values, null_count, + [&]() { + uint8_t gathered_byte_data[kNumStreams]; + for (size_t b = 0; b < kNumStreams; ++b) { + const size_t byte_index = b * num_values_in_buffer_ + offset; + gathered_byte_data[b] = data[byte_index]; + } + builder->UnsafeAppend(::arrow::util::SafeLoadAs(&gathered_byte_data[0])); + ++offset; + }, + [&]() { builder->UnsafeAppendNull(); }); +#endif + + num_values_ -= values_decoded; + len_ -= sizeof(T) * values_decoded; + return values_decoded; +} + +template +int ByteStreamSplitDecoder::DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) { + ParquetException::NYI("DecodeArrow for ByteStreamSplitDecoder"); +} + +} // namespace + +// ---------------------------------------------------------------------- +// Encoder and decoder factory functions + +std::unique_ptr MakeEncoder(Type::type type_num, Encoding::type encoding, + bool use_dictionary, const ColumnDescriptor* descr, + MemoryPool* pool) { + if (use_dictionary) { + switch (type_num) { + case Type::INT32: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::INT64: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::INT96: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::FLOAT: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::DOUBLE: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::BYTE_ARRAY: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::unique_ptr(new DictEncoderImpl(descr, pool)); + default: + DCHECK(false) << "Encoder not implemented"; + break; + } + } else if (encoding == Encoding::PLAIN) { + switch (type_num) { + case Type::BOOLEAN: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::INT32: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::INT64: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::INT96: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::FLOAT: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::DOUBLE: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::BYTE_ARRAY: + return std::unique_ptr(new PlainEncoder(descr, pool)); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::unique_ptr(new PlainEncoder(descr, pool)); + default: + DCHECK(false) << "Encoder not implemented"; + break; + } + } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { + switch (type_num) { + case Type::FLOAT: + return std::unique_ptr( + new ByteStreamSplitEncoder(descr, pool)); + case Type::DOUBLE: + return std::unique_ptr( + new ByteStreamSplitEncoder(descr, pool)); + default: + throw ParquetException("BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); + break; + } + } else { + ParquetException::NYI("Selected encoding is not supported"); + } + DCHECK(false) << "Should not be able to reach this code"; + return nullptr; +} + +std::unique_ptr MakeDecoder(Type::type type_num, Encoding::type encoding, + const ColumnDescriptor* descr) { + if (encoding == Encoding::PLAIN) { + switch (type_num) { + case Type::BOOLEAN: + return std::unique_ptr(new PlainBooleanDecoder(descr)); + case Type::INT32: + return std::unique_ptr(new PlainDecoder(descr)); + case Type::INT64: + return std::unique_ptr(new PlainDecoder(descr)); + case Type::INT96: + return std::unique_ptr(new PlainDecoder(descr)); + case Type::FLOAT: + return std::unique_ptr(new PlainDecoder(descr)); + case Type::DOUBLE: + return std::unique_ptr(new PlainDecoder(descr)); + case Type::BYTE_ARRAY: + return std::unique_ptr(new PlainByteArrayDecoder(descr)); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::unique_ptr(new PlainFLBADecoder(descr)); + default: + break; + } + } else if (encoding == Encoding::BYTE_STREAM_SPLIT) { + switch (type_num) { + case Type::FLOAT: + return std::unique_ptr(new ByteStreamSplitDecoder(descr)); + case Type::DOUBLE: + return std::unique_ptr(new ByteStreamSplitDecoder(descr)); + default: + throw ParquetException("BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE"); + break; + } + } else if (encoding == Encoding::DELTA_BINARY_PACKED) { + switch (type_num) { + case Type::INT32: + return std::unique_ptr(new DeltaBitPackDecoder(descr)); + case Type::INT64: + return std::unique_ptr(new DeltaBitPackDecoder(descr)); + default: + throw ParquetException("DELTA_BINARY_PACKED only supports INT32 and INT64"); + break; + } + } else { + ParquetException::NYI("Selected encoding is not supported"); + } + DCHECK(false) << "Should not be able to reach this code"; + return nullptr; +} + +namespace detail { +std::unique_ptr MakeDictDecoder(Type::type type_num, + const ColumnDescriptor* descr, + MemoryPool* pool) { + switch (type_num) { + case Type::BOOLEAN: + ParquetException::NYI("Dictionary encoding not implemented for boolean type"); + case Type::INT32: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + case Type::INT64: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + case Type::INT96: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + case Type::FLOAT: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + case Type::DOUBLE: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + case Type::BYTE_ARRAY: + return std::unique_ptr(new DictByteArrayDecoderImpl(descr, pool)); + case Type::FIXED_LEN_BYTE_ARRAY: + return std::unique_ptr(new DictDecoderImpl(descr, pool)); + default: + break; + } + DCHECK(false) << "Should not be able to reach this code"; + return nullptr; +} + +} // namespace detail +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/encoding.h b/utils/local-engine/Storages/ch_parquet/arrow/encoding.h new file mode 100644 index 000000000000..c780eb8d39ec --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/encoding.h @@ -0,0 +1,486 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/util/spaced.h" +#include "arrow/util/bitmap_writer.h" + + +#include "parquet/exception.h" +#include "parquet/platform.h" +#include "parquet/types.h" + +#include +#include + + +namespace arrow { + +class Array; +class ArrayBuilder; +class BinaryArray; +class BinaryBuilder; +class BooleanBuilder; +class Int32Type; +class Int64Type; +class FloatType; +class DoubleType; +class FixedSizeBinaryType; +template +class NumericBuilder; +class FixedSizeBinaryBuilder; +template +class Dictionary32Builder; + +} // namespace arrow + +namespace parquet{ +class ColumnDescriptor; +} + +namespace ch_parquet { +using namespace parquet; +using namespace DB; + +template +class TypedEncoder; + +using BooleanEncoder = TypedEncoder; +using Int32Encoder = TypedEncoder; +using Int64Encoder = TypedEncoder; +using Int96Encoder = TypedEncoder; +using FloatEncoder = TypedEncoder; +using DoubleEncoder = TypedEncoder; +using ByteArrayEncoder = TypedEncoder; +using FLBAEncoder = TypedEncoder; + +template +class TypedDecoder; + +class BooleanDecoder; +using Int32Decoder = TypedDecoder; +using Int64Decoder = TypedDecoder; +using Int96Decoder = TypedDecoder; +using FloatDecoder = TypedDecoder; +using DoubleDecoder = TypedDecoder; +using ByteArrayDecoder = TypedDecoder; +class FLBADecoder; + +template +struct EncodingTraits; + +template <> +struct EncodingTraits { + using Encoder = BooleanEncoder; + using Decoder = BooleanDecoder; + + using ArrowType = ::arrow::BooleanType; + using Accumulator = ::arrow::BooleanBuilder; + struct DictAccumulator {}; +}; + +template <> +struct EncodingTraits { + using Encoder = Int32Encoder; + using Decoder = Int32Decoder; + + using ArrowType = ::arrow::Int32Type; + using Accumulator = ::arrow::NumericBuilder<::arrow::Int32Type>; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::Int32Type>; +}; + +template <> +struct EncodingTraits { + using Encoder = Int64Encoder; + using Decoder = Int64Decoder; + + using ArrowType = ::arrow::Int64Type; + using Accumulator = ::arrow::NumericBuilder<::arrow::Int64Type>; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::Int64Type>; +}; + +template <> +struct EncodingTraits { + using Encoder = Int96Encoder; + using Decoder = Int96Decoder; + + struct Accumulator {}; + struct DictAccumulator {}; +}; + +template <> +struct EncodingTraits { + using Encoder = FloatEncoder; + using Decoder = FloatDecoder; + + using ArrowType = ::arrow::FloatType; + using Accumulator = ::arrow::NumericBuilder<::arrow::FloatType>; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::FloatType>; +}; + +template <> +struct EncodingTraits { + using Encoder = DoubleEncoder; + using Decoder = DoubleDecoder; + + using ArrowType = ::arrow::DoubleType; + using Accumulator = ::arrow::NumericBuilder<::arrow::DoubleType>; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::DoubleType>; +}; + +template <> +struct EncodingTraits { + using Encoder = ByteArrayEncoder; + using Decoder = ByteArrayDecoder; + + /// \brief Internal helper class for decoding BYTE_ARRAY data where we can + /// overflow the capacity of a single arrow::BinaryArray + struct Accumulator { + std::unique_ptr<::arrow::BinaryBuilder> builder; + std::vector> chunks; + }; + using ArrowType = ::arrow::BinaryType; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::BinaryType>; +}; + +template <> +struct EncodingTraits { + using Encoder = FLBAEncoder; + using Decoder = FLBADecoder; + + using ArrowType = ::arrow::FixedSizeBinaryType; + using Accumulator = ::arrow::FixedSizeBinaryBuilder; + using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::FixedSizeBinaryType>; +}; + + +// Untyped base for all encoders +class Encoder { + public: + virtual ~Encoder() = default; + + virtual int64_t EstimatedDataEncodedSize() = 0; + virtual std::shared_ptr FlushValues() = 0; + virtual Encoding::type encoding() const = 0; + + virtual void Put(const ::arrow::Array& values) = 0; + + virtual MemoryPool* memory_pool() const = 0; +}; + +// Base class for value encoders. Since encoders may or not have state (e.g., +// dictionary encoding) we use a class instance to maintain any state. +// +// Encode interfaces are internal, subject to change without deprecation. +template +class TypedEncoder : virtual public Encoder { + public: + typedef typename DType::c_type T; + + using Encoder::Put; + + virtual void Put(const T* src, int num_values) = 0; + + virtual void Put(const std::vector& src, int num_values = -1); + + virtual void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits, + int64_t valid_bits_offset) = 0; +}; + +template +void TypedEncoder::Put(const std::vector& src, int num_values) { + if (num_values == -1) { + num_values = static_cast(src.size()); + } + Put(src.data(), num_values); +} + +template <> +inline void TypedEncoder::Put(const std::vector& src, int num_values) { + // NOTE(wesm): This stub is here only to satisfy the compiler; it is + // overridden later with the actual implementation +} + +// Base class for dictionary encoders +template +class DictEncoder : virtual public TypedEncoder { + public: + /// Writes out any buffered indices to buffer preceded by the bit width of this data. + /// Returns the number of bytes written. + /// If the supplied buffer is not big enough, returns -1. + /// buffer must be preallocated with buffer_len bytes. Use EstimatedDataEncodedSize() + /// to size buffer. + virtual int WriteIndices(uint8_t* buffer, int buffer_len) = 0; + + virtual int dict_encoded_size() = 0; + // virtual int dict_encoded_size() { return dict_encoded_size_; } + + virtual int bit_width() const = 0; + + /// Writes out the encoded dictionary to buffer. buffer must be preallocated to + /// dict_encoded_size() bytes. + virtual void WriteDict(uint8_t* buffer) = 0; + + virtual int num_entries() const = 0; + + /// \brief EXPERIMENTAL: Append dictionary indices into the encoder. It is + /// assumed (without any boundschecking) that the indices reference + /// pre-existing dictionary values + /// \param[in] indices the dictionary index values. Only Int32Array currently + /// supported + virtual void PutIndices(const ::arrow::Array& indices) = 0; + + /// \brief EXPERIMENTAL: Append dictionary into encoder, inserting indices + /// separately. Currently throws exception if the current dictionary memo is + /// non-empty + /// \param[in] values the dictionary values. Only valid for certain + /// Parquet/Arrow type combinations, like BYTE_ARRAY/BinaryArray + virtual void PutDictionary(const ::arrow::Array& values) = 0; +}; + +// ---------------------------------------------------------------------- +// Value decoding + +class Decoder { + public: + virtual ~Decoder() = default; + + // Sets the data for a new page. This will be called multiple times on the same + // decoder and should reset all internal state. + virtual void SetData(int num_values, const uint8_t* data, int len) = 0; + + // Returns the number of values left (for the last call to SetData()). This is + // the number of values left in this page. + virtual int values_left() const = 0; + virtual Encoding::type encoding() const = 0; +}; + +template +class TypedDecoder : virtual public Decoder { + public: + using T = typename DType::c_type; + + /// \brief Decode values into a buffer + /// + /// Subclasses may override the more specialized Decode methods below. + /// + /// \param[in] buffer destination for decoded values + /// \param[in] max_values maximum number of values to decode + /// \return The number of values decoded. Should be identical to max_values except + /// at the end of the current data page. + virtual int Decode(T* buffer, int max_values) = 0; + + /// \brief Decode the values in this data page but leave spaces for null entries. + /// + /// \param[in] buffer destination for decoded values + /// \param[in] num_values size of the def_levels and buffer arrays including the number + /// of null slots + /// \param[in] null_count number of null slots + /// \param[in] valid_bits bitmap data indicating position of valid slots + /// \param[in] valid_bits_offset offset into valid_bits + /// \return The number of values decoded, including nulls. + virtual int DecodeSpaced(T* buffer, int num_values, int null_count, + const uint8_t* valid_bits, int64_t valid_bits_offset) { + if (null_count > 0) { + int values_to_read = num_values - null_count; + int values_read = Decode(buffer, values_to_read); + if (values_read != values_to_read) { + throw ParquetException("Number of values / definition_levels read did not match"); + } + + return ::arrow::util::internal::SpacedExpand(buffer, num_values, null_count, + valid_bits, valid_bits_offset); + } else { + return Decode(buffer, num_values); + } + } + + virtual int DecodeCH(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer + ) {} + + int DecodeCHNonNull(int num_values, + PaddedPODArray* column_chars_t_p, + PaddedPODArray* column_offsets_p, + ::arrow::internal::BitmapWriter& bitmap_writer) { + return DecodeCH(num_values, 0, /*valid_bits=*/NULLPTR, 0, column_chars_t_p, column_offsets_p, bitmap_writer); + } + + /// \brief Decode into an ArrayBuilder or other accumulator + /// + /// This function assumes the definition levels were already decoded + /// as a validity bitmap in the given `valid_bits`. `null_count` + /// is the number of 0s in `valid_bits`. + /// As a space optimization, it is allowed for `valid_bits` to be null + /// if `null_count` is zero. + /// + /// \return number of values decoded + virtual int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::Accumulator* out + ) = 0; + + /// \brief Decode into an ArrayBuilder or other accumulator ignoring nulls + /// + /// \return number of values decoded + int DecodeArrowNonNull(int num_values, + typename EncodingTraits::Accumulator* out) { + return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, out); + } + + /// \brief Decode into a DictionaryBuilder + /// + /// This function assumes the definition levels were already decoded + /// as a validity bitmap in the given `valid_bits`. `null_count` + /// is the number of 0s in `valid_bits`. + /// As a space optimization, it is allowed for `valid_bits` to be null + /// if `null_count` is zero. + /// + /// \return number of values decoded + virtual int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename EncodingTraits::DictAccumulator* builder) = 0; + + /// \brief Decode into a DictionaryBuilder ignoring nulls + /// + /// \return number of values decoded + int DecodeArrowNonNull(int num_values, + typename EncodingTraits::DictAccumulator* builder) { + return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, builder); + } +}; + +template +class DictDecoder : virtual public TypedDecoder { + public: + using T = typename DType::c_type; + + virtual void SetDict(TypedDecoder* dictionary) = 0; + + /// \brief Insert dictionary values into the Arrow dictionary builder's memo, + /// but do not append any indices + virtual void InsertDictionary(::arrow::ArrayBuilder* builder) = 0; + + /// \brief Decode only dictionary indices and append to dictionary + /// builder. The builder must have had the dictionary from this decoder + /// inserted already. + /// + /// \warning Remember to reset the builder each time the dict decoder is initialized + /// with a new dictionary page + virtual int DecodeIndicesSpaced(int num_values, int null_count, + const uint8_t* valid_bits, int64_t valid_bits_offset, + ::arrow::ArrayBuilder* builder) = 0; + + /// \brief Decode only dictionary indices (no nulls) + /// + /// \warning Remember to reset the builder each time the dict decoder is initialized + /// with a new dictionary page + virtual int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) = 0; + + /// \brief Decode only dictionary indices (no nulls). Same as above + /// DecodeIndices but target is an array instead of a builder. + /// + /// \note API EXPERIMENTAL + virtual int DecodeIndices(int num_values, int32_t* indices) = 0; + + /// \brief Get dictionary. The reader will call this API when it encounters a + /// new dictionary. + /// + /// @param[out] dictionary The pointer to dictionary values. Dictionary is owned by + /// the decoder and is destroyed when the decoder is destroyed. + /// @param[out] dictionary_length The dictionary length. + /// + /// \note API EXPERIMENTAL + virtual void GetDictionary(const T** dictionary, int32_t* dictionary_length) = 0; +}; + +// ---------------------------------------------------------------------- +// TypedEncoder specializations, traits, and factory functions + +class BooleanDecoder : virtual public TypedDecoder { + public: + using TypedDecoder::Decode; + virtual int Decode(uint8_t* buffer, int max_values) = 0; +}; + +class FLBADecoder : virtual public TypedDecoder { + public: + using TypedDecoder::DecodeSpaced; + + // TODO(wesm): As possible follow-up to PARQUET-1508, we should examine if + // there is value in adding specialized read methods for + // FIXED_LEN_BYTE_ARRAY. If only Decimal data can occur with this data type + // then perhaps not +}; + +PARQUET_EXPORT +std::unique_ptr MakeEncoder( + Type::type type_num, Encoding::type encoding, bool use_dictionary = false, + const ColumnDescriptor* descr = NULLPTR, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + +template +std::unique_ptr::Encoder> MakeTypedEncoder( + Encoding::type encoding, bool use_dictionary = false, + const ColumnDescriptor* descr = NULLPTR, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { + using OutType = typename EncodingTraits::Encoder; + std::unique_ptr base = + MakeEncoder(DType::type_num, encoding, use_dictionary, descr, pool); + return std::unique_ptr(dynamic_cast(base.release())); +} + +PARQUET_EXPORT +std::unique_ptr MakeDecoder(Type::type type_num, Encoding::type encoding, + const ColumnDescriptor* descr = NULLPTR); + +namespace detail { + +PARQUET_EXPORT +std::unique_ptr MakeDictDecoder(Type::type type_num, + const ColumnDescriptor* descr, + ::arrow::MemoryPool* pool); + +} // namespace detail + +template +std::unique_ptr> MakeDictDecoder( + const ColumnDescriptor* descr = NULLPTR, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { + using OutType = DictDecoder; + auto decoder = detail::MakeDictDecoder(DType::type_num, descr, pool); + return std::unique_ptr(dynamic_cast(decoder.release())); +} + +template +std::unique_ptr::Decoder> MakeTypedDecoder( + Encoding::type encoding, const ColumnDescriptor* descr = NULLPTR) { + using OutType = typename EncodingTraits::Decoder; + std::unique_ptr base = MakeDecoder(DType::type_num, encoding, descr); + return std::unique_ptr(dynamic_cast(base.release())); +} + +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/reader.cc b/utils/local-engine/Storages/ch_parquet/arrow/reader.cc new file mode 100644 index 000000000000..31e6286f0e0d --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/reader.cc @@ -0,0 +1,1324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "reader.h" + +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/future.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/parallel.h" +#include "arrow/util/range.h" +#include "Storages/ch_parquet/arrow/reader_internal.h" +#include "Storages/ch_parquet/arrow/column_reader.h" +#include "parquet/exception.h" +#include "parquet/file_reader.h" +#include "parquet/metadata.h" +#include "parquet/properties.h" +#include "parquet/schema.h" + +using arrow::Array; +using arrow::ArrayData; +using arrow::BooleanArray; +using arrow::ChunkedArray; +using arrow::DataType; +using arrow::ExtensionType; +using arrow::Field; +using arrow::Future; +using arrow::Int32Array; +using arrow::ListArray; +using arrow::MemoryPool; +using arrow::RecordBatchReader; +using arrow::ResizableBuffer; +using arrow::Status; +using arrow::StructArray; +using arrow::Table; +using arrow::TimestampArray; + +using arrow::internal::checked_cast; +using arrow::internal::Iota; + +// Help reduce verbosity +using ParquetReader = ch_parquet::ParquetFileReader; + +using ch_parquet::internal::RecordReader; + +namespace BitUtil = arrow::BitUtil; + + +using parquet::ParquetFileReader; +using parquet::ArrowReaderProperties; +using parquet::PageReader; +using parquet::ColumnDescriptor; +using parquet::Buffer; +using parquet::arrow::SchemaManifest; + +namespace ch_parquet { + +namespace arrow { + using namespace parquet::arrow; +namespace { + +::arrow::Result> ChunksToSingle(const ChunkedArray& chunked) { + switch (chunked.num_chunks()) { + case 0: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr array, + ::arrow::MakeArrayOfNull(chunked.type(), 0)); + return array->data(); + } + case 1: + return chunked.chunk(0)->data(); + default: + // ARROW-3762(wesm): If item reader yields a chunked array, we reject as + // this is not yet implemented + return Status::NotImplemented( + "Nested data conversions not implemented for chunked array outputs"); + } +} + +} // namespace + +class ColumnReaderImpl : public ColumnReader { + public: + virtual Status GetDefLevels(const int16_t** data, int64_t* length) = 0; + virtual Status GetRepLevels(const int16_t** data, int64_t* length) = 0; + virtual const std::shared_ptr field() = 0; + + ::arrow::Status NextBatch(int64_t batch_size, + std::shared_ptr<::arrow::ChunkedArray>* out) final { + RETURN_NOT_OK(LoadBatch(batch_size)); + RETURN_NOT_OK(BuildArray(batch_size, out)); + for (int x = 0; x < (*out)->num_chunks(); x++) { + RETURN_NOT_OK((*out)->chunk(x)->Validate()); + } + return Status::OK(); + } + + virtual ::arrow::Status LoadBatch(int64_t num_records) = 0; + + virtual ::arrow::Status BuildArray(int64_t length_upper_bound, + std::shared_ptr<::arrow::ChunkedArray>* out) = 0; + virtual bool IsOrHasRepeatedChild() const = 0; +}; + +namespace { + +std::shared_ptr> VectorToSharedSet( + const std::vector& values) { + std::shared_ptr> result(new std::unordered_set()); + result->insert(values.begin(), values.end()); + return result; +} + +// Forward declaration +Status GetReader(const SchemaField& field, const std::shared_ptr& context, + std::unique_ptr* out); + +// ---------------------------------------------------------------------- +// FileReaderImpl forward declaration + +class FileReaderImpl : public FileReader { + public: + FileReaderImpl(MemoryPool* pool, std::unique_ptr reader, + ArrowReaderProperties properties) + : pool_(pool), + reader_(std::move(reader)), + reader_properties_(std::move(properties)) {} + + Status Init() { + return SchemaManifest::Make(reader_->metadata()->schema(), + reader_->metadata()->key_value_metadata(), + reader_properties_, &manifest_); + } + + FileColumnIteratorFactory SomeRowGroupsFactory(std::vector row_groups) { + return [row_groups](int i, ParquetFileReader* reader) { + return new FileColumnIterator(i, reader, row_groups); + }; + } + + FileColumnIteratorFactory AllRowGroupsFactory() { + return SomeRowGroupsFactory(Iota(reader_->metadata()->num_row_groups())); + } + + Status BoundsCheckColumn(int column) { + if (column < 0 || column >= this->num_columns()) { + return Status::Invalid("Column index out of bounds (got ", column, + ", should be " + "between 0 and ", + this->num_columns() - 1, ")"); + } + return Status::OK(); + } + + Status BoundsCheckRowGroup(int row_group) { + // row group indices check + if (row_group < 0 || row_group >= num_row_groups()) { + return Status::Invalid("Some index in row_group_indices is ", row_group, + ", which is either < 0 or >= num_row_groups(", + num_row_groups(), ")"); + } + return Status::OK(); + } + + Status BoundsCheck(const std::vector& row_groups, + const std::vector& column_indices) { + for (int i : row_groups) { + RETURN_NOT_OK(BoundsCheckRowGroup(i)); + } + for (int i : column_indices) { + RETURN_NOT_OK(BoundsCheckColumn(i)); + } + return Status::OK(); + } + + std::shared_ptr RowGroup(int row_group_index) override; + + Status ReadTable(const std::vector& indices, + std::shared_ptr* out) override { + return ReadRowGroups(Iota(reader_->metadata()->num_row_groups()), indices, out); + } + + Status GetFieldReader(int i, + const std::shared_ptr>& included_leaves, + const std::vector& row_groups, + std::unique_ptr* out) { + auto ctx = std::make_shared(); + ctx->reader = reader_.get(); + ctx->pool = pool_; + ctx->iterator_factory = SomeRowGroupsFactory(row_groups); + ctx->filter_leaves = true; + ctx->included_leaves = included_leaves; + return GetReader(manifest_.schema_fields[i], ctx, out); + } + + Status GetFieldReaders(const std::vector& column_indices, + const std::vector& row_groups, + std::vector>* out, + std::shared_ptr<::arrow::Schema>* out_schema) { + // We only need to read schema fields which have columns indicated + // in the indices vector + ARROW_ASSIGN_OR_RAISE(std::vector field_indices, + manifest_.GetFieldIndices(column_indices)); + + auto included_leaves = VectorToSharedSet(column_indices); + + out->resize(field_indices.size()); + ::arrow::FieldVector out_fields(field_indices.size()); + for (size_t i = 0; i < out->size(); ++i) { + std::unique_ptr reader; + RETURN_NOT_OK( + GetFieldReader(field_indices[i], included_leaves, row_groups, &reader)); + + out_fields[i] = reader->field(); + out->at(i) = std::move(reader); + } + + *out_schema = ::arrow::schema(std::move(out_fields), manifest_.schema_metadata); + return Status::OK(); + } + + Status GetColumn(int i, FileColumnIteratorFactory iterator_factory, + std::unique_ptr* out); + + Status GetColumn(int i, std::unique_ptr* out) override { + return GetColumn(i, AllRowGroupsFactory(), out); + } + + Status GetSchema(std::shared_ptr<::arrow::Schema>* out) override { + return FromParquetSchema(reader_->metadata()->schema(), reader_properties_, + reader_->metadata()->key_value_metadata(), out); + } + + Status ReadSchemaField(int i, std::shared_ptr* out) override { + auto included_leaves = VectorToSharedSet(Iota(reader_->metadata()->num_columns())); + std::vector row_groups = Iota(reader_->metadata()->num_row_groups()); + + std::unique_ptr reader; + RETURN_NOT_OK(GetFieldReader(i, included_leaves, row_groups, &reader)); + + return ReadColumn(i, row_groups, reader.get(), out); + } + + Status ReadColumn(int i, const std::vector& row_groups, ColumnReader* reader, + std::shared_ptr* out) { + BEGIN_PARQUET_CATCH_EXCEPTIONS + // TODO(wesm): This calculation doesn't make much sense when we have repeated + // schema nodes + int64_t records_to_read = 0; + for (auto row_group : row_groups) { + // Can throw exception + records_to_read += + reader_->metadata()->RowGroup(row_group)->ColumnChunk(i)->num_values(); + } + return reader->NextBatch(records_to_read, out); + END_PARQUET_CATCH_EXCEPTIONS + } + + Status ReadColumn(int i, const std::vector& row_groups, + std::shared_ptr* out) { + std::unique_ptr flat_column_reader; + RETURN_NOT_OK(GetColumn(i, SomeRowGroupsFactory(row_groups), &flat_column_reader)); + return ReadColumn(i, row_groups, flat_column_reader.get(), out); + } + + Status ReadColumn(int i, std::shared_ptr* out) override { + return ReadColumn(i, Iota(reader_->metadata()->num_row_groups()), out); + } + + Status ReadTable(std::shared_ptr
* table) override { + return ReadTable(Iota(reader_->metadata()->num_columns()), table); + } + + Status ReadRowGroups(const std::vector& row_groups, + const std::vector& indices, + std::shared_ptr
* table) override; + + // Helper method used by ReadRowGroups - read the given row groups/columns, skipping + // bounds checks and pre-buffering. Takes a shared_ptr to self to keep the reader + // alive in async contexts. + Future> DecodeRowGroups( + std::shared_ptr self, const std::vector& row_groups, + const std::vector& column_indices, ::arrow::internal::Executor* cpu_executor); + + Status ReadRowGroups(const std::vector& row_groups, + std::shared_ptr
* table) override { + return ReadRowGroups(row_groups, Iota(reader_->metadata()->num_columns()), table); + } + + Status ReadRowGroup(int row_group_index, const std::vector& column_indices, + std::shared_ptr
* out) override { + return ReadRowGroups({row_group_index}, column_indices, out); + } + + Status ReadRowGroup(int i, std::shared_ptr
* table) override { + return ReadRowGroup(i, Iota(reader_->metadata()->num_columns()), table); + } + + Status GetRecordBatchReader(const std::vector& row_group_indices, + const std::vector& column_indices, + std::unique_ptr* out) override; + + Status GetRecordBatchReader(const std::vector& row_group_indices, + std::unique_ptr* out) override { + return GetRecordBatchReader(row_group_indices, + Iota(reader_->metadata()->num_columns()), out); + } + + ::arrow::Result<::arrow::AsyncGenerator>> + GetRecordBatchGenerator(std::shared_ptr reader, + const std::vector row_group_indices, + const std::vector column_indices, + ::arrow::internal::Executor* cpu_executor, + int row_group_readahead) override; + + int num_columns() const { return reader_->metadata()->num_columns(); } + + ParquetFileReader* parquet_reader() const override { return reader_.get(); } + + int num_row_groups() const override { return reader_->metadata()->num_row_groups(); } + + void set_use_threads(bool use_threads) override { + reader_properties_.set_use_threads(use_threads); + } + + void set_batch_size(int64_t batch_size) override { + reader_properties_.set_batch_size(batch_size); + } + + const ArrowReaderProperties& properties() const override { return reader_properties_; } + + const SchemaManifest& manifest() const override { return manifest_; } + + Status ScanContents(std::vector columns, const int32_t column_batch_size, + int64_t* num_rows) override { + BEGIN_PARQUET_CATCH_EXCEPTIONS + *num_rows = ScanFileContents(columns, column_batch_size, reader_.get()); + return Status::OK(); + END_PARQUET_CATCH_EXCEPTIONS + } + + MemoryPool* pool_; + std::unique_ptr reader_; + ArrowReaderProperties reader_properties_; + + SchemaManifest manifest_; +}; + +class RowGroupRecordBatchReader : public ::arrow::RecordBatchReader { + public: + RowGroupRecordBatchReader(::arrow::RecordBatchIterator batches, + std::shared_ptr<::arrow::Schema> schema) + : batches_(std::move(batches)), schema_(std::move(schema)) {} + + ~RowGroupRecordBatchReader() override {} + + Status ReadNext(std::shared_ptr<::arrow::RecordBatch>* out) override { + return batches_.Next().Value(out); + } + + std::shared_ptr<::arrow::Schema> schema() const override { return schema_; } + + private: + ::arrow::Iterator> batches_; + std::shared_ptr<::arrow::Schema> schema_; +}; + +class ColumnChunkReaderImpl : public ColumnChunkReader { + public: + ColumnChunkReaderImpl(FileReaderImpl* impl, int row_group_index, int column_index) + : impl_(impl), column_index_(column_index), row_group_index_(row_group_index) {} + + Status Read(std::shared_ptr<::arrow::ChunkedArray>* out) override { + return impl_->ReadColumn(column_index_, {row_group_index_}, out); + } + + private: + FileReaderImpl* impl_; + int column_index_; + int row_group_index_; +}; + +class RowGroupReaderImpl : public RowGroupReader { + public: + RowGroupReaderImpl(FileReaderImpl* impl, int row_group_index) + : impl_(impl), row_group_index_(row_group_index) {} + + std::shared_ptr Column(int column_index) override { + return std::shared_ptr( + new ColumnChunkReaderImpl(impl_, row_group_index_, column_index)); + } + + Status ReadTable(const std::vector& column_indices, + std::shared_ptr<::arrow::Table>* out) override { + return impl_->ReadRowGroup(row_group_index_, column_indices, out); + } + + Status ReadTable(std::shared_ptr<::arrow::Table>* out) override { + return impl_->ReadRowGroup(row_group_index_, out); + } + + private: + FileReaderImpl* impl_; + int row_group_index_; +}; + +// ---------------------------------------------------------------------- +// Column reader implementations + +// Leaf reader is for primitive arrays and primitive children of nested arrays +class LeafReader : public ColumnReaderImpl { + public: + LeafReader(std::shared_ptr ctx, std::shared_ptr field, + std::unique_ptr input, + ::parquet::internal::LevelInfo leaf_info) + : ctx_(std::move(ctx)), + field_(std::move(field)), + input_(std::move(input)), + descr_(input_->descr()) { + record_reader_ = RecordReader::Make( + descr_, leaf_info, ctx_->pool, field_->type()->id() == ::arrow::Type::DICTIONARY); + NextRowGroup(); + } + + Status GetDefLevels(const int16_t** data, int64_t* length) final { + *data = record_reader_->def_levels(); + *length = record_reader_->levels_position(); + return Status::OK(); + } + + Status GetRepLevels(const int16_t** data, int64_t* length) final { + *data = record_reader_->rep_levels(); + *length = record_reader_->levels_position(); + return Status::OK(); + } + + bool IsOrHasRepeatedChild() const final { return false; } + + Status LoadBatch(int64_t records_to_read) final { + BEGIN_PARQUET_CATCH_EXCEPTIONS + out_ = nullptr; + record_reader_->Reset(); + // Pre-allocation gives much better performance for flat columns + record_reader_->Reserve(records_to_read); + while (records_to_read > 0) { + if (!record_reader_->HasMoreData()) { + break; + } + int64_t records_read = record_reader_->ReadRecords(records_to_read); + records_to_read -= records_read; + if (records_read == 0) { + NextRowGroup(); + } + } + RETURN_NOT_OK(TransferColumnData(record_reader_.get(), field_->type(), descr_, + ctx_->pool, &out_)); + return Status::OK(); + END_PARQUET_CATCH_EXCEPTIONS + } + + ::arrow::Status BuildArray(int64_t length_upper_bound, + std::shared_ptr<::arrow::ChunkedArray>* out) final { + *out = out_; + return Status::OK(); + } + + const std::shared_ptr field() override { return field_; } + + private: + std::shared_ptr out_; + void NextRowGroup() { + std::unique_ptr page_reader = input_->NextChunk(); + record_reader_->SetPageReader(std::move(page_reader)); + } + + std::shared_ptr ctx_; + std::shared_ptr field_; + std::unique_ptr input_; + const ColumnDescriptor* descr_; + std::shared_ptr record_reader_; +}; + +// Column reader for extension arrays +class ExtensionReader : public ColumnReaderImpl { + public: + ExtensionReader(std::shared_ptr field, + std::unique_ptr storage_reader) + : field_(std::move(field)), storage_reader_(std::move(storage_reader)) {} + + Status GetDefLevels(const int16_t** data, int64_t* length) override { + return storage_reader_->GetDefLevels(data, length); + } + + Status GetRepLevels(const int16_t** data, int64_t* length) override { + return storage_reader_->GetRepLevels(data, length); + } + + Status LoadBatch(int64_t number_of_records) final { + return storage_reader_->LoadBatch(number_of_records); + } + + Status BuildArray(int64_t length_upper_bound, + std::shared_ptr* out) override { + std::shared_ptr storage; + RETURN_NOT_OK(storage_reader_->BuildArray(length_upper_bound, &storage)); + *out = ExtensionType::WrapArray(field_->type(), storage); + return Status::OK(); + } + + bool IsOrHasRepeatedChild() const final { + return storage_reader_->IsOrHasRepeatedChild(); + } + + const std::shared_ptr field() override { return field_; } + + private: + std::shared_ptr field_; + std::unique_ptr storage_reader_; +}; + +template +class ListReader : public ColumnReaderImpl { + public: + ListReader(std::shared_ptr ctx, std::shared_ptr field, + ::parquet::internal::LevelInfo level_info, + std::unique_ptr child_reader) + : ctx_(std::move(ctx)), + field_(std::move(field)), + level_info_(level_info), + item_reader_(std::move(child_reader)) {} + + Status GetDefLevels(const int16_t** data, int64_t* length) override { + return item_reader_->GetDefLevels(data, length); + } + + Status GetRepLevels(const int16_t** data, int64_t* length) override { + return item_reader_->GetRepLevels(data, length); + } + + bool IsOrHasRepeatedChild() const final { return true; } + + Status LoadBatch(int64_t number_of_records) final { + return item_reader_->LoadBatch(number_of_records); + } + + virtual ::arrow::Result> AssembleArray( + std::shared_ptr data) { + if (field_->type()->id() == ::arrow::Type::MAP) { + // Error out if data is not map-compliant instead of aborting in MakeArray below + RETURN_NOT_OK(::arrow::MapArray::ValidateChildData(data->child_data)); + } + std::shared_ptr result = ::arrow::MakeArray(data); + return std::make_shared(result); + } + + Status BuildArray(int64_t length_upper_bound, + std::shared_ptr* out) override { + const int16_t* def_levels; + const int16_t* rep_levels; + int64_t num_levels; + RETURN_NOT_OK(item_reader_->GetDefLevels(&def_levels, &num_levels)); + RETURN_NOT_OK(item_reader_->GetRepLevels(&rep_levels, &num_levels)); + + std::shared_ptr validity_buffer; + ::parquet::internal::ValidityBitmapInputOutput validity_io; + validity_io.values_read_upper_bound = length_upper_bound; + if (field_->nullable()) { + ARROW_ASSIGN_OR_RAISE( + validity_buffer, + AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool)); + validity_io.valid_bits = validity_buffer->mutable_data(); + } + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr offsets_buffer, + AllocateResizableBuffer( + sizeof(IndexType) * std::max(int64_t{1}, length_upper_bound + 1), + ctx_->pool)); + // Ensure zero initialization in case we have reached a zero length list (and + // because first entry is always zero). + IndexType* offset_data = reinterpret_cast(offsets_buffer->mutable_data()); + offset_data[0] = 0; + BEGIN_PARQUET_CATCH_EXCEPTIONS + ::parquet::internal::DefRepLevelsToList(def_levels, rep_levels, num_levels, + level_info_, &validity_io, offset_data); + END_PARQUET_CATCH_EXCEPTIONS + + RETURN_NOT_OK(item_reader_->BuildArray(offset_data[validity_io.values_read], out)); + + // Resize to actual number of elements returned. + RETURN_NOT_OK( + offsets_buffer->Resize((validity_io.values_read + 1) * sizeof(IndexType))); + if (validity_buffer != nullptr) { + RETURN_NOT_OK( + validity_buffer->Resize(BitUtil::BytesForBits(validity_io.values_read))); + validity_buffer->ZeroPadding(); + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr item_chunk, ChunksToSingle(**out)); + + std::vector> buffers{ + validity_io.null_count > 0 ? validity_buffer : nullptr, offsets_buffer}; + auto data = std::make_shared( + field_->type(), + /*length=*/validity_io.values_read, std::move(buffers), + std::vector>{item_chunk}, validity_io.null_count); + + ARROW_ASSIGN_OR_RAISE(*out, AssembleArray(std::move(data))); + return Status::OK(); + } + + const std::shared_ptr field() override { return field_; } + + private: + std::shared_ptr ctx_; + std::shared_ptr field_; + ::parquet::internal::LevelInfo level_info_; + std::unique_ptr item_reader_; +}; + +class PARQUET_NO_EXPORT FixedSizeListReader : public ListReader { + public: + FixedSizeListReader(std::shared_ptr ctx, std::shared_ptr field, + ::parquet::internal::LevelInfo level_info, + std::unique_ptr child_reader) + : ListReader(std::move(ctx), std::move(field), level_info, + std::move(child_reader)) {} + ::arrow::Result> AssembleArray( + std::shared_ptr data) final { + DCHECK_EQ(data->buffers.size(), 2); + DCHECK_EQ(field()->type()->id(), ::arrow::Type::FIXED_SIZE_LIST); + const auto& type = checked_cast<::arrow::FixedSizeListType&>(*field()->type()); + const int32_t* offsets = reinterpret_cast(data->buffers[1]->data()); + for (int x = 1; x <= data->length; x++) { + int32_t size = offsets[x] - offsets[x - 1]; + if (size != type.list_size()) { + return Status::Invalid("Expected all lists to be of size=", type.list_size(), + " but index ", x, " had size=", size); + } + } + data->buffers.resize(1); + std::shared_ptr result = ::arrow::MakeArray(data); + return std::make_shared(result); + } +}; + +class PARQUET_NO_EXPORT StructReader : public ColumnReaderImpl { + public: + explicit StructReader(std::shared_ptr ctx, + std::shared_ptr filtered_field, + ::parquet::internal::LevelInfo level_info, + std::vector> children) + : ctx_(std::move(ctx)), + filtered_field_(std::move(filtered_field)), + level_info_(level_info), + children_(std::move(children)) { + // There could be a mix of children some might be repeated some might not be. + // If possible use one that isn't since that will be guaranteed to have the least + // number of levels to reconstruct a nullable bitmap. + auto result = std::find_if(children_.begin(), children_.end(), + [](const std::unique_ptr& child) { + return !child->IsOrHasRepeatedChild(); + }); + if (result != children_.end()) { + def_rep_level_child_ = result->get(); + has_repeated_child_ = false; + } else if (!children_.empty()) { + def_rep_level_child_ = children_.front().get(); + has_repeated_child_ = true; + } + } + + bool IsOrHasRepeatedChild() const final { return has_repeated_child_; } + + Status LoadBatch(int64_t records_to_read) override { + for (const std::unique_ptr& reader : children_) { + RETURN_NOT_OK(reader->LoadBatch(records_to_read)); + } + return Status::OK(); + } + Status BuildArray(int64_t length_upper_bound, + std::shared_ptr* out) override; + Status GetDefLevels(const int16_t** data, int64_t* length) override; + Status GetRepLevels(const int16_t** data, int64_t* length) override; + const std::shared_ptr field() override { return filtered_field_; } + + private: + const std::shared_ptr ctx_; + const std::shared_ptr filtered_field_; + const ::parquet::internal::LevelInfo level_info_; + const std::vector> children_; + ColumnReaderImpl* def_rep_level_child_ = nullptr; + bool has_repeated_child_; +}; + +Status StructReader::GetDefLevels(const int16_t** data, int64_t* length) { + *data = nullptr; + if (children_.size() == 0) { + *length = 0; + return Status::Invalid("StructReader had no children"); + } + + // This method should only be called when this struct or one of its parents + // are optional/repeated or it has a repeated child. + // Meaning all children must have rep/def levels associated + // with them. + RETURN_NOT_OK(def_rep_level_child_->GetDefLevels(data, length)); + return Status::OK(); +} + +Status StructReader::GetRepLevels(const int16_t** data, int64_t* length) { + *data = nullptr; + if (children_.size() == 0) { + *length = 0; + return Status::Invalid("StructReader had no childre"); + } + + // This method should only be called when this struct or one of its parents + // are optional/repeated or it has repeated child. + // Meaning all children must have rep/def levels associated + // with them. + RETURN_NOT_OK(def_rep_level_child_->GetRepLevels(data, length)); + return Status::OK(); +} + +Status StructReader::BuildArray(int64_t length_upper_bound, + std::shared_ptr* out) { + std::vector> children_array_data; + std::shared_ptr null_bitmap; + + ::parquet::internal::ValidityBitmapInputOutput validity_io; + validity_io.values_read_upper_bound = length_upper_bound; + // This simplifies accounting below. + validity_io.values_read = length_upper_bound; + + BEGIN_PARQUET_CATCH_EXCEPTIONS + const int16_t* def_levels; + const int16_t* rep_levels; + int64_t num_levels; + + if (has_repeated_child_) { + ARROW_ASSIGN_OR_RAISE( + null_bitmap, + AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool)); + validity_io.valid_bits = null_bitmap->mutable_data(); + RETURN_NOT_OK(GetDefLevels(&def_levels, &num_levels)); + RETURN_NOT_OK(GetRepLevels(&rep_levels, &num_levels)); + DefRepLevelsToBitmap(def_levels, rep_levels, num_levels, level_info_, &validity_io); + } else if (filtered_field_->nullable()) { + ARROW_ASSIGN_OR_RAISE( + null_bitmap, + AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool)); + validity_io.valid_bits = null_bitmap->mutable_data(); + RETURN_NOT_OK(GetDefLevels(&def_levels, &num_levels)); + DefLevelsToBitmap(def_levels, num_levels, level_info_, &validity_io); + } + + // Ensure all values are initialized. + if (null_bitmap) { + RETURN_NOT_OK(null_bitmap->Resize(BitUtil::BytesForBits(validity_io.values_read))); + null_bitmap->ZeroPadding(); + } + + END_PARQUET_CATCH_EXCEPTIONS + // Gather children arrays and def levels + for (auto& child : children_) { + std::shared_ptr field; + RETURN_NOT_OK(child->BuildArray(validity_io.values_read, &field)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr array_data, ChunksToSingle(*field)); + children_array_data.push_back(std::move(array_data)); + } + + if (!filtered_field_->nullable() && !has_repeated_child_) { + validity_io.values_read = children_array_data.front()->length; + } + + std::vector> buffers{validity_io.null_count > 0 ? null_bitmap + : nullptr}; + auto data = + std::make_shared(filtered_field_->type(), + /*length=*/validity_io.values_read, std::move(buffers), + std::move(children_array_data)); + std::shared_ptr result = ::arrow::MakeArray(data); + + *out = std::make_shared(result); + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// File reader implementation + +Status GetReader(const SchemaField& field, const std::shared_ptr& arrow_field, + const std::shared_ptr& ctx, + std::unique_ptr* out) { + BEGIN_PARQUET_CATCH_EXCEPTIONS + + auto type_id = arrow_field->type()->id(); + + if (type_id == ::arrow::Type::EXTENSION) { + auto storage_field = arrow_field->WithType( + checked_cast(*arrow_field->type()).storage_type()); + RETURN_NOT_OK(GetReader(field, storage_field, ctx, out)); + out->reset(new ExtensionReader(arrow_field, std::move(*out))); + return Status::OK(); + } + + if (field.children.size() == 0) { + if (!field.is_leaf()) { + return Status::Invalid("Parquet non-leaf node has no children"); + } + if (!ctx->IncludesLeaf(field.column_index)) { + *out = nullptr; + return Status::OK(); + } + std::unique_ptr input( + ctx->iterator_factory(field.column_index, ctx->reader)); + out->reset(new LeafReader(ctx, arrow_field, std::move(input), field.level_info)); + } else if (type_id == ::arrow::Type::LIST || type_id == ::arrow::Type::MAP || + type_id == ::arrow::Type::FIXED_SIZE_LIST || + type_id == ::arrow::Type::LARGE_LIST) { + auto list_field = arrow_field; + auto child = &field.children[0]; + std::unique_ptr child_reader; + RETURN_NOT_OK(GetReader(*child, ctx, &child_reader)); + if (child_reader == nullptr) { + *out = nullptr; + return Status::OK(); + } + + // These two types might not be equal if there column pruning occurred. + // further down the stack. + const std::shared_ptr reader_child_type = child_reader->field()->type(); + // This should really never happen but was raised as a question on the code + // review, this should be pretty cheap check so leave it in. + if (ARROW_PREDICT_FALSE(list_field->type()->num_fields() != 1)) { + return Status::Invalid("expected exactly one child field for: ", + list_field->ToString()); + } + const DataType& schema_child_type = *(list_field->type()->field(0)->type()); + if (type_id == ::arrow::Type::MAP) { + if (reader_child_type->num_fields() != 2 || + !reader_child_type->field(0)->type()->Equals( + *schema_child_type.field(0)->type())) { + // This case applies if either key or value are completed filtered + // out so we can take the type as is or the key was partially + // so keeping it as a map no longer makes sence. + list_field = list_field->WithType(::arrow::list(child_reader->field())); + } else if (!reader_child_type->field(1)->type()->Equals( + *schema_child_type.field(1)->type())) { + list_field = list_field->WithType(std::make_shared<::arrow::MapType>( + reader_child_type->field( + 0), // field 0 is unchanged baed on previous if statement + reader_child_type->field(1))); + } + // Map types are list> so use ListReader + // for reconstruction. + out->reset(new ListReader(ctx, list_field, field.level_info, + std::move(child_reader))); + } else if (type_id == ::arrow::Type::LIST) { + if (!reader_child_type->Equals(schema_child_type)) { + list_field = list_field->WithType(::arrow::list(reader_child_type)); + } + + out->reset(new ListReader(ctx, list_field, field.level_info, + std::move(child_reader))); + } else if (type_id == ::arrow::Type::LARGE_LIST) { + if (!reader_child_type->Equals(schema_child_type)) { + list_field = list_field->WithType(::arrow::large_list(reader_child_type)); + } + + out->reset(new ListReader(ctx, list_field, field.level_info, + std::move(child_reader))); + } else if (type_id == ::arrow::Type::FIXED_SIZE_LIST) { + if (!reader_child_type->Equals(schema_child_type)) { + auto& fixed_list_type = + checked_cast(*list_field->type()); + int32_t list_size = fixed_list_type.list_size(); + list_field = + list_field->WithType(::arrow::fixed_size_list(reader_child_type, list_size)); + } + + out->reset(new FixedSizeListReader(ctx, list_field, field.level_info, + std::move(child_reader))); + } else { + return Status::UnknownError("Unknown list type: ", field.field->ToString()); + } + } else if (type_id == ::arrow::Type::STRUCT) { + std::vector> child_fields; + int arrow_field_idx = 0; + std::vector> child_readers; + for (const auto& child : field.children) { + std::unique_ptr child_reader; + RETURN_NOT_OK(GetReader(child, ctx, &child_reader)); + if (!child_reader) { + arrow_field_idx++; + // If all children were pruned, then we do not try to read this field + continue; + } + std::shared_ptr<::arrow::Field> child_field = child.field; + const DataType& reader_child_type = *child_reader->field()->type(); + const DataType& schema_child_type = + *arrow_field->type()->field(arrow_field_idx++)->type(); + // These might not be equal if column pruning occurred. + if (!schema_child_type.Equals(reader_child_type)) { + child_field = child_field->WithType(child_reader->field()->type()); + } + child_fields.push_back(child_field); + child_readers.emplace_back(std::move(child_reader)); + } + if (child_fields.size() == 0) { + *out = nullptr; + return Status::OK(); + } + auto filtered_field = + ::arrow::field(arrow_field->name(), ::arrow::struct_(child_fields), + arrow_field->nullable(), arrow_field->metadata()); + out->reset(new StructReader(ctx, filtered_field, field.level_info, + std::move(child_readers))); + } else { + return Status::Invalid("Unsupported nested type: ", arrow_field->ToString()); + } + return Status::OK(); + + END_PARQUET_CATCH_EXCEPTIONS +} + +Status GetReader(const SchemaField& field, const std::shared_ptr& ctx, + std::unique_ptr* out) { + return GetReader(field, field.field, ctx, out); +} + +} // namespace + +Status FileReaderImpl::GetRecordBatchReader(const std::vector& row_groups, + const std::vector& column_indices, + std::unique_ptr* out) { + RETURN_NOT_OK(BoundsCheck(row_groups, column_indices)); + + if (reader_properties_.pre_buffer()) { + // PARQUET-1698/PARQUET-1820: pre-buffer row groups/column chunks if enabled + BEGIN_PARQUET_CATCH_EXCEPTIONS + reader_->PreBuffer(row_groups, column_indices, reader_properties_.io_context(), + reader_properties_.cache_options()); + END_PARQUET_CATCH_EXCEPTIONS + } + + std::vector> readers; + std::shared_ptr<::arrow::Schema> batch_schema; + RETURN_NOT_OK(GetFieldReaders(column_indices, row_groups, &readers, &batch_schema)); + + if (readers.empty()) { + // Just generate all batches right now; they're cheap since they have no columns. + int64_t batch_size = properties().batch_size(); + auto max_sized_batch = + ::arrow::RecordBatch::Make(batch_schema, batch_size, ::arrow::ArrayVector{}); + + ::arrow::RecordBatchVector batches; + + for (int row_group : row_groups) { + int64_t num_rows = parquet_reader()->metadata()->RowGroup(row_group)->num_rows(); + + batches.insert(batches.end(), num_rows / batch_size, max_sized_batch); + + if (int64_t trailing_rows = num_rows % batch_size) { + batches.push_back(max_sized_batch->Slice(0, trailing_rows)); + } + } + + *out = ::arrow::internal::make_unique( + ::arrow::MakeVectorIterator(std::move(batches)), std::move(batch_schema)); + + return Status::OK(); + } + + int64_t num_rows = 0; + for (int row_group : row_groups) { + num_rows += parquet_reader()->metadata()->RowGroup(row_group)->num_rows(); + } + + using ::arrow::RecordBatchIterator; + + // NB: This lambda will be invoked outside the scope of this call to + // `GetRecordBatchReader()`, so it must capture `readers` and `batch_schema` by value. + // `this` is a non-owning pointer so we are relying on the parent FileReader outliving + // this RecordBatchReader. + ::arrow::Iterator batches = ::arrow::MakeFunctionIterator( + [readers, batch_schema, num_rows, + this]() mutable -> ::arrow::Result { + ::arrow::ChunkedArrayVector columns(readers.size()); + + // don't reserve more rows than necessary + int64_t batch_size = std::min(properties().batch_size(), num_rows); + num_rows -= batch_size; + + RETURN_NOT_OK(::arrow::internal::OptionalParallelFor( + reader_properties_.use_threads(), static_cast(readers.size()), + [&](int i) { return readers[i]->NextBatch(batch_size, &columns[i]); })); + + for (const auto& column : columns) { + if (column == nullptr || column->length() == 0) { + return ::arrow::IterationTraits::End(); + } + } + + //Table reader will slice the batch, we don't want it happen +// auto table = ::arrow::Table::Make(batch_schema, std::move(columns)); +// auto table_reader = std::make_shared<::arrow::TableBatchReader>(*table); +// +// // NB: explicitly preserve table so that table_reader doesn't outlive it +// return ::arrow::MakeFunctionIterator( +// [table, table_reader] { return table_reader->Next(); }); + + std::vector> arrays; + for (const auto& column : columns) { + arrays.emplace_back(column->chunk(0)); + } + return ::arrow::MakeVectorIterator>( + {(::arrow::RecordBatch::Make(batch_schema, batch_size, std::move(arrays)))}); + }); + + + *out = ::arrow::internal::make_unique( + ::arrow::MakeFlattenIterator(std::move(batches)), std::move(batch_schema)); + + return Status::OK(); +} + +/// Given a file reader and a list of row groups, this is a generator of record +/// batch generators (where each sub-generator is the contents of a single row group). +class RowGroupGenerator { + public: + using RecordBatchGenerator = + ::arrow::AsyncGenerator>; + + explicit RowGroupGenerator(std::shared_ptr arrow_reader, + ::arrow::internal::Executor* cpu_executor, + std::vector row_groups, std::vector column_indices) + : arrow_reader_(std::move(arrow_reader)), + cpu_executor_(cpu_executor), + row_groups_(std::move(row_groups)), + column_indices_(std::move(column_indices)), + index_(0) {} + + ::arrow::Future operator()() { + if (index_ >= row_groups_.size()) { + return ::arrow::AsyncGeneratorEnd(); + } + int row_group = row_groups_[index_++]; + std::vector column_indices = column_indices_; + auto reader = arrow_reader_; + if (!reader->properties().pre_buffer()) { + return SubmitRead(cpu_executor_, reader, row_group, column_indices); + } + auto ready = reader->parquet_reader()->WhenBuffered({row_group}, column_indices); + if (cpu_executor_) ready = cpu_executor_->TransferAlways(ready); + return ready.Then([=]() -> ::arrow::Future { + return ReadOneRowGroup(cpu_executor_, reader, row_group, column_indices); + }); + } + + private: + // Synchronous fallback for when pre-buffer isn't enabled. + // + // Making the Parquet reader truly asynchronous requires heavy refactoring, so the + // generator piggybacks on ReadRangeCache. The lazy ReadRangeCache can be used for + // async I/O without forcing readahead. + static ::arrow::Future SubmitRead( + ::arrow::internal::Executor* cpu_executor, std::shared_ptr self, + const int row_group, const std::vector& column_indices) { + if (!cpu_executor) { + return ReadOneRowGroup(cpu_executor, self, row_group, column_indices); + } + // If we have an executor, then force transfer (even if I/O was complete) + return ::arrow::DeferNotOk(cpu_executor->Submit(ReadOneRowGroup, cpu_executor, self, + row_group, column_indices)); + } + + static ::arrow::Future ReadOneRowGroup( + ::arrow::internal::Executor* cpu_executor, std::shared_ptr self, + const int row_group, const std::vector& column_indices) { + // Skips bound checks/pre-buffering, since we've done that already + const int64_t batch_size = self->properties().batch_size(); + return self->DecodeRowGroups(self, {row_group}, column_indices, cpu_executor) + .Then([batch_size](const std::shared_ptr
& table) + -> ::arrow::Result { + ::arrow::TableBatchReader table_reader(*table); + table_reader.set_chunksize(batch_size); + ::arrow::RecordBatchVector batches; + RETURN_NOT_OK(table_reader.ReadAll(&batches)); + return ::arrow::MakeVectorGenerator(std::move(batches)); + }); + } + + std::shared_ptr arrow_reader_; + ::arrow::internal::Executor* cpu_executor_; + std::vector row_groups_; + std::vector column_indices_; + size_t index_; +}; + +::arrow::Result<::arrow::AsyncGenerator>> +FileReaderImpl::GetRecordBatchGenerator(std::shared_ptr reader, + const std::vector row_group_indices, + const std::vector column_indices, + ::arrow::internal::Executor* cpu_executor, + int row_group_readahead) { + RETURN_NOT_OK(BoundsCheck(row_group_indices, column_indices)); + if (reader_properties_.pre_buffer()) { + BEGIN_PARQUET_CATCH_EXCEPTIONS + reader_->PreBuffer(row_group_indices, column_indices, reader_properties_.io_context(), + reader_properties_.cache_options()); + END_PARQUET_CATCH_EXCEPTIONS + } + ::arrow::AsyncGenerator row_group_generator = + RowGroupGenerator(::arrow::internal::checked_pointer_cast(reader), + cpu_executor, row_group_indices, column_indices); + if (row_group_readahead > 0) { + row_group_generator = ::arrow::MakeReadaheadGenerator(std::move(row_group_generator), + row_group_readahead); + } + return ::arrow::MakeConcatenatedGenerator(std::move(row_group_generator)); +} + +Status FileReaderImpl::GetColumn(int i, FileColumnIteratorFactory iterator_factory, + std::unique_ptr* out) { + RETURN_NOT_OK(BoundsCheckColumn(i)); + auto ctx = std::make_shared(); + ctx->reader = reader_.get(); + ctx->pool = pool_; + ctx->iterator_factory = iterator_factory; + ctx->filter_leaves = false; + std::unique_ptr result; + RETURN_NOT_OK(GetReader(manifest_.schema_fields[i], ctx, &result)); + out->reset(result.release()); + return Status::OK(); +} + +Status FileReaderImpl::ReadRowGroups(const std::vector& row_groups, + const std::vector& column_indices, + std::shared_ptr
* out) { + RETURN_NOT_OK(BoundsCheck(row_groups, column_indices)); + + // PARQUET-1698/PARQUET-1820: pre-buffer row groups/column chunks if enabled + if (reader_properties_.pre_buffer()) { + BEGIN_PARQUET_CATCH_EXCEPTIONS + parquet_reader()->PreBuffer(row_groups, column_indices, + reader_properties_.io_context(), + reader_properties_.cache_options()); + END_PARQUET_CATCH_EXCEPTIONS + } + + auto fut = DecodeRowGroups(/*self=*/nullptr, row_groups, column_indices, + /*cpu_executor=*/nullptr); + ARROW_ASSIGN_OR_RAISE(*out, fut.MoveResult()); + return Status::OK(); +} + +Future> FileReaderImpl::DecodeRowGroups( + std::shared_ptr self, const std::vector& row_groups, + const std::vector& column_indices, ::arrow::internal::Executor* cpu_executor) { + // `self` is used solely to keep `this` alive in an async context - but we use this + // in a sync context too so use `this` over `self` + std::vector> readers; + std::shared_ptr<::arrow::Schema> result_schema; + RETURN_NOT_OK(GetFieldReaders(column_indices, row_groups, &readers, &result_schema)); + // OptionalParallelForAsync requires an executor + if (!cpu_executor) cpu_executor = ::arrow::internal::GetCpuThreadPool(); + + auto read_column = [row_groups, self, this](size_t i, + std::shared_ptr reader) + -> ::arrow::Result> { + std::shared_ptr<::arrow::ChunkedArray> column; + RETURN_NOT_OK(ReadColumn(static_cast(i), row_groups, reader.get(), &column)); + return column; + }; + auto make_table = [result_schema, row_groups, self, + this](const ::arrow::ChunkedArrayVector& columns) + -> ::arrow::Result> { + int64_t num_rows = 0; + if (!columns.empty()) { + num_rows = columns[0]->length(); + } else { + for (int i : row_groups) { + num_rows += parquet_reader()->metadata()->RowGroup(i)->num_rows(); + } + } + auto table = Table::Make(std::move(result_schema), columns, num_rows); + RETURN_NOT_OK(table->Validate()); + return table; + }; + return ::arrow::internal::OptionalParallelForAsync(reader_properties_.use_threads(), + std::move(readers), read_column, + cpu_executor) + .Then(std::move(make_table)); +} + +std::shared_ptr FileReaderImpl::RowGroup(int row_group_index) { + return std::make_shared(this, row_group_index); +} + +// ---------------------------------------------------------------------- +// Public factory functions + +Status FileReader::GetRecordBatchReader(const std::vector& row_group_indices, + std::shared_ptr* out) { + std::unique_ptr tmp; + ARROW_RETURN_NOT_OK(GetRecordBatchReader(row_group_indices, &tmp)); + out->reset(tmp.release()); + return Status::OK(); +} + +Status FileReader::GetRecordBatchReader(const std::vector& row_group_indices, + const std::vector& column_indices, + std::shared_ptr* out) { + std::unique_ptr tmp; + ARROW_RETURN_NOT_OK(GetRecordBatchReader(row_group_indices, column_indices, &tmp)); + out->reset(tmp.release()); + return Status::OK(); +} + +Status FileReader::Make(::arrow::MemoryPool* pool, + std::unique_ptr reader, + const ArrowReaderProperties& properties, + std::unique_ptr* out) { + out->reset(new FileReaderImpl(pool, std::move(reader), properties)); + return static_cast(out->get())->Init(); +} + +Status FileReader::Make(::arrow::MemoryPool* pool, + std::unique_ptr reader, + std::unique_ptr* out) { + return Make(pool, std::move(reader), default_arrow_reader_properties(), out); +} + +FileReaderBuilder::FileReaderBuilder() + : pool_(::arrow::default_memory_pool()), + properties_(default_arrow_reader_properties()) {} + +Status FileReaderBuilder::Open(std::shared_ptr<::arrow::io::RandomAccessFile> file, + const ReaderProperties& properties, + std::shared_ptr metadata) { + PARQUET_CATCH_NOT_OK(raw_reader_ = ParquetReader::Open(std::move(file), properties, + std::move(metadata))); + return Status::OK(); +} + +FileReaderBuilder* FileReaderBuilder::memory_pool(::arrow::MemoryPool* pool) { + pool_ = pool; + return this; +} + +FileReaderBuilder* FileReaderBuilder::properties( + const ArrowReaderProperties& arg_properties) { + properties_ = arg_properties; + return this; +} + +Status FileReaderBuilder::Build(std::unique_ptr* out) { + return FileReader::Make(pool_, std::move(raw_reader_), properties_, out); +} + +Status OpenFile(std::shared_ptr<::arrow::io::RandomAccessFile> file, MemoryPool* pool, + std::unique_ptr* reader) { + FileReaderBuilder builder; + RETURN_NOT_OK(builder.Open(std::move(file))); + return builder.memory_pool(pool)->Build(reader); +} + +namespace internal { + +Status FuzzReader(std::unique_ptr reader) { + auto st = Status::OK(); + for (int i = 0; i < reader->num_row_groups(); ++i) { + std::shared_ptr
table; + auto row_group_status = reader->ReadRowGroup(i, &table); + if (row_group_status.ok()) { + row_group_status &= table->ValidateFull(); + } + st &= row_group_status; + } + return st; +} + +Status FuzzReader(const uint8_t* data, int64_t size) { + auto buffer = std::make_shared<::arrow::Buffer>(data, size); + auto file = std::make_shared<::arrow::io::BufferReader>(buffer); + FileReaderBuilder builder; + RETURN_NOT_OK(builder.Open(std::move(file))); + + std::unique_ptr reader; + RETURN_NOT_OK(builder.Build(&reader)); + return FuzzReader(std::move(reader)); +} + +} // namespace internal + +} // namespace arrow +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/reader.h b/utils/local-engine/Storages/ch_parquet/arrow/reader.h new file mode 100644 index 000000000000..4f46bd60763b --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/reader.h @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +// N.B. we don't include async_generator.h as it's relatively heavy +#include +#include +#include + +#include "parquet/file_reader.h" +#include "parquet/platform.h" +#include "parquet/properties.h" + +namespace arrow { + +class ChunkedArray; +class KeyValueMetadata; +class RecordBatchReader; +struct Scalar; +class Schema; +class Table; +class RecordBatch; + +} // namespace arrow + +namespace parquet +{ + +class FileMetaData; +class SchemaDescriptor; +namespace arrow +{ + //class ColumnChunkReader; + //class ColumnReader; + struct SchemaManifest; + //class RowGroupReader; +} +} + +namespace ch_parquet { + +using namespace parquet; +using namespace parquet::arrow; + +namespace arrow { + +class ColumnChunkReader; +class ColumnReader; +class RowGroupReader; + +/// \brief Arrow read adapter class for deserializing Parquet files as Arrow row batches. +/// +/// This interfaces caters for different use cases and thus provides different +/// interfaces. In its most simplistic form, we cater for a user that wants to +/// read the whole Parquet at once with the `FileReader::ReadTable` method. +/// +/// More advanced users that also want to implement parallelism on top of each +/// single Parquet files should do this on the RowGroup level. For this, they can +/// call `FileReader::RowGroup(i)->ReadTable` to receive only the specified +/// RowGroup as a table. +/// +/// In the most advanced situation, where a consumer wants to independently read +/// RowGroups in parallel and consume each column individually, they can call +/// `FileReader::RowGroup(i)->Column(j)->Read` and receive an `arrow::Column` +/// instance. +/// +/// The parquet format supports an optional integer field_id which can be assigned +/// to a field. Arrow will convert these field IDs to a metadata key named +/// PARQUET:field_id on the appropriate field. +// TODO(wesm): nested data does not always make sense with this user +// interface unless you are only reading a single leaf node from a branch of +// a table. For example: +// +// repeated group data { +// optional group record { +// optional int32 val1; +// optional byte_array val2; +// optional bool val3; +// } +// optional int32 val4; +// } +// +// In the Parquet file, there are 3 leaf nodes: +// +// * data.record.val1 +// * data.record.val2 +// * data.record.val3 +// * data.val4 +// +// When materializing this data in an Arrow array, we would have: +// +// data: list), +// val3: bool, +// >, +// val4: int32 +// >> +// +// However, in the Parquet format, each leaf node has its own repetition and +// definition levels describing the structure of the intermediate nodes in +// this array structure. Thus, we will need to scan the leaf data for a group +// of leaf nodes part of the same type tree to create a single result Arrow +// nested array structure. +// +// This is additionally complicated "chunky" repeated fields or very large byte +// arrays +class PARQUET_EXPORT FileReader { + public: + /// Factory function to create a FileReader from a ParquetFileReader and properties + static ::arrow::Status Make(::arrow::MemoryPool* pool, + std::unique_ptr reader, + const ArrowReaderProperties& properties, + std::unique_ptr* out); + + /// Factory function to create a FileReader from a ParquetFileReader + static ::arrow::Status Make(::arrow::MemoryPool* pool, + std::unique_ptr reader, + std::unique_ptr* out); + + // Since the distribution of columns amongst a Parquet file's row groups may + // be uneven (the number of values in each column chunk can be different), we + // provide a column-oriented read interface. The ColumnReader hides the + // details of paging through the file's row groups and yielding + // fully-materialized arrow::Array instances + // + // Returns error status if the column of interest is not flat. + virtual ::arrow::Status GetColumn(int i, std::unique_ptr* out) = 0; + + /// \brief Return arrow schema for all the columns. + virtual ::arrow::Status GetSchema(std::shared_ptr<::arrow::Schema>* out) = 0; + + /// \brief Read column as a whole into a chunked array. + /// + /// The indicated column index is relative to the schema + virtual ::arrow::Status ReadColumn(int i, + std::shared_ptr<::arrow::ChunkedArray>* out) = 0; + + // NOTE: Experimental API + // Reads a specific top level schema field into an Array + // The index i refers the index of the top level schema field, which may + // be nested or flat - e.g. + // + // 0 foo.bar + // foo.bar.baz + // foo.qux + // 1 foo2 + // 2 foo3 + // + // i=0 will read the entire foo struct, i=1 the foo2 primitive column etc + virtual ::arrow::Status ReadSchemaField( + int i, std::shared_ptr<::arrow::ChunkedArray>* out) = 0; + + /// \brief Return a RecordBatchReader of row groups selected from row_group_indices. + /// + /// Note that the ordering in row_group_indices matters. FileReaders must outlive + /// their RecordBatchReaders. + /// + /// \returns error Status if row_group_indices contains an invalid index + virtual ::arrow::Status GetRecordBatchReader( + const std::vector& row_group_indices, + std::unique_ptr<::arrow::RecordBatchReader>* out) = 0; + + ::arrow::Status GetRecordBatchReader(const std::vector& row_group_indices, + std::shared_ptr<::arrow::RecordBatchReader>* out); + + /// \brief Return a RecordBatchReader of row groups selected from + /// row_group_indices, whose columns are selected by column_indices. + /// + /// Note that the ordering in row_group_indices and column_indices + /// matter. FileReaders must outlive their RecordBatchReaders. + /// + /// \returns error Status if either row_group_indices or column_indices + /// contains an invalid index + virtual ::arrow::Status GetRecordBatchReader( + const std::vector& row_group_indices, const std::vector& column_indices, + std::unique_ptr<::arrow::RecordBatchReader>* out) = 0; + + /// \brief Return a generator of record batches. + /// + /// The FileReader must outlive the generator, so this requires that you pass in a + /// shared_ptr. + /// + /// \returns error Result if either row_group_indices or column_indices contains an + /// invalid index + virtual ::arrow::Result< + std::function<::arrow::Future>()>> + GetRecordBatchGenerator(std::shared_ptr reader, + const std::vector row_group_indices, + const std::vector column_indices, + ::arrow::internal::Executor* cpu_executor = NULLPTR, + int row_group_readahead = 0) = 0; + + ::arrow::Status GetRecordBatchReader(const std::vector& row_group_indices, + const std::vector& column_indices, + std::shared_ptr<::arrow::RecordBatchReader>* out); + + /// Read all columns into a Table + virtual ::arrow::Status ReadTable(std::shared_ptr<::arrow::Table>* out) = 0; + + /// \brief Read the given columns into a Table + /// + /// The indicated column indices are relative to the schema + virtual ::arrow::Status ReadTable(const std::vector& column_indices, + std::shared_ptr<::arrow::Table>* out) = 0; + + virtual ::arrow::Status ReadRowGroup(int i, const std::vector& column_indices, + std::shared_ptr<::arrow::Table>* out) = 0; + + virtual ::arrow::Status ReadRowGroup(int i, std::shared_ptr<::arrow::Table>* out) = 0; + + virtual ::arrow::Status ReadRowGroups(const std::vector& row_groups, + const std::vector& column_indices, + std::shared_ptr<::arrow::Table>* out) = 0; + + virtual ::arrow::Status ReadRowGroups(const std::vector& row_groups, + std::shared_ptr<::arrow::Table>* out) = 0; + + /// \brief Scan file contents with one thread, return number of rows + virtual ::arrow::Status ScanContents(std::vector columns, + const int32_t column_batch_size, + int64_t* num_rows) = 0; + + /// \brief Return a reader for the RowGroup, this object must not outlive the + /// FileReader. + virtual std::shared_ptr RowGroup(int row_group_index) = 0; + + /// \brief The number of row groups in the file + virtual int num_row_groups() const = 0; + + virtual ParquetFileReader* parquet_reader() const = 0; + + /// Set whether to use multiple threads during reads of multiple columns. + /// By default only one thread is used. + virtual void set_use_threads(bool use_threads) = 0; + + /// Set number of records to read per batch for the RecordBatchReader. + virtual void set_batch_size(int64_t batch_size) = 0; + + virtual const ArrowReaderProperties& properties() const = 0; + + virtual const SchemaManifest& manifest() const = 0; + + virtual ~FileReader() = default; +}; + +class RowGroupReader { + public: + virtual ~RowGroupReader() = default; + virtual std::shared_ptr Column(int column_index) = 0; + virtual ::arrow::Status ReadTable(const std::vector& column_indices, + std::shared_ptr<::arrow::Table>* out) = 0; + virtual ::arrow::Status ReadTable(std::shared_ptr<::arrow::Table>* out) = 0; + + private: + struct Iterator; +}; + +class ColumnChunkReader { + public: + virtual ~ColumnChunkReader() = default; + virtual ::arrow::Status Read(std::shared_ptr<::arrow::ChunkedArray>* out) = 0; +}; + +// At this point, the column reader is a stream iterator. It only knows how to +// read the next batch of values for a particular column from the file until it +// runs out. +// +// We also do not expose any internal Parquet details, such as row groups. This +// might change in the future. +class PARQUET_EXPORT ColumnReader { + public: + virtual ~ColumnReader() = default; + + // Scan the next array of the indicated size. The actual size of the + // returned array may be less than the passed size depending how much data is + // available in the file. + // + // When all the data in the file has been exhausted, the result is set to + // nullptr. + // + // Returns Status::OK on a successful read, including if you have exhausted + // the data available in the file. + virtual ::arrow::Status NextBatch(int64_t batch_size, + std::shared_ptr<::arrow::ChunkedArray>* out) = 0; +}; + +/// \brief Experimental helper class for bindings (like Python) that struggle +/// either with std::move or C++ exceptions +class PARQUET_EXPORT FileReaderBuilder { + public: + FileReaderBuilder(); + + /// Create FileReaderBuilder from Arrow file and optional properties / metadata + ::arrow::Status Open(std::shared_ptr<::arrow::io::RandomAccessFile> file, + const ReaderProperties& properties = default_reader_properties(), + std::shared_ptr metadata = NULLPTR); + + ParquetFileReader* raw_reader() { return raw_reader_.get(); } + + /// Set Arrow MemoryPool for memory allocation + FileReaderBuilder* memory_pool(::arrow::MemoryPool* pool); + /// Set Arrow reader properties + FileReaderBuilder* properties(const ArrowReaderProperties& arg_properties); + /// Build FileReader instance + ::arrow::Status Build(std::unique_ptr* out); + + private: + ::arrow::MemoryPool* pool_; + ArrowReaderProperties properties_; + std::unique_ptr raw_reader_; +}; + +/// \defgroup parquet-arrow-reader-factories Factory functions for Parquet Arrow readers +/// +/// @{ + +/// \brief Build FileReader from Arrow file and MemoryPool +/// +/// Advanced settings are supported through the FileReaderBuilder class. +PARQUET_EXPORT +::arrow::Status OpenFile(std::shared_ptr<::arrow::io::RandomAccessFile>, + ::arrow::MemoryPool* allocator, + std::unique_ptr* reader); + +/// @} + +PARQUET_EXPORT +::arrow::Status StatisticsAsScalars(const Statistics& Statistics, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max); + +namespace internal { + +PARQUET_EXPORT +::arrow::Status FuzzReader(const uint8_t* data, int64_t size); + +} // namespace internal +} // namespace arrow +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.cc b/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.cc new file mode 100644 index 000000000000..f67945647b69 --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.cc @@ -0,0 +1,800 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "reader_internal.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/compute/api.h" +#include "arrow/datum.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/scalar.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/base64.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/endian.h" +#include "arrow/util/int_util_internal.h" +#include "arrow/util/logging.h" +#include "arrow/util/string_view.h" +#include "arrow/util/ubsan.h" +#include "arrow/visitor_inline.h" +#include "Storages/ch_parquet/arrow/reader.h" +#include "parquet/arrow/schema.h" +#include "parquet/arrow/schema_internal.h" +#include "Storages/ch_parquet/arrow/column_reader.h" +#include "parquet/platform.h" +#include "parquet/properties.h" +#include "parquet/schema.h" +#include "parquet/statistics.h" +#include "parquet/types.h" +// Required after "arrow/util/int_util_internal.h" (for OPTIONAL) +#include "parquet/windows_compatibility.h" + +using arrow::Array; +using arrow::BooleanArray; +using arrow::ChunkedArray; +using arrow::DataType; +using arrow::Datum; +using arrow::Decimal128; +using arrow::Decimal128Array; +using arrow::Decimal128Type; +using arrow::Decimal256; +using arrow::Decimal256Array; +using arrow::Decimal256Type; +using arrow::Field; +using arrow::Int32Array; +using arrow::ListArray; +using arrow::MemoryPool; +using arrow::ResizableBuffer; +using arrow::Status; +using arrow::StructArray; +using arrow::Table; +using arrow::TimestampArray; + +using ::arrow::BitUtil::FromBigEndian; +using ::arrow::internal::checked_cast; +using ::arrow::internal::checked_pointer_cast; +using ::arrow::internal::SafeLeftShift; +using ::arrow::util::SafeLoadAs; + +using ch_parquet::internal::BinaryRecordReader; +using ch_parquet::internal::DictionaryRecordReader; +using ch_parquet::internal::RecordReader; +using parquet::schema::GroupNode; +using parquet::schema::Node; +using parquet::schema::PrimitiveNode; +using ParquetType = parquet::Type; + +namespace BitUtil = arrow::BitUtil; + +namespace ch_parquet { +using namespace parquet; +namespace arrow { +using namespace parquet::arrow; +namespace { + +template +using ArrayType = typename ::arrow::TypeTraits::ArrayType; + +template +Status MakeMinMaxScalar(const StatisticsType& statistics, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + *min = ::arrow::MakeScalar(static_cast(statistics.min())); + *max = ::arrow::MakeScalar(static_cast(statistics.max())); + return Status::OK(); +} + +template +Status MakeMinMaxTypedScalar(const StatisticsType& statistics, + std::shared_ptr type, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + ARROW_ASSIGN_OR_RAISE(*min, ::arrow::MakeScalar(type, statistics.min())); + ARROW_ASSIGN_OR_RAISE(*max, ::arrow::MakeScalar(type, statistics.max())); + return Status::OK(); +} + +template +Status MakeMinMaxIntegralScalar(const StatisticsType& statistics, + const ::arrow::DataType& arrow_type, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + const auto column_desc = statistics.descr(); + const auto& logical_type = column_desc->logical_type(); + const auto& integer = checked_pointer_cast(logical_type); + const bool is_signed = integer->is_signed(); + + switch (integer->bit_width()) { + case 8: + return is_signed ? MakeMinMaxScalar(statistics, min, max) + : MakeMinMaxScalar(statistics, min, max); + case 16: + return is_signed ? MakeMinMaxScalar(statistics, min, max) + : MakeMinMaxScalar(statistics, min, max); + case 32: + return is_signed ? MakeMinMaxScalar(statistics, min, max) + : MakeMinMaxScalar(statistics, min, max); + case 64: + return is_signed ? MakeMinMaxScalar(statistics, min, max) + : MakeMinMaxScalar(statistics, min, max); + } + + return Status::OK(); +} + +static Status FromInt32Statistics(const Int32Statistics& statistics, + const LogicalType& logical_type, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + ARROW_ASSIGN_OR_RAISE(auto type, FromInt32(logical_type)); + + switch (logical_type.type()) { + case LogicalType::Type::INT: + return MakeMinMaxIntegralScalar(statistics, *type, min, max); + break; + case LogicalType::Type::DATE: + case LogicalType::Type::TIME: + case LogicalType::Type::NONE: + return MakeMinMaxTypedScalar(statistics, type, min, max); + break; + default: + break; + } + + return Status::NotImplemented("Cannot extract statistics for type "); +} + +static Status FromInt64Statistics(const Int64Statistics& statistics, + const LogicalType& logical_type, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + ARROW_ASSIGN_OR_RAISE(auto type, FromInt64(logical_type)); + + switch (logical_type.type()) { + case LogicalType::Type::INT: + return MakeMinMaxIntegralScalar(statistics, *type, min, max); + break; + case LogicalType::Type::TIME: + case LogicalType::Type::TIMESTAMP: + case LogicalType::Type::NONE: + return MakeMinMaxTypedScalar(statistics, type, min, max); + break; + default: + break; + } + + return Status::NotImplemented("Cannot extract statistics for type "); +} + +template +Result> FromBigEndianString( + const std::string& data, std::shared_ptr arrow_type) { + ARROW_ASSIGN_OR_RAISE( + DecimalType decimal, + DecimalType::FromBigEndian(reinterpret_cast(data.data()), + static_cast(data.size()))); + return ::arrow::MakeScalar(std::move(arrow_type), decimal); +} + +// Extracts Min and Max scalar from bytes like types (i.e. types where +// decimal is encoded as little endian. +Status ExtractDecimalMinMaxFromBytesType(const Statistics& statistics, + const LogicalType& logical_type, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + const DecimalLogicalType& decimal_type = + checked_cast(logical_type); + + Result> maybe_type = + Decimal128Type::Make(decimal_type.precision(), decimal_type.scale()); + std::shared_ptr arrow_type; + if (maybe_type.ok()) { + arrow_type = maybe_type.ValueOrDie(); + ARROW_ASSIGN_OR_RAISE( + *min, FromBigEndianString(statistics.EncodeMin(), arrow_type)); + ARROW_ASSIGN_OR_RAISE(*max, FromBigEndianString(statistics.EncodeMax(), + std::move(arrow_type))); + return Status::OK(); + } + // Fallback to see if Decimal256 can represent the type. + ARROW_ASSIGN_OR_RAISE( + arrow_type, Decimal256Type::Make(decimal_type.precision(), decimal_type.scale())); + ARROW_ASSIGN_OR_RAISE( + *min, FromBigEndianString(statistics.EncodeMin(), arrow_type)); + ARROW_ASSIGN_OR_RAISE(*max, FromBigEndianString(statistics.EncodeMax(), + std::move(arrow_type))); + + return Status::OK(); +} + +Status ByteArrayStatisticsAsScalars(const Statistics& statistics, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + auto logical_type = statistics.descr()->logical_type(); + if (logical_type->type() == LogicalType::Type::DECIMAL) { + return ExtractDecimalMinMaxFromBytesType(statistics, *logical_type, min, max); + } + std::shared_ptr<::arrow::DataType> type; + if (statistics.descr()->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) { + type = ::arrow::fixed_size_binary(statistics.descr()->type_length()); + } else { + type = logical_type->type() == LogicalType::Type::STRING ? ::arrow::utf8() + : ::arrow::binary(); + } + ARROW_ASSIGN_OR_RAISE( + *min, ::arrow::MakeScalar(type, Buffer::FromString(statistics.EncodeMin()))); + ARROW_ASSIGN_OR_RAISE( + *max, ::arrow::MakeScalar(type, Buffer::FromString(statistics.EncodeMax()))); + + return Status::OK(); +} + +} // namespace + +Status StatisticsAsScalars(const Statistics& statistics, + std::shared_ptr<::arrow::Scalar>* min, + std::shared_ptr<::arrow::Scalar>* max) { + if (!statistics.HasMinMax()) { + return Status::Invalid("Statistics has no min max."); + } + + auto column_desc = statistics.descr(); + if (column_desc == nullptr) { + return Status::Invalid("Statistics carries no descriptor, can't infer arrow type."); + } + + auto physical_type = column_desc->physical_type(); + auto logical_type = column_desc->logical_type(); + switch (physical_type) { + case Type::BOOLEAN: + return MakeMinMaxScalar( + checked_cast(statistics), min, max); + case Type::FLOAT: + return MakeMinMaxScalar( + checked_cast(statistics), min, max); + case Type::DOUBLE: + return MakeMinMaxScalar( + checked_cast(statistics), min, max); + case Type::INT32: + return FromInt32Statistics(checked_cast(statistics), + *logical_type, min, max); + case Type::INT64: + return FromInt64Statistics(checked_cast(statistics), + *logical_type, min, max); + case Type::BYTE_ARRAY: + case Type::FIXED_LEN_BYTE_ARRAY: + return ByteArrayStatisticsAsScalars(statistics, min, max); + default: + return Status::NotImplemented("Extract statistics unsupported for physical_type ", + physical_type, " unsupported."); + } + + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Primitive types + +namespace { + +template +Status TransferInt(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& type, Datum* out) { + using ArrowCType = typename ArrowType::c_type; + using ParquetCType = typename ParquetType::c_type; + int64_t length = reader->values_written(); + ARROW_ASSIGN_OR_RAISE(auto data, + ::arrow::AllocateBuffer(length * sizeof(ArrowCType), pool)); + + auto values = reinterpret_cast(reader->values()); + auto out_ptr = reinterpret_cast(data->mutable_data()); + std::copy(values, values + length, out_ptr); + *out = std::make_shared>( + type, length, std::move(data), reader->ReleaseIsValid(), reader->null_count()); + return Status::OK(); +} + +std::shared_ptr TransferZeroCopy(RecordReader* reader, + const std::shared_ptr& type) { + std::vector> buffers = {reader->ReleaseIsValid(), + reader->ReleaseValues()}; + auto data = std::make_shared<::arrow::ArrayData>(type, reader->values_written(), + buffers, reader->null_count()); + return ::arrow::MakeArray(data); +} + +Status TransferBool(RecordReader* reader, MemoryPool* pool, Datum* out) { + int64_t length = reader->values_written(); + + const int64_t buffer_size = BitUtil::BytesForBits(length); + ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(buffer_size, pool)); + + // Transfer boolean values to packed bitmap + auto values = reinterpret_cast(reader->values()); + uint8_t* data_ptr = data->mutable_data(); + memset(data_ptr, 0, buffer_size); + + for (int64_t i = 0; i < length; i++) { + if (values[i]) { + ::arrow::BitUtil::SetBit(data_ptr, i); + } + } + + *out = std::make_shared(length, std::move(data), reader->ReleaseIsValid(), + reader->null_count()); + return Status::OK(); +} + +Status TransferInt96(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& type, Datum* out, + const ::arrow::TimeUnit::type int96_arrow_time_unit) { + int64_t length = reader->values_written(); + auto values = reinterpret_cast(reader->values()); + ARROW_ASSIGN_OR_RAISE(auto data, + ::arrow::AllocateBuffer(length * sizeof(int64_t), pool)); + auto data_ptr = reinterpret_cast(data->mutable_data()); + for (int64_t i = 0; i < length; i++) { + if (values[i].value[2] == 0) { + // Happens for null entries: avoid triggering UBSAN as that Int96 timestamp + // isn't representable as a 64-bit Unix timestamp. + *data_ptr++ = 0; + } else { + switch (int96_arrow_time_unit) { + case ::arrow::TimeUnit::NANO: + *data_ptr++ = Int96GetNanoSeconds(values[i]); + break; + case ::arrow::TimeUnit::MICRO: + *data_ptr++ = Int96GetMicroSeconds(values[i]); + break; + case ::arrow::TimeUnit::MILLI: + *data_ptr++ = Int96GetMilliSeconds(values[i]); + break; + case ::arrow::TimeUnit::SECOND: + *data_ptr++ = Int96GetSeconds(values[i]); + break; + } + } + } + *out = std::make_shared(type, length, std::move(data), + reader->ReleaseIsValid(), reader->null_count()); + return Status::OK(); +} + +Status TransferDate64(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& type, Datum* out) { + int64_t length = reader->values_written(); + auto values = reinterpret_cast(reader->values()); + + ARROW_ASSIGN_OR_RAISE(auto data, + ::arrow::AllocateBuffer(length * sizeof(int64_t), pool)); + auto out_ptr = reinterpret_cast(data->mutable_data()); + + for (int64_t i = 0; i < length; i++) { + *out_ptr++ = static_cast(values[i]) * kMillisecondsPerDay; + } + + *out = std::make_shared<::arrow::Date64Array>( + type, length, std::move(data), reader->ReleaseIsValid(), reader->null_count()); + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Binary, direct to dictionary-encoded + +Status TransferDictionary(RecordReader* reader, + const std::shared_ptr& logical_value_type, + std::shared_ptr* out) { + auto dict_reader = dynamic_cast(reader); + DCHECK(dict_reader); + *out = dict_reader->GetResult(); + if (!logical_value_type->Equals(*(*out)->type())) { + ARROW_ASSIGN_OR_RAISE(*out, (*out)->View(logical_value_type)); + } + return Status::OK(); +} + +Status TransferBinary(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& logical_value_type, + std::shared_ptr* out) { + if (reader->read_dictionary()) { + return TransferDictionary( + reader, ::arrow::dictionary(::arrow::int32(), logical_value_type), out); + } + ::arrow::compute::ExecContext ctx(pool); + ::arrow::compute::CastOptions cast_options; + cast_options.allow_invalid_utf8 = true; // avoid spending time validating UTF8 data + + auto binary_reader = dynamic_cast(reader); + DCHECK(binary_reader); + + auto chunks = binary_reader->GetBuilderChunks(); + if (chunks.size() > 0 && dynamic_cast<::ch_parquet::internal::CHStringArray*>(chunks.at(0).get()) != nullptr) { + //bypass any cast if it's already CHStringArray + *out = std::make_shared(chunks, logical_value_type); + return Status::OK(); + } + + for (auto& chunk : chunks) { + if (!chunk->type()->Equals(*logical_value_type)) { + // XXX: if a LargeBinary chunk is larger than 2GB, the MSBs of offsets + // will be lost because they are first created as int32 and then cast to int64. + ARROW_ASSIGN_OR_RAISE( + chunk, ::arrow::compute::Cast(*chunk, logical_value_type, cast_options, &ctx)); + } + } + *out = std::make_shared(chunks, logical_value_type); + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// INT32 / INT64 / BYTE_ARRAY / FIXED_LEN_BYTE_ARRAY -> Decimal128 || Decimal256 + +template +Status RawBytesToDecimalBytes(const uint8_t* value, int32_t byte_width, + uint8_t* out_buf) { + ARROW_ASSIGN_OR_RAISE(DecimalType t, DecimalType::FromBigEndian(value, byte_width)); + t.ToBytes(out_buf); + return ::arrow::Status::OK(); +} + +template +struct DecimalTypeTrait; + +template <> +struct DecimalTypeTrait<::arrow::Decimal128Array> { + using value = ::arrow::Decimal128; +}; + +template <> +struct DecimalTypeTrait<::arrow::Decimal256Array> { + using value = ::arrow::Decimal256; +}; + +template +struct DecimalConverter { + static inline Status ConvertToDecimal(const Array& array, + const std::shared_ptr&, + MemoryPool* pool, std::shared_ptr*) { + return Status::NotImplemented("not implemented"); + } +}; + +template +struct DecimalConverter { + static inline Status ConvertToDecimal(const Array& array, + const std::shared_ptr& type, + MemoryPool* pool, std::shared_ptr* out) { + const auto& fixed_size_binary_array = + checked_cast(array); + + // The byte width of each decimal value + const int32_t type_length = + checked_cast(*type).byte_width(); + + // number of elements in the entire array + const int64_t length = fixed_size_binary_array.length(); + + // Get the byte width of the values in the FixedSizeBinaryArray. Most of the time + // this will be different from the decimal array width because we write the minimum + // number of bytes necessary to represent a given precision + const int32_t byte_width = + checked_cast(*fixed_size_binary_array.type()) + .byte_width(); + // allocate memory for the decimal array + ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool)); + + // raw bytes that we can write to + uint8_t* out_ptr = data->mutable_data(); + + // convert each FixedSizeBinary value to valid decimal bytes + const int64_t null_count = fixed_size_binary_array.null_count(); + + using DecimalType = typename DecimalTypeTrait::value; + if (null_count > 0) { + for (int64_t i = 0; i < length; ++i, out_ptr += type_length) { + if (!fixed_size_binary_array.IsNull(i)) { + RETURN_NOT_OK(RawBytesToDecimalBytes( + fixed_size_binary_array.GetValue(i), byte_width, out_ptr)); + } else { + std::memset(out_ptr, 0, type_length); + } + } + } else { + for (int64_t i = 0; i < length; ++i, out_ptr += type_length) { + RETURN_NOT_OK(RawBytesToDecimalBytes( + fixed_size_binary_array.GetValue(i), byte_width, out_ptr)); + } + } + + *out = std::make_shared( + type, length, std::move(data), fixed_size_binary_array.null_bitmap(), null_count); + + return Status::OK(); + } +}; + +template +struct DecimalConverter { + static inline Status ConvertToDecimal(const Array& array, + const std::shared_ptr& type, + MemoryPool* pool, std::shared_ptr* out) { + const auto& binary_array = checked_cast(array); + const int64_t length = binary_array.length(); + + const auto& decimal_type = checked_cast(*type); + const int64_t type_length = decimal_type.byte_width(); + + ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool)); + + // raw bytes that we can write to + uint8_t* out_ptr = data->mutable_data(); + + const int64_t null_count = binary_array.null_count(); + + // convert each BinaryArray value to valid decimal bytes + for (int64_t i = 0; i < length; i++, out_ptr += type_length) { + int32_t record_len = 0; + const uint8_t* record_loc = binary_array.GetValue(i, &record_len); + + if (record_len < 0 || record_len > type_length) { + return Status::Invalid("Invalid BYTE_ARRAY length for ", type->ToString()); + } + + auto out_ptr_view = reinterpret_cast(out_ptr); + out_ptr_view[0] = 0; + out_ptr_view[1] = 0; + + // only convert rows that are not null if there are nulls, or + // all rows, if there are not + if ((null_count > 0 && !binary_array.IsNull(i)) || null_count <= 0) { + using DecimalType = typename DecimalTypeTrait::value; + RETURN_NOT_OK( + RawBytesToDecimalBytes(record_loc, record_len, out_ptr)); + } + } + *out = std::make_shared(type, length, std::move(data), + binary_array.null_bitmap(), null_count); + return Status::OK(); + } +}; + +/// \brief Convert an Int32 or Int64 array into a Decimal128Array +/// The parquet spec allows systems to write decimals in int32, int64 if the values are +/// small enough to fit in less 4 bytes or less than 8 bytes, respectively. +/// This function implements the conversion from int32 and int64 arrays to decimal arrays. +template < + typename ParquetIntegerType, + typename = ::arrow::enable_if_t::value || + std::is_same::value>> +static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& type, Datum* out) { + // Decimal128 and Decimal256 are only Arrow constructs. Parquet does not + // specifically distinguish between decimal byte widths. + // Decimal256 isn't relevant here because the Arrow-Parquet C++ bindings never + // write Decimal values as integers and if the decimal value can fit in an + // integer it is wasteful to use Decimal256. Put another way, the only + // way an integer column could be construed as Decimal256 is if an arrow + // schema was stored as metadata in the file indicating the column was + // Decimal256. The current Arrow-Parquet C++ bindings will never do this. + DCHECK(type->id() == ::arrow::Type::DECIMAL128); + + const int64_t length = reader->values_written(); + + using ElementType = typename ParquetIntegerType::c_type; + static_assert(std::is_same::value || + std::is_same::value, + "ElementType must be int32_t or int64_t"); + + const auto values = reinterpret_cast(reader->values()); + + const auto& decimal_type = checked_cast(*type); + const int64_t type_length = decimal_type.byte_width(); + + ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool)); + uint8_t* out_ptr = data->mutable_data(); + + using ::arrow::BitUtil::FromLittleEndian; + + for (int64_t i = 0; i < length; ++i, out_ptr += type_length) { + // sign/zero extend int32_t values, otherwise a no-op + const auto value = static_cast(values[i]); + + ::arrow::Decimal128 decimal(value); + decimal.ToBytes(out_ptr); + } + + if (reader->nullable_values()) { + std::shared_ptr is_valid = reader->ReleaseIsValid(); + *out = std::make_shared(type, length, std::move(data), is_valid, + reader->null_count()); + } else { + *out = std::make_shared(type, length, std::move(data)); + } + return Status::OK(); +} + +/// \brief Convert an arrow::BinaryArray to an arrow::Decimal{128,256}Array +/// We do this by: +/// 1. Creating an arrow::BinaryArray from the RecordReader's builder +/// 2. Allocating a buffer for the arrow::Decimal{128,256}Array +/// 3. Converting the big-endian bytes in each BinaryArray entry to two integers +/// representing the high and low bits of each decimal value. +template +Status TransferDecimal(RecordReader* reader, MemoryPool* pool, + const std::shared_ptr& type, Datum* out) { + auto binary_reader = dynamic_cast(reader); + DCHECK(binary_reader); + ::arrow::ArrayVector chunks = binary_reader->GetBuilderChunks(); + for (size_t i = 0; i < chunks.size(); ++i) { + std::shared_ptr chunk_as_decimal; + auto fn = &DecimalConverter::ConvertToDecimal; + RETURN_NOT_OK(fn(*chunks[i], type, pool, &chunk_as_decimal)); + // Replace the chunk, which will hopefully also free memory as we go + chunks[i] = chunk_as_decimal; + } + *out = std::make_shared(chunks, type); + return Status::OK(); +} + +} // namespace + +#define TRANSFER_INT32(ENUM, ArrowType) \ + case ::arrow::Type::ENUM: { \ + Status s = TransferInt(reader, pool, value_type, &result); \ + RETURN_NOT_OK(s); \ + } break; + +#define TRANSFER_INT64(ENUM, ArrowType) \ + case ::arrow::Type::ENUM: { \ + Status s = TransferInt(reader, pool, value_type, &result); \ + RETURN_NOT_OK(s); \ + } break; + +Status TransferColumnData(RecordReader* reader, std::shared_ptr value_type, + const ColumnDescriptor* descr, MemoryPool* pool, + std::shared_ptr* out) { + Datum result; + std::shared_ptr chunked_result; + switch (value_type->id()) { + case ::arrow::Type::DICTIONARY: { + RETURN_NOT_OK(TransferDictionary(reader, value_type, &chunked_result)); + result = chunked_result; + } break; + case ::arrow::Type::NA: { + result = std::make_shared<::arrow::NullArray>(reader->values_written()); + break; + } + case ::arrow::Type::INT32: + case ::arrow::Type::INT64: + case ::arrow::Type::FLOAT: + case ::arrow::Type::DOUBLE: + case ::arrow::Type::DATE32: + result = TransferZeroCopy(reader, value_type); + break; + case ::arrow::Type::BOOL: + RETURN_NOT_OK(TransferBool(reader, pool, &result)); + break; + TRANSFER_INT32(UINT8, ::arrow::UInt8Type); + TRANSFER_INT32(INT8, ::arrow::Int8Type); + TRANSFER_INT32(UINT16, ::arrow::UInt16Type); + TRANSFER_INT32(INT16, ::arrow::Int16Type); + TRANSFER_INT32(UINT32, ::arrow::UInt32Type); + TRANSFER_INT64(UINT64, ::arrow::UInt64Type); + TRANSFER_INT32(TIME32, ::arrow::Time32Type); + TRANSFER_INT64(TIME64, ::arrow::Time64Type); + case ::arrow::Type::DATE64: + RETURN_NOT_OK(TransferDate64(reader, pool, value_type, &result)); + break; + case ::arrow::Type::FIXED_SIZE_BINARY: + case ::arrow::Type::BINARY: + case ::arrow::Type::STRING: + case ::arrow::Type::LARGE_BINARY: + case ::arrow::Type::LARGE_STRING: { + RETURN_NOT_OK(TransferBinary(reader, pool, value_type, &chunked_result)); + result = chunked_result; + } break; + case ::arrow::Type::DECIMAL128: { + switch (descr->physical_type()) { + case ::parquet::Type::INT32: { + auto fn = DecimalIntegerTransfer; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + case ::parquet::Type::INT64: { + auto fn = &DecimalIntegerTransfer; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + case ::parquet::Type::BYTE_ARRAY: { + auto fn = &TransferDecimal; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: { + auto fn = &TransferDecimal; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + default: + return Status::Invalid( + "Physical type for decimal128 must be int32, int64, byte array, or fixed " + "length binary"); + } + } break; + case ::arrow::Type::DECIMAL256: + switch (descr->physical_type()) { + case ::parquet::Type::BYTE_ARRAY: { + auto fn = &TransferDecimal; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: { + auto fn = &TransferDecimal; + RETURN_NOT_OK(fn(reader, pool, value_type, &result)); + } break; + default: + return Status::Invalid( + "Physical type for decimal256 must be fixed length binary"); + } + break; + + case ::arrow::Type::TIMESTAMP: { + const ::arrow::TimestampType& timestamp_type = + checked_cast<::arrow::TimestampType&>(*value_type); + if (descr->physical_type() == ::parquet::Type::INT96) { + RETURN_NOT_OK( + TransferInt96(reader, pool, value_type, &result, timestamp_type.unit())); + } else { + switch (timestamp_type.unit()) { + case ::arrow::TimeUnit::MILLI: + case ::arrow::TimeUnit::MICRO: + case ::arrow::TimeUnit::NANO: + result = TransferZeroCopy(reader, value_type); + break; + default: + return Status::NotImplemented("TimeUnit not supported"); + } + } + } break; + default: + return Status::NotImplemented("No support for reading columns of type ", + value_type->ToString()); + } + + if (result.kind() == Datum::ARRAY) { + *out = std::make_shared(result.make_array()); + } else if (result.kind() == Datum::CHUNKED_ARRAY) { + *out = result.chunked_array(); + } else { + DCHECK(false) << "Should be impossible, result was " << result.ToString(); + } + + return Status::OK(); +} + +} // namespace arrow +} // namespace parquet diff --git a/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.h b/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.h new file mode 100644 index 000000000000..1060d90ba84f --- /dev/null +++ b/utils/local-engine/Storages/ch_parquet/arrow/reader_internal.h @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "parquet/arrow/schema.h" +#include "Storages/ch_parquet/arrow/column_reader.h" +#include "parquet/file_reader.h" +#include "parquet/metadata.h" +#include "parquet/platform.h" +#include "parquet/schema.h" + +namespace arrow { + +class Array; +class ChunkedArray; +class DataType; +class Field; +class KeyValueMetadata; +class Schema; + +} // namespace arrow + +using arrow::Status; + +namespace parquet +{ + +class ArrowReaderProperties; +} + +namespace ch_parquet +{ +using namespace parquet; + +namespace arrow { + +class ColumnReaderImpl; + +// ---------------------------------------------------------------------- +// Iteration utilities + +// Abstraction to decouple row group iteration details from the ColumnReader, +// so we can read only a single row group if we want +class FileColumnIterator { + public: + explicit FileColumnIterator(int column_index, ParquetFileReader* reader, + std::vector row_groups) + : column_index_(column_index), + reader_(reader), + schema_(reader->metadata()->schema()), + row_groups_(row_groups.begin(), row_groups.end()) {} + + virtual ~FileColumnIterator() {} + + std::unique_ptr<::parquet::PageReader> NextChunk() { + if (row_groups_.empty()) { + return nullptr; + } + + auto row_group_reader = reader_->RowGroup(row_groups_.front()); + row_groups_.pop_front(); + return row_group_reader->GetColumnPageReader(column_index_); + } + + const SchemaDescriptor* schema() const { return schema_; } + + const ColumnDescriptor* descr() const { return schema_->Column(column_index_); } + + std::shared_ptr metadata() const { return reader_->metadata(); } + + int column_index() const { return column_index_; } + + protected: + int column_index_; + ParquetFileReader* reader_; + const SchemaDescriptor* schema_; + std::deque row_groups_; +}; + +using FileColumnIteratorFactory = + std::function; + +Status TransferColumnData(::ch_parquet::internal::RecordReader* reader, + std::shared_ptr<::arrow::DataType> value_type, + const ColumnDescriptor* descr, ::arrow::MemoryPool* pool, + std::shared_ptr<::arrow::ChunkedArray>* out); + +struct ReaderContext { + ParquetFileReader* reader; + ::arrow::MemoryPool* pool; + FileColumnIteratorFactory iterator_factory; + bool filter_leaves; + std::shared_ptr> included_leaves; + + bool IncludesLeaf(int leaf_index) const { + if (this->filter_leaves) { + return this->included_leaves->find(leaf_index) != this->included_leaves->end(); + } + return true; + } +}; + +} // namespace arrow +} // namespace parquet diff --git a/utils/local-engine/build/build.sh b/utils/local-engine/build/build.sh new file mode 100755 index 000000000000..107c33aceeab --- /dev/null +++ b/utils/local-engine/build/build.sh @@ -0,0 +1 @@ +sudo docker run --rm --volume="$2":/output --volume="$1":/clickhouse --volume=/tmp/.cache:/ccache -e ENABLE_EMBEDDED_COMPILER=ON libchbuilder:1.0 \ No newline at end of file diff --git a/utils/local-engine/build/image/Dockerfile b/utils/local-engine/build/image/Dockerfile new file mode 100644 index 000000000000..0f8d6f5cdda8 --- /dev/null +++ b/utils/local-engine/build/image/Dockerfile @@ -0,0 +1,81 @@ +# rebuild in #33610 +# docker build -t clickhouse/binary-builder . +FROM ubuntu:20.04 + +# ARG for quick switch to a given ubuntu mirror +ARG apt_archive="http://mirrors.aliyun.com" +RUN sed -i "s|http://archive.ubuntu.com|$apt_archive|g" /etc/apt/sources.list + +ENV DEBIAN_FRONTEND=noninteractive LLVM_VERSION=14 + +RUN apt-get update \ + && apt-get install \ + apt-transport-https \ + apt-utils \ + ca-certificates \ + curl \ + dnsutils \ + gnupg \ + iputils-ping \ + lsb-release \ + wget \ + --yes --no-install-recommends --verbose-versions \ + && export LLVM_PUBKEY_HASH="bda960a8da687a275a2078d43c111d66b1c6a893a3275271beedf266c1ff4a0cdecb429c7a5cccf9f486ea7aa43fd27f" \ + && wget -nv -O /tmp/llvm-snapshot.gpg.key https://apt.llvm.org/llvm-snapshot.gpg.key \ + && echo "${LLVM_PUBKEY_HASH} /tmp/llvm-snapshot.gpg.key" | sha384sum -c \ + && apt-key add /tmp/llvm-snapshot.gpg.key \ + && export CODENAME="$(lsb_release --codename --short | tr 'A-Z' 'a-z')" \ + && echo "deb https://apt.llvm.org/${CODENAME}/ llvm-toolchain-${CODENAME}-${LLVM_VERSION} main" >> \ + /etc/apt/sources.list \ + && apt-get clean + +RUN curl -s https://apt.kitware.com/keys/kitware-archive-latest.asc | \ + gpg --dearmor - > /etc/apt/trusted.gpg.d/kitware.gpg && \ + echo "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" >> /etc/apt/sources.list +# initial packages +RUN apt-get update \ + && apt-get install \ + bash \ + openjdk-8-jdk\ + # build-essential \ + ccache \ + clang-${LLVM_VERSION} \ +# clang-tidy-${LLVM_VERSION} \ + cmake \ + fakeroot \ +# gdb \ +# git \ +# gperf \ + lld-${LLVM_VERSION} \ + llvm-${LLVM_VERSION} \ +# llvm-${LLVM_VERSION}-dev \ +# moreutils \ + ninja-build \ +# pigz \ +# rename \ + software-properties-common \ + tzdata \ + --yes --no-install-recommends \ + && apt-get clean + +# This symlink required by gcc to find lld compiler +RUN ln -s /usr/bin/lld-${LLVM_VERSION} /usr/bin/ld.lld + +ENV RUSTUP_HOME=/rust/rustup +ENV CARGO_HOME=/rust/cargo +ENV PATH="/rust/cargo/env:${PATH}" +ENV PATH="/rust/cargo/bin:${PATH}" +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y && \ + chmod 777 -R /rust && \ + rustup target add aarch64-unknown-linux-gnu && \ + rustup target add x86_64-apple-darwin && \ + rustup target add x86_64-unknown-freebsd && \ + rustup target add aarch64-apple-darwin && \ + rustup target add powerpc64le-unknown-linux-gnu + +ENV CC=clang-${LLVM_VERSION} +ENV CXX=clang++-${LLVM_VERSION} + +ADD ./build.sh /build.sh +RUN chmod +x /build.sh +CMD ["bash", "-c", "/build.sh 2>&1"] \ No newline at end of file diff --git a/utils/local-engine/build/image/build.sh b/utils/local-engine/build/image/build.sh new file mode 100644 index 000000000000..7d88445c7c4b --- /dev/null +++ b/utils/local-engine/build/image/build.sh @@ -0,0 +1,11 @@ +mkdir -p /build && cd /build || exit +export CCACHE_DIR=/ccache +export CCACHE_BASEDIR=/build +export CCACHE_NOHASHDIR=true +export CCACHE_COMPILERCHECK=content +export CCACHE_MAXSIZE=15G + +cmake -G Ninja "-DCMAKE_C_COMPILER=$CC" "-DCMAKE_CXX_COMPILER=$CXX" "-DCMAKE_BUILD_TYPE=Release" "-DENABLE_PROTOBUF=1" "-DENABLE_EMBEDDED_COMPILER=$ENABLE_EMBEDDED_COMPILER" "-DENABLE_TESTS=OFF" "-DWERROR=OFF" "-DENABLE_JEMALLOC=1" "-DENABLE_MULTITARGET_CODE=ON" /clickhouse +ninja ch + +cp /build/utils/local-engine/libch.so "/output/libch_$(date +%Y%m%d).so" \ No newline at end of file diff --git a/utils/local-engine/include/com_intel_oap_row_RowIterator.h b/utils/local-engine/include/com_intel_oap_row_RowIterator.h new file mode 100644 index 000000000000..efd36c4f4c23 --- /dev/null +++ b/utils/local-engine/include/com_intel_oap_row_RowIterator.h @@ -0,0 +1,45 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class com_intel_oap_row_RowIterator */ + +#ifndef _Included_com_intel_oap_row_RowIterator +#define _Included_com_intel_oap_row_RowIterator +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: com_intel_oap_row_RowIterator + * Method: nativeHasNext + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL Java_com_intel_oap_row_RowIterator_nativeHasNext + (JNIEnv *, jobject, jlong); + +/* + * Class: com_intel_oap_row_RowIterator + * Method: nativeNext + * Signature: (J)Lcom/intel/oap/row/SparkRowInfo; + */ +JNIEXPORT jobject JNICALL Java_com_intel_oap_row_RowIterator_nativeNext + (JNIEnv *, jobject, jlong); + +/* + * Class: com_intel_oap_row_RowIterator + * Method: nativeClose + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_row_RowIterator_nativeClose + (JNIEnv *, jobject, jlong); + +/* + * Class: com_intel_oap_row_RowIterator + * Method: nativeFetchMetrics + * Signature: (J)Lcom/intel/oap/vectorized/MetricsObject; + */ +JNIEXPORT jobject JNICALL Java_com_intel_oap_row_RowIterator_nativeFetchMetrics + (JNIEnv *, jobject, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/local-engine/include/com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper.h b/utils/local-engine/include/com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper.h new file mode 100644 index 000000000000..b83aaea04175 --- /dev/null +++ b/utils/local-engine/include/com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper.h @@ -0,0 +1,69 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper */ + +#ifndef _Included_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper +#define _Included_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeInitNative + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative + (JNIEnv *, jobject); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeCreateKernelWithIterator + * Signature: (J[B[Lcom/intel/oap/execution/ColumnarNativeIterator;)J + */ +JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeCreateKernelWithIterator + (JNIEnv *, jobject, jlong, jbyteArray, jobjectArray); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeCreateKernelWithRowIterator + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeCreateKernelWithRowIterator + (JNIEnv *, jobject, jbyteArray); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeSetJavaTmpDir + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeSetJavaTmpDir + (JNIEnv *, jobject, jstring); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeSetBatchSize + * Signature: (I)V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeSetBatchSize + (JNIEnv *, jobject, jint); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeSetMetricsTime + * Signature: (Z)V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeSetMetricsTime + (JNIEnv *, jobject, jboolean); + +/* + * Class: com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper + * Method: nativeClose + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeClose + (JNIEnv *, jobject, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/utils/local-engine/jni/ReservationListenerWrapper.cpp b/utils/local-engine/jni/ReservationListenerWrapper.cpp new file mode 100644 index 000000000000..faad1388530e --- /dev/null +++ b/utils/local-engine/jni/ReservationListenerWrapper.cpp @@ -0,0 +1,44 @@ +#include "ReservationListenerWrapper.h" +#include +#include + +namespace local_engine +{ +jclass ReservationListenerWrapper::reservation_listener_class = nullptr; +jmethodID ReservationListenerWrapper::reservation_listener_reserve = nullptr; +jmethodID ReservationListenerWrapper::reservation_listener_reserve_or_throw = nullptr; +jmethodID ReservationListenerWrapper::reservation_listener_unreserve = nullptr; + +ReservationListenerWrapper::ReservationListenerWrapper(jobject listener_) : listener(listener_) +{ +} + +ReservationListenerWrapper::~ReservationListenerWrapper() +{ + GET_JNIENV(env) + env->DeleteGlobalRef(listener); + CLEAN_JNIENV +} + +void ReservationListenerWrapper::reserve(int64_t size) +{ + GET_JNIENV(env) + safeCallVoidMethod(env, listener, reservation_listener_reserve, size); + CLEAN_JNIENV +} + +void ReservationListenerWrapper::reserveOrThrow(int64_t size) +{ + GET_JNIENV(env) + safeCallVoidMethod(env, listener, reservation_listener_reserve_or_throw, size); + CLEAN_JNIENV + +} + +void ReservationListenerWrapper::free(int64_t size) +{ + GET_JNIENV(env) + safeCallVoidMethod(env, listener, reservation_listener_unreserve, size); + CLEAN_JNIENV +} +} diff --git a/utils/local-engine/jni/ReservationListenerWrapper.h b/utils/local-engine/jni/ReservationListenerWrapper.h new file mode 100644 index 000000000000..a0e61a651a3e --- /dev/null +++ b/utils/local-engine/jni/ReservationListenerWrapper.h @@ -0,0 +1,26 @@ +#pragma once +#include +#include +#include + +namespace local_engine +{ +class ReservationListenerWrapper +{ +public: + static jclass reservation_listener_class; + static jmethodID reservation_listener_reserve; + static jmethodID reservation_listener_reserve_or_throw; + static jmethodID reservation_listener_unreserve; + + explicit ReservationListenerWrapper(jobject listener); + ~ReservationListenerWrapper(); + void reserve(int64_t size); + void reserveOrThrow(int64_t size); + void free(int64_t size); + +private: + jobject listener; +}; +using ReservationListenerWrapperPtr = std::shared_ptr; +} diff --git a/utils/local-engine/jni/jni_common.cpp b/utils/local-engine/jni/jni_common.cpp new file mode 100644 index 000000000000..7df7d34c841b --- /dev/null +++ b/utils/local-engine/jni/jni_common.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +jclass CreateGlobalExceptionClassReference(JNIEnv* env, const char* class_name) +{ + jclass local_class = env->FindClass(class_name); + jclass global_class = static_cast(env->NewGlobalRef(local_class)); + env->DeleteLocalRef(local_class); + if (global_class == nullptr) { + std::string error_msg = "Unable to createGlobalClassReference for" + std::string(class_name); + throw std::runtime_error(error_msg); + } + return global_class; +} + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) +{ + jclass local_class = env->FindClass(class_name); + jclass global_class = static_cast(env->NewGlobalRef(local_class)); + env->DeleteLocalRef(local_class); + if (global_class == nullptr) { + std::string error_message = + "Unable to createGlobalClassReference for" + std::string(class_name); + env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str()); + } + return global_class; +} + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) +{ + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature" + std::string(sig); + env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str()); + } + + return ret; +} + +jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name, const char * sig) +{ + jmethodID ret = env->GetStaticMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find static method " + std::string(name) + + " within signature" + std::string(sig); + env->ThrowNew(JniErrorsGlobalState::instance().getIllegalAccessExceptionClass(), error_message.c_str()); + } + return ret; +} + +jstring charTojstring(JNIEnv* env, const char* pat) { + jclass str_class = (env)->FindClass("Ljava/lang/String;"); + jmethodID ctor_id = (env)->GetMethodID(str_class, "", "([BLjava/lang/String;)V"); + jbyteArray bytes = (env)->NewByteArray(strlen(pat)); + (env)->SetByteArrayRegion(bytes, 0, strlen(pat), reinterpret_cast(const_cast(pat))); + jstring encoding = (env)->NewStringUTF("UTF-8"); + jstring result = static_cast((env)->NewObject(str_class, ctor_id, bytes, encoding)); + env->DeleteLocalRef(bytes); + env->DeleteLocalRef(encoding); + return result; +} + +jbyteArray stringTojbyteArray(JNIEnv* env, const std::string & str) { + const auto * ptr = reinterpret_cast(str.c_str()) ; + jbyteArray jarray = env->NewByteArray(str.size()); + env->SetByteArrayRegion(jarray, 0, str.size(), ptr); + return jarray; +} + +} diff --git a/utils/local-engine/jni/jni_common.h b/utils/local-engine/jni/jni_common.h new file mode 100644 index 000000000000..de121b33620a --- /dev/null +++ b/utils/local-engine/jni/jni_common.h @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +jclass CreateGlobalExceptionClassReference(JNIEnv *env, const char *class_name); + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name); + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig); + +jmethodID GetStaticMethodID(JNIEnv * env, jclass this_class, const char * name, const char * sig); + +jstring charTojstring(JNIEnv* env, const char* pat); + +jbyteArray stringTojbyteArray(JNIEnv* env, const std::string & str); + +#define LOCAL_ENGINE_JNI_JMETHOD_START +#define LOCAL_ENGINE_JNI_JMETHOD_END(env) \ + if ((env)->ExceptionCheck())\ + {\ + LOG_ERROR(&Poco::Logger::get("local_engine"), "Enter java exception handle.");\ + auto throwable = (env)->ExceptionOccurred();\ + jclass exceptionClass = (env)->FindClass("java/lang/Exception"); \ + jmethodID getMessageMethod = (env)->GetMethodID(exceptionClass, "getMessage", "()Ljava/lang/String;"); \ + jstring message = static_cast((env)->CallObjectMethod(throwable, getMessageMethod)); \ + const char *messageChars = (env)->GetStringUTFChars(message, NULL); \ + LOG_ERROR(&Poco::Logger::get("jni"), "exception:{}", messageChars); \ + (env)->ReleaseStringUTFChars(message, messageChars); \ + (env)->Throw(throwable);\ + } + +template +jobject safeCallObjectMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallObjectMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env) + return ret; +} + +template +jboolean safeCallBooleanMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallBooleanMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +jlong safeCallLongMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallLongMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +jint safeCallIntMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + auto ret = env->CallIntMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); + return ret; +} + +template +void safeCallVoidMethod(JNIEnv * env, jobject obj, jmethodID method_id, Args ... args) +{ + LOCAL_ENGINE_JNI_JMETHOD_START + env->CallVoidMethod(obj, method_id, args...); + LOCAL_ENGINE_JNI_JMETHOD_END(env); +} +} diff --git a/utils/local-engine/jni/jni_error.cpp b/utils/local-engine/jni/jni_error.cpp new file mode 100644 index 000000000000..7dd9714af3ac --- /dev/null +++ b/utils/local-engine/jni/jni_error.cpp @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include "Common/Exception.h" +#include + +namespace local_engine +{ +JniErrorsGlobalState & JniErrorsGlobalState::instance() +{ + static JniErrorsGlobalState instance; + return instance; +} + +void JniErrorsGlobalState::destroy(JNIEnv * env) +{ + if (env) + { + if (io_exception_class) + { + env->DeleteGlobalRef(io_exception_class); + } + if (runtime_exception_class) + { + env->DeleteGlobalRef(runtime_exception_class); + } + if (unsupportedoperation_exception_class) + { + env->DeleteGlobalRef(unsupportedoperation_exception_class); + } + if (illegal_access_exception_class) + { + env->DeleteGlobalRef(illegal_access_exception_class); + } + if (illegal_argument_exception_class) + { + env->DeleteGlobalRef(illegal_argument_exception_class); + } + } +} + +void JniErrorsGlobalState::initialize(JNIEnv * env_) +{ + io_exception_class = CreateGlobalExceptionClassReference(env_, "Ljava/io/IOException;"); + runtime_exception_class = CreateGlobalExceptionClassReference(env_, "Ljava/lang/RuntimeException;"); + unsupportedoperation_exception_class = CreateGlobalExceptionClassReference(env_, "Ljava/lang/UnsupportedOperationException;"); + illegal_access_exception_class = CreateGlobalExceptionClassReference(env_, "Ljava/lang/IllegalAccessException;"); + illegal_argument_exception_class = CreateGlobalExceptionClassReference(env_, "Ljava/lang/IllegalArgumentException;"); +} + +void JniErrorsGlobalState::throwException(JNIEnv * env, const DB::Exception & e) +{ + throwRuntimeException(env, e.message(), e.getStackTraceString()); +} + +void JniErrorsGlobalState::throwException(JNIEnv * env, const std::exception & e) +{ + throwRuntimeException(env, e.what(), DB::getExceptionStackTraceString(e)); +} + +void JniErrorsGlobalState::throwException(JNIEnv * env,jclass exception_class, const std::string & message, const std::string & stack_trace) +{ + if (exception_class) + { + std::string error_msg = message + "\n" + stack_trace; + env->ThrowNew(exception_class, error_msg.c_str()); + } + else + { + // This will cause a coredump + throw std::runtime_error("Not found java runtime exception class"); + } + +} + +void JniErrorsGlobalState::throwRuntimeException(JNIEnv * env,const std::string & message, const std::string & stack_trace) +{ + throwException(env, runtime_exception_class, message, stack_trace); +} + + +} diff --git a/utils/local-engine/jni/jni_error.h b/utils/local-engine/jni/jni_error.h new file mode 100644 index 000000000000..0efd05b831f6 --- /dev/null +++ b/utils/local-engine/jni/jni_error.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +namespace local_engine +{ +class JniErrorsGlobalState : boost::noncopyable +{ +protected: + JniErrorsGlobalState() = default; +public: + ~JniErrorsGlobalState() = default; + + static JniErrorsGlobalState & instance(); + void initialize(JNIEnv * env_); + void destroy(JNIEnv * env); + + inline jclass getIOExceptionClass() { return io_exception_class; } + inline jclass getRuntimeExceptionClass() { return runtime_exception_class; } + inline jclass getUnsupportedOperationExceptionClass() { return unsupportedoperation_exception_class; } + inline jclass getIllegalAccessExceptionClass() { return illegal_access_exception_class; } + inline jclass getIllegalArgumentExceptionClass() { return illegal_argument_exception_class; } + + void throwException(JNIEnv * env, const DB::Exception & e); + void throwException(JNIEnv * env, const std::exception & e); + static void throwException(JNIEnv * env, jclass exception_class, const std::string & message, const std::string & stack_trace = ""); + void throwRuntimeException(JNIEnv * env, const std::string & message, const std::string & stack_trace = ""); + + +private: + jclass io_exception_class = nullptr; + jclass runtime_exception_class = nullptr; + jclass unsupportedoperation_exception_class = nullptr; + jclass illegal_access_exception_class = nullptr; + jclass illegal_argument_exception_class = nullptr; + +}; +// + +#define LOCAL_ENGINE_JNI_METHOD_START \ + try { + +#define LOCAL_ENGINE_JNI_METHOD_END(env, ret) \ + }\ + catch(DB::Exception & e)\ + {\ + local_engine::JniErrorsGlobalState::instance().throwException(env, e);\ + return ret;\ + }\ + catch (std::exception & e)\ + {\ + local_engine::JniErrorsGlobalState::instance().throwException(env, e);\ + return ret;\ + }\ + catch (...)\ + {\ + DB::WriteBufferFromOwnString ostr;\ + auto trace = boost::stacktrace::stacktrace();\ + boost::stacktrace::detail::to_string(&trace.as_vector()[0], trace.size());\ + local_engine::JniErrorsGlobalState::instance().throwRuntimeException(env, "Unknow Exception", ostr.str().c_str());\ + return ret;\ + } +} diff --git a/utils/local-engine/local_engine_jni.cpp b/utils/local-engine/local_engine_jni.cpp new file mode 100644 index 000000000000..0a5b47d5233d --- /dev/null +++ b/utils/local-engine/local_engine_jni.cpp @@ -0,0 +1,972 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus + +static DB::ColumnWithTypeAndName getColumnFromColumnVector(JNIEnv * /*env*/, jobject /*obj*/, jlong block_address, jint column_position) +{ + DB::Block * block = reinterpret_cast(block_address); + return block->getByPosition(column_position); +} + +static std::string jstring2string(JNIEnv * env, jstring jStr) +{ + try + { + if (!jStr) + return ""; + + jclass string_class = env->GetObjectClass(jStr); + jmethodID get_bytes = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); + jbyteArray string_jbytes = static_cast(local_engine::safeCallObjectMethod(env, jStr, get_bytes, env->NewStringUTF("UTF-8"))); + + size_t length = static_cast(env->GetArrayLength(string_jbytes)); + jbyte * p_bytes = env->GetByteArrayElements(string_jbytes, nullptr); + + std::string ret = std::string(reinterpret_cast(p_bytes), length); + env->ReleaseByteArrayElements(string_jbytes, p_bytes, JNI_ABORT); + + env->DeleteLocalRef(string_jbytes); + env->DeleteLocalRef(string_class); + return ret; + } + catch (DB::Exception & e) + { + local_engine::ExceptionUtils::handleException(e); + } +} + +extern "C" { +#endif + +extern char * createExecutor(const std::string &); + +namespace dbms +{ + class LocalExecutor; +} + +static jclass spark_row_info_class; +static jmethodID spark_row_info_constructor; + +static jclass split_result_class; +static jmethodID split_result_constructor; + +jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/) +{ + JNIEnv * env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_8) != JNI_OK) + return JNI_ERR; + + local_engine::JniErrorsGlobalState::instance().initialize(env); + + spark_row_info_class = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/row/SparkRowInfo;"); + spark_row_info_constructor = env->GetMethodID(spark_row_info_class, "", "([J[JJJJ)V"); + + split_result_class = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/vectorized/SplitResult;"); + split_result_constructor = local_engine::GetMethodID(env, split_result_class, "", "(JJJJJJ[J[J)V"); + + local_engine::ShuffleReader::input_stream_class = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/vectorized/ShuffleInputStream;"); + local_engine::NativeSplitter::iterator_class = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/vectorized/IteratorWrapper;"); + local_engine::WriteBufferFromJavaOutputStream::output_stream_class = local_engine::CreateGlobalClassReference(env, "Ljava/io/OutputStream;"); + local_engine::SourceFromJavaIter::serialized_record_batch_iterator_class + = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/execution/ColumnarNativeIterator;"); + + local_engine::ShuffleReader::input_stream_read = env->GetMethodID(local_engine::ShuffleReader::input_stream_class, "read", "(JJ)J"); + + local_engine::NativeSplitter::iterator_has_next = local_engine::GetMethodID(env, local_engine::NativeSplitter::iterator_class, "hasNext", "()Z"); + local_engine::NativeSplitter::iterator_next = local_engine::GetMethodID(env, local_engine::NativeSplitter::iterator_class, "next", "()J"); + + local_engine::WriteBufferFromJavaOutputStream::output_stream_write + = local_engine::GetMethodID(env, local_engine::WriteBufferFromJavaOutputStream::output_stream_class, "write", "([BII)V"); + local_engine::WriteBufferFromJavaOutputStream::output_stream_flush + = local_engine::GetMethodID(env, local_engine::WriteBufferFromJavaOutputStream::output_stream_class, "flush", "()V"); + + local_engine::SourceFromJavaIter::serialized_record_batch_iterator_hasNext + = local_engine::GetMethodID(env, local_engine::SourceFromJavaIter::serialized_record_batch_iterator_class, "hasNext", "()Z"); + local_engine::SourceFromJavaIter::serialized_record_batch_iterator_next + = local_engine::GetMethodID(env, local_engine::SourceFromJavaIter::serialized_record_batch_iterator_class, "next", "()[B"); + + local_engine::SparkRowToCHColumn::spark_row_interator_class + = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/execution/SparkRowIterator;"); + local_engine::SparkRowToCHColumn::spark_row_interator_hasNext + = local_engine::GetMethodID(env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "hasNext", "()Z"); + local_engine::SparkRowToCHColumn::spark_row_interator_next + = local_engine::GetMethodID(env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "next", "()[B"); + local_engine::SparkRowToCHColumn::spark_row_iterator_nextBatch + = local_engine::GetMethodID(env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "nextBatch", "()Ljava/nio/ByteBuffer;"); + + local_engine::ReservationListenerWrapper::reservation_listener_class + = local_engine::CreateGlobalClassReference(env, "Lio/glutenproject/memory/alloc/ReservationListener;"); + local_engine::ReservationListenerWrapper::reservation_listener_reserve + = local_engine::GetMethodID(env, local_engine::ReservationListenerWrapper::reservation_listener_class, "reserve", "(J)J"); + local_engine::ReservationListenerWrapper::reservation_listener_reserve_or_throw + = local_engine::GetMethodID(env, local_engine::ReservationListenerWrapper::reservation_listener_class, "reserveOrThrow", "(J)V"); + local_engine::ReservationListenerWrapper::reservation_listener_unreserve + = local_engine::GetMethodID(env, local_engine::ReservationListenerWrapper::reservation_listener_class, "unreserve", "(J)J"); + + local_engine::JNIUtils::vm = vm; + return JNI_VERSION_1_8; +} + +void JNI_OnUnload(JavaVM * vm, void * /*reserved*/) +{ + local_engine::BackendFinalizerUtil::finalizeGlobally(); + + JNIEnv * env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_8); + + local_engine::JniErrorsGlobalState::instance().destroy(env); + + env->DeleteGlobalRef(spark_row_info_class); + env->DeleteGlobalRef(split_result_class); + env->DeleteGlobalRef(local_engine::ShuffleReader::input_stream_class); + env->DeleteGlobalRef(local_engine::NativeSplitter::iterator_class); + env->DeleteGlobalRef(local_engine::WriteBufferFromJavaOutputStream::output_stream_class); + env->DeleteGlobalRef(local_engine::SourceFromJavaIter::serialized_record_batch_iterator_class); + env->DeleteGlobalRef(local_engine::SparkRowToCHColumn::spark_row_interator_class); + env->DeleteGlobalRef(local_engine::ReservationListenerWrapper::reservation_listener_class); +} + +void Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(JNIEnv * env, jobject, jbyteArray plan) +{ + LOCAL_ENGINE_JNI_METHOD_START + jsize plan_buf_size = env->GetArrayLength(plan); + jbyte * plan_buf_addr = env->GetByteArrayElements(plan, nullptr); + std::string plan_str; + plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); + local_engine::BackendInitializerUtil::init(plan_str); + LOCAL_ENGINE_JNI_METHOD_END(env, ) +} + +void Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeFinalizeNative(JNIEnv * env) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BackendFinalizerUtil::finalizeSessionally(); + LOCAL_ENGINE_JNI_METHOD_END(env, ) +} + + +jlong Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeCreateKernelWithRowIterator( + JNIEnv * env, jobject /*obj*/, jbyteArray plan) +{ + LOCAL_ENGINE_JNI_METHOD_START + jsize plan_size = env->GetArrayLength(plan); + jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); + std::string plan_string; + plan_string.assign(reinterpret_cast(plan_address), plan_size); + auto * executor = createExecutor(plan_string); + env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); + return reinterpret_cast(executor); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeCreateKernelWithIterator( + JNIEnv * env, jobject /*obj*/, jlong allocator_id, jbyteArray plan, jobjectArray iter_arr) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto query_context = local_engine::getAllocator(allocator_id)->query_context; + local_engine::SerializedPlanParser parser(query_context); + jsize iter_num = env->GetArrayLength(iter_arr); + for (jsize i = 0; i < iter_num; i++) + { + jobject iter = env->GetObjectArrayElement(iter_arr, i); + iter = env->NewGlobalRef(iter); + parser.addInputIter(iter); + } + jsize plan_size = env->GetArrayLength(plan); + jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); + std::string plan_string; + plan_string.assign(reinterpret_cast(plan_address), plan_size); + auto query_plan = parser.parse(plan_string); + local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(parser.query_context); + executor->execute(std::move(query_plan)); + env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); + return reinterpret_cast(executor); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jboolean Java_io_glutenproject_row_RowIterator_nativeHasNext(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + return executor->hasNext(); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jobject Java_io_glutenproject_row_RowIterator_nativeNext(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + local_engine::SparkRowInfoPtr spark_row_info = executor->next(); + + auto * offsets_arr = env->NewLongArray(spark_row_info->getNumRows()); + const auto * offsets_src = reinterpret_cast(spark_row_info->getOffsets().data()); + env->SetLongArrayRegion(offsets_arr, 0, spark_row_info->getNumRows(), offsets_src); + auto * lengths_arr = env->NewLongArray(spark_row_info->getNumRows()); + const auto * lengths_src = reinterpret_cast(spark_row_info->getLengths().data()); + env->SetLongArrayRegion(lengths_arr, 0, spark_row_info->getNumRows(), lengths_src); + int64_t address = reinterpret_cast(spark_row_info->getBufferAddress()); + int64_t column_number = reinterpret_cast(spark_row_info->getNumCols()); + int64_t total_size = reinterpret_cast(spark_row_info->getTotalBytes()); + + jobject spark_row_info_object + = env->NewObject(spark_row_info_class, spark_row_info_constructor, offsets_arr, lengths_arr, address, column_number, total_size); + return spark_row_info_object; + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr) +} + +void Java_io_glutenproject_row_RowIterator_nativeClose(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + delete executor; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +// Columnar Iterator +jboolean Java_io_glutenproject_vectorized_BatchIterator_nativeHasNext(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + return executor->hasNext(); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jlong Java_io_glutenproject_vectorized_BatchIterator_nativeCHNext(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + DB::Block * column_batch = executor->nextColumnar(); + LOG_DEBUG(&Poco::Logger::get("jni"), "row size of the column batch: {}", column_batch->rows()); + return reinterpret_cast(column_batch); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BatchIterator_nativeClose(JNIEnv * env, jobject /*obj*/, jlong executor_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); + delete executor; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +void Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeSetJavaTmpDir(JNIEnv * /*env*/, jobject /*obj*/, jstring /*dir*/) +{ +} + +void Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeSetBatchSize( + JNIEnv * /*env*/, jobject /*obj*/, jint /*batch_size*/) +{ +} + +void Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeSetMetricsTime( + JNIEnv * /*env*/, jobject /*obj*/, jboolean /*setMetricsTime*/) +{ +} + +jboolean Java_io_glutenproject_vectorized_CHColumnVector_nativeHasNull(JNIEnv * env, jobject obj, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + DB::Block * block = reinterpret_cast(block_address); + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + if (!col.column->isNullable()) + { + return false; + } + else + { + const auto * nullable = checkAndGetColumn(*col.column); + size_t num_nulls = std::accumulate(nullable->getNullMapData().begin(), nullable->getNullMapData().end(), 0); + return num_nulls < block->rows(); + } + LOCAL_ENGINE_JNI_METHOD_END(env,false) +} + +jint Java_io_glutenproject_vectorized_CHColumnVector_nativeNumNulls(JNIEnv * env, jobject obj, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + if (!col.column->isNullable()) + { + return 0; + } + else + { + const auto * nullable = checkAndGetColumn(*col.column); + return std::accumulate(nullable->getNullMapData().begin(), nullable->getNullMapData().end(), 0); + } + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jboolean Java_io_glutenproject_vectorized_CHColumnVector_nativeIsNullAt( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + return col.column->isNullAt(row_id); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jboolean Java_io_glutenproject_vectorized_CHColumnVector_nativeGetBoolean( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return nested_col->getBool(row_id); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jbyte Java_io_glutenproject_vectorized_CHColumnVector_nativeGetByte( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return reinterpret_cast(nested_col->getDataAt(row_id).data)[0]; + LOCAL_ENGINE_JNI_METHOD_END(env, 0) +} + +jshort Java_io_glutenproject_vectorized_CHColumnVector_nativeGetShort( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return reinterpret_cast(nested_col->getDataAt(row_id).data)[0]; + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jint Java_io_glutenproject_vectorized_CHColumnVector_nativeGetInt( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + if (col.type->getTypeId() == DB::TypeIndex::Date) + { + return nested_col->getUInt(row_id); + } + else + { + return nested_col->getInt(row_id); + } + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_vectorized_CHColumnVector_nativeGetLong( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return nested_col->getInt(row_id); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jfloat Java_io_glutenproject_vectorized_CHColumnVector_nativeGetFloat( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return nested_col->getFloat32(row_id); + LOCAL_ENGINE_JNI_METHOD_END(env, 0.0) +} + +jdouble Java_io_glutenproject_vectorized_CHColumnVector_nativeGetDouble( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + return nested_col->getFloat64(row_id); + LOCAL_ENGINE_JNI_METHOD_END(env, 0.0) +} + +jstring Java_io_glutenproject_vectorized_CHColumnVector_nativeGetString( + JNIEnv * env, jobject obj, jint row_id, jlong block_address, jint column_position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto col = getColumnFromColumnVector(env, obj, block_address, column_position); + DB::ColumnPtr nested_col = col.column; + if (const auto * nullable_col = checkAndGetColumn(nested_col.get())) + { + nested_col = nullable_col->getNestedColumnPtr(); + } + const auto * string_col = checkAndGetColumn(nested_col.get()); + auto result = string_col->getDataAt(row_id); + return local_engine::charTojstring(env, result.toString().c_str()); + LOCAL_ENGINE_JNI_METHOD_END(env, local_engine::charTojstring(env, "")) +} + +// native block +void Java_io_glutenproject_vectorized_CHNativeBlock_nativeClose(JNIEnv * /*env*/, jobject /*obj*/, jlong /*block_address*/) +{ +} + +jint Java_io_glutenproject_vectorized_CHNativeBlock_nativeNumRows(JNIEnv * env, jobject /*obj*/, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + DB::Block * block = reinterpret_cast(block_address); + return block->rows(); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jint Java_io_glutenproject_vectorized_CHNativeBlock_nativeNumColumns(JNIEnv * env, jobject /*obj*/, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * block = reinterpret_cast(block_address); + return block->columns(); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jbyteArray Java_io_glutenproject_vectorized_CHNativeBlock_nativeColumnType(JNIEnv * env, jobject /*obj*/, jlong block_address, jint position) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * block = reinterpret_cast(block_address); + const auto & col = block->getByPosition(position); + std::string substrait_type; + dbms::SerializedPlanBuilder::buildType(col.type, substrait_type); + return local_engine::stringTojbyteArray(env, substrait_type); + LOCAL_ENGINE_JNI_METHOD_END(env, local_engine::stringTojbyteArray(env, "")) +} + +jlong Java_io_glutenproject_vectorized_CHNativeBlock_nativeTotalBytes(JNIEnv * env, jobject /*obj*/, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * block = reinterpret_cast(block_address); + return block->bytes(); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_vectorized_CHStreamReader_createNativeShuffleReader( + JNIEnv * env, jclass /*clazz*/, jobject input_stream, jboolean compressed, size_t customize_buffer_size) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * input = env->NewGlobalRef(input_stream); + auto read_buffer = std::make_unique(input, customize_buffer_size); + auto * shuffle_reader = new local_engine::ShuffleReader(std::move(read_buffer), compressed); + return reinterpret_cast(shuffle_reader); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_vectorized_CHStreamReader_nativeNext(JNIEnv * env, jobject /*obj*/, jlong shuffle_reader) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleReader * reader = reinterpret_cast(shuffle_reader); + DB::Block * block = reader->read(); + return reinterpret_cast(block); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_CHStreamReader_nativeClose(JNIEnv * env, jobject /*obj*/, jlong shuffle_reader) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleReader * reader = reinterpret_cast(shuffle_reader); + delete reader; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jlong Java_io_glutenproject_vectorized_CHCoalesceOperator_createNativeOperator(JNIEnv * env, jobject /*obj*/, jint buf_size) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BlockCoalesceOperator * instance = new local_engine::BlockCoalesceOperator(buf_size); + return reinterpret_cast(instance); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_CHCoalesceOperator_nativeMergeBlock( + JNIEnv * env, jobject /*obj*/, jlong instance_address, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BlockCoalesceOperator * instance = reinterpret_cast(instance_address); + DB::Block * block = reinterpret_cast(block_address); + auto new_block = DB::Block(*block); + instance->mergeBlock(new_block); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jboolean Java_io_glutenproject_vectorized_CHCoalesceOperator_nativeIsFull(JNIEnv * env, jobject /*obj*/, jlong instance_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BlockCoalesceOperator * instance = reinterpret_cast(instance_address); + bool full = instance->isFull(); + return full ? JNI_TRUE : JNI_FALSE; + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jlong Java_io_glutenproject_vectorized_CHCoalesceOperator_nativeRelease(JNIEnv * env, jobject /*obj*/, jlong instance_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BlockCoalesceOperator * instance = reinterpret_cast(instance_address); + auto * block = instance->releaseBlock(); + Int64 address = reinterpret_cast(block); + return address; + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_CHCoalesceOperator_nativeClose(JNIEnv * env, jobject /*obj*/, jlong instance_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::BlockCoalesceOperator * instance = reinterpret_cast(instance_address); + delete instance; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +// Splitter Jni Wrapper +jlong Java_io_glutenproject_vectorized_CHShuffleSplitterJniWrapper_nativeMake( + JNIEnv * env, + jobject, + jstring short_name, + jint num_partitions, + jbyteArray expr_list, + jbyteArray out_expr_list, + jint shuffle_id, + jlong map_id, + jint split_size, + jstring codec, + jstring data_file, + jstring local_dirs, + jint num_sub_dirs) +{ + LOCAL_ENGINE_JNI_METHOD_START + std::string hash_exprs; + std::string out_exprs; + if (expr_list != nullptr) + { + int len = env->GetArrayLength(expr_list); + auto * str = reinterpret_cast(new char[len]); + memset(str, 0, len); + env->GetByteArrayRegion(expr_list, 0, len, str); + hash_exprs = std::string(str, str + len); + delete[] str; + } + + if (out_expr_list != nullptr) + { + int len = env->GetArrayLength(out_expr_list); + auto * str = reinterpret_cast(new char[len]); + memset(str, 0, len); + env->GetByteArrayRegion(out_expr_list, 0, len, str); + out_exprs = std::string(str, str + len); + delete[] str; + } + + Poco::StringTokenizer local_dirs_tokenizer(jstring2string(env, local_dirs), ","); + std::vector local_dirs_list; + local_dirs_list.insert(local_dirs_list.end(), local_dirs_tokenizer.begin(), local_dirs_tokenizer.end()); + + local_engine::SplitOptions options{ + .split_size = static_cast(split_size), + .io_buffer_size = DBMS_DEFAULT_BUFFER_SIZE, + .data_file = jstring2string(env, data_file), + .local_dirs_list = std::move(local_dirs_list), + .num_sub_dirs = num_sub_dirs, + .shuffle_id = shuffle_id, + .map_id = static_cast(map_id), + .partition_nums = static_cast(num_partitions), + .hash_exprs = hash_exprs, + .out_exprs = out_exprs, + .compress_method = jstring2string(env, codec)}; + local_engine::SplitterHolder * splitter + = new local_engine::SplitterHolder{.splitter = local_engine::ShuffleSplitter::create(jstring2string(env, short_name), options)}; + return reinterpret_cast(splitter); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_CHShuffleSplitterJniWrapper_split(JNIEnv * env, jobject, jlong splitterId, jint, jlong block) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); + DB::Block * data = reinterpret_cast(block); + splitter->splitter->split(*data); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jobject Java_io_glutenproject_vectorized_CHShuffleSplitterJniWrapper_stop(JNIEnv * env, jobject, jlong splitterId) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); + auto result = splitter->splitter->stop(); + const auto & partition_lengths = result.partition_length; + auto *partition_length_arr = env->NewLongArray(partition_lengths.size()); + const auto *src = reinterpret_cast(partition_lengths.data()); + env->SetLongArrayRegion(partition_length_arr, 0, partition_lengths.size(), src); + + const auto & raw_partition_lengths = result.raw_partition_length; + auto *raw_partition_length_arr = env->NewLongArray(raw_partition_lengths.size()); + const auto *raw_src = reinterpret_cast(raw_partition_lengths.data()); + env->SetLongArrayRegion(raw_partition_length_arr, 0, raw_partition_lengths.size(), raw_src); + + jobject split_result = env->NewObject( + split_result_class, + split_result_constructor, + result.total_compute_pid_time, + result.total_write_time, + result.total_spill_time, + 0, + result.total_bytes_written, + result.total_bytes_written, + partition_length_arr, + raw_partition_length_arr); + + return split_result; + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr) +} + +void Java_io_glutenproject_vectorized_CHShuffleSplitterJniWrapper_close(JNIEnv * env, jobject, jlong splitterId) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); + delete splitter; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +// BlockNativeConverter +jobject Java_io_glutenproject_vectorized_BlockNativeConverter_convertColumnarToRow(JNIEnv * env, jobject, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::CHColumnToSparkRow converter; + DB::Block * block = reinterpret_cast(block_address); + auto spark_row_info = converter.convertCHColumnToSparkRow(*block); + + auto * offsets_arr = env->NewLongArray(spark_row_info->getNumRows()); + const auto * offsets_src = reinterpret_cast(spark_row_info->getOffsets().data()); + env->SetLongArrayRegion(offsets_arr, 0, spark_row_info->getNumRows(), offsets_src); + auto * lengths_arr = env->NewLongArray(spark_row_info->getNumRows()); + const auto * lengths_src = reinterpret_cast(spark_row_info->getLengths().data()); + env->SetLongArrayRegion(lengths_arr, 0, spark_row_info->getNumRows(), lengths_src); + int64_t address = reinterpret_cast(spark_row_info->getBufferAddress()); + int64_t column_number = reinterpret_cast(spark_row_info->getNumCols()); + int64_t total_size = reinterpret_cast(spark_row_info->getTotalBytes()); + + jobject spark_row_info_object + = env->NewObject(spark_row_info_class, spark_row_info_constructor, offsets_arr, lengths_arr, address, column_number, total_size); + + return spark_row_info_object; + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr) +} + +void Java_io_glutenproject_vectorized_BlockNativeConverter_freeMemory(JNIEnv * env, jobject, jlong address, jlong size) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::CHColumnToSparkRow converter; + converter.freeMem(reinterpret_cast(address), size); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jlong Java_io_glutenproject_vectorized_BlockNativeConverter_convertSparkRowsToCHColumn( + JNIEnv * env, jobject, jobject java_iter, jobjectArray names, jobjectArray types) +{ + LOCAL_ENGINE_JNI_METHOD_START + using namespace std; + + int num_columns = env->GetArrayLength(names); + vector c_names; + vector c_types; + c_names.reserve(num_columns); + for (int i = 0; i < num_columns; i++) + { + auto * name = static_cast(env->GetObjectArrayElement(names, i)); + c_names.emplace_back(std::move(jstring2string(env, name))); + + auto * type = static_cast(env->GetObjectArrayElement(types, i)); + auto type_length = env->GetArrayLength(type); + jbyte * type_ptr = env->GetByteArrayElements(type, nullptr); + string str_type(reinterpret_cast(type_ptr), type_length); + c_types.emplace_back(std::move(str_type)); + + env->ReleaseByteArrayElements(type, type_ptr, JNI_ABORT); + env->DeleteLocalRef(name); + env->DeleteLocalRef(type); + } + local_engine::SparkRowToCHColumn converter; + auto * block = converter.convertSparkRowItrToCHColumn(java_iter, c_names, c_types); + return reinterpret_cast(block); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BlockNativeConverter_freeBlock(JNIEnv * env, jobject, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::SparkRowToCHColumn converter; + converter.freeBlock(reinterpret_cast(block_address)); + LOCAL_ENGINE_JNI_METHOD_END(env, ) +} + +jlong Java_io_glutenproject_vectorized_BlockNativeWriter_nativeCreateInstance(JNIEnv * env, jobject) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * writer = new local_engine::NativeWriterInMemory(); + return reinterpret_cast(writer); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BlockNativeWriter_nativeWrite(JNIEnv * env, jobject, jlong instance, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * writer = reinterpret_cast(instance); + auto * block = reinterpret_cast(block_address); + writer->write(*block); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jint Java_io_glutenproject_vectorized_BlockNativeWriter_nativeResultSize(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * writer = reinterpret_cast(instance); + return static_cast(writer->collect().size()); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BlockNativeWriter_nativeCollect(JNIEnv * env, jobject, jlong instance, jbyteArray result) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * writer = reinterpret_cast(instance); + auto data = writer->collect(); + env->SetByteArrayRegion(result, 0, data.size(), reinterpret_cast(data.data())); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +void Java_io_glutenproject_vectorized_BlockNativeWriter_nativeClose(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * writer = reinterpret_cast(instance); + delete writer; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +void Java_io_glutenproject_vectorized_StorageJoinBuilder_nativeBuild( + JNIEnv * env, jobject, jstring hash_table_id_, jobject in, jint io_buffer_size, jstring join_key_, jstring join_type_, jbyteArray named_struct) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto * input = env->NewGlobalRef(in); + auto hash_table_id = jstring2string(env, hash_table_id_); + auto join_key = jstring2string(env, join_key_); + auto join_type = jstring2string(env, join_type_); + jsize struct_size = env->GetArrayLength(named_struct); + jbyte * struct_address = env->GetByteArrayElements(named_struct, nullptr); + std::string struct_string; + struct_string.assign(reinterpret_cast(struct_address), struct_size); + local_engine::BroadCastJoinBuilder::buildJoinIfNotExist(hash_table_id, input, io_buffer_size, join_key, join_type, struct_string); + env->ReleaseByteArrayElements(named_struct, struct_address, JNI_ABORT); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +// BlockSplitIterator +jlong Java_io_glutenproject_vectorized_BlockSplitIterator_nativeCreate( + JNIEnv * env, jobject, jobject in, jstring name, jstring expr, jstring schema, jint partition_num, jint buffer_size) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::NativeSplitter::Options options; + options.partition_nums = partition_num; + options.buffer_size = buffer_size; + auto expr_str = jstring2string(env, expr); + std::string schema_str; + if (schema) { + schema_str = jstring2string(env, schema); + } + options.exprs_buffer.swap(expr_str); + options.schema_buffer.swap(schema_str); + local_engine::NativeSplitter::Holder * splitter = new local_engine::NativeSplitter::Holder{ + .splitter = local_engine::NativeSplitter::create(jstring2string(env, name), options, in)}; + return reinterpret_cast(splitter); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BlockSplitIterator_nativeClose(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::NativeSplitter::Holder * splitter = reinterpret_cast(instance); + delete splitter; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jboolean Java_io_glutenproject_vectorized_BlockSplitIterator_nativeHasNext(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::NativeSplitter::Holder * splitter = reinterpret_cast(instance); + return splitter->splitter->hasNext(); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jlong Java_io_glutenproject_vectorized_BlockSplitIterator_nativeNext(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::NativeSplitter::Holder * splitter = reinterpret_cast(instance); + return reinterpret_cast(splitter->splitter->next()); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jint Java_io_glutenproject_vectorized_BlockSplitIterator_nativeNextPartitionId(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::NativeSplitter::Holder * splitter = reinterpret_cast(instance); + return reinterpret_cast(splitter->splitter->nextPartitionId()); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_vectorized_BlockOutputStream_nativeCreate(JNIEnv * env, jobject, jobject output_stream, jbyteArray buffer, jstring codec, jboolean compressed, jint customize_buffer_size) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleWriter * writer = new local_engine::ShuffleWriter(output_stream, buffer, jstring2string(env, codec), compressed, customize_buffer_size); + return reinterpret_cast(writer); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_BlockOutputStream_nativeClose(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleWriter * writer = reinterpret_cast(instance); + writer->flush(); + delete writer; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +void Java_io_glutenproject_vectorized_BlockOutputStream_nativeWrite(JNIEnv * env, jobject, jlong instance, jlong block_address) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleWriter * writer = reinterpret_cast(instance); + DB::Block * block = reinterpret_cast(block_address); + writer->write(*block); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +void Java_io_glutenproject_vectorized_BlockOutputStream_nativeFlush(JNIEnv * env, jobject, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::ShuffleWriter * writer = reinterpret_cast(instance); + writer->flush(); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jlong Java_io_glutenproject_vectorized_SimpleExpressionEval_createNativeInstance(JNIEnv * env, jclass, jobject input, jbyteArray plan) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto context = DB::Context::createCopy(local_engine::SerializedPlanParser::global_context); + local_engine::SerializedPlanParser parser(context); + jobject iter = env->NewGlobalRef(input); + parser.addInputIter(iter); + jsize plan_size = env->GetArrayLength(plan); + jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); + std::string plan_string; + plan_string.assign(reinterpret_cast(plan_address), plan_size); + auto query_plan = parser.parse(plan_string); + local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(parser.query_context); + executor->execute(std::move(query_plan)); + env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); + return reinterpret_cast(executor); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_vectorized_SimpleExpressionEval_nativeClose(JNIEnv * env, jclass, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(instance); + delete executor; + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jboolean Java_io_glutenproject_vectorized_SimpleExpressionEval_nativeHasNext(JNIEnv * env, jclass, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(instance); + return executor->hasNext(); + LOCAL_ENGINE_JNI_METHOD_END(env, false) +} + +jlong Java_io_glutenproject_vectorized_SimpleExpressionEval_nativeNext(JNIEnv * env, jclass, jlong instance) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor * executor = reinterpret_cast(instance); + return reinterpret_cast(executor->nextColumnar()); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +jlong Java_io_glutenproject_memory_alloc_NativeMemoryAllocator_getDefaultAllocator(JNIEnv* env, jclass) +{ + return -1; +} + +jlong Java_io_glutenproject_memory_alloc_NativeMemoryAllocator_createListenableAllocator(JNIEnv* env, jclass, jobject listener) +{ + LOCAL_ENGINE_JNI_METHOD_START + auto listener_wrapper = std::make_shared(env->NewGlobalRef(listener)); + return local_engine::initializeQuery(listener_wrapper); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +void Java_io_glutenproject_memory_alloc_NativeMemoryAllocator_releaseAllocator(JNIEnv* env, jclass, jlong allocator_id) +{ + LOCAL_ENGINE_JNI_METHOD_START + local_engine::releaseAllocator(allocator_id); + LOCAL_ENGINE_JNI_METHOD_END(env,) +} + +jlong Java_io_glutenproject_memory_alloc_NativeMemoryAllocator_bytesAllocated(JNIEnv* env, jclass, jlong allocator_id) +{ + LOCAL_ENGINE_JNI_METHOD_START + return local_engine::allocatorMemoryUsage(allocator_id); + LOCAL_ENGINE_JNI_METHOD_END(env, -1) +} + +#ifdef __cplusplus +} + +#endif diff --git a/utils/local-engine/proto/CMakeLists.txt b/utils/local-engine/proto/CMakeLists.txt new file mode 100644 index 000000000000..13bb59cb03e7 --- /dev/null +++ b/utils/local-engine/proto/CMakeLists.txt @@ -0,0 +1,30 @@ +file(GLOB protobuf_files + substrait/*.proto + substrait/extensions/*.proto + ) + +FOREACH(FIL ${protobuf_files}) + file(RELATIVE_PATH FIL_RELATIVE ${ClickHouse_SOURCE_DIR}/utils/local-engine/proto/ ${FIL}) + string(REGEX REPLACE "\\.proto" "" FILE_NAME ${FIL_RELATIVE}) + LIST(APPEND SUBSTRAIT_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME}.pb.cc") + LIST(APPEND SUBSTRAIT_HEADERS "${CMAKE_CURRENT_BINARY_DIR}/${FILE_NAME}.pb.h") +ENDFOREACH() + +add_custom_target( + generate_substrait + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/../../../contrib/protobuf-cmake/protoc -I${CMAKE_CURRENT_SOURCE_DIR} -I${ClickHouse_SOURCE_DIR}/contrib/protobuf/src --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/ ${protobuf_files} + DEPENDS protoc + COMMENT "Running cpp protocol buffer compiler" + VERBATIM ) + +set(Protobuf_INCLUDE_DIR "${ClickHouse_SOURCE_DIR}/contrib/protobuf/src") + +set_source_files_properties(${SUBSTRAIT_SRCS} PROPERTIES GENERATED TRUE) + +add_library(substrait ${SUBSTRAIT_SRCS}) +target_compile_options(substrait PUBLIC -fPIC -Wno-reserved-identifier -Wno-deprecated) +add_dependencies(substrait generate_substrait) +target_include_directories(substrait SYSTEM BEFORE PRIVATE ${PROTOBUF_INCLUDE_DIR}) +target_include_directories(substrait SYSTEM BEFORE PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_link_libraries(substrait _libprotobuf) + diff --git a/utils/local-engine/proto/Exprs.proto b/utils/local-engine/proto/Exprs.proto new file mode 100644 index 000000000000..95185aa46a1d --- /dev/null +++ b/utils/local-engine/proto/Exprs.proto @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax = "proto2"; +package exprs; + +option java_package = "org.apache.arrow.gandiva.ipc"; +option java_outer_classname = "GandivaTypes"; +option optimize_for = SPEED; + +enum GandivaType { + NONE = 0; // arrow::Type::NA + BOOL = 1; // arrow::Type::BOOL + UINT8 = 2; // arrow::Type::UINT8 + INT8 = 3; // arrow::Type::INT8 + UINT16 = 4; // represents arrow::Type fields in src/arrow/type.h + INT16 = 5; + UINT32 = 6; + INT32 = 7; + UINT64 = 8; + INT64 = 9; + HALF_FLOAT = 10; + FLOAT = 11; + DOUBLE = 12; + UTF8 = 13; + BINARY = 14; + FIXED_SIZE_BINARY = 15; + DATE32 = 16; + DATE64 = 17; + TIMESTAMP = 18; + TIME32 = 19; + TIME64 = 20; + INTERVAL = 21; + DECIMAL = 22; + LIST = 23; + STRUCT = 24; + UNION = 25; + DICTIONARY = 26; + MAP = 27; +} + +enum DateUnit { + DAY = 0; + MILLI = 1; +} + +enum TimeUnit { + SEC = 0; + MILLISEC = 1; + MICROSEC = 2; + NANOSEC = 3; +} + +enum IntervalType { + YEAR_MONTH = 0; + DAY_TIME = 1; +} + +enum SelectionVectorType { + SV_NONE = 0; + SV_INT16 = 1; + SV_INT32 = 2; +} + +message ExtGandivaType { + optional GandivaType type = 1; + optional uint32 width = 2; // used by FIXED_SIZE_BINARY + optional int32 precision = 3; // used by DECIMAL + optional int32 scale = 4; // used by DECIMAL + optional DateUnit dateUnit = 5; // used by DATE32/DATE64 + optional TimeUnit timeUnit = 6; // used by TIME32/TIME64 + optional string timeZone = 7; // used by TIMESTAMP + optional IntervalType intervalType = 8; // used by INTERVAL +} + +message Field { + // name of the field + optional string name = 1; + optional ExtGandivaType type = 2; + optional bool nullable = 3; + // for complex data types like structs, unions + repeated Field children = 4; +} + +message FieldNode { + optional Field field = 1; +} + +message FunctionNode { + optional string functionName = 1; + repeated TreeNode inArgs = 2; + optional ExtGandivaType returnType = 3; +} + +message IfNode { + optional TreeNode cond = 1; + optional TreeNode thenNode = 2; + optional TreeNode elseNode = 3; + optional ExtGandivaType returnType = 4; +} + +message AndNode { + repeated TreeNode args = 1; +} + +message OrNode { + repeated TreeNode args = 1; +} + +message NullNode { + optional ExtGandivaType type = 1; +} + +message IntNode { + optional int32 value = 1; +} + +message FloatNode { + optional float value = 1; +} + +message DoubleNode { + optional double value = 1; +} + +message BooleanNode { + optional bool value = 1; +} + +message LongNode { + optional int64 value = 1; +} + +message StringNode { + optional bytes value = 1; +} + +message BinaryNode { + optional bytes value = 1; +} + +message DecimalNode { + optional string value = 1; + optional int32 precision = 2; + optional int32 scale = 3; +} + +message TreeNode { + optional FieldNode fieldNode = 1; + optional FunctionNode fnNode = 2; + + // control expressions + optional IfNode ifNode = 6; + optional AndNode andNode = 7; + optional OrNode orNode = 8; + + // literals + optional NullNode nullNode = 11; + optional IntNode intNode = 12; + optional FloatNode floatNode = 13; + optional LongNode longNode = 14; + optional BooleanNode booleanNode = 15; + optional DoubleNode doubleNode = 16; + optional StringNode stringNode = 17; + optional BinaryNode binaryNode = 18; + optional DecimalNode decimalNode = 19; + + // in expr + optional InNode inNode = 21; +} + +message ExpressionRoot { + optional TreeNode root = 1; + optional Field resultType = 2; +} + +message ExpressionList { + repeated ExpressionRoot exprs = 2; +} + +message Condition { + optional TreeNode root = 1; +} + +message Schema { + repeated Field columns = 1; +} + +message GandivaDataTypes { + repeated ExtGandivaType dataType = 1; +} + +message GandivaFunctions { + repeated FunctionSignature function = 1; +} + +message FunctionSignature { + optional string name = 1; + optional ExtGandivaType returnType = 2; + repeated ExtGandivaType paramTypes = 3; +} + +message InNode { + optional TreeNode node = 1; + optional IntConstants intValues = 2; + optional LongConstants longValues = 3; + optional StringConstants stringValues = 4; + optional BinaryConstants binaryValues = 5; +} + +message IntConstants { + repeated IntNode intValues = 1; +} + +message LongConstants { + repeated LongNode longValues = 1; +} + +message StringConstants { + repeated StringNode stringValues = 1; +} + +message BinaryConstants { + repeated BinaryNode binaryValues = 1; +} diff --git a/utils/local-engine/proto/substrait/algebra.proto b/utils/local-engine/proto/substrait/algebra.proto new file mode 100644 index 000000000000..1435ad5b3d20 --- /dev/null +++ b/utils/local-engine/proto/substrait/algebra.proto @@ -0,0 +1,1346 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "google/protobuf/any.proto"; +import "substrait/extensions/extensions.proto"; +import "substrait/type.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +// Common fields for all relational operators +message RelCommon { + oneof emit_kind { + // The underlying relation is output as is (no reordering or projection of columns) + Direct direct = 1; + // Allows to control for order and inclusion of fields + Emit emit = 2; + } + + Hint hint = 3; + substrait.extensions.AdvancedExtension advanced_extension = 4; + + // Direct indicates no change on presence and ordering of fields in the output + message Direct {} + + // Remap which fields are output and in which order + message Emit { + repeated int32 output_mapping = 1; + } + + // Changes to the operation that can influence efficiency/performance but + // should not impact correctness. + message Hint { + Stats stats = 1; + RuntimeConstraint constraint = 2; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + // The statistics related to a hint (physical properties of records) + message Stats { + double row_count = 1; + double record_size = 2; + substrait.extensions.AdvancedExtension advanced_extension = 10; + } + + message RuntimeConstraint { + // TODO: nodes, cpu threads/%, memory, iops, etc. + + substrait.extensions.AdvancedExtension advanced_extension = 10; + } + } +} + +// The scan operator of base data (physical or virtual), including filtering and projection. +message ReadRel { + RelCommon common = 1; + NamedStruct base_schema = 2; + Expression filter = 3; + Expression best_effort_filter = 11; + Expression.MaskExpression projection = 4; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + // Definition of which type of scan operation is to be performed + oneof read_type { + VirtualTable virtual_table = 5; + LocalFiles local_files = 6; + NamedTable named_table = 7; + ExtensionTable extension_table = 8; + } + + // A base table. The list of string is used to represent namespacing (e.g., mydb.mytable). + // This assumes shared catalog between systems exchanging a message. + message NamedTable { + repeated string names = 1; + substrait.extensions.AdvancedExtension advanced_extension = 10; + } + + // A table composed of literals. + message VirtualTable { + repeated Expression.Literal.Struct values = 1; + } + + // A stub type that can be used to extend/introduce new table types outside + // the specification. + message ExtensionTable { + google.protobuf.Any detail = 1; + } + + // Represents a list of files in input of a scan operation + message LocalFiles { + repeated FileOrFiles items = 1; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + // Many files consist of indivisible chunks (e.g. parquet row groups + // or CSV rows). If a slice partially selects an indivisible chunk + // then the consumer should employ some rule to decide which slice to + // include the chunk in (e.g. include it in the slice that contains + // the midpoint of the chunk) + message FileOrFiles { + oneof path_type { + // A URI that can refer to either a single folder or a single file + string uri_path = 1; + // A URI where the path portion is a glob expression that can + // identify zero or more paths. + // Consumers should support the POSIX syntax. The recursive + // globstar (**) may not be supported. + string uri_path_glob = 2; + // A URI that refers to a single file + string uri_file = 3; + // A URI that refers to a single folder + string uri_folder = 4; + } + + // Original file format enum, superseded by the file_format oneof. + reserved 5; + reserved "format"; + + // The index of the partition this item belongs to + uint64 partition_index = 6; + + // The start position in byte to read from this item + uint64 start = 7; + + // The length in byte to read from this item + uint64 length = 8; + + message ParquetReadOptions {} + message ArrowReadOptions {} + message OrcReadOptions {} + message DwrfReadOptions {} + + // The format of the files. + oneof file_format { + ParquetReadOptions parquet = 9; + ArrowReadOptions arrow = 10; + OrcReadOptions orc = 11; + google.protobuf.Any extension = 12; + DwrfReadOptions dwrf = 13; + } + } + } +} + +// This operator allows to represent calculated expressions of fields (e.g., a+b). Direct/Emit are used to represent classical relational projections +message ProjectRel { + RelCommon common = 1; + Rel input = 2; + repeated Expression expressions = 3; + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The binary JOIN relational operator left-join-right, including various join types, a join condition and post_join_filter expression +message JoinRel { + RelCommon common = 1; + Rel left = 2; + Rel right = 3; + Expression expression = 4; + Expression post_join_filter = 5; + + JoinType type = 6; + + enum JoinType { + JOIN_TYPE_UNSPECIFIED = 0; + JOIN_TYPE_INNER = 1; + JOIN_TYPE_OUTER = 2; + JOIN_TYPE_LEFT = 3; + JOIN_TYPE_RIGHT = 4; + JOIN_TYPE_LEFT_SEMI = 5; + JOIN_TYPE_RIGHT_SEMI = 6; + JOIN_TYPE_ANTI = 7; + // This join is useful for nested sub-queries where we need exactly one tuple in output (or throw exception) + // See Section 3.2 of https://15721.courses.cs.cmu.edu/spring2018/papers/16-optimizer2/hyperjoins-btw2017.pdf + JOIN_TYPE_SINGLE = 8; + } + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// Cartesian product relational operator of two tables (left and right) +message CrossRel { + RelCommon common = 1; + Rel left = 2; + Rel right = 3; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The relational operator representing LIMIT/OFFSET or TOP type semantics. +message FetchRel { + RelCommon common = 1; + Rel input = 2; + // the offset expressed in number of records + int64 offset = 3; + // the amount of records to return + int64 count = 4; + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The relational operator representing a GROUP BY Aggregate +message AggregateRel { + RelCommon common = 1; + + // Input of the aggregation + Rel input = 2; + + // A list of expression grouping that the aggregation measured should be calculated for. + repeated Grouping groupings = 3; + + // A list of one or more aggregate expressions along with an optional filter. + repeated Measure measures = 4; + + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message Grouping { + repeated Expression grouping_expressions = 1; + } + + message Measure { + AggregateFunction measure = 1; + + // An optional boolean expression that acts to filter which records are + // included in the measure. True means include this record for calculation + // within the measure. + // Helps to support SUM() FILTER(WHERE...) syntax without masking opportunities for optimization + Expression filter = 2; + } +} + +// The ORDERY BY (or sorting) relational operator. Beside describing a base relation, it includes a list of fields to sort on +message SortRel { + RelCommon common = 1; + Rel input = 2; + repeated SortField sorts = 3; + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +message WindowRel { + RelCommon common = 1; + Rel input = 2; + repeated Measure measures = 3; + repeated Expression partition_expressions = 4; + repeated SortField sorts = 5; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message Measure { + Expression.WindowFunction measure = 1; + } +} + +// The relational operator capturing simple FILTERs (as in the WHERE clause of SQL) +message FilterRel { + RelCommon common = 1; + Rel input = 2; + Expression condition = 3; + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The relational set operators (intersection/union/etc..) +message SetRel { + RelCommon common = 1; + // The first input is the primary input, the remaining are secondary + // inputs. There must be at least two inputs. + repeated Rel inputs = 2; + SetOp op = 3; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + enum SetOp { + SET_OP_UNSPECIFIED = 0; + SET_OP_MINUS_PRIMARY = 1; + SET_OP_MINUS_MULTISET = 2; + SET_OP_INTERSECTION_PRIMARY = 3; + SET_OP_INTERSECTION_MULTISET = 4; + SET_OP_UNION_DISTINCT = 5; + SET_OP_UNION_ALL = 6; + } +} + +// Stub to support extension with a single input +message ExtensionSingleRel { + RelCommon common = 1; + Rel input = 2; + google.protobuf.Any detail = 3; +} + +// Stub to support extension with a zero inputs +message ExtensionLeafRel { + RelCommon common = 1; + google.protobuf.Any detail = 2; +} + +// Stub to support extension with multiple inputs +message ExtensionMultiRel { + RelCommon common = 1; + repeated Rel inputs = 2; + google.protobuf.Any detail = 3; +} + +// A redistribution operation +message ExchangeRel { + RelCommon common = 1; + Rel input = 2; + int32 partition_count = 3; + repeated ExchangeTarget targets = 4; + + // the type of exchange used + oneof exchange_kind { + ScatterFields scatter_by_fields = 5; + SingleBucketExpression single_target = 6; + MultiBucketExpression multi_target = 7; + RoundRobin round_robin = 8; + Broadcast broadcast = 9; + } + + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message ScatterFields { + repeated Expression.FieldReference fields = 1; + } + + // Returns a single bucket number per record. + message SingleBucketExpression { + Expression expression = 1; + } + + // Returns zero or more bucket numbers per record + message MultiBucketExpression { + Expression expression = 1; + bool constrained_to_count = 2; + } + + // Send all data to every target. + message Broadcast {} + + // Route approximately + message RoundRobin { + // whether the round robin behavior is required to exact (per record) or + // approximate. Defaults to approximate. + bool exact = 1; + } + + // The message to describe partition targets of an exchange + message ExchangeTarget { + // Describes the partition id(s) to send. If this is empty, all data is sent + // to this target. + repeated int32 partition_id = 1; + + oneof target_type { + string uri = 2; + google.protobuf.Any extended = 3; + } + } +} + +message ExpandRel { + RelCommon common = 1; + Rel input = 2; + + repeated Expression aggregate_expressions = 3; + + // A list of expression grouping that the aggregation measured should be calculated for. + repeated GroupSets groupings = 4; + + message GroupSets { + repeated Expression groupSets_expressions = 1; + } + + string group_name = 5; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// A relation with output field names. +// +// This is for use at the root of a `Rel` tree. +message RelRoot { + // A relation + Rel input = 1; + // Field names in depth-first order + repeated string names = 2; +} + +// A relation (used internally in a plan) +message Rel { + oneof rel_type { + ReadRel read = 1; + FilterRel filter = 2; + FetchRel fetch = 3; + AggregateRel aggregate = 4; + SortRel sort = 5; + JoinRel join = 6; + ProjectRel project = 7; + SetRel set = 8; + ExtensionSingleRel extension_single = 9; + ExtensionMultiRel extension_multi = 10; + ExtensionLeafRel extension_leaf = 11; + CrossRel cross = 12; + //Physical relations + HashJoinRel hash_join = 13; + MergeJoinRel merge_join = 14; + ExpandRel expand = 15; + WindowRel window = 16; + GenerateRel generate = 17; + } +} + +// A base object for writing (e.g., a table or a view). +message NamedObjectWrite { + // The list of string is used to represent namespacing (e.g., mydb.mytable). + // This assumes shared catalog between systems exchanging a message. + repeated string names = 1; + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// A stub type that can be used to extend/introduce new table types outside +// the specification. +message ExtensionObject { + google.protobuf.Any detail = 1; +} + +message DdlRel { + // Definition of which type of object we are operating on + oneof write_type { + NamedObjectWrite named_object = 1; + ExtensionObject extension_object = 2; + } + + // The columns that will be modified (representing after-image of a schema change) + NamedStruct table_schema = 3; + // The default values for the columns (representing after-image of a schema change) + // E.g., in case of an ALTER TABLE that changes some of the column default values, we expect + // the table_defaults Struct to report a full list of default values reflecting the result of applying + // the ALTER TABLE operator successfully + Expression.Literal.Struct table_defaults = 4; + + // Which type of object we operate on + DdlObject object = 5; + + // The type of operation to perform + DdlOp op = 6; + + // The body of the CREATE VIEW + Rel view_definition = 7; + + enum DdlObject { + DDL_OBJECT_UNSPECIFIED = 0; + // A Table object in the system + DDL_OBJECT_TABLE = 1; + // A View object in the system + DDL_OBJECT_VIEW = 2; + } + + enum DdlOp { + DDL_OP_UNSPECIFIED = 0; + // A create operation (for any object) + DDL_OP_CREATE = 1; + // A create operation if the object does not exist, or replaces it (equivalent to a DROP + CREATE) if the object already exists + DDL_OP_CREATE_OR_REPLACE = 2; + // An operation that modifies the schema (e.g., column names, types, default values) for the target object + DDL_OP_ALTER = 3; + // An operation that removes an object from the system + DDL_OP_DROP = 4; + // An operation that removes an object from the system (without throwing an exception if the object did not exist) + DDL_OP_DROP_IF_EXIST = 5; + } + //TODO add PK/constraints/indexes/etc..? +} + +// The operator that modifies the content of a database (operates on 1 table at a time, but tuple-selection/source can be +// based on joining of multiple tables). +message WriteRel { + // Definition of which TABLE we are operating on + oneof write_type { + NamedObjectWrite named_table = 1; + ExtensionObject extension_table = 2; + } + + // The schema of the table (must align with Rel input (e.g., number of leaf fields must match)) + NamedStruct table_schema = 3; + + // The type of operation to perform + WriteOp op = 4; + + // The relation that determines the tuples to add/remove/modify + // the schema must match with table_schema. Default values must be explicitly stated + // in a ProjectRel at the top of the input. The match must also + // occur in case of DELETE to ensure multi-engine plans are unequivocal. + Rel input = 5; + + // Output mode determines what is the output of executing this rel + OutputMode output = 6; + + enum WriteOp { + WRITE_OP_UNSPECIFIED = 0; + // The insert of new tuples in a table + WRITE_OP_INSERT = 1; + // The removal of tuples from a table + WRITE_OP_DELETE = 2; + // The modification of existing tuples within a table + WRITE_OP_UPDATE = 3; + // The Creation of a new table, and the insert of new tuples in the table + WRITE_OP_CTAS = 4; + } + + enum OutputMode { + OUTPUT_MODE_UNSPECIFIED = 0; + // return no tuples at all + OUTPUT_MODE_NO_OUTPUT = 1; + // this mode makes the operator return all the tuple INSERTED/DELETED/UPDATED by the operator. + // The operator returns the AFTER-image of any change. This can be further manipulated by operators upstreams + // (e.g., retunring the typical "count of modified tuples"). + // For scenarios in which the BEFORE image is required, the user must implement a spool (via references to + // subplans in the body of the Rel input) and return those with anounter PlanRel.relations. + OUTPUT_MODE_MODIFIED_TUPLES = 2; + } +} + +// The hash equijoin join operator will build a hash table out of the right input based on a set of join keys. +// It will then probe that hash table for incoming inputs, finding matches. +message HashJoinRel { + RelCommon common = 1; + Rel left = 2; + Rel right = 3; + repeated Expression.FieldReference left_keys = 4; + repeated Expression.FieldReference right_keys = 5; + Expression post_join_filter = 6; + + JoinType type = 7; + + enum JoinType { + JOIN_TYPE_UNSPECIFIED = 0; + JOIN_TYPE_INNER = 1; + JOIN_TYPE_OUTER = 2; + JOIN_TYPE_LEFT = 3; + JOIN_TYPE_RIGHT = 4; + JOIN_TYPE_LEFT_SEMI = 5; + JOIN_TYPE_RIGHT_SEMI = 6; + JOIN_TYPE_LEFT_ANTI = 7; + JOIN_TYPE_RIGHT_ANTI = 8; + } + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The merge equijoin does a join by taking advantage of two sets that are sorted on the join keys. +// This allows the join operation to be done in a streaming fashion. +message MergeJoinRel { + RelCommon common = 1; + Rel left = 2; + Rel right = 3; + repeated Expression.FieldReference left_keys = 4; + repeated Expression.FieldReference right_keys = 5; + Expression post_join_filter = 6; + + JoinType type = 7; + + enum JoinType { + JOIN_TYPE_UNSPECIFIED = 0; + JOIN_TYPE_INNER = 1; + JOIN_TYPE_OUTER = 2; + JOIN_TYPE_LEFT = 3; + JOIN_TYPE_RIGHT = 4; + JOIN_TYPE_LEFT_SEMI = 5; + JOIN_TYPE_RIGHT_SEMI = 6; + JOIN_TYPE_LEFT_ANTI = 7; + JOIN_TYPE_RIGHT_ANTI = 8; + } + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The argument of a function +message FunctionArgument { + oneof arg_type { + string enum = 1; + Type type = 2; + Expression value = 3; + } +} + +// An optional function argument. Typically used for specifying behavior in +// invalid or corner cases. +message FunctionOption { + // Name of the option to set. If the consumer does not recognize the + // option, it must reject the plan. The name is matched case-insensitively + // with option names defined for the function. + string name = 1; + + // List of behavior options allowed by the producer. At least one must be + // specified; to leave an option unspecified, simply don't add an entry to + // `options`. The consumer must use the first option from the list that it + // supports. If the consumer supports none of the specified options, it + // must reject the plan. The name is matched case-insensitively and must + // match one of the option values defined for the option. + repeated string preference = 2; +} + +message Expression { + oneof rex_type { + Literal literal = 1; + FieldReference selection = 2; + ScalarFunction scalar_function = 3; + WindowFunction window_function = 5; + IfThen if_then = 6; + SwitchExpression switch_expression = 7; + SingularOrList singular_or_list = 8; + MultiOrList multi_or_list = 9; + Cast cast = 11; + Subquery subquery = 12; + Nested nested = 13; + + // deprecated: enum literals are only sensible in the context of + // function arguments, for which FunctionArgument should now be + // used + Enum enum = 10 [deprecated = true]; + } + + message Enum { + option deprecated = true; + + oneof enum_kind { + string specified = 1; + Empty unspecified = 2; + } + + message Empty { + option deprecated = true; + } + } + + message Literal { + oneof literal_type { + bool boolean = 1; + int32 i8 = 2; + int32 i16 = 3; + int32 i32 = 5; + int64 i64 = 7; + float fp32 = 10; + double fp64 = 11; + string string = 12; + bytes binary = 13; + // Timestamp in units of microseconds since the UNIX epoch. + int64 timestamp = 14; + // Date in units of days since the UNIX epoch. + int32 date = 16; + // Time in units of microseconds past midnight + int64 time = 17; + IntervalYearToMonth interval_year_to_month = 19; + IntervalDayToSecond interval_day_to_second = 20; + string fixed_char = 21; + VarChar var_char = 22; + bytes fixed_binary = 23; + Decimal decimal = 24; + Struct struct = 25; + Map map = 26; + // Timestamp in units of microseconds since the UNIX epoch. + int64 timestamp_tz = 27; + bytes uuid = 28; + Type null = 29; // a typed null literal + List list = 30; + Type.List empty_list = 31; + Type.Map empty_map = 32; + UserDefined user_defined = 33; + } + + // whether the literal type should be treated as a nullable type. Applies to + // all members of union other than the Typed null (which should directly + // declare nullability). + bool nullable = 50; + + // optionally points to a type_variation_anchor defined in this plan. + // Applies to all members of union other than the Typed null (which should + // directly declare the type variation). + uint32 type_variation_reference = 51; + + message VarChar { + string value = 1; + uint32 length = 2; + } + + message Decimal { + // little-endian twos-complement integer representation of complete value + // (ignoring precision) Always 16 bytes in length + bytes value = 1; + // The maximum number of digits allowed in the value. + // the maximum precision is 38. + int32 precision = 2; + // declared scale of decimal literal + int32 scale = 3; + } + + message Map { + message KeyValue { + Literal key = 1; + Literal value = 2; + } + + repeated KeyValue key_values = 1; + } + + message IntervalYearToMonth { + int32 years = 1; + int32 months = 2; + } + + message IntervalDayToSecond { + int32 days = 1; + int32 seconds = 2; + int32 microseconds = 3; + } + + message Struct { + // A possibly heterogeneously typed list of literals + repeated Literal fields = 1; + } + + message List { + // A homogeneously typed list of literals + repeated Literal values = 1; + } + + message UserDefined { + // points to a type_anchor defined in this plan + uint32 type_reference = 1; + + // The parameters to be bound to the type class, if the type class is + // parameterizable. + repeated Type.Parameter type_parameters = 3; + + // the value of the literal, serialized using some type-specific + // protobuf message + google.protobuf.Any value = 2; + } + } + + // Expression to dynamically construct nested types. + message Nested { + // Whether the returned nested type is nullable. + bool nullable = 1; + + // Optionally points to a type_variation_anchor defined in this plan for + // the returned nested type. + uint32 type_variation_reference = 2; + + oneof nested_type { + Struct struct = 3; + List list = 4; + Map map = 5; + } + + message Map { + message KeyValue { + // Mandatory key/value expressions. + Expression key = 1; + Expression value = 2; + } + + // One or more key-value pairs. To specify an empty map, use + // Literal.empty_map (otherwise type information would be missing). + repeated KeyValue key_values = 1; + } + + message Struct { + // Zero or more possibly heterogeneously-typed list of expressions that + // form the struct fields. + repeated Expression fields = 1; + } + + message List { + // A homogeneously-typed list of one or more expressions that form the + // list entries. To specify an empty list, use Literal.empty_list + // (otherwise type information would be missing). + repeated Expression values = 1; + } + } + + // A scalar function call. + message ScalarFunction { + // Points to a function_anchor defined in this plan, which must refer + // to a scalar function in the associated YAML file. Required; avoid + // using anchor/reference zero. + uint32 function_reference = 1; + + // The arguments to be bound to the function. This must have exactly the + // number of arguments specified in the function definition, and the + // argument types must also match exactly: + // + // - Value arguments must be bound using FunctionArgument.value, and + // the expression in that must yield a value of a type that a function + // overload is defined for. + // - Type arguments must be bound using FunctionArgument.type. + // - Enum arguments must be bound using FunctionArgument.enum + // followed by Enum.specified, with a string that case-insensitively + // matches one of the allowed options. + repeated FunctionArgument arguments = 4; + + // Options to specify behavior for corner cases, or leave behavior + // unspecified if the consumer does not need specific behavior in these + // cases. + repeated FunctionOption options = 5; + + // Must be set to the return type of the function, exactly as derived + // using the declaration in the extension. + Type output_type = 3; + + // Deprecated; use arguments instead. + repeated Expression args = 2 [deprecated = true]; + } + + // A window function call. + message WindowFunction { + // Points to a function_anchor defined in this plan, which must refer + // to a window function in the associated YAML file. Required; 0 is + // considered to be a valid anchor/reference. + uint32 function_reference = 1; + + // The arguments to be bound to the function. This must have exactly the + // number of arguments specified in the function definition, and the + // argument types must also match exactly: + // + // - Value arguments must be bound using FunctionArgument.value, and + // the expression in that must yield a value of a type that a function + // overload is defined for. + // - Type arguments must be bound using FunctionArgument.type, and a + // function overload must be defined for that type. + // - Enum arguments must be bound using FunctionArgument.enum + // followed by Enum.specified, with a string that case-insensitively + // matches one of the allowed options. + repeated FunctionArgument arguments = 9; + + // Options to specify behavior for corner cases, or leave behavior + // unspecified if the consumer does not need specific behavior in these + // cases. + repeated FunctionOption options = 11; + + // Must be set to the return type of the function, exactly as derived + // using the declaration in the extension. + Type output_type = 7; + + // Describes which part of the window function to perform within the + // context of distributed algorithms. Required. Must be set to + // INITIAL_TO_RESULT for window functions that are not decomposable. + AggregationPhase phase = 6; + + // If specified, the records that are part of the window defined by + // upper_bound and lower_bound are ordered according to this list + // before they are aggregated. The first sort field has the highest + // priority; only if a sort field determines two records to be equivalent + // is the next field queried. This field is optional, and is only allowed + // if the window function is defined to support sorting. + repeated SortField sorts = 3; + + // Specifies whether equivalent records are merged before being aggregated. + // Optional, defaults to AGGREGATION_INVOCATION_ALL. + AggregateFunction.AggregationInvocation invocation = 10; + + // When one or more partition expressions are specified, two records are + // considered to be in the same partition if and only if these expressions + // yield an equal tuple of values for both. When computing the window + // function, only the subset of records within the bounds that are also in + // the same partition as the current record are aggregated. + repeated Expression partitions = 2; + + // Defines the record relative to the current record from which the window + // extends. The bound is inclusive. If the lower bound indexes a record + // greater than the upper bound, TODO (null range/no records passed? + // wrapping around as if lower/upper were swapped? error? null?). + // Optional; defaults to the start of the partition. + Bound lower_bound = 5; + + string column_name = 12; + WindowType window_type = 13; + + // Defines the record relative to the current record up to which the window + // extends. The bound is inclusive. If the upper bound indexes a record + // less than the lower bound, TODO (null range/no records passed? + // wrapping around as if lower/upper were swapped? error? null?). + // Optional; defaults to the end of the partition. + Bound upper_bound = 4; + + // Deprecated; use arguments instead. + repeated Expression args = 8 [deprecated = true]; + + // Defines one of the two boundaries for the window of a window function. + message Bound { + // Defines that the bound extends this far back from the current record. + message Preceding { + // A strictly positive integer specifying the number of records that + // the window extends back from the current record. Required. Use + // CurrentRow for offset zero and Following for negative offsets. + int64 offset = 1; + } + + // Defines that the bound extends this far ahead of the current record. + message Following { + // A strictly positive integer specifying the number of records that + // the window extends ahead of the current record. Required. Use + // CurrentRow for offset zero and Preceding for negative offsets. + int64 offset = 1; + } + + // Defines that the bound extends to or from the current record. + message CurrentRow {} + + message Unbounded_Preceding {} + + message Unbounded_Following {} + + oneof kind { + // The bound extends some number of records behind the current record. + Preceding preceding = 1; + + // The bound extends some number of records ahead of the current + // record. + Following following = 2; + + // The bound extends to the current record. + CurrentRow current_row = 3; + + Unbounded_Preceding unbounded_preceding = 4; + Unbounded_Following unbounded_following = 5; + } + } + } + + message IfThen { + repeated IfClause ifs = 1; + Expression else = 2; + + message IfClause { + Expression if = 1; + Expression then = 2; + } + } + + message Cast { + Type type = 1; + Expression input = 2; + FailureBehavior failure_behavior = 3; + + enum FailureBehavior { + FAILURE_BEHAVIOR_UNSPECIFIED = 0; + FAILURE_BEHAVIOR_RETURN_NULL = 1; + FAILURE_BEHAVIOR_THROW_EXCEPTION = 2; + } + } + + message SwitchExpression { + Expression match = 3; + repeated IfValue ifs = 1; + Expression else = 2; + + message IfValue { + Literal if = 1; + Expression then = 2; + } + } + + message SingularOrList { + Expression value = 1; + repeated Expression options = 2; + } + + message MultiOrList { + repeated Expression value = 1; + repeated Record options = 2; + + message Record { + repeated Expression fields = 1; + } + } + + message EmbeddedFunction { + repeated Expression arguments = 1; + Type output_type = 2; + oneof kind { + PythonPickleFunction python_pickle_function = 3; + WebAssemblyFunction web_assembly_function = 4; + } + + message PythonPickleFunction { + bytes function = 1; + repeated string prerequisite = 2; + } + + message WebAssemblyFunction { + bytes script = 1; + repeated string prerequisite = 2; + } + } + + // A way to reference the inner property of a complex record. Can reference + // either a map key by literal, a struct field by the ordinal position of + // the desired field or a particular element in an array. Supports + // expressions that would roughly translate to something similar to: + // a.b[2].c['my_map_key'].x where a,b,c and x are struct field references + // (ordinalized in the internal representation here), [2] is a list offset + // and ['my_map_key'] is a reference into a map field. + message ReferenceSegment { + oneof reference_type { + MapKey map_key = 1; + StructField struct_field = 2; + ListElement list_element = 3; + } + + message MapKey { + // literal based reference to specific possible value in map. + Literal map_key = 1; + + // Optional child segment + ReferenceSegment child = 2; + } + + message StructField { + // zero-indexed ordinal position of field in struct + int32 field = 1; + + // Optional child segment + ReferenceSegment child = 2; + } + + message ListElement { + // zero-indexed ordinal position of element in list + int32 offset = 1; + + // Optional child segment + ReferenceSegment child = 2; + } + } + + // A reference that takes an existing subtype and selectively removes fields + // from it. For example, one might initially have an inner struct with 100 + // fields but a a particular operation only needs to interact with only 2 of + // those 100 fields. In this situation, one would use a mask expression to + // eliminate the 98 fields that are not relevant to the rest of the operation + // pipeline. + // + // Note that this does not fundamentally alter the structure of data beyond + // the elimination of unecessary elements. + message MaskExpression { + StructSelect select = 1; + bool maintain_singular_struct = 2; + + message Select { + oneof type { + StructSelect struct = 1; + ListSelect list = 2; + MapSelect map = 3; + } + } + + message StructSelect { + repeated StructItem struct_items = 1; + } + + message StructItem { + int32 field = 1; + Select child = 2; + } + + message ListSelect { + repeated ListSelectItem selection = 1; + Select child = 2; + + message ListSelectItem { + oneof type { + ListElement item = 1; + ListSlice slice = 2; + } + + message ListElement { + int32 field = 1; + } + + message ListSlice { + int32 start = 1; + int32 end = 2; + } + } + } + + message MapSelect { + oneof select { + MapKey key = 1; + MapKeyExpression expression = 2; + } + + Select child = 3; + + message MapKey { + string map_key = 1; + } + + message MapKeyExpression { + string map_key_expression = 1; + } + } + } + + // A reference to an inner part of a complex object. Can reference reference a + // single element or a masked version of elements + message FieldReference { + // Whether this is composed of a single element reference or a masked + // element subtree + oneof reference_type { + ReferenceSegment direct_reference = 1; + MaskExpression masked_reference = 2; + } + + // Whether this reference has an origin of a root struct or is based on the + // ouput of an expression. When this is a RootReference and direct_reference + // above is used, the direct_reference must be of a type StructField. + oneof root_type { + Expression expression = 3; + RootReference root_reference = 4; + OuterReference outer_reference = 5; + } + + // Singleton that expresses this FieldReference is rooted off the root + // incoming record type + message RootReference {} + + // A root reference for the outer relation's subquery + message OuterReference { + // number of subquery boundaries to traverse up for this field's reference + // + // This value must be >= 1 + uint32 steps_out = 1; + } + } + + // Subquery relation expression + message Subquery { + oneof subquery_type { + // Scalar subquery + Scalar scalar = 1; + // x IN y predicate + InPredicate in_predicate = 2; + // EXISTS/UNIQUE predicate + SetPredicate set_predicate = 3; + // ANY/ALL predicate + SetComparison set_comparison = 4; + } + + // A subquery with one row and one column. This is often an aggregate + // though not required to be. + message Scalar { + Rel input = 1; + } + + // Predicate checking that the left expression is contained in the right + // subquery + // + // Examples: + // + // x IN (SELECT * FROM t) + // (x, y) IN (SELECT a, b FROM t) + message InPredicate { + repeated Expression needles = 1; + Rel haystack = 2; + } + + // A predicate over a set of rows in the form of a subquery + // EXISTS and UNIQUE are common SQL forms of this operation. + message SetPredicate { + enum PredicateOp { + PREDICATE_OP_UNSPECIFIED = 0; + PREDICATE_OP_EXISTS = 1; + PREDICATE_OP_UNIQUE = 2; + } + // TODO: should allow expressions + PredicateOp predicate_op = 1; + Rel tuples = 2; + } + + // A subquery comparison using ANY or ALL. + // Examples: + // + // SELECT * + // FROM t1 + // WHERE x < ANY(SELECT y from t2) + message SetComparison { + enum ComparisonOp { + COMPARISON_OP_UNSPECIFIED = 0; + COMPARISON_OP_EQ = 1; + COMPARISON_OP_NE = 2; + COMPARISON_OP_LT = 3; + COMPARISON_OP_GT = 4; + COMPARISON_OP_LE = 5; + COMPARISON_OP_GE = 6; + } + + enum ReductionOp { + REDUCTION_OP_UNSPECIFIED = 0; + REDUCTION_OP_ANY = 1; + REDUCTION_OP_ALL = 2; + } + + // ANY or ALL + ReductionOp reduction_op = 1; + // A comparison operator + ComparisonOp comparison_op = 2; + // left side of the expression + Expression left = 3; + // right side of the expression + Rel right = 4; + } + } +} + +message GenerateRel { + RelCommon common = 1; + Rel input = 2; + + Expression generator = 3; + repeated Expression child_output = 4; + bool outer = 5; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + +// The description of a field to sort on (including the direction of sorting and null semantics) +message SortField { + Expression expr = 1; + + oneof sort_kind { + SortDirection direction = 2; + uint32 comparison_function_reference = 3; + } + enum SortDirection { + SORT_DIRECTION_UNSPECIFIED = 0; + SORT_DIRECTION_ASC_NULLS_FIRST = 1; + SORT_DIRECTION_ASC_NULLS_LAST = 2; + SORT_DIRECTION_DESC_NULLS_FIRST = 3; + SORT_DIRECTION_DESC_NULLS_LAST = 4; + SORT_DIRECTION_CLUSTERED = 5; + } +} + +// Describes which part of an aggregation or window function to perform within +// the context of distributed algorithms. +enum AggregationPhase { + // Implies `INTERMEDIATE_TO_RESULT`. + AGGREGATION_PHASE_UNSPECIFIED = 0; + + // Specifies that the function should be run only up to the point of + // generating an intermediate value, to be further aggregated later using + // INTERMEDIATE_TO_INTERMEDIATE or INTERMEDIATE_TO_RESULT. + AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE = 1; + + // Specifies that the inputs of the aggregate or window function are the + // intermediate values of the function, and that the output should also be + // an intermediate value, to be further aggregated later using + // INTERMEDIATE_TO_INTERMEDIATE or INTERMEDIATE_TO_RESULT. + AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE = 2; + + // A complete invocation: the function should aggregate the given set of + // inputs to yield a single return value. This style must be used for + // aggregate or window functions that are not decomposable. + AGGREGATION_PHASE_INITIAL_TO_RESULT = 3; + + // Specifies that the inputs of the aggregate or window function are the + // intermediate values of the function, generated previously using + // INITIAL_TO_INTERMEDIATE and possibly INTERMEDIATE_TO_INTERMEDIATE calls. + // This call should combine the intermediate values to yield the final + // return value. + AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT = 4; +} + +enum WindowType { + ROWS = 0; + RANGE = 1; +} + +// An aggregate function. +message AggregateFunction { + // Points to a function_anchor defined in this plan, which must refer + // to an aggregate function in the associated YAML file. Required; 0 is + // considered to be a valid anchor/reference. + uint32 function_reference = 1; + + // The arguments to be bound to the function. This must have exactly the + // number of arguments specified in the function definition, and the + // argument types must also match exactly: + // + // - Value arguments must be bound using FunctionArgument.value, and + // the expression in that must yield a value of a type that a function + // overload is defined for. + // - Type arguments must be bound using FunctionArgument.type, and a + // function overload must be defined for that type. + // - Enum arguments must be bound using FunctionArgument.enum + // followed by Enum.specified, with a string that case-insensitively + // matches one of the allowed options. + // - Optional enum arguments must be bound using FunctionArgument.enum + // followed by either Enum.specified or Enum.unspecified. If specified, + // the string must case-insensitively match one of the allowed options. + repeated FunctionArgument arguments = 7; + + // Options to specify behavior for corner cases, or leave behavior + // unspecified if the consumer does not need specific behavior in these + // cases. + repeated FunctionOption options = 8; + + // Must be set to the return type of the function, exactly as derived + // using the declaration in the extension. + Type output_type = 5; + + // Describes which part of the aggregation to perform within the context of + // distributed algorithms. Required. Must be set to INITIAL_TO_RESULT for + // aggregate functions that are not decomposable. + AggregationPhase phase = 4; + + // If specified, the aggregated records are ordered according to this list + // before they are aggregated. The first sort field has the highest + // priority; only if a sort field determines two records to be equivalent is + // the next field queried. This field is optional. + repeated SortField sorts = 3; + + // Specifies whether equivalent records are merged before being aggregated. + // Optional, defaults to AGGREGATION_INVOCATION_ALL. + AggregationInvocation invocation = 6; + + // deprecated; use arguments instead + repeated Expression args = 2 [deprecated = true]; + + // Method in which equivalent records are merged before being aggregated. + enum AggregationInvocation { + // This default value implies AGGREGATION_INVOCATION_ALL. + AGGREGATION_INVOCATION_UNSPECIFIED = 0; + + // Use all values in the aggregation calculation. + AGGREGATION_INVOCATION_ALL = 1; + + // Use only distinct values in the aggregation calculation. + AGGREGATION_INVOCATION_DISTINCT = 2; + } + + // This rel is used to create references, + // in case we refer to a RelRoot field names will be ignored + message ReferenceRel { + int32 subtree_ordinal = 1; + } +} diff --git a/utils/local-engine/proto/substrait/capabilities.proto b/utils/local-engine/proto/substrait/capabilities.proto new file mode 100644 index 000000000000..9a001158b8eb --- /dev/null +++ b/utils/local-engine/proto/substrait/capabilities.proto @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +// Defines a set of Capabilities that a system (producer or consumer) supports. +message Capabilities { + // List of Substrait versions this system supports + repeated string substrait_versions = 1; + + // list of com.google.Any message types this system supports for advanced + // extensions. + repeated string advanced_extension_type_urls = 2; + + // list of simple extensions this system supports. + repeated SimpleExtension simple_extensions = 3; + + message SimpleExtension { + string uri = 1; + repeated string function_keys = 2; + repeated string type_keys = 3; + repeated string type_variation_keys = 4; + } +} diff --git a/utils/local-engine/proto/substrait/ddl.proto b/utils/local-engine/proto/substrait/ddl.proto new file mode 100644 index 000000000000..833ec87369ae --- /dev/null +++ b/utils/local-engine/proto/substrait/ddl.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package substrait; + +import "substrait/plan.proto"; +import "substrait/algebra.proto"; + +option java_multiple_files = true; +option java_package = "io.substrait.proto"; +option csharp_namespace = "Substrait.Protobuf"; + +message DllPlan { + oneof dll_type { + InsertPlan insert_plan = 1; + } +} + +message InsertPlan { + Plan input = 1; + ReadRel.ExtensionTable output = 2; +} + +message Dll { + repeated DllPlan dll_plan = 1; +} \ No newline at end of file diff --git a/utils/local-engine/proto/substrait/extended_expression.proto b/utils/local-engine/proto/substrait/extended_expression.proto new file mode 100755 index 000000000000..5d1152055930 --- /dev/null +++ b/utils/local-engine/proto/substrait/extended_expression.proto @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "substrait/algebra.proto"; +import "substrait/extensions/extensions.proto"; +import "substrait/plan.proto"; +import "substrait/type.proto"; + +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +message ExpressionReference { + oneof expr_type { + Expression expression = 1; + AggregateFunction measure = 2; + } + // Field names in depth-first order + repeated string output_names = 3; +} + +// Describe a set of operations to complete. +// For compactness sake, identifiers are normalized at the plan level. +message ExtendedExpression { + // Substrait version of the expression. Optional up to 0.17.0, required for later + // versions. + Version version = 7; + + // a list of yaml specifications this expression may depend on + repeated substrait.extensions.SimpleExtensionURI extension_uris = 1; + + // a list of extensions this expression may depend on + repeated substrait.extensions.SimpleExtensionDeclaration extensions = 2; + + // one or more expression trees with same order in plan rel + repeated ExpressionReference referred_expr = 3; + + NamedStruct base_schema = 4; + // additional extensions associated with this expression. + substrait.extensions.AdvancedExtension advanced_extensions = 5; + + // A list of com.google.Any entities that this plan may use. Can be used to + // warn if some embedded message types are unknown. Note that this list may + // include message types that are ignorable (optimizations) or that are + // unused. In many cases, a consumer may be able to work with a plan even if + // one or more message types defined here are unknown. + repeated string expected_type_urls = 6; +} diff --git a/utils/local-engine/proto/substrait/extensions/extensions.proto b/utils/local-engine/proto/substrait/extensions/extensions.proto new file mode 100644 index 000000000000..4fd34b143fa0 --- /dev/null +++ b/utils/local-engine/proto/substrait/extensions/extensions.proto @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait.extensions; + +import "google/protobuf/any.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto/extensions"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +message SimpleExtensionURI { + // A surrogate key used in the context of a single plan used to reference the + // URI associated with an extension. + uint32 extension_uri_anchor = 1; + + // The URI where this extension YAML can be retrieved. This is the "namespace" + // of this extension. + string uri = 2; +} + +// Describes a mapping between a specific extension entity and the uri where +// that extension can be found. +message SimpleExtensionDeclaration { + oneof mapping_type { + ExtensionType extension_type = 1; + ExtensionTypeVariation extension_type_variation = 2; + ExtensionFunction extension_function = 3; + } + + // Describes a Type + message ExtensionType { + // references the extension_uri_anchor defined for a specific extension URI. + uint32 extension_uri_reference = 1; + + // A surrogate key used in the context of a single plan to reference a + // specific extension type + uint32 type_anchor = 2; + + // the name of the type in the defined extension YAML. + string name = 3; + } + + message ExtensionTypeVariation { + // references the extension_uri_anchor defined for a specific extension URI. + uint32 extension_uri_reference = 1; + + // A surrogate key used in the context of a single plan to reference a + // specific type variation + uint32 type_variation_anchor = 2; + + // the name of the type in the defined extension YAML. + string name = 3; + } + + message ExtensionFunction { + // references the extension_uri_anchor defined for a specific extension URI. + uint32 extension_uri_reference = 1; + + // A surrogate key used in the context of a single plan to reference a + // specific function + uint32 function_anchor = 2; + + // A simple name if there is only one impl for the function within the YAML. + // A compound name, referencing that includes type short names if there is + // more than one impl per name in the YAML. + string name = 3; + } +} + +// A generic object that can be used to embed additional extension information +// into the serialized substrait plan. +message AdvancedExtension { + // An optimization is helpful information that don't influence semantics. May + // be ignored by a consumer. + google.protobuf.Any optimization = 1; + + // An enhancement alter semantics. Cannot be ignored by a consumer. + google.protobuf.Any enhancement = 2; +} diff --git a/utils/local-engine/proto/substrait/function.proto b/utils/local-engine/proto/substrait/function.proto new file mode 100644 index 000000000000..0d09bef0eb8d --- /dev/null +++ b/utils/local-engine/proto/substrait/function.proto @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "substrait/parameterized_types.proto"; +import "substrait/type.proto"; +import "substrait/type_expressions.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +// List of function signatures available. +message FunctionSignature { + message FinalArgVariadic { + // the minimum number of arguments allowed for the list of final arguments + // (inclusive). + int64 min_args = 1; + + // the maximum number of arguments allowed for the list of final arguments + // (exclusive) + int64 max_args = 2; + + // the type of parameterized type consistency + ParameterConsistency consistency = 3; + + enum ParameterConsistency { + PARAMETER_CONSISTENCY_UNSPECIFIED = 0; + + // All argument must be the same concrete type. + PARAMETER_CONSISTENCY_CONSISTENT = 1; + + // Each argument can be any possible concrete type afforded by the bounds + // of any parameter defined in the arguments specification. + PARAMETER_CONSISTENCY_INCONSISTENT = 2; + } + } + + message FinalArgNormal {} + + message Scalar { + repeated Argument arguments = 2; + repeated string name = 3; + Description description = 4; + + bool deterministic = 7; + bool session_dependent = 8; + + DerivationExpression output_type = 9; + + oneof final_variable_behavior { + FinalArgVariadic variadic = 10; + FinalArgNormal normal = 11; + } + + repeated Implementation implementations = 12; + } + + message Aggregate { + repeated Argument arguments = 2; + string name = 3; + Description description = 4; + + bool deterministic = 7; + bool session_dependent = 8; + + DerivationExpression output_type = 9; + + oneof final_variable_behavior { + FinalArgVariadic variadic = 10; + FinalArgNormal normal = 11; + } + + bool ordered = 14; + uint64 max_set = 12; + Type intermediate_type = 13; + + repeated Implementation implementations = 15; + } + + message Window { + repeated Argument arguments = 2; + repeated string name = 3; + Description description = 4; + + bool deterministic = 7; + bool session_dependent = 8; + + DerivationExpression intermediate_type = 9; + DerivationExpression output_type = 10; + oneof final_variable_behavior { + FinalArgVariadic variadic = 16; + FinalArgNormal normal = 17; + } + bool ordered = 11; + uint64 max_set = 12; + WindowType window_type = 14; + repeated Implementation implementations = 15; + + enum WindowType { + WINDOW_TYPE_UNSPECIFIED = 0; + WINDOW_TYPE_STREAMING = 1; + WINDOW_TYPE_PARTITION = 2; + } + } + + message Description { + string language = 1; + string body = 2; + } + + message Implementation { + Type type = 1; + string uri = 2; + + enum Type { + TYPE_UNSPECIFIED = 0; + TYPE_WEB_ASSEMBLY = 1; + TYPE_TRINO_JAR = 2; + } + } + + message Argument { + string name = 1; + + oneof argument_kind { + ValueArgument value = 2; + TypeArgument type = 3; + EnumArgument enum = 4; + } + + message ValueArgument { + ParameterizedType type = 1; + bool constant = 2; + } + + message TypeArgument { + ParameterizedType type = 1; + } + + message EnumArgument { + repeated string options = 1; + bool optional = 2; + } + } +} diff --git a/utils/local-engine/proto/substrait/parameterized_types.proto b/utils/local-engine/proto/substrait/parameterized_types.proto new file mode 100644 index 000000000000..9b83fc43300e --- /dev/null +++ b/utils/local-engine/proto/substrait/parameterized_types.proto @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "substrait/type.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +message ParameterizedType { + oneof kind { + Type.Boolean bool = 1; + Type.I8 i8 = 2; + Type.I16 i16 = 3; + Type.I32 i32 = 5; + Type.I64 i64 = 7; + Type.FP32 fp32 = 10; + Type.FP64 fp64 = 11; + Type.String string = 12; + Type.Binary binary = 13; + Type.Timestamp timestamp = 14; + Type.Date date = 16; + Type.Time time = 17; + Type.IntervalYear interval_year = 19; + Type.IntervalDay interval_day = 20; + Type.TimestampTZ timestamp_tz = 29; + Type.UUID uuid = 32; + + ParameterizedFixedChar fixed_char = 21; + ParameterizedVarChar varchar = 22; + ParameterizedFixedBinary fixed_binary = 23; + ParameterizedDecimal decimal = 24; + + ParameterizedStruct struct = 25; + ParameterizedList list = 27; + ParameterizedMap map = 28; + + ParameterizedUserDefined user_defined = 30; + + // Deprecated in favor of user_defined, which allows nullability and + // variations to be specified. If user_defined_pointer is encountered, + // treat it as being non-nullable and having the default variation. + uint32 user_defined_pointer = 31 [deprecated = true]; + + TypeParameter type_parameter = 33; + } + + message TypeParameter { + string name = 1; + repeated ParameterizedType bounds = 2; + } + + message IntegerParameter { + string name = 1; + NullableInteger range_start_inclusive = 2; + NullableInteger range_end_exclusive = 3; + } + + message NullableInteger { + int64 value = 1; + } + + message ParameterizedFixedChar { + IntegerOption length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ParameterizedVarChar { + IntegerOption length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ParameterizedFixedBinary { + IntegerOption length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ParameterizedDecimal { + IntegerOption scale = 1; + IntegerOption precision = 2; + uint32 variation_pointer = 3; + Type.Nullability nullability = 4; + } + + message ParameterizedStruct { + repeated ParameterizedType types = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ParameterizedNamedStruct { + // list of names in dfs order + repeated string names = 1; + ParameterizedStruct struct = 2; + } + + message ParameterizedList { + ParameterizedType type = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ParameterizedMap { + ParameterizedType key = 1; + ParameterizedType value = 2; + uint32 variation_pointer = 3; + Type.Nullability nullability = 4; + } + + message ParameterizedUserDefined { + uint32 type_pointer = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message IntegerOption { + oneof integer_type { + int32 literal = 1; + IntegerParameter parameter = 2; + } + } +} diff --git a/utils/local-engine/proto/substrait/plan.proto b/utils/local-engine/proto/substrait/plan.proto new file mode 100644 index 000000000000..b6aee4d424da --- /dev/null +++ b/utils/local-engine/proto/substrait/plan.proto @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "substrait/algebra.proto"; +import "substrait/extensions/extensions.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +// Either a relation or root relation +message PlanRel { + oneof rel_type { + // Any relation (used for references and CTEs) + Rel rel = 1; + // The root of a relation tree + RelRoot root = 2; + } +} + +// Describe a set of operations to complete. +// For compactness sake, identifiers are normalized at the plan level. +message Plan { + // Substrait version of the plan. Optional up to 0.17.0, required for later + // versions. + Version version = 6; + + // a list of yaml specifications this plan may depend on + repeated substrait.extensions.SimpleExtensionURI extension_uris = 1; + + // a list of extensions this plan may depend on + repeated substrait.extensions.SimpleExtensionDeclaration extensions = 2; + + // one or more relation trees that are associated with this plan. + repeated PlanRel relations = 3; + + // additional extensions associated with this plan. + substrait.extensions.AdvancedExtension advanced_extensions = 4; + + // A list of com.google.Any entities that this plan may use. Can be used to + // warn if some embedded message types are unknown. Note that this list may + // include message types that are ignorable (optimizations) or that are + // unused. In many cases, a consumer may be able to work with a plan even if + // one or more message types defined here are unknown. + repeated string expected_type_urls = 5; +} + +// This message type can be used to deserialize only the version of a Substrait +// Plan message. This prevents deserialization errors when there were breaking +// changes between the Substrait version of the tool that produced the plan and +// the Substrait version used to deserialize it, such that a consumer can emit +// a more helpful error message in this case. +message PlanVersion { + Version version = 6; +} + +message Version { + // Substrait version number. + uint32 major_number = 1; + uint32 minor_number = 2; + uint32 patch_number = 3; + + // If a particular version of Substrait is used that does not correspond to + // a version number exactly (for example when using an unofficial fork or + // using a version that is not yet released or is between versions), set this + // to the full git hash of the utilized commit of + // https://github.com/substrait-io/substrait (or fork thereof), represented + // using a lowercase hex ASCII string 40 characters in length. The version + // number above should be set to the most recent version tag in the history + // of that commit. + string git_hash = 4; + + // Identifying information for the producer that created this plan. Under + // ideal circumstances, consumers should not need this information. However, + // it is foreseen that consumers may need to work around bugs in particular + // producers in practice, and therefore may need to know which producer + // created the plan. + string producer = 5; +} diff --git a/utils/local-engine/proto/substrait/type.proto b/utils/local-engine/proto/substrait/type.proto new file mode 100644 index 000000000000..5d4a8f918b83 --- /dev/null +++ b/utils/local-engine/proto/substrait/type.proto @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "google/protobuf/empty.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +message Type { + oneof kind { + Boolean bool = 1; + I8 i8 = 2; + I16 i16 = 3; + I32 i32 = 5; + I64 i64 = 7; + FP32 fp32 = 10; + FP64 fp64 = 11; + String string = 12; + Binary binary = 13; + Timestamp timestamp = 14; + Date date = 16; + Time time = 17; + IntervalYear interval_year = 19; + IntervalDay interval_day = 20; + TimestampTZ timestamp_tz = 29; + UUID uuid = 32; + + FixedChar fixed_char = 21; + VarChar varchar = 22; + FixedBinary fixed_binary = 23; + Decimal decimal = 24; + + Struct struct = 25; + List list = 27; + Map map = 28; + + UserDefined user_defined = 30; + + // Deprecated in favor of user_defined, which allows nullability and + // variations to be specified. If user_defined_type_reference is + // encountered, treat it as being non-nullable and having the default + // variation. + uint32 user_defined_type_reference = 31 [deprecated = true]; + + Nothing nothing = 33; + } + + enum Nullability { + NULLABILITY_UNSPECIFIED = 0; + NULLABILITY_NULLABLE = 1; + NULLABILITY_REQUIRED = 2; + } + + message Nothing { + uint32 type_variation_reference = 1; + } + + message Boolean { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I8 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I16 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I32 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message I64 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message FP32 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message FP64 { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message String { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Binary { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Timestamp { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Date { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message Time { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message TimestampTZ { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message IntervalYear { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message IntervalDay { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + message UUID { + uint32 type_variation_reference = 1; + Nullability nullability = 2; + } + + // Start compound types. + message FixedChar { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message VarChar { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message FixedBinary { + int32 length = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message Decimal { + int32 scale = 1; + int32 precision = 2; + uint32 type_variation_reference = 3; + Nullability nullability = 4; + } + + message Struct { + repeated Type types = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message List { + Type type = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + } + + message Map { + Type key = 1; + Type value = 2; + uint32 type_variation_reference = 3; + Nullability nullability = 4; + } + + message UserDefined { + uint32 type_reference = 1; + uint32 type_variation_reference = 2; + Nullability nullability = 3; + repeated Parameter type_parameters = 4; + } + + message Parameter { + oneof parameter { + // Explicitly null/unspecified parameter, to select the default value (if + // any). + google.protobuf.Empty null = 1; + + // Data type parameters, like the i32 in LIST. + Type data_type = 2; + + // Value parameters, like the 10 in VARCHAR<10>. + bool boolean = 3; + int64 integer = 4; + string enum = 5; + string string = 6; + } + } +} + +// A message for modeling name/type pairs. +// +// Useful for representing relation schemas. +// +// Notes: +// +// * The names field is in depth-first order. +// +// For example a schema such as: +// +// a: int64 +// b: struct +// +// would have a `names` field that looks like: +// +// ["a", "b", "c", "d"] +// +// * Only struct fields are contained in this field's elements, +// * Map keys should be traversed first, then values when producing/consuming +message NamedStruct { + // list of names in dfs order + repeated string names = 1; + Type.Struct struct = 2; + PartitionColumns partition_columns = 3; +} + +message PartitionColumns { + repeated ColumnType column_type = 1; + enum ColumnType { + NORMAL_COL = 0; + PARTITION_COL = 1; + } +} diff --git a/utils/local-engine/proto/substrait/type_expressions.proto b/utils/local-engine/proto/substrait/type_expressions.proto new file mode 100644 index 000000000000..fcfeb641b273 --- /dev/null +++ b/utils/local-engine/proto/substrait/type_expressions.proto @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +syntax = "proto3"; + +package substrait; + +import "substrait/type.proto"; + +option cc_enable_arenas = true; +option csharp_namespace = "Substrait.Protobuf"; +option go_package = "github.com/substrait-io/substrait-go/proto"; +option java_multiple_files = true; +option java_package = "io.substrait.proto"; + +message DerivationExpression { + oneof kind { + Type.Boolean bool = 1; + Type.I8 i8 = 2; + Type.I16 i16 = 3; + Type.I32 i32 = 5; + Type.I64 i64 = 7; + Type.FP32 fp32 = 10; + Type.FP64 fp64 = 11; + Type.String string = 12; + Type.Binary binary = 13; + Type.Timestamp timestamp = 14; + Type.Date date = 16; + Type.Time time = 17; + Type.IntervalYear interval_year = 19; + Type.IntervalDay interval_day = 20; + Type.TimestampTZ timestamp_tz = 29; + Type.UUID uuid = 32; + + ExpressionFixedChar fixed_char = 21; + ExpressionVarChar varchar = 22; + ExpressionFixedBinary fixed_binary = 23; + ExpressionDecimal decimal = 24; + + ExpressionStruct struct = 25; + ExpressionList list = 27; + ExpressionMap map = 28; + + ExpressionUserDefined user_defined = 30; + + // Deprecated in favor of user_defined, which allows nullability and + // variations to be specified. If user_defined_pointer is encountered, + // treat it as being non-nullable and having the default variation. + uint32 user_defined_pointer = 31 [deprecated = true]; + + string type_parameter_name = 33; + string integer_parameter_name = 34; + + int32 integer_literal = 35; + UnaryOp unary_op = 36; + BinaryOp binary_op = 37; + IfElse if_else = 38; + ReturnProgram return_program = 39; + } + + message ExpressionFixedChar { + DerivationExpression length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ExpressionVarChar { + DerivationExpression length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ExpressionFixedBinary { + DerivationExpression length = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ExpressionDecimal { + DerivationExpression scale = 1; + DerivationExpression precision = 2; + uint32 variation_pointer = 3; + Type.Nullability nullability = 4; + } + + message ExpressionStruct { + repeated DerivationExpression types = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ExpressionNamedStruct { + repeated string names = 1; + ExpressionStruct struct = 2; + } + + message ExpressionList { + DerivationExpression type = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message ExpressionMap { + DerivationExpression key = 1; + DerivationExpression value = 2; + uint32 variation_pointer = 3; + Type.Nullability nullability = 4; + } + + message ExpressionUserDefined { + uint32 type_pointer = 1; + uint32 variation_pointer = 2; + Type.Nullability nullability = 3; + } + + message IfElse { + DerivationExpression if_condition = 1; + DerivationExpression if_return = 2; + DerivationExpression else_return = 3; + } + + message UnaryOp { + UnaryOpType op_type = 1; + DerivationExpression arg = 2; + + enum UnaryOpType { + UNARY_OP_TYPE_UNSPECIFIED = 0; + UNARY_OP_TYPE_BOOLEAN_NOT = 1; + } + } + + message BinaryOp { + BinaryOpType op_type = 1; + DerivationExpression arg1 = 2; + DerivationExpression arg2 = 3; + + enum BinaryOpType { + BINARY_OP_TYPE_UNSPECIFIED = 0; + BINARY_OP_TYPE_PLUS = 1; + BINARY_OP_TYPE_MINUS = 2; + BINARY_OP_TYPE_MULTIPLY = 3; + BINARY_OP_TYPE_DIVIDE = 4; + BINARY_OP_TYPE_MIN = 5; + BINARY_OP_TYPE_MAX = 6; + BINARY_OP_TYPE_GREATER_THAN = 7; + BINARY_OP_TYPE_LESS_THAN = 8; + BINARY_OP_TYPE_AND = 9; + BINARY_OP_TYPE_OR = 10; + BINARY_OP_TYPE_EQUALS = 11; + BINARY_OP_TYPE_COVERS = 12; + } + } + + message ReturnProgram { + message Assignment { + string name = 1; + DerivationExpression expression = 2; + } + + repeated Assignment assignments = 1; + DerivationExpression final_expression = 2; + } +} diff --git a/utils/local-engine/tests/CMakeLists.txt b/utils/local-engine/tests/CMakeLists.txt new file mode 100644 index 000000000000..5a36c33e3ab4 --- /dev/null +++ b/utils/local-engine/tests/CMakeLists.txt @@ -0,0 +1,53 @@ +macro (grep_gtest_sources BASE_DIR DST_VAR) + # Cold match files that are not in tests/ directories + file(GLOB_RECURSE "${DST_VAR}" RELATIVE "${BASE_DIR}" "gtest*.cpp") +endmacro() + +set(USE_INTERNAL_GTEST_LIBRARY 0) +set(BENCHMARK_ENABLE_TESTING OFF) + +enable_testing() +include(CTest) + +include_directories(${GTEST_INCLUDE_DIRS}) + +set(TEST_DATA_DIR "${ClickHouse_SOURCE_DIR}/utils/local-engine/tests") + +configure_file( + ${ClickHouse_SOURCE_DIR}/utils/local-engine/tests/testConfig.h.in + ${ClickHouse_SOURCE_DIR}/utils/local-engine/tests/testConfig.h +) +set(HAVE_POSIX_REGEX 1) +include(FetchContent) +FetchContent_Declare(googlebenchmark GIT_REPOSITORY https://github.com/google/benchmark GIT_TAG main) +FetchContent_MakeAvailable(googlebenchmark) +include_directories( + ${builder_headers} + ${parser_headers} +) + +target_compile_options(benchmark PUBLIC + -Wno-extra-semi-stmt + -Wno-format-nonliteral + -Wno-missing-noreturn + -Wno-old-style-cast + -Wno-undef + -Wno-used-but-marked-unused + -Wno-zero-as-null-pointer-constant + -Wno-shift-sign-overflow + -Wno-thread-safety-analysis + ) + +grep_gtest_sources("${ClickHouse_SOURCE_DIR}/utils/local_engine/tests" local_engine_gtest_sources) + +add_executable(unit_tests_local_engine ${local_engine_gtest_sources} ) + +add_executable(benchmark_local_engine benchmark_local_engine.cpp benchmark_parquet_read.cpp benchmark_spark_row.cpp) + +target_include_directories(unit_tests_local_engine PRIVATE + ${GTEST_INCLUDE_DIRS}/include + ) +include_directories(benchmark_local_engine SYSTEM PUBLIC ${FETCH_CONTENT_SOURCE_DIR_GOOGLEBENCHMARK}/include ${ClickHouse_SOURCE_DIR}/utils/local_engine) + +target_link_libraries(unit_tests_local_engine PRIVATE ${LOCALENGINE_SHARED_LIB} _gtest_all clickhouse_parsers) +target_link_libraries(benchmark_local_engine PRIVATE ${LOCALENGINE_SHARED_LIB} benchmark::benchmark) diff --git a/utils/local-engine/tests/benchmark_local_engine.cpp b/utils/local-engine/tests/benchmark_local_engine.cpp new file mode 100644 index 000000000000..8af616489cda --- /dev/null +++ b/utils/local-engine/tests/benchmark_local_engine.cpp @@ -0,0 +1,1514 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "testConfig.h" + +#if defined(__SSE2__) +# include +#endif + + +using namespace local_engine; +using namespace dbms; + +DB::ContextMutablePtr global_context; + +[[maybe_unused]] static void BM_CHColumnToSparkRow(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_orderkey", "I64") + .column("l_partkey", "I64") + .column("l_suppkey", "I64") + .column("l_linenumber", "I32") + .column("l_quantity", "FP64") + .column("l_extendedprice", "FP64") + .column("l_discount", "FP64") + .column("l_tax", "FP64") + .column("l_returnflag", "String") + .column("l_linestatus", "String") + .column("l_shipdate", "Date") + .column("l_commitdate", "Date") + .column("l_receiptdate", "Date") + .column("l_shipinstruct", "String") + .column("l_shipmode", "String") + .column("l_comment", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan = plan_builder.readMergeTree("default", "test", "home/saber/Documents/data/mergetree", 1, 10, std::move(schema)).build(); + local_engine::SerializedPlanParser parser(global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + } + } +} + +[[maybe_unused]] static void BM_MergeTreeRead(benchmark::State & state) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + + const auto * type_string = "columns format version: 1\n" + "3 columns:\n" + "`l_orderkey` Int64\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + metadata = local_engine::buildMetaData(names_and_types_list, global_context); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = local_engine::buildMergeTreeSettings(); + + local_engine::CustomStorageMergeTree custom_merge_tree( + DB::StorageID("default", "test"), + "data0/tpch100_zhichao/mergetree_nullable/lineitem", + *metadata, + false, + global_context, + "", + param, + std::move(settings)); + auto snapshot = std::make_shared(custom_merge_tree, metadata); + custom_merge_tree.loadDataParts(false); + for (auto _ : state) + { + state.PauseTiming(); + auto query_info = local_engine::buildQueryInfo(names_and_types_list); + auto data_parts = custom_merge_tree.getDataPartsVectorForInternalUsage(); + int min_block = 0; + int max_block = state.range(0); + MergeTreeData::DataPartsVector selected_parts; + std::copy_if( + std::begin(data_parts), + std::end(data_parts), + std::inserter(selected_parts, std::begin(selected_parts)), + [min_block, max_block](MergeTreeData::DataPartPtr part) + { return part->info.min_block >= min_block && part->info.max_block <= max_block; }); + auto query = custom_merge_tree.reader.readFromParts( + selected_parts, names_and_types_list.getNames(), snapshot, *query_info, global_context, 10000, 1); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = false}; + auto query_pipeline = query->buildQueryPipeline(optimization_settings, {}); + state.ResumeTiming(); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*query_pipeline)); + auto executor = PullingPipelineExecutor(pipeline); + Chunk chunk; + int sum = 0; + while (executor.pull(chunk)) + { + sum += chunk.getNumRows(); + } + std::cerr << "rows:" << sum << std::endl; + } +} + +[[maybe_unused]] static void BM_ParquetRead(benchmark::State & state) +{ + const auto * type_string = "columns format version: 1\n" + "2 columns:\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + ColumnsWithTypeAndName columns; + for (const auto & item : names_and_types_list) + { + ColumnWithTypeAndName col; + col.column = item.type->createColumn(); + col.type = item.type; + col.name = item.name; + columns.emplace_back(std::move(col)); + } + auto header = Block(std::move(columns)); + + for (auto _ : state) + { + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + std::string file_path = "file:///home/hongbin/code/gluten/jvm/src/test/resources/tpch-data/lineitem/" + "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + file->set_uri_file(file_path); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + auto builder = std::make_unique(); + builder->init(Pipe(std::make_shared(SerializedPlanParser::global_context, header, files))); + + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto executor = PullingPipelineExecutor(pipeline); + auto result = header.cloneEmpty(); + size_t total_rows = 0; + while (executor.pull(result)) + { + debug::headBlock(result); + total_rows += result.rows(); + } + std::cerr << "rows:" << total_rows << std::endl; + } +} + +[[maybe_unused]] static void BM_ShuffleSplitter(benchmark::State & state) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + metadata = local_engine::buildMetaData(names_and_types_list, global_context); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = local_engine::buildMergeTreeSettings(); + + local_engine::CustomStorageMergeTree custom_merge_tree( + DB::StorageID("default", "test"), + "home/saber/Documents/data/mergetree", + *metadata, + false, + global_context, + "", + param, + std::move(settings)); + custom_merge_tree.loadDataParts(false); + auto snapshot = std::make_shared(custom_merge_tree, metadata); + for (auto _ : state) + { + state.PauseTiming(); + auto query_info = local_engine::buildQueryInfo(names_and_types_list); + auto data_parts = custom_merge_tree.getDataPartsVectorForInternalUsage(); + int min_block = 0; + int max_block = state.range(0); + MergeTreeData::DataPartsVector selected_parts; + std::copy_if( + std::begin(data_parts), + std::end(data_parts), + std::inserter(selected_parts, std::begin(selected_parts)), + [min_block, max_block](MergeTreeData::DataPartPtr part) + { return part->info.min_block >= min_block && part->info.max_block <= max_block; }); + auto query = custom_merge_tree.reader.readFromParts( + selected_parts, names_and_types_list.getNames(), snapshot, *query_info, global_context, 10000, 1); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = false}; + auto query_pipeline = query->buildQueryPipeline(optimization_settings, {}); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*query_pipeline)); + state.ResumeTiming(); + auto executor = PullingPipelineExecutor(pipeline); + Block chunk = executor.getHeader(); + int sum = 0; + auto root = "/tmp/test_shuffle/" + local_engine::ShuffleSplitter::compress_methods[state.range(1)]; + local_engine::SplitOptions options{ + .split_size = 8192, + .io_buffer_size = DBMS_DEFAULT_BUFFER_SIZE, + .data_file = root + "/data.dat", + .map_id = 1, + .partition_nums = 4, + .compress_method = local_engine::ShuffleSplitter::compress_methods[state.range(1)]}; + auto splitter = local_engine::ShuffleSplitter::create("rr", options); + while (executor.pull(chunk)) + { + sum += chunk.rows(); + splitter->split(chunk); + } + splitter->stop(); + splitter->writeIndexFile(); + std::cout << sum << "\n"; + } +} + +[[maybe_unused]] static void BM_HashShuffleSplitter(benchmark::State & state) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + metadata = local_engine::buildMetaData(names_and_types_list, global_context); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = local_engine::buildMergeTreeSettings(); + + local_engine::CustomStorageMergeTree custom_merge_tree( + DB::StorageID("default", "test"), + "home/saber/Documents/data/mergetree", + *metadata, + false, + global_context, + "", + param, + std::move(settings)); + custom_merge_tree.loadDataParts(false); + auto snapshot = std::make_shared(custom_merge_tree, metadata); + + for (auto _ : state) + { + state.PauseTiming(); + auto query_info = local_engine::buildQueryInfo(names_and_types_list); + auto data_parts = custom_merge_tree.getDataPartsVectorForInternalUsage(); + int min_block = 0; + int max_block = state.range(0); + MergeTreeData::DataPartsVector selected_parts; + std::copy_if( + std::begin(data_parts), + std::end(data_parts), + std::inserter(selected_parts, std::begin(selected_parts)), + [min_block, max_block](MergeTreeData::DataPartPtr part) + { return part->info.min_block >= min_block && part->info.max_block <= max_block; }); + auto query = custom_merge_tree.reader.readFromParts( + selected_parts, names_and_types_list.getNames(), snapshot, *query_info, global_context, 10000, 1); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = false}; + auto query_pipeline = query->buildQueryPipeline(optimization_settings, {}); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*query_pipeline)); + state.ResumeTiming(); + auto executor = PullingPipelineExecutor(pipeline); + Block chunk = executor.getHeader(); + int sum = 0; + auto root = "/tmp/test_shuffle/" + local_engine::ShuffleSplitter::compress_methods[state.range(1)]; + local_engine::SplitOptions options{ + .split_size = 8192, + .io_buffer_size = DBMS_DEFAULT_BUFFER_SIZE, + .data_file = root + "/data.dat", + .map_id = 1, + .partition_nums = 4, + .compress_method = local_engine::ShuffleSplitter::compress_methods[state.range(1)]}; + auto splitter = local_engine::ShuffleSplitter::create("hash", options); + while (executor.pull(chunk)) + { + sum += chunk.rows(); + splitter->split(chunk); + } + splitter->stop(); + splitter->writeIndexFile(); + std::cout << sum << "\n"; + } +} + +[[maybe_unused]] static void BM_ShuffleReader(benchmark::State & state) +{ + for (auto _ : state) + { + auto read_buffer = std::make_unique("/tmp/test_shuffle/ZSTD/data.dat"); + // read_buffer->seek(357841655, SEEK_SET); + auto shuffle_reader = local_engine::ShuffleReader(std::move(read_buffer), true); + Block * block; + int sum = 0; + do + { + block = shuffle_reader.read(); + sum += block->rows(); + } while (block->columns() != 0); + std::cout << "total rows:" << sum << std::endl; + } +} + +[[maybe_unused]] static void BM_SimpleAggregate(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_orderkey", "I64") + .column("l_partkey", "I64") + .column("l_suppkey", "I64") + .column("l_linenumber", "I32") + .column("l_quantity", "FP64") + .column("l_extendedprice", "FP64") + .column("l_discount", "FP64") + .column("l_tax", "FP64") + .column("l_shipdate_new", "FP64") + .column("l_commitdate_new", "FP64") + .column("l_receiptdate_new", "FP64") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * measure = dbms::measureFunction(dbms::SUM, {dbms::selection(6)}); + auto plan + = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure}) + .read( + "/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", + std::move(schema)) + .build(); + local_engine::SerializedPlanParser parser(global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + } + } +} + +[[maybe_unused]] static void BM_TPCH_Q6(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_discount", "FP64") + .column("l_extendedprice", "FP64") + .column("l_quantity", "FP64") + .column("l_shipdate_new", "Date") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * agg_mul = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(1), dbms::selection(0)}); + auto * measure1 = dbms::measureFunction(dbms::SUM, {agg_mul}); + auto * measure2 = dbms::measureFunction(dbms::SUM, {dbms::selection(1)}); + auto * measure3 = dbms::measureFunction(dbms::SUM, {dbms::selection(2)}); + auto plan + = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure1, measure2, measure3}) + .project({dbms::selection(2), dbms::selection(1), dbms::selection(0)}) + .filter(dbms::scalarFunction( + dbms::AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {scalarFunction(IS_NOT_NULL, {selection(3)}), + scalarFunction(IS_NOT_NULL, {selection(0)})}), + scalarFunction(IS_NOT_NULL, {selection(2)})}), + dbms::scalarFunction(GREATER_THAN_OR_EQUAL, {selection(3), literalDate(8766)})}), + scalarFunction(LESS_THAN, {selection(3), literalDate(9131)})}), + scalarFunction(GREATER_THAN_OR_EQUAL, {selection(0), literal(0.05)})}), + scalarFunction(LESS_THAN_OR_EQUAL, {selection(0), literal(0.07)})}), + scalarFunction(LESS_THAN, {selection(2), literal(24.0)})})) + .read( + "/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", + std::move(schema)) + .build(); + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + Block * block = local_executor.nextColumnar(); + delete block; + } + } +} + + +[[maybe_unused]] static void BM_MERGE_TREE_TPCH_Q6(benchmark::State & state) +{ + SerializedPlanParser::global_context = global_context; + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_discount", "FP64") + .column("l_extendedprice", "FP64") + .column("l_quantity", "FP64") + .column("l_shipdate", "Date") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * agg_mul = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(1), dbms::selection(0)}); + auto * measure1 = dbms::measureFunction(dbms::SUM, {agg_mul}); + auto plan = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure1}) + .project({dbms::selection(2), dbms::selection(1), dbms::selection(0)}) + .filter(dbms::scalarFunction( + dbms::AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {scalarFunction(IS_NOT_NULL, {selection(3)}), + scalarFunction(IS_NOT_NULL, {selection(0)})}), + scalarFunction(IS_NOT_NULL, {selection(2)})}), + dbms::scalarFunction(GREATER_THAN_OR_EQUAL, {selection(3), literalDate(8766)})}), + scalarFunction(LESS_THAN, {selection(3), literalDate(9131)})}), + scalarFunction(GREATER_THAN_OR_EQUAL, {selection(0), literal(0.05)})}), + scalarFunction(LESS_THAN_OR_EQUAL, {selection(0), literal(0.07)})}), + scalarFunction(LESS_THAN, {selection(2), literal(24.0)})})) + .readMergeTree("default", "test", "home/saber/Documents/data/mergetree/", 1, 4, std::move(schema)) + .build(); + + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + } + } +} + +[[maybe_unused]] static void BM_MERGE_TREE_TPCH_Q6_NEW(benchmark::State & state) +{ + SerializedPlanParser::global_context = global_context; + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_discount", "FP64") + .column("l_extendedprice", "FP64") + .column("l_quantity", "FP64") + .column("l_shipdate", "Date") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * agg_mul = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(1), dbms::selection(0)}); + + auto * measure1 = dbms::measureFunction(dbms::SUM, {dbms::selection(0)}); + auto plan = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure1}) + .project({agg_mul}) + .project({dbms::selection(2), dbms::selection(1), dbms::selection(0)}) + .filter(dbms::scalarFunction( + dbms::AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {scalarFunction(IS_NOT_NULL, {selection(3)}), + scalarFunction(IS_NOT_NULL, {selection(0)})}), + scalarFunction(IS_NOT_NULL, {selection(2)})}), + dbms::scalarFunction(GREATER_THAN_OR_EQUAL, {selection(3), literalDate(8766)})}), + scalarFunction(LESS_THAN, {selection(3), literalDate(9131)})}), + scalarFunction(GREATER_THAN_OR_EQUAL, {selection(0), literal(0.05)})}), + scalarFunction(LESS_THAN_OR_EQUAL, {selection(0), literal(0.07)})}), + scalarFunction(LESS_THAN, {selection(2), literal(24.0)})})) + .readMergeTree("default", "test", "home/saber/Documents/data/mergetree/", 1, 4, std::move(schema)) + .build(); + + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + } + } +} + + +[[maybe_unused]] static void BM_MERGE_TREE_TPCH_Q6_FROM_TEXT(benchmark::State & state) +{ + SerializedPlanParser::global_context = global_context; + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + for (auto _ : state) + { + state.PauseTiming(); + + //const char * path = "/data1/tpc_data/tpch1000_zhichao/serialized_q6_substrait_plan1.txt"; + const char * path = "/data1/tpc_data/tpch100_zhichao/serialized_q4_substrait_plan_parquet.bin"; + //const char * path = "/data1/tpc_data/tpch100_zhichao/serialized_q4_substrait_plan_mergetree.bin"; + std::ifstream t(path); + std::string str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + std::cout << "the plan from: " << path << std::endl; + + auto query_plan = parser.parse(str); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + [[maybe_unused]] auto * x = local_executor.nextColumnar(); + } + } +} + + +[[maybe_unused]] static void BM_CHColumnToSparkRowWithString(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_orderkey", "I64") + .column("l_partkey", "I64") + .column("l_suppkey", "I64") + .column("l_linenumber", "I32") + .column("l_quantity", "FP64") + .column("l_extendedprice", "FP64") + .column("l_discount", "FP64") + .column("l_tax", "FP64") + .column("l_returnflag", "String") + .column("l_linestatus", "String") + .column("l_shipdate_new", "FP64") + .column("l_commitdate_new", "FP64") + .column("l_receiptdate_new", "FP64") + .column("l_shipinstruct", "String") + .column("l_shipmode", "String") + .column("l_comment", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan + = plan_builder + .read( + "/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", + std::move(schema)) + .build(); + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + state.ResumeTiming(); + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + } + } +} + +[[maybe_unused]] static void BM_SparkRowToCHColumn(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_orderkey", "I64") + .column("l_partkey", "I64") + .column("l_suppkey", "I64") + .column("l_linenumber", "I32") + .column("l_quantity", "FP64") + .column("l_extendedprice", "FP64") + .column("l_discount", "FP64") + .column("l_tax", "FP64") + .column("l_shipdate_new", "FP64") + .column("l_commitdate_new", "FP64") + .column("l_receiptdate_new", "FP64") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan + = plan_builder + .read( + "/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", + std::move(schema)) + .build(); + + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + + local_executor.execute(std::move(query_plan)); + local_engine::SparkRowToCHColumn converter; + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + state.ResumeTiming(); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + state.PauseTiming(); + } + state.ResumeTiming(); + } +} + + +[[maybe_unused]] static void BM_SparkRowToCHColumnWithString(benchmark::State & state) +{ + for (auto _ : state) + { + state.PauseTiming(); + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_orderkey", "I64") + .column("l_partkey", "I64") + .column("l_suppkey", "I64") + .column("l_linenumber", "I32") + .column("l_quantity", "FP64") + .column("l_extendedprice", "FP64") + .column("l_discount", "FP64") + .column("l_tax", "FP64") + .column("l_returnflag", "String") + .column("l_linestatus", "String") + .column("l_shipdate_new", "FP64") + .column("l_commitdate_new", "FP64") + .column("l_receiptdate_new", "FP64") + .column("l_shipinstruct", "String") + .column("l_shipmode", "String") + .column("l_comment", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan + = plan_builder + .read( + "/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", + std::move(schema)) + .build(); + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_engine::LocalExecutor local_executor; + + local_executor.execute(std::move(query_plan)); + local_engine::SparkRowToCHColumn converter; + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + state.ResumeTiming(); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + state.PauseTiming(); + } + state.ResumeTiming(); + } +} + +[[maybe_unused]] static void BM_SIMDFilter(benchmark::State & state) +{ + const int n = 10000000; + for (auto _ : state) + { + state.PauseTiming(); + PaddedPODArray arr; + PaddedPODArray condition; + PaddedPODArray res_data; + arr.reserve(n); + condition.reserve(n); + res_data.reserve(n); + for (int i = 0; i < n; i++) + { + arr.push_back(i); + condition.push_back(state.range(0)); + } + const Int32 * data_pos = arr.data(); + const UInt8 * filt_pos = condition.data(); + state.ResumeTiming(); +#ifdef __SSE2__ + int size = n; + static constexpr size_t SIMD_BYTES = 16; + const __m128i zero16 = _mm_setzero_si128(); + const UInt8 * filt_end_sse = filt_pos + size / SIMD_BYTES * SIMD_BYTES; + + while (filt_pos < filt_end_sse) + { + UInt16 mask = _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_loadu_si128(reinterpret_cast(filt_pos)), zero16)); + mask = ~mask; + + if (0 == mask) + { + /// Nothing is inserted. + } + else if (0xFFFF == mask) + { + res_data.insert(data_pos, data_pos + SIMD_BYTES); + } + else + { + for (size_t i = 0; i < SIMD_BYTES; ++i) + if (filt_pos[i]) + [[maybe_unused]] auto x = data_pos[i]; + } + + filt_pos += SIMD_BYTES; + data_pos += SIMD_BYTES; + } +#endif + } +} + +[[maybe_unused]] static void BM_NormalFilter(benchmark::State & state) +{ + const int n = 10000000; + for (auto _ : state) + { + state.PauseTiming(); + PaddedPODArray arr; + PaddedPODArray condition; + PaddedPODArray res_data; + arr.reserve(n); + condition.reserve(n); + res_data.reserve(n); + for (int i = 0; i < n; i++) + { + arr.push_back(i); + condition.push_back(state.range(0)); + } + const Int32 * data_pos = arr.data(); + const UInt8 * filt_pos = condition.data(); + const UInt8 * filt_end = filt_pos + n; + state.ResumeTiming(); + while (filt_pos < filt_end) + { + if (*filt_pos) + res_data.push_back(*data_pos); + + ++filt_pos; + ++data_pos; + } + } +} + +[[maybe_unused]] static void BM_TestCreateExecute(benchmark::State & state) +{ + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_discount", "FP64") + .column("l_extendedprice", "FP64") + .column("l_quantity", "FP64") + .column("l_shipdate_new", "Date") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * agg_mul = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(1), dbms::selection(0)}); + auto * measure1 = dbms::measureFunction(dbms::SUM, {agg_mul}); + auto * measure2 = dbms::measureFunction(dbms::SUM, {dbms::selection(1)}); + auto * measure3 = dbms::measureFunction(dbms::SUM, {dbms::selection(2)}); + auto plan + = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure1, measure2, measure3}) + .project({dbms::selection(2), dbms::selection(1), dbms::selection(0)}) + .filter(dbms::scalarFunction( + dbms::AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {scalarFunction(IS_NOT_NULL, {selection(3)}), scalarFunction(IS_NOT_NULL, {selection(0)})}), + scalarFunction(IS_NOT_NULL, {selection(2)})}), + dbms::scalarFunction(GREATER_THAN_OR_EQUAL, {selection(3), literalDate(8766)})}), + scalarFunction(LESS_THAN, {selection(3), literalDate(9131)})}), + scalarFunction(GREATER_THAN_OR_EQUAL, {selection(0), literal(0.05)})}), + scalarFunction(LESS_THAN_OR_EQUAL, {selection(0), literal(0.07)})}), + scalarFunction(LESS_THAN, {selection(2), literal(24.0)})})) + .readMergeTree("default", "test", "home/saber/Documents/data/mergetree", 1, 4, std::move(schema)) + .build(); + std::string plan_string = plan->SerializeAsString(); + local_engine::SerializedPlanParser::global_context = global_context; + local_engine::SerializedPlanParser::global_context->setConfig(local_engine::SerializedPlanParser::config); + for (auto _ : state) + { + Stopwatch stopwatch; + stopwatch.start(); + auto context = Context::createCopy(local_engine::SerializedPlanParser::global_context); + context->setPath("/"); + auto context_us = stopwatch.elapsedMicroseconds(); + local_engine::SerializedPlanParser parser(context); + auto query_plan = parser.parse(plan_string); + auto parser_us = stopwatch.elapsedMicroseconds() - context_us; + local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(parser.query_context); + auto executor_us = stopwatch.elapsedMicroseconds() - parser_us; + executor->execute(std::move(query_plan)); + auto execute_us = stopwatch.elapsedMicroseconds() - executor_us; + LOG_DEBUG( + &Poco::Logger::root(), + "create context: {}, create parser: {}, create executor: {}, execute executor: {}", + context_us, + parser_us, + executor_us, + execute_us); + } +} + +[[maybe_unused]] static int add(int a, int b) +{ + return a + b; +} + +[[maybe_unused]] static void BM_TestSum(benchmark::State & state) +{ + int cnt = state.range(0); + int i = 0; + std::vector x; + std::vector y; + x.reserve(cnt); + x.assign(cnt, 2); + y.reserve(cnt); + + for (auto _ : state) + { + for (i = 0; i < cnt; i++) + { + y[i] = add(x[i], i); + } + } +} + +[[maybe_unused]] static void BM_TestSumInline(benchmark::State & state) +{ + int cnt = state.range(0); + int i = 0; + std::vector x; + std::vector y; + x.reserve(cnt); + x.assign(cnt, 2); + y.reserve(cnt); + + for (auto _ : state) + { + for (i = 0; i < cnt; i++) + { + y[i] = x[i] + i; + } + } +} + +[[maybe_unused]] static void BM_TestPlus(benchmark::State & state) +{ + UInt64 rows = state.range(0); + auto & factory = FunctionFactory::instance(); + auto & type_factory = DataTypeFactory::instance(); + auto plus = factory.get("plus", global_context); + auto type = type_factory.get("UInt64"); + ColumnsWithTypeAndName arguments; + arguments.push_back(ColumnWithTypeAndName(type, "x")); + arguments.push_back(ColumnWithTypeAndName(type, "y")); + auto function = plus->build(arguments); + + ColumnsWithTypeAndName arguments_with_data; + Block block; + auto x = ColumnWithTypeAndName(type, "x"); + auto y = ColumnWithTypeAndName(type, "y"); + MutableColumnPtr mutable_x = x.type->createColumn(); + MutableColumnPtr mutable_y = y.type->createColumn(); + mutable_x->reserve(rows); + mutable_y->reserve(rows); + ColumnVector & column_x = assert_cast &>(*mutable_x); + ColumnVector & column_y = assert_cast &>(*mutable_y); + for (UInt64 i = 0; i < rows; i++) + { + column_x.insertValue(i); + column_y.insertValue(i + 1); + } + x.column = std::move(mutable_x); + y.column = std::move(mutable_y); + block.insert(x); + block.insert(y); + auto executable_function = function->prepare(arguments); + for (auto _ : state) + { + auto result = executable_function->execute(block.getColumnsWithTypeAndName(), type, rows, false); + } +} + +[[maybe_unused]] static void BM_TestPlusEmbedded(benchmark::State & state) +{ + UInt64 rows = state.range(0); + auto & factory = FunctionFactory::instance(); + auto & type_factory = DataTypeFactory::instance(); + auto plus = factory.get("plus", global_context); + auto type = type_factory.get("UInt64"); + ColumnsWithTypeAndName arguments; + arguments.push_back(ColumnWithTypeAndName(type, "x")); + arguments.push_back(ColumnWithTypeAndName(type, "y")); + auto function = plus->build(arguments); + ColumnsWithTypeAndName arguments_with_data; + Block block; + auto x = ColumnWithTypeAndName(type, "x"); + auto y = ColumnWithTypeAndName(type, "y"); + MutableColumnPtr mutable_x = x.type->createColumn(); + MutableColumnPtr mutable_y = y.type->createColumn(); + mutable_x->reserve(rows); + mutable_y->reserve(rows); + ColumnVector & column_x = assert_cast &>(*mutable_x); + ColumnVector & column_y = assert_cast &>(*mutable_y); + for (UInt64 i = 0; i < rows; i++) + { + column_x.insertValue(i); + column_y.insertValue(i + 1); + } + x.column = std::move(mutable_x); + y.column = std::move(mutable_y); + block.insert(x); + block.insert(y); + CHJIT chjit; + auto compiled_function = compileFunction(chjit, *function); + std::vector columns(arguments.size() + 1); + for (size_t i = 0; i < arguments.size(); ++i) + { + auto column = block.getByPosition(i).column->convertToFullColumnIfConst(); + columns[i] = getColumnData(column.get()); + } + for (auto _ : state) + { + auto result_column = type->createColumn(); + result_column->reserve(rows); + columns[arguments.size()] = getColumnData(result_column.get()); + compiled_function.compiled_function(rows, columns.data()); + } +} + +[[maybe_unused]] static void BM_TestReadColumn(benchmark::State & state) +{ + for (auto _ : state) + { + ReadBufferFromFile data_buf("/home/saber/Documents/test/c151.bin", 100000); + CompressedReadBuffer compressed(data_buf); + ReadBufferFromFile buf("/home/saber/Documents/test/c151.mrk2"); + while (!buf.eof() && !data_buf.eof()) + { + size_t x; + size_t y; + size_t z; + readIntBinary(x, buf); + readIntBinary(y, buf); + readIntBinary(z, buf); + std::cout << std::to_string(x) + " " << std::to_string(y) + " " << std::to_string(z) + " " + << "\n"; + data_buf.seek(x, SEEK_SET); + assert(!data_buf.eof()); + std::string data; + data.reserve(y); + compressed.readBig(reinterpret_cast(data.data()), y); + std::cout << data << "\n"; + } + } +} + +[[maybe_unused]] static double quantile(const std::vector & x) +{ + double q = 0.8; + assert(q >= 0.0 && q <= 1.0); + const int n = x.size(); + double id = (n - 1) * q; + int lo = static_cast(floor(id)); + int hi = static_cast(ceil(id)); + double qs = x[lo]; + double h = (id - lo); + return (1.0 - h) * qs + h * x[hi]; +} + +// compress benchmark +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace DB +{ +class FasterCompressedReadBufferBase +{ +protected: + ReadBuffer * compressed_in; + + /// If 'compressed_in' buffer has whole compressed block - then use it. Otherwise copy parts of data to 'own_compressed_buffer'. + PODArray own_compressed_buffer; + /// Points to memory, holding compressed block. + char * compressed_buffer = nullptr; + + ssize_t variant; + + /// Variant for reference implementation of LZ4. + static constexpr ssize_t LZ4_REFERENCE = -3; + + LZ4::StreamStatistics stream_stat; + LZ4::PerformanceStatistics perf_stat; + + size_t readCompressedData(size_t & size_decompressed, size_t & size_compressed_without_checksum) + { + if (compressed_in->eof()) + return 0; + + CityHash_v1_0_2::uint128 checksum; + compressed_in->readStrict(reinterpret_cast(&checksum), sizeof(checksum)); + + own_compressed_buffer.resize(COMPRESSED_BLOCK_HEADER_SIZE); + compressed_in->readStrict(&own_compressed_buffer[0], COMPRESSED_BLOCK_HEADER_SIZE); + + UInt8 method = own_compressed_buffer[0]; /// See CompressedWriteBuffer.h + + size_t & size_compressed = size_compressed_without_checksum; + + if (method == static_cast(CompressionMethodByte::LZ4) || method == static_cast(CompressionMethodByte::ZSTD) + || method == static_cast(CompressionMethodByte::NONE)) + { + size_compressed = unalignedLoad(&own_compressed_buffer[1]); + size_decompressed = unalignedLoad(&own_compressed_buffer[5]); + } + else + throw std::runtime_error("Unknown compression method: " + toString(method)); + + if (size_compressed > DBMS_MAX_COMPRESSED_SIZE) + throw std::runtime_error("Too large size_compressed. Most likely corrupted data."); + + /// Is whole compressed block located in 'compressed_in' buffer? + if (compressed_in->offset() >= COMPRESSED_BLOCK_HEADER_SIZE + && compressed_in->position() + size_compressed - COMPRESSED_BLOCK_HEADER_SIZE <= compressed_in->buffer().end()) + { + compressed_in->position() -= COMPRESSED_BLOCK_HEADER_SIZE; + compressed_buffer = compressed_in->position(); + compressed_in->position() += size_compressed; + } + else + { + own_compressed_buffer.resize(size_compressed + (variant == LZ4_REFERENCE ? 0 : LZ4::ADDITIONAL_BYTES_AT_END_OF_BUFFER)); + compressed_buffer = &own_compressed_buffer[0]; + compressed_in->readStrict(compressed_buffer + COMPRESSED_BLOCK_HEADER_SIZE, size_compressed - COMPRESSED_BLOCK_HEADER_SIZE); + } + + return size_compressed + sizeof(checksum); + } + + void decompress(char * to, size_t size_decompressed, size_t size_compressed_without_checksum) + { + UInt8 method = compressed_buffer[0]; /// See CompressedWriteBuffer.h + + if (method == static_cast(CompressionMethodByte::LZ4)) + { + //LZ4::statistics(compressed_buffer + COMPRESSED_BLOCK_HEADER_SIZE, to, size_decompressed, stat); + LZ4::decompress( + compressed_buffer + COMPRESSED_BLOCK_HEADER_SIZE, to, size_compressed_without_checksum, size_decompressed, perf_stat); + } + else + throw std::runtime_error("Unknown compression method: " + toString(method)); + } + +public: + /// 'compressed_in' could be initialized lazily, but before first call of 'readCompressedData'. + FasterCompressedReadBufferBase(ReadBuffer * in, ssize_t variant_) + : compressed_in(in), own_compressed_buffer(COMPRESSED_BLOCK_HEADER_SIZE), variant(variant_), perf_stat(variant) + { + } + LZ4::StreamStatistics getStreamStatistics() const { return stream_stat; } + LZ4::PerformanceStatistics getPerformanceStatistics() const { return perf_stat; } +}; + + +class FasterCompressedReadBuffer : public FasterCompressedReadBufferBase, public BufferWithOwnMemory +{ +private: + size_t size_compressed = 0; + + bool nextImpl() override + { + size_t size_decompressed; + size_t size_compressed_without_checksum; + size_compressed = readCompressedData(size_decompressed, size_compressed_without_checksum); + if (!size_compressed) + return false; + + memory.resize(size_decompressed + LZ4::ADDITIONAL_BYTES_AT_END_OF_BUFFER); + working_buffer = Buffer(&memory[0], &memory[size_decompressed]); + + decompress(working_buffer.begin(), size_decompressed, size_compressed_without_checksum); + + return true; + } + +public: + FasterCompressedReadBuffer(ReadBuffer & in_, ssize_t method) + : FasterCompressedReadBufferBase(&in_, method), BufferWithOwnMemory(0) + { + } +}; + +} + + +[[maybe_unused]] static void BM_TestDecompress(benchmark::State & state) +{ + std::vector files + = {"/home/saber/Documents/data/mergetree/all_1_1_0/l_discount.bin", + "/home/saber/Documents/data/mergetree/all_1_1_0/l_extendedprice.bin", + "/home/saber/Documents/data/mergetree/all_1_1_0/l_quantity.bin", + "/home/saber/Documents/data/mergetree/all_1_1_0/l_shipdate.bin", + + "/home/saber/Documents/data/mergetree/all_2_2_0/l_discount.bin", + "/home/saber/Documents/data/mergetree/all_2_2_0/l_extendedprice.bin", + "/home/saber/Documents/data/mergetree/all_2_2_0/l_quantity.bin", + "/home/saber/Documents/data/mergetree/all_2_2_0/l_shipdate.bin", + + "/home/saber/Documents/data/mergetree/all_3_3_0/l_discount.bin", + "/home/saber/Documents/data/mergetree/all_3_3_0/l_extendedprice.bin", + "/home/saber/Documents/data/mergetree/all_3_3_0/l_quantity.bin", + "/home/saber/Documents/data/mergetree/all_3_3_0/l_shipdate.bin"}; + for (auto _ : state) + { + for (const auto & file : files) + { + ReadBufferFromFile in(file); + FasterCompressedReadBuffer decompressing_in(in, state.range(0)); + while (!decompressing_in.eof()) + { + decompressing_in.position() = decompressing_in.buffer().end(); + decompressing_in.next(); + } + // std::cout << "call count:" << std::to_string(decompressing_in.getPerformanceStatistics().data[state.range(0)].count) << "\n"; + // std::cout << "false count:" << std::to_string(decompressing_in.false_count) << "\n"; + // decompressing_in.getStreamStatistics().print(); + } + } +} + +#include + +[[maybe_unused]] static void BM_CHColumnToSparkRowNew(benchmark::State & state) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + metadata = local_engine::buildMetaData(names_and_types_list, global_context); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = local_engine::buildMergeTreeSettings(); + + local_engine::CustomStorageMergeTree custom_merge_tree( + DB::StorageID("default", "test"), + "data1/tpc_data/tpch10_liuneng/mergetree/lineitem", + *metadata, + false, + global_context, + "", + param, + std::move(settings)); + auto snapshot = std::make_shared(custom_merge_tree, metadata); + custom_merge_tree.loadDataParts(false); + for (auto _ : state) + { + state.PauseTiming(); + auto query_info = local_engine::buildQueryInfo(names_and_types_list); + auto data_parts = custom_merge_tree.getDataPartsVectorForInternalUsage(); + int min_block = 0; + int max_block = 10; + MergeTreeData::DataPartsVector selected_parts; + std::copy_if( + std::begin(data_parts), + std::end(data_parts), + std::inserter(selected_parts, std::begin(selected_parts)), + [min_block, max_block](MergeTreeData::DataPartPtr part) + { return part->info.min_block >= min_block && part->info.max_block <= max_block; }); + auto query = custom_merge_tree.reader.readFromParts( + selected_parts, names_and_types_list.getNames(), snapshot, *query_info, global_context, 10000, 1); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = false}; + auto query_pipeline = query->buildQueryPipeline(optimization_settings, {}); + state.ResumeTiming(); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*query_pipeline)); + auto executor = PullingPipelineExecutor(pipeline); + Block header = executor.getHeader(); + CHColumnToSparkRow converter; + int sum = 0; + while (executor.pull(header)) + { + sum += header.rows(); + auto spark_row = converter.convertCHColumnToSparkRow(header); + converter.freeMem(spark_row->getBufferAddress(), spark_row->getTotalBytes()); + } + std::cerr << "rows: " << sum << std::endl; + } +} + +struct MergeTreeWithSnapshot +{ + std::shared_ptr merge_tree; + std::shared_ptr snapshot; + NamesAndTypesList columns; +}; + +MergeTreeWithSnapshot buildMergeTree(NamesAndTypesList names_and_types, std::string relative_path, std::string table) +{ + auto metadata = local_engine::buildMetaData(names_and_types, global_context); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = local_engine::buildMergeTreeSettings(); + std::shared_ptr custom_merge_tree = std::make_shared( + DB::StorageID("default", table), relative_path, *metadata, false, global_context, "", param, std::move(settings)); + auto snapshot = std::make_shared(*custom_merge_tree, metadata); + custom_merge_tree->loadDataParts(false); + return MergeTreeWithSnapshot{.merge_tree = custom_merge_tree, .snapshot = snapshot, .columns = names_and_types}; +} + +QueryPlanPtr readFromMergeTree(MergeTreeWithSnapshot storage) +{ + auto query_info = local_engine::buildQueryInfo(storage.columns); + auto data_parts = storage.merge_tree->getDataPartsVectorForInternalUsage(); + return storage.merge_tree->reader.readFromParts( + data_parts, storage.columns.getNames(), storage.snapshot, *query_info, global_context, 10000, 1); +} + +QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr right, String left_key, String right_key, size_t block_size = 8192) +{ + auto join = std::make_shared(global_context->getSettings(), global_context->getTemporaryVolume()); + auto left_columns = left->getCurrentDataStream().header.getColumnsWithTypeAndName(); + auto right_columns = right->getCurrentDataStream().header.getColumnsWithTypeAndName(); + join->setKind(JoinKind::Left); + join->setStrictness(JoinStrictness::All); + join->setColumnsFromJoinedTable(right->getCurrentDataStream().header.getNamesAndTypesList()); + join->addDisjunct(); + ASTPtr lkey = std::make_shared(left_key); + ASTPtr rkey = std::make_shared(right_key); + join->addOnKeys(lkey, rkey); + for (const auto & column : join->columnsFromJoinedTable()) + { + join->addJoinedColumn(column); + } + + auto left_keys = left->getCurrentDataStream().header.getNamesAndTypesList(); + join->addJoinedColumnsAndCorrectTypes(left_keys, true); + ActionsDAGPtr left_convert_actions = nullptr; + ActionsDAGPtr right_convert_actions = nullptr; + std::tie(left_convert_actions, right_convert_actions) = join->createConvertingActions(left_columns, right_columns); + + if (right_convert_actions) + { + auto converting_step = std::make_unique(right->getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + right->addStep(std::move(converting_step)); + } + + if (left_convert_actions) + { + auto converting_step = std::make_unique(right->getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + left->addStep(std::move(converting_step)); + } + auto hash_join = std::make_shared(join, right->getCurrentDataStream().header); + + QueryPlanStepPtr join_step + = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, block_size, 1, false); + + std::vector plans; + plans.emplace_back(std::move(left)); + plans.emplace_back(std::move(right)); + + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), std::move(plans)); + return query_plan; +} + +[[maybe_unused]] static void BM_JoinTest(benchmark::State & state) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + const auto * supplier_type_string = "columns format version: 1\n" + "2 columns:\n" + "`s_suppkey` Int64\n" + "`s_nationkey` Int64\n"; + auto supplier_types = NamesAndTypesList::parse(supplier_type_string); + auto supplier = buildMergeTree(supplier_types, "home/saber/Documents/data/tpch/mergetree/supplier", "supplier"); + + const auto * nation_type_string = "columns format version: 1\n" + "1 columns:\n" + "`n_nationkey` Int64\n"; + auto nation_types = NamesAndTypesList::parse(nation_type_string); + auto nation = buildMergeTree(nation_types, "home/saber/Documents/data/tpch/mergetree/nation", "nation"); + + + const auto * partsupp_type_string = "columns format version: 1\n" + "3 columns:\n" + "`ps_suppkey` Int64\n" + "`ps_availqty` Int64\n" + "`ps_supplycost` Float64\n"; + auto partsupp_types = NamesAndTypesList::parse(partsupp_type_string); + auto partsupp = buildMergeTree(partsupp_types, "home/saber/Documents/data/tpch/mergetree/partsupp", "partsupp"); + + for (auto _ : state) + { + state.PauseTiming(); + QueryPlanPtr supplier_query; + { + auto left = readFromMergeTree(partsupp); + auto right = readFromMergeTree(supplier); + supplier_query = joinPlan(std::move(left), std::move(right), "ps_suppkey", "s_suppkey"); + } + auto right = readFromMergeTree(nation); + auto query_plan = joinPlan(std::move(supplier_query), std::move(right), "s_nationkey", "n_nationkey"); + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = false}; + BuildQueryPipelineSettings pipeline_settings; + auto pipeline_builder = query_plan->buildQueryPipeline(optimization_settings, pipeline_settings); + state.ResumeTiming(); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + auto executor = PullingPipelineExecutor(pipeline); + Block header = executor.getHeader(); + [[maybe_unused]] int sum = 0; + while (executor.pull(header)) + { + sum += header.rows(); + } + } +} + +BENCHMARK(BM_ParquetRead)->Unit(benchmark::kMillisecond)->Iterations(10); + +// BENCHMARK(BM_TestDecompress)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Unit(benchmark::kMillisecond)->Iterations(50)->Repetitions(6)->ComputeStatistics("80%", quantile); +// BENCHMARK(BM_JoinTest)->Unit(benchmark::k +// Millisecond)->Iterations(10)->Repetitions(250)->ComputeStatistics("80%", quantile); + +//BENCHMARK(BM_CHColumnToSparkRow)->Unit(benchmark::kMillisecond)->Iterations(40); +//BENCHMARK(BM_MergeTreeRead)->Arg(1)->Unit(benchmark::kMillisecond)->Iterations(10); + +//BENCHMARK(BM_ShuffleSplitter)->Args({2, 0})->Args({2, 1})->Args({2, 2})->Unit(benchmark::kMillisecond)->Iterations(1); +//BENCHMARK(BM_HashShuffleSplitter)->Args({2, 0})->Args({2, 1})->Args({2, 2})->Unit(benchmark::kMillisecond)->Iterations(1); +//BENCHMARK(BM_ShuffleReader)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_SimpleAggregate)->Arg(150)->Unit(benchmark::kMillisecond)->Iterations(40); +//BENCHMARK(BM_SIMDFilter)->Arg(1)->Arg(0)->Unit(benchmark::kMillisecond)->Iterations(40); +//BENCHMARK(BM_NormalFilter)->Arg(1)->Arg(0)->Unit(benchmark::kMillisecond)->Iterations(40); +//BENCHMARK(BM_TPCH_Q6)->Arg(150)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_MERGE_TREE_TPCH_Q6)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_MERGE_TREE_TPCH_Q6_NEW)->Unit(benchmark::kMillisecond)->Iterations(10); + +//BENCHMARK(BM_MERGE_TREE_TPCH_Q6_FROM_TEXT)->Unit(benchmark::kMillisecond)->Iterations(5); + +//BENCHMARK(BM_CHColumnToSparkRowWithString)->Arg(1)->Arg(3)->Arg(30)->Arg(90)->Arg(150)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_SparkRowToCHColumn)->Arg(1)->Arg(3)->Arg(30)->Arg(90)->Arg(150)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_SparkRowToCHColumnWithString)->Arg(1)->Arg(3)->Arg(30)->Arg(90)->Arg(150)->Unit(benchmark::kMillisecond)->Iterations(10); +//BENCHMARK(BM_TestCreateExecute)->Unit(benchmark::kMillisecond)->Iterations(1000); +//BENCHMARK(BM_TestReadColumn)->Unit(benchmark::kMillisecond)->Iterations(1); + +//BENCHMARK(BM_TestSum)->Arg(1000000)->Unit(benchmark::kMicrosecond)->Iterations(100)->Repetitions(100)->ComputeStatistics("80%", quantile)->DisplayAggregatesOnly(); +//BENCHMARK(BM_TestSumInline)->Arg(1000000)->Unit(benchmark::kMicrosecond)->Iterations(100)->Repetitions(100)->ComputeStatistics("80%", quantile)->DisplayAggregatesOnly(); +// +//BENCHMARK(BM_TestPlus)->Arg(65505)->Unit(benchmark::kMicrosecond)->Iterations(100)->Repetitions(1000)->ComputeStatistics("80%", quantile)->DisplayAggregatesOnly(); +//BENCHMARK(BM_TestPlusEmbedded)->Arg(65505)->Unit(benchmark::kMicrosecond)->Iterations(100)->Repetitions(1000)->ComputeStatistics("80%", quantile)->DisplayAggregatesOnly(); + + +int main(int argc, char ** argv) +{ + //local_engine::Logger::initConsoleLogger(); + SharedContextHolder shared_context = Context::createShared(); + global_context = Context::createGlobal(shared_context.get()); + global_context->makeGlobalContext(); + + auto config = Poco::AutoPtr(new Poco::Util::MapConfiguration()); + global_context->setConfig(config); + const std::string path = "/"; + global_context->setPath(path); + SerializedPlanParser::global_context = global_context; + local_engine::SerializedPlanParser::initFunctionEnv(); + + registerReadBufferBuilders(); + + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) + return 1; + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/utils/local-engine/tests/benchmark_parquet_read.cpp b/utils/local-engine/tests/benchmark_parquet_read.cpp new file mode 100644 index 000000000000..fd1627903551 --- /dev/null +++ b/utils/local-engine/tests/benchmark_parquet_read.cpp @@ -0,0 +1,138 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void BM_ParquetReadString(benchmark::State& state) +{ + using namespace DB; + Block header{ + ColumnWithTypeAndName(DataTypeString().createColumn(), std::make_shared(), "l_returnflag"), + ColumnWithTypeAndName(DataTypeString().createColumn(), std::make_shared(), "l_linestatus") + }; + std::string file + = "/data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + FormatSettings format_settings; + Block res; + for (auto _ : state) + { + auto in = std::make_unique(file); + auto format = std::make_shared(*in, header, format_settings); + auto pipeline = QueryPipeline(std::move(format)); + auto reader = std::make_unique(pipeline); + while (reader->pull(res)) + { + // debug::headBlock(res); + } + } +} + +static void BM_ParquetReadDate32(benchmark::State& state) +{ + using namespace DB; + Block header{ + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_shipdate"), + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_commitdate"), + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_receiptdate") + }; + std::string file + = "/data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + FormatSettings format_settings; + Block res; + for (auto _ : state) + { + auto in = std::make_unique(file); + auto format = std::make_shared(*in, header, format_settings); + auto pipeline = QueryPipeline(std::move(format)); + auto reader = std::make_unique(pipeline); + while (reader->pull(res)) + { + // debug::headBlock(res); + } + } +} + +static void BM_OptimizedParquetReadString(benchmark::State& state) +{ + using namespace DB; + using namespace local_engine; + Block header{ + ColumnWithTypeAndName(DataTypeString().createColumn(), std::make_shared(), "l_returnflag"), + ColumnWithTypeAndName(DataTypeString().createColumn(), std::make_shared(), "l_linestatus") + }; + std::string file = "file:///data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/" + "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + Block res; + + for (auto _ : state) + { + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file_item = files.add_items(); + file_item->set_uri_file(file); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file_item->mutable_parquet()->CopyFrom(parquet_format); + + auto builder = std::make_unique(); + builder->init( + Pipe(std::make_shared(local_engine::SerializedPlanParser::global_context, header, files))); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto reader = PullingPipelineExecutor(pipeline); + while (reader.pull(res)) + { + // debug::headBlock(res); + } + } +} + +static void BM_OptimizedParquetReadDate32(benchmark::State& state) +{ + using namespace DB; + using namespace local_engine; + Block header{ + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_shipdate"), + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_commitdate"), + ColumnWithTypeAndName(DataTypeDate32().createColumn(), std::make_shared(), "l_receiptdate") + }; + std::string file = "file:///data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/" + "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + Block res; + + for (auto _ : state) + { + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file_item = files.add_items(); + file_item->set_uri_file(file); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file_item->mutable_parquet()->CopyFrom(parquet_format); + + auto builder = std::make_unique(); + builder->init( + Pipe(std::make_shared(local_engine::SerializedPlanParser::global_context, header, files))); + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto reader = PullingPipelineExecutor(pipeline); + while (reader.pull(res)) + { + // debug::headBlock(res); + } + } +} + +BENCHMARK(BM_ParquetReadString)->Unit(benchmark::kMillisecond)->Iterations(10); +BENCHMARK(BM_ParquetReadDate32)->Unit(benchmark::kMillisecond)->Iterations(10); +BENCHMARK(BM_OptimizedParquetReadString)->Unit(benchmark::kMillisecond)->Iterations(10); +BENCHMARK(BM_OptimizedParquetReadDate32)->Unit(benchmark::kMillisecond)->Iterations(200); + diff --git a/utils/local-engine/tests/benchmark_spark_row.cpp b/utils/local-engine/tests/benchmark_spark_row.cpp new file mode 100644 index 000000000000..32f305f5e9ec --- /dev/null +++ b/utils/local-engine/tests/benchmark_spark_row.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace DB; +using namespace local_engine; + +struct NameType +{ + String name; + String type; +}; + +using NameTypes = std::vector; + +static Block getLineitemHeader(const NameTypes & name_types) +{ + auto & factory = DataTypeFactory::instance(); + ColumnsWithTypeAndName columns(name_types.size()); + for (size_t i=0; i(file); + FormatSettings format_settings; + auto format = std::make_shared(*in, header, format_settings); + auto pipeline = QueryPipeline(std::move(format)); + auto reader = std::make_unique(pipeline); + while (reader->pull(block)) + return; +} + +static void BM_CHColumnToSparkRow_Lineitem(benchmark::State& state) +{ + const NameTypes name_types = { + {"l_orderkey", "Nullable(Int64)"}, + {"l_partkey", "Nullable(Int64)"}, + {"l_suppkey", "Nullable(Int64)"}, + {"l_linenumber", "Nullable(Int64)"}, + {"l_quantity", "Nullable(Float64)"}, + {"l_extendedprice", "Nullable(Float64)"}, + {"l_discount", "Nullable(Float64)"}, + {"l_tax", "Nullable(Float64)"}, + {"l_returnflag", "Nullable(String)"}, + {"l_linestatus", "Nullable(String)"}, + {"l_shipdate", "Nullable(Date32)"}, + {"l_commitdate", "Nullable(Date32)"}, + {"l_receiptdate", "Nullable(Date32)"}, + {"l_shipinstruct", "Nullable(String)"}, + {"l_shipmode", "Nullable(String)"}, + {"l_comment", "Nullable(String)"}, + }; + + const Block header = std::move(getLineitemHeader(name_types)); + const String file = "/data1/liyang/cppproject/gluten/gluten-core/src/test/resources/tpch-data/lineitem/" + "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + Block block; + readParquetFile(header, file, block); + // std::cerr << "read_rows:" << block.rows() << std::endl; + CHColumnToSparkRow converter; + for (auto _ : state) + { + auto spark_row_info = converter.convertCHColumnToSparkRow(block); + converter.freeMem(spark_row_info->getBufferAddress(), spark_row_info->getTotalBytes()); + } +} + + +static void BM_SparkRowToCHColumn_Lineitem(benchmark::State& state) +{ + const NameTypes name_types = { + {"l_orderkey", "Nullable(Int64)"}, + {"l_partkey", "Nullable(Int64)"}, + {"l_suppkey", "Nullable(Int64)"}, + {"l_linenumber", "Nullable(Int64)"}, + {"l_quantity", "Nullable(Float64)"}, + {"l_extendedprice", "Nullable(Float64)"}, + {"l_discount", "Nullable(Float64)"}, + {"l_tax", "Nullable(Float64)"}, + {"l_returnflag", "Nullable(String)"}, + {"l_linestatus", "Nullable(String)"}, + {"l_shipdate", "Nullable(Date32)"}, + {"l_commitdate", "Nullable(Date32)"}, + {"l_receiptdate", "Nullable(Date32)"}, + {"l_shipinstruct", "Nullable(String)"}, + {"l_shipmode", "Nullable(String)"}, + {"l_comment", "Nullable(String)"}, + }; + + const Block header = std::move(getLineitemHeader(name_types)); + const String file = "/data1/liyang/cppproject/gluten/jvm/src/test/resources/tpch-data/lineitem/" + "part-00000-d08071cb-0dfa-42dc-9198-83cb334ccda3-c000.snappy.parquet"; + Block in_block; + readParquetFile(header, file, in_block); + + CHColumnToSparkRow spark_row_converter; + auto spark_row_info = spark_row_converter.convertCHColumnToSparkRow(in_block); + for (auto _ : state) + [[maybe_unused]] auto out_block = SparkRowToCHColumn::convertSparkRowInfoToCHColumn(*spark_row_info, header); +} + +BENCHMARK(BM_CHColumnToSparkRow_Lineitem)->Unit(benchmark::kMillisecond)->Iterations(10); +BENCHMARK(BM_SparkRowToCHColumn_Lineitem)->Unit(benchmark::kMillisecond)->Iterations(10); diff --git a/utils/local-engine/tests/data/alltypes/alltypes_notnull.parquet b/utils/local-engine/tests/data/alltypes/alltypes_notnull.parquet new file mode 100644 index 000000000000..64aab87b6139 Binary files /dev/null and b/utils/local-engine/tests/data/alltypes/alltypes_notnull.parquet differ diff --git a/utils/local-engine/tests/data/alltypes/alltypes_null.parquet b/utils/local-engine/tests/data/alltypes/alltypes_null.parquet new file mode 100644 index 000000000000..926a5c5f435c Binary files /dev/null and b/utils/local-engine/tests/data/alltypes/alltypes_null.parquet differ diff --git a/utils/local-engine/tests/data/array.parquet b/utils/local-engine/tests/data/array.parquet new file mode 100644 index 000000000000..d989f3d7cbc1 Binary files /dev/null and b/utils/local-engine/tests/data/array.parquet differ diff --git a/utils/local-engine/tests/data/date.parquet b/utils/local-engine/tests/data/date.parquet new file mode 100644 index 000000000000..5f2e11525aad Binary files /dev/null and b/utils/local-engine/tests/data/date.parquet differ diff --git a/utils/local-engine/tests/data/datetime64.parquet b/utils/local-engine/tests/data/datetime64.parquet new file mode 100644 index 000000000000..a597b85aa0de Binary files /dev/null and b/utils/local-engine/tests/data/datetime64.parquet differ diff --git a/utils/local-engine/tests/data/decimal.parquet b/utils/local-engine/tests/data/decimal.parquet new file mode 100644 index 000000000000..e1981938866e Binary files /dev/null and b/utils/local-engine/tests/data/decimal.parquet differ diff --git a/utils/local-engine/tests/data/iris.parquet b/utils/local-engine/tests/data/iris.parquet new file mode 100644 index 000000000000..20979952d618 Binary files /dev/null and b/utils/local-engine/tests/data/iris.parquet differ diff --git a/utils/local-engine/tests/data/lineitem.orc b/utils/local-engine/tests/data/lineitem.orc new file mode 100644 index 000000000000..70a6e11c8778 Binary files /dev/null and b/utils/local-engine/tests/data/lineitem.orc differ diff --git a/utils/local-engine/tests/data/map.parquet b/utils/local-engine/tests/data/map.parquet new file mode 100644 index 000000000000..def9242ee305 Binary files /dev/null and b/utils/local-engine/tests/data/map.parquet differ diff --git a/utils/local-engine/tests/data/struct.parquet b/utils/local-engine/tests/data/struct.parquet new file mode 100644 index 000000000000..7a90433ae703 Binary files /dev/null and b/utils/local-engine/tests/data/struct.parquet differ diff --git a/utils/local-engine/tests/gtest_ch_functions.cpp b/utils/local-engine/tests/gtest_ch_functions.cpp new file mode 100644 index 000000000000..ab0b8edc9bf0 --- /dev/null +++ b/utils/local-engine/tests/gtest_ch_functions.cpp @@ -0,0 +1,168 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST(TestFuntion, Hash) +{ + using namespace DB; + auto & factory = FunctionFactory::instance(); + auto function = factory.get("murmurHash2_64", local_engine::SerializedPlanParser::global_context); + auto type0 = DataTypeFactory::instance().get("String"); + auto column0 = type0->createColumn(); + column0->insert("A"); + column0->insert("A"); + column0->insert("B"); + column0->insert("c"); + + auto column1 = type0->createColumn(); + column1->insert("X"); + column1->insert("X"); + column1->insert("Y"); + column1->insert("Z"); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column0),type0, "string0"), + ColumnWithTypeAndName(std::move(column1),type0, "string0")}; + Block block(columns); + std::cerr << "input:\n"; + debug::headBlock(block); + auto executable = function->build(block.getColumnsWithTypeAndName()); + auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows()); + std::cerr << "output:\n"; + debug::headColumn(result); + ASSERT_EQ(result->getUInt(0), result->getUInt(1)); +} + +TEST(TestFunction, In) +{ + using namespace DB; + auto & factory = FunctionFactory::instance(); + auto function = factory.get("in", local_engine::SerializedPlanParser::global_context); + auto type0 = DataTypeFactory::instance().get("String"); + auto type_set = std::make_shared(); + + + auto column1 = type0->createColumn(); + column1->insert("X"); + column1->insert("X"); + column1->insert("Y"); + column1->insert("Z"); + + SizeLimits limit; + auto set = std::make_shared(limit, true, false); + Block col1_set_block; + auto col1_set = type0->createColumn(); + col1_set->insert("X"); + col1_set->insert("Y"); + + col1_set_block.insert(ColumnWithTypeAndName(std::move(col1_set), type0, "string0")); + set->setHeader(col1_set_block.getColumnsWithTypeAndName()); + set->insertFromBlock(col1_set_block.getColumnsWithTypeAndName()); + set->finishInsert(); + + auto arg = ColumnSet::create(set->getTotalRowCount(), set); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column1),type0, "string0"), + ColumnWithTypeAndName(std::move(arg),type_set, "__set")}; + Block block(columns); + std::cerr << "input:\n"; + debug::headBlock(block); + auto executable = function->build(block.getColumnsWithTypeAndName()); + auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows()); + std::cerr << "output:\n"; + debug::headColumn(result); + ASSERT_EQ(result->getUInt(3), 0); +} + + +TEST(TestFunction, NotIn1) +{ + using namespace DB; + auto & factory = FunctionFactory::instance(); + auto function = factory.get("notIn", local_engine::SerializedPlanParser::global_context); + auto type0 = DataTypeFactory::instance().get("String"); + auto type_set = std::make_shared(); + + + auto column1 = type0->createColumn(); + column1->insert("X"); + column1->insert("X"); + column1->insert("Y"); + column1->insert("Z"); + + SizeLimits limit; + auto set = std::make_shared(limit, true, false); + Block col1_set_block; + auto col1_set = type0->createColumn(); + col1_set->insert("X"); + col1_set->insert("Y"); + + col1_set_block.insert(ColumnWithTypeAndName(std::move(col1_set), type0, "string0")); + set->setHeader(col1_set_block.getColumnsWithTypeAndName()); + set->insertFromBlock(col1_set_block.getColumnsWithTypeAndName()); + set->finishInsert(); + + auto arg = ColumnSet::create(set->getTotalRowCount(), set); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column1),type0, "string0"), + ColumnWithTypeAndName(std::move(arg),type_set, "__set")}; + Block block(columns); + std::cerr << "input:\n"; + debug::headBlock(block); + auto executable = function->build(block.getColumnsWithTypeAndName()); + auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows()); + std::cerr << "output:\n"; + debug::headColumn(result); + ASSERT_EQ(result->getUInt(3), 1); +} + +TEST(TestFunction, NotIn2) +{ + using namespace DB; + auto & factory = FunctionFactory::instance(); + auto function = factory.get("in", local_engine::SerializedPlanParser::global_context); + auto type0 = DataTypeFactory::instance().get("String"); + auto type_set = std::make_shared(); + + + auto column1 = type0->createColumn(); + column1->insert("X"); + column1->insert("X"); + column1->insert("Y"); + column1->insert("Z"); + + SizeLimits limit; + auto set = std::make_shared(limit, true, false); + Block col1_set_block; + auto col1_set = type0->createColumn(); + col1_set->insert("X"); + col1_set->insert("Y"); + + col1_set_block.insert(ColumnWithTypeAndName(std::move(col1_set), type0, "string0")); + set->setHeader(col1_set_block.getColumnsWithTypeAndName()); + set->insertFromBlock(col1_set_block.getColumnsWithTypeAndName()); + set->finishInsert(); + + auto arg = ColumnSet::create(set->getTotalRowCount(), set); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column1),type0, "string0"), + ColumnWithTypeAndName(std::move(arg),type_set, "__set")}; + Block block(columns); + std::cerr << "input:\n"; + debug::headBlock(block); + auto executable = function->build(block.getColumnsWithTypeAndName()); + auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows()); + + auto function_not = factory.get("not", local_engine::SerializedPlanParser::global_context); + auto type_bool = DataTypeFactory::instance().get("UInt8"); + ColumnsWithTypeAndName columns2 = {ColumnWithTypeAndName(result, type_bool, "string0")}; + Block block2(columns2); + auto executable2 = function_not->build(block2.getColumnsWithTypeAndName()); + auto result2 = executable2->execute(block2.getColumnsWithTypeAndName(), executable2->getResultType(), block2.rows()); + std::cerr << "output:\n"; + debug::headColumn(result2); + ASSERT_EQ(result2->getUInt(3), 1); +} diff --git a/utils/local-engine/tests/gtest_ch_join.cpp b/utils/local-engine/tests/gtest_ch_join.cpp new file mode 100644 index 000000000000..1133e1571163 --- /dev/null +++ b/utils/local-engine/tests/gtest_ch_join.cpp @@ -0,0 +1,233 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +using namespace DB; +using namespace local_engine; + +TEST(TestJoin, simple) +{ + auto global_context = SerializedPlanParser::global_context; + local_engine::SerializedPlanParser::global_context->setSetting("join_use_nulls", true); + auto & factory = DB::FunctionFactory::instance(); + auto function = factory.get("murmurHash2_64", local_engine::SerializedPlanParser::global_context); + auto int_type = DataTypeFactory::instance().get("Int32"); + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + auto column1 = int_type->createColumn(); + column1->insert(2); + column1->insert(4); + column1->insert(6); + column1->insert(8); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column0),int_type, "colA"), + ColumnWithTypeAndName(std::move(column1),int_type, "colB")}; + Block left(columns); + + auto column3 = int_type->createColumn(); + column3->insert(1); + column3->insert(2); + column3->insert(3); + column3->insert(5); + + auto column4 = int_type->createColumn(); + column4->insert(1); + column4->insert(3); + column4->insert(5); + column4->insert(9); + + ColumnsWithTypeAndName columns2 = {ColumnWithTypeAndName(std::move(column3),int_type, "colD"), + ColumnWithTypeAndName(std::move(column4),int_type, "colC")}; + Block right(columns2); + + auto left_table = std::make_shared(left); + auto right_table = std::make_shared(right); + QueryPlan left_plan; + left_plan.addStep(std::make_unique(Pipe(left_table))); + QueryPlan right_plan; + right_plan.addStep(std::make_unique(Pipe(right_table))); + + auto join = std::make_shared(global_context->getSettings(), global_context->getTemporaryVolume()); + join->setKind(JoinKind::Left); + join->setStrictness(JoinStrictness::All); + join->setColumnsFromJoinedTable(right.getNamesAndTypesList()); + join->addDisjunct(); + ASTPtr lkey = std::make_shared("colA"); + ASTPtr rkey = std::make_shared("colD"); + join->addOnKeys(lkey, rkey); + for (const auto & column : join->columnsFromJoinedTable()) + { + join->addJoinedColumn(column); + } + + auto left_keys = left.getNamesAndTypesList(); + join->addJoinedColumnsAndCorrectTypes(left_keys, true); + std::cerr << "after join:\n"; + for (const auto& key : left_keys) + { + std::cerr << key.dump() <createConvertingActions(left.getColumnsWithTypeAndName(), right.getColumnsWithTypeAndName()); + + if (right_convert_actions) + { + auto converting_step = std::make_unique(right_plan.getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + right_plan.addStep(std::move(converting_step)); + } + + if (left_convert_actions) + { + auto converting_step = std::make_unique(right_plan.getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + left_plan.addStep(std::move(converting_step)); + } + auto hash_join = std::make_shared(join, right_plan.getCurrentDataStream().header); + + QueryPlanStepPtr join_step = std::make_unique( + left_plan.getCurrentDataStream(), + right_plan.getCurrentDataStream(), + hash_join, + 8192, 1, false); + + std::cerr<< "join step:" <getOutputStream().header.dumpStructure() << std::endl; + + std::vector plans; + plans.emplace_back(std::make_unique(std::move(left_plan))); + plans.emplace_back(std::make_unique(std::move(right_plan))); + + auto query_plan = QueryPlan(); + query_plan.unitePlans(std::move(join_step), {std::move(plans)}); + std::cerr << query_plan.getCurrentDataStream().header.dumpStructure() << std::endl; + ActionsDAGPtr project = std::make_shared(query_plan.getCurrentDataStream().header.getNamesAndTypesList()); + project->project({NameWithAlias("colA", "colA"),NameWithAlias("colB", "colB"),NameWithAlias("colD", "colD"),NameWithAlias("colC", "colC")}); + QueryPlanStepPtr project_step = std::make_unique(query_plan.getCurrentDataStream(), project); + query_plan.addStep(std::move(project_step)); + auto pipeline = query_plan.buildQueryPipeline(QueryPlanOptimizationSettings(), BuildQueryPipelineSettings()); + auto executable_pipe = QueryPipelineBuilder::getPipeline(std::move(*pipeline)); + PullingPipelineExecutor executor(executable_pipe); + auto res = pipeline->getHeader().cloneEmpty(); + executor.pull(res); + debug::headBlock(res); +} + + +TEST(TestJoin, StorageJoinFromReadBufferTest) +{ + auto global_context = SerializedPlanParser::global_context; + auto & factory = DB::FunctionFactory::instance(); + auto function = factory.get("murmurHash2_64", local_engine::SerializedPlanParser::global_context); + auto int_type = DataTypeFactory::instance().get("Int32"); + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + auto column1 = int_type->createColumn(); + column1->insert(2); + column1->insert(4); + column1->insert(6); + column1->insert(8); + + ColumnsWithTypeAndName columns = {ColumnWithTypeAndName(std::move(column0),int_type, "colA"), + ColumnWithTypeAndName(std::move(column1),int_type, "colB")}; + Block left(columns); + + auto column3 = int_type->createColumn(); + column3->insert(1); + column3->insert(2); + column3->insert(3); + column3->insert(5); + + auto column4 = int_type->createColumn(); + column4->insert(1); + column4->insert(3); + column4->insert(5); + column4->insert(9); + + ColumnsWithTypeAndName columns2 = {ColumnWithTypeAndName(std::move(column3),int_type, "colD"), + ColumnWithTypeAndName(std::move(column4),int_type, "colC")}; + Block right(columns2); + std::string buf; + WriteBufferFromString write_buf(buf); + NativeWriter writer(write_buf, 0, right.cloneEmpty()); + writer.write(right); + + auto in = std::make_unique(buf); + auto metadata = local_engine::buildMetaData(right.getNamesAndTypesList(), global_context); + + auto join_storage = std::shared_ptr(new StorageJoinFromReadBuffer( + std::move(in), + StorageID("default", "test"), + {"colD"}, + false, + {}, + JoinKind::Left, + JoinStrictness::All, + ColumnsDescription(right.getNamesAndTypesList()), + {}, + "test", + true)); + auto storage_snapshot = std::make_shared(*join_storage, metadata); + auto left_table = std::make_shared(left); + SelectQueryInfo query_info; + auto right_table = join_storage->read(right.getNames(), storage_snapshot, query_info, global_context, QueryProcessingStage::Enum::FetchColumns, 8192, 1); + QueryPlan left_plan; + left_plan.addStep(std::make_unique(Pipe(left_table))); + + auto join = std::make_shared(SizeLimits(), false, JoinKind::Left, JoinStrictness::All, right.getNames()); + auto required_rkey = NameAndTypePair("colD", int_type); + join->addJoinedColumn(required_rkey); + join->addJoinedColumn(NameAndTypePair("colC", int_type)); + ASTPtr lkey = std::make_shared("colA"); + ASTPtr rkey = std::make_shared("colD"); + join->addOnKeys(lkey, rkey); + + + auto hash_join = join_storage->getJoinLocked(join, global_context); + + QueryPlanStepPtr join_step = std::make_unique( + left_plan.getCurrentDataStream(), + hash_join, + 8192); + + join_step->setStepDescription("JOIN"); + left_plan.addStep(std::move(join_step)); + + ActionsDAGPtr project = std::make_shared(left_plan.getCurrentDataStream().header.getNamesAndTypesList()); + project->project({NameWithAlias("colA", "colA"),NameWithAlias("colB", "colB"),NameWithAlias("colD", "colD"),NameWithAlias("colC", "colC")}); + QueryPlanStepPtr project_step = std::make_unique(left_plan.getCurrentDataStream(), project); + left_plan.addStep(std::move(project_step)); + auto pipeline = left_plan.buildQueryPipeline(QueryPlanOptimizationSettings(), BuildQueryPipelineSettings()); + auto executable_pipe = QueryPipelineBuilder::getPipeline(std::move(*pipeline)); + PullingPipelineExecutor executor(executable_pipe); + auto res = pipeline->getHeader().cloneEmpty(); + executor.pull(res); + debug::headBlock(res); +} + diff --git a/utils/local-engine/tests/gtest_ch_storages.cpp b/utils/local-engine/tests/gtest_ch_storages.cpp new file mode 100644 index 000000000000..6801340d1b0b --- /dev/null +++ b/utils/local-engine/tests/gtest_ch_storages.cpp @@ -0,0 +1,270 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace DB; +using namespace local_engine; + +TEST(TestBatchParquetFileSource, blob) +{ + auto config = local_engine::SerializedPlanParser::config; + config->setString("blob.storage_account_url", "http://127.0.0.1:10000/devstoreaccount1"); + config->setString("blob.container_name", "libch"); + config->setString("blob.container_already_exists", "true"); + config->setString( + "blob.connection_string", + "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=" + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/" + "devstoreaccount1;"); + + auto builder = std::make_unique(); + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + std::string file_path = "wasb://libch/parquet/lineitem/part-00000-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"; + file->set_uri_file(file_path); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + ColumnsWithTypeAndName columns; + for (const auto & item : names_and_types_list) + { + ColumnWithTypeAndName col; + col.column = item.type->createColumn(); + col.type = item.type; + col.name = item.name; + columns.emplace_back(std::move(col)); + } + auto header = Block(std::move(columns)); + builder->init(Pipe(std::make_shared(SerializedPlanParser::global_context, header, files))); + + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto executor = PullingPipelineExecutor(pipeline); + auto result = header.cloneEmpty(); + size_t total_rows = 0; + bool is_first = true; + while (executor.pull(result)) + { + if (is_first) + debug::headBlock(result); + total_rows += result.rows(); + is_first = false; + } + + ASSERT_TRUE(total_rows > 0); + std::cerr << "rows:" << total_rows << std::endl; +} + +TEST(TestBatchParquetFileSource, s3) +{ + auto config = local_engine::SerializedPlanParser::config; + config->setString("s3.endpoint", "http://localhost:9000/tpch/"); + config->setString("s3.region", "us-east-1"); + config->setString("s3.access_key_id", "admin"); + config->setString("s3.secret_access_key", "password"); + + auto builder = std::make_unique(); + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + std::string file_path = "s3://tpch/lineitem/part-00000-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"; + file->set_uri_file(file_path); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + ColumnsWithTypeAndName columns; + for (const auto & item : names_and_types_list) + { + ColumnWithTypeAndName col; + col.column = item.type->createColumn(); + col.type = item.type; + col.name = item.name; + columns.emplace_back(std::move(col)); + } + auto header = Block(std::move(columns)); + builder->init(Pipe(std::make_shared(SerializedPlanParser::global_context, header, files))); + + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto executor = PullingPipelineExecutor(pipeline); + auto result = header.cloneEmpty(); + size_t total_rows = 0; + bool is_first = true; + while (executor.pull(result)) + { + if (is_first) + debug::headBlock(result); + total_rows += result.rows(); + is_first = false; + } + + ASSERT_TRUE(total_rows > 0); + std::cerr << "rows:" << total_rows << std::endl; +} + +TEST(TestBatchParquetFileSource, local_file) +{ + auto builder = std::make_unique(); + + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + file->set_uri_file("file:///home/admin1/Documents/data/tpch/parquet/lineitem/part-00000-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + file = files.add_items(); + file->set_uri_file("file:///home/admin1/Documents/data/tpch/parquet/lineitem/part-00001-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"); + file->mutable_parquet()->CopyFrom(parquet_format); + file = files.add_items(); + file->set_uri_file("file:///home/admin1/Documents/data/tpch/parquet/lineitem/part-00002-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"); + file->mutable_parquet()->CopyFrom(parquet_format); + + const auto * type_string = "columns format version: 1\n" + "2 columns:\n" + // "`l_partkey` Int64\n" + // "`l_suppkey` Int64\n" + // "`l_linenumber` Int32\n" + // "`l_quantity` Float64\n" + // "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n"; + // "`l_returnflag` String\n" + // "`l_linestatus` String\n" + // "`l_shipdate` Date\n" + // "`l_commitdate` Date\n" + // "`l_receiptdate` Date\n" + // "`l_shipinstruct` String\n" + // "`l_shipmode` String\n" + // "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + ColumnsWithTypeAndName columns; + for (const auto & item : names_and_types_list) + { + ColumnWithTypeAndName col; + col.column = item.type->createColumn(); + col.type = item.type; + col.name = item.name; + columns.emplace_back(std::move(col)); + } + auto header = Block(std::move(columns)); + builder->init(Pipe(std::make_shared(SerializedPlanParser::global_context, header, files))); + + auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); + auto executor = PullingPipelineExecutor(pipeline); + auto result = header.cloneEmpty(); + size_t total_rows = 0; + bool is_first = true; + while (executor.pull(result)) + { + if (is_first) + debug::headBlock(result); + total_rows += result.rows(); + is_first = false; + } + std::cerr << "rows:" << total_rows << std::endl; + ASSERT_TRUE(total_rows == 59986052); +} + +TEST(TestWrite, MergeTreeWriteTest) +{ + auto config = local_engine::SerializedPlanParser::config; + config->setString("s3.endpoint", "http://localhost:9000/tpch/"); + config->setString("s3.region", "us-east-1"); + config->setString("s3.access_key_id", "admin"); + config->setString("s3.secret_access_key", "password"); + auto global_context = local_engine::SerializedPlanParser::global_context; + + auto param = DB::MergeTreeData::MergingParams(); + auto settings = std::make_unique(); + settings->set("min_bytes_for_wide_part", Field(0)); + settings->set("min_rows_for_wide_part", Field(0)); + + const auto * type_string = "columns format version: 1\n" + "15 columns:\n" + "`l_partkey` Int64\n" + "`l_suppkey` Int64\n" + "`l_linenumber` Int32\n" + "`l_quantity` Float64\n" + "`l_extendedprice` Float64\n" + "`l_discount` Float64\n" + "`l_tax` Float64\n" + "`l_returnflag` String\n" + "`l_linestatus` String\n" + "`l_shipdate` Date\n" + "`l_commitdate` Date\n" + "`l_receiptdate` Date\n" + "`l_shipinstruct` String\n" + "`l_shipmode` String\n" + "`l_comment` String\n"; + auto names_and_types_list = NamesAndTypesList::parse(type_string); + auto metadata = local_engine::buildMetaData(names_and_types_list, global_context); + + local_engine::CustomStorageMergeTree custom_merge_tree(DB::StorageID("default", "test"), + "tmp/test-write/", + *metadata, + false, + global_context, + "", + param, + std::move(settings) + ); + + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + file->set_uri_file("s3://tpch/lineitem/part-00000-f83d0a59-2bff-41bc-acde-911002bf1b33-c000.snappy.parquet"); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + auto source = std::make_shared(SerializedPlanParser::global_context, metadata->getSampleBlock(), files); + + QueryPipelineBuilder query_pipeline_builder; + query_pipeline_builder.init(Pipe(source)); + query_pipeline_builder.setSinks([&](const Block &, Pipe::StreamType type) -> ProcessorPtr + { + if (type != Pipe::StreamType::Main) + return nullptr; + + return std::make_shared(custom_merge_tree, metadata, global_context); + }); + auto executor = query_pipeline_builder.execute(); + executor->execute(1); +} diff --git a/utils/local-engine/tests/gtest_local_engine.cpp b/utils/local-engine/tests/gtest_local_engine.cpp new file mode 100644 index 000000000000..9ea81e240684 --- /dev/null +++ b/utils/local-engine/tests/gtest_local_engine.cpp @@ -0,0 +1,373 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "testConfig.h" +#include "Common/Logger.h" +#include "Common/DebugUtils.h" + +using namespace local_engine; +using namespace dbms; + +TEST(TestSelect, ReadRel) +{ + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("sepal_length", "FP64") + .column("sepal_width", "FP64") + .column("petal_length", "FP64") + .column("petal_width", "FP64") + .column("type", "I64") + .column("type_string", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan = plan_builder.read(TEST_DATA(/data/iris.parquet), std::move(schema)).build(); + + ASSERT_TRUE(plan->relations(0).root().input().has_read()); + ASSERT_EQ(plan->relations_size(), 1); + local_engine::LocalExecutor local_executor; + local_engine::SerializedPlanParser parser(local_engine::SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_executor.execute(std::move(query_plan)); + ASSERT_TRUE(local_executor.hasNext()); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + ASSERT_GT(spark_row_info->getNumRows(), 0); + local_engine::SparkRowToCHColumn converter; + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + ASSERT_EQ(spark_row_info->getNumRows(), block->rows()); + } +} + +TEST(TestSelect, ReadDate) +{ + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("date", "Date").build(); + dbms::SerializedPlanBuilder plan_builder; + auto plan = plan_builder.read(TEST_DATA(/data/date.parquet), std::move(schema)).build(); + + ASSERT_TRUE(plan->relations(0).root().input().has_read()); + ASSERT_EQ(plan->relations_size(), 1); + local_engine::LocalExecutor local_executor; + local_engine::SerializedPlanParser parser(local_engine::SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_executor.execute(std::move(query_plan)); + ASSERT_TRUE(local_executor.hasNext()); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + ASSERT_GT(spark_row_info->getNumRows(), 0); + local_engine::SparkRowToCHColumn converter; + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + ASSERT_EQ(spark_row_info->getNumRows(), block->rows()); + } +} + +TEST(TestSelect, TestFilter) +{ + dbms::SerializedSchemaBuilder schema_builder; + // sorted by key + auto * schema = schema_builder.column("sepal_length", "FP64") + .column("sepal_width", "FP64") + .column("petal_length", "FP64") + .column("petal_width", "FP64") + .column("type", "I64") + .column("type_string", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + // sepal_length * 0.8 + auto * mul_exp = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(2), dbms::literal(0.8)}); + // sepal_length * 0.8 < 4.0 + auto * less_exp = dbms::scalarFunction(dbms::LESS_THAN, {mul_exp, dbms::literal(4.0)}); + // type_string = '类型1' + auto * type_0 = dbms::scalarFunction(dbms::EQUAL_TO, {dbms::selection(5), dbms::literal("类型1")}); + + auto * filter = dbms::scalarFunction(dbms::AND, {less_exp, type_0}); + auto plan = plan_builder.registerSupportedFunctions().filter(filter).read(TEST_DATA(/data/iris.parquet), std::move(schema)).build(); + ASSERT_EQ(plan->relations_size(), 1); + local_engine::LocalExecutor local_executor; + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_executor.execute(std::move(query_plan)); + ASSERT_TRUE(local_executor.hasNext()); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + ASSERT_EQ(spark_row_info->getNumRows(), 1); + local_engine::SparkRowToCHColumn converter; + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + ASSERT_EQ(spark_row_info->getNumRows(), block->rows()); + } +} + +TEST(TestSelect, TestAgg) +{ + dbms::SerializedSchemaBuilder schema_builder; + // sorted by key + auto * schema = schema_builder.column("sepal_length", "FP64") + .column("sepal_width", "FP64") + .column("petal_length", "FP64") + .column("petal_width", "FP64") + .column("type", "I64") + .column("type_string", "String") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * mul_exp = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(2), dbms::literal(0.8)}); + auto * less_exp = dbms::scalarFunction(dbms::LESS_THAN, {mul_exp, dbms::literal(4.0)}); + auto * measure = dbms::measureFunction(dbms::SUM, {dbms::selection(2)}); + auto plan = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure}) + .filter(less_exp) + .read(TEST_DATA(/data/iris.parquet), std::move(schema)) + .build(); + ASSERT_EQ(plan->relations_size(), 1); + local_engine::LocalExecutor local_executor; + local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); + auto query_plan = parser.parse(std::move(plan)); + local_executor.execute(std::move(query_plan)); + ASSERT_TRUE(local_executor.hasNext()); + while (local_executor.hasNext()) + { + local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + ASSERT_EQ(spark_row_info->getNumRows(), 1); + ASSERT_EQ(spark_row_info->getNumCols(), 1); + local_engine::SparkRowToCHColumn converter; + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + ASSERT_EQ(spark_row_info->getNumRows(), block->rows()); + auto reader = SparkRowReader(block->getDataTypes()); + reader.pointTo(spark_row_info->getBufferAddress() + spark_row_info->getOffsets()[1], spark_row_info->getLengths()[0]); + ASSERT_EQ(reader.getDouble(0), 103.2); + } +} + +TEST(TestSelect, MergeTreeWriteTest) +{ + std::shared_ptr metadata = std::make_shared(); + ColumnsDescription columns_description; + auto shared_context = Context::createShared(); + auto global_context = Context::createGlobal(shared_context.get()); + global_context->makeGlobalContext(); + global_context->setPath("/home/kyligence/Documents/clickhouse_conf/data/"); + global_context->getDisksMap().emplace(); + auto int64_type = std::make_shared(); + auto int32_type = std::make_shared(); + auto double_type = std::make_shared(); + columns_description.add(ColumnDescription("l_orderkey", int64_type)); + columns_description.add(ColumnDescription("l_partkey", int64_type)); + columns_description.add(ColumnDescription("l_suppkey", int64_type)); + columns_description.add(ColumnDescription("l_linenumber", int32_type)); + columns_description.add(ColumnDescription("l_quantity", double_type)); + columns_description.add(ColumnDescription("l_extendedprice", double_type)); + columns_description.add(ColumnDescription("l_discount", double_type)); + columns_description.add(ColumnDescription("l_tax", double_type)); + columns_description.add(ColumnDescription("l_shipdate_new", double_type)); + columns_description.add(ColumnDescription("l_commitdate_new", double_type)); + columns_description.add(ColumnDescription("l_receiptdate_new", double_type)); + metadata->setColumns(columns_description); + metadata->partition_key.expression_list_ast = std::make_shared(); + metadata->sorting_key = KeyDescription::getSortingKeyFromAST(makeASTFunction("tuple"), columns_description, global_context, {}); + metadata->primary_key.expression = std::make_shared(std::make_shared()); + auto param = DB::MergeTreeData::MergingParams(); + auto settings = std::make_unique(); + settings->set("min_bytes_for_wide_part", Field(0)); + settings->set("min_rows_for_wide_part", Field(0)); + + local_engine::CustomStorageMergeTree custom_merge_tree( + DB::StorageID("default", "test"), "test-intel/", *metadata, false, global_context, "", param, std::move(settings)); + + auto sink = std::make_shared(custom_merge_tree, metadata, global_context); + + substrait::ReadRel::LocalFiles files; + substrait::ReadRel::LocalFiles::FileOrFiles * file = files.add_items(); + std::string file_path = "file:///home/kyligence/Documents/test-dataset/intel-gazelle-test-150.snappy.parquet"; + file->set_uri_file(file_path); + substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions parquet_format; + file->mutable_parquet()->CopyFrom(parquet_format); + auto source = std::make_shared(SerializedPlanParser::global_context, metadata->getSampleBlock(), files); + + QueryPipelineBuilder query_pipeline; + query_pipeline.init(Pipe(source)); + query_pipeline.setSinks( + [&](const Block &, Pipe::StreamType type) -> ProcessorPtr + { + if (type != Pipe::StreamType::Main) + return nullptr; + + return std::make_shared(custom_merge_tree, metadata, global_context); + }); + auto executor = query_pipeline.execute(); + executor->execute(1); +} + +TEST(TESTUtil, TestByteToLong) +{ + Int64 expected = 0xf085460ccf7f0000l; + char * arr = new char[8]; + arr[0] = -16; + arr[1] = -123; + arr[2] = 70; + arr[3] = 12; + arr[4] = -49; + arr[5] = 127; + arr[6] = 0; + arr[7] = 0; + std::reverse(arr, arr + 8); + Int64 result = reinterpret_cast(arr)[0]; + std::cout << std::to_string(result); + + ASSERT_EQ(expected, result); +} + + +TEST(TestSimpleAgg, TestGenerate) +{ +// dbms::SerializedSchemaBuilder schema_builder; +// auto * schema = schema_builder.column("l_orderkey", "I64") +// .column("l_partkey", "I64") +// .column("l_suppkey", "I64") +// .build(); +// dbms::SerializedPlanBuilder plan_builder; +// auto * measure = dbms::measureFunction(dbms::SUM, {dbms::selection(6)}); +// auto plan +// = plan_builder.registerSupportedFunctions() +// .aggregate({}, {measure}) +// .read( +// //"/home/kyligence/Documents/test-dataset/intel-gazelle-test-" + std::to_string(state.range(0)) + ".snappy.parquet", +// "/data0/tpch100_zhichao/parquet_origin/lineitem/part-00087-066b93b4-39e1-4d46-83ab-d7752096b599-c000.snappy.parquet", +// std::move(schema)) +// .build(); + local_engine::SerializedPlanParser parser(local_engine::SerializedPlanParser::global_context); +//// auto query_plan = parser.parse(std::move(plan)); + + //std::ifstream t("/home/hongbin/develop/experiments/221011_substrait_agg_on_empty_table.json"); + //std::ifstream t("/home/hongbin/develop/experiments/221101_substrait_agg_on_simple_table_last_phrase.json"); + std::ifstream t("/home/hongbin/develop/experiments/221102_substrait_agg_and_countdistinct_second_phrase.json"); + std::string str((std::istreambuf_iterator(t)), + std::istreambuf_iterator()); + auto query_plan = parser.parseJson(str); + local_engine::LocalExecutor local_executor; + local_executor.execute(std::move(query_plan)); + while (local_executor.hasNext()) + { + auto block = local_executor.nextColumnar(); + debug::headBlock(*block); + } +} + +TEST(TestSubstrait, TestGenerate) +{ + dbms::SerializedSchemaBuilder schema_builder; + auto * schema = schema_builder.column("l_discount", "FP64") + .column("l_extendedprice", "FP64") + .column("l_quantity", "FP64") + .column("l_shipdate_new", "Date") + .build(); + dbms::SerializedPlanBuilder plan_builder; + auto * agg_mul = dbms::scalarFunction(dbms::MULTIPLY, {dbms::selection(1), dbms::selection(0)}); + auto * measure1 = dbms::measureFunction(dbms::SUM, {agg_mul}); + auto * measure2 = dbms::measureFunction(dbms::SUM, {dbms::selection(1)}); + auto * measure3 = dbms::measureFunction(dbms::SUM, {dbms::selection(2)}); + auto plan + = plan_builder.registerSupportedFunctions() + .aggregate({}, {measure1, measure2, measure3}) + .project({dbms::selection(2), dbms::selection(1), dbms::selection(0)}) + .filter(dbms::scalarFunction( + dbms::AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {dbms::scalarFunction( + AND, + {scalarFunction(IS_NOT_NULL, {selection(3)}), scalarFunction(IS_NOT_NULL, {selection(0)})}), + scalarFunction(IS_NOT_NULL, {selection(2)})}), + dbms::scalarFunction(GREATER_THAN_OR_EQUAL, {selection(3), literalDate(8766)})}), + scalarFunction(LESS_THAN, {selection(3), literalDate(9131)})}), + scalarFunction(GREATER_THAN_OR_EQUAL, {selection(0), literal(0.05)})}), + scalarFunction(LESS_THAN_OR_EQUAL, {selection(0), literal(0.07)})}), + scalarFunction(LESS_THAN, {selection(2), literal(24.0)})})) + .readMergeTree("default", "test", "usr/code/data/test-mergetree", 1, 12, std::move(schema)) + .build(); + std::ofstream output; + output.open("/home/kyligence/Documents/code/ClickHouse/plan.txt", std::fstream::in | std::fstream::out | std::fstream::trunc); + output << plan->SerializeAsString(); + // plan->SerializeToOstream(&output); + output.flush(); + output.close(); +} + +TEST(ReadBufferFromFile, seekBackwards) +{ + static constexpr size_t N = 256; + static constexpr size_t BUF_SIZE = 64; + + auto tmp_file = createTemporaryFile("/tmp/"); + + { + WriteBufferFromFile out(tmp_file->path()); + for (size_t i = 0; i < N; ++i) + writeIntBinary(i, out); + } + + ReadBufferFromFile in(tmp_file->path(), BUF_SIZE); + size_t x; + + /// Read something to initialize the buffer. + in.seek(BUF_SIZE * 10, SEEK_SET); + readIntBinary(x, in); + + /// Check 2 consecutive seek calls without reading. + in.seek(BUF_SIZE * 2, SEEK_SET); + // readIntBinary(x, in); + in.seek(BUF_SIZE, SEEK_SET); + + readIntBinary(x, in); + ASSERT_EQ(x, 8); +} + +int main(int argc, char ** argv) +{ + local_engine::Logger::initConsoleLogger(); + + SharedContextHolder shared_context = Context::createShared(); + local_engine::SerializedPlanParser::global_context = Context::createGlobal(shared_context.get()); + local_engine::SerializedPlanParser::global_context->makeGlobalContext(); + auto config = Poco::AutoPtr(new Poco::Util::MapConfiguration()); + local_engine::SerializedPlanParser::global_context->setConfig(config); + local_engine::SerializedPlanParser::global_context->setPath("/tmp"); + local_engine::SerializedPlanParser::global_context->getDisksMap().emplace(); + local_engine::SerializedPlanParser::initFunctionEnv(); + registerReadBufferBuilders(); + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/utils/local-engine/tests/gtest_orc_input_format.cpp b/utils/local-engine/tests/gtest_orc_input_format.cpp new file mode 100644 index 000000000000..1d6113c1a77f --- /dev/null +++ b/utils/local-engine/tests/gtest_orc_input_format.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +class TestOrcInputFormat : public local_engine::ORCBlockInputFormat +{ +public: + explicit TestOrcInputFormat( + DB::ReadBuffer & in_, + DB::Block header_, + const DB::FormatSettings & format_settings_, + const std::vector & stripes_) + : local_engine::ORCBlockInputFormat(in_, header_, format_settings_, stripes_) + {} + + DB::Chunk callGenerate() + { + return generate(); + } +}; + +static std::string orc_file_path = "./utils/local-engine/tests/data/lineitem.orc"; + +static DB::Block buildLineitemHeader() +{ + /* + `l_orderkey` bigint COMMENT 'oops', + `l_partkey` bigint COMMENT 'oops', + `l_suppkey` bigint COMMENT 'oops', + `l_linenumber` int COMMENT 'oops', + `l_quantity` double COMMENT 'oops', + `l_extendedprice` double COMMENT 'oops', + `l_discount` double COMMENT 'oops', + `l_tax` double COMMENT 'oops', + `l_returnflag` string COMMENT 'oops', + `l_linestatus` string COMMENT 'oops', + `l_shipdate` date COMMENT 'oops', + `l_commitdate` date COMMENT 'oops', + `l_receiptdate` date COMMENT 'oops', + `l_shipinstruct` string COMMENT 'oops', + `l_shipmode` string COMMENT 'oops', + `l_comment` string COMMENT 'oops') + */ + + DB::Block header; + + auto bigint_ty = std::make_shared(); + auto int_ty = std::make_shared(); + auto double_ty = std::make_shared(); + auto string_ty = std::make_shared(); + auto date_ty = std::make_shared(); + + auto l_orderkey_col = bigint_ty->createColumn(); + DB::ColumnWithTypeAndName l_orderkey(std::move(l_orderkey_col), bigint_ty, "l_orderkey"); + header.insert(l_orderkey); + DB::ColumnWithTypeAndName l_partkey(std::move(bigint_ty->createColumn()), bigint_ty, "l_partkey"); + header.insert(l_partkey); + DB::ColumnWithTypeAndName l_suppkey(std::move(bigint_ty->createColumn()), bigint_ty, "l_suppkey"); + header.insert(l_suppkey); + DB::ColumnWithTypeAndName l_linenumber(std::move(int_ty->createColumn()), int_ty, "l_linenumber"); + header.insert(l_linenumber); + DB::ColumnWithTypeAndName l_quantity(std::move(double_ty->createColumn()), double_ty, "l_quantity"); + header.insert(l_quantity); + DB::ColumnWithTypeAndName l_extendedprice(std::move(double_ty->createColumn()), double_ty, "l_extendedprice"); + header.insert(l_extendedprice); + DB::ColumnWithTypeAndName l_discount(std::move(double_ty->createColumn()), double_ty, "l_discount"); + header.insert(l_discount); + DB::ColumnWithTypeAndName l_tax(std::move(double_ty->createColumn()), double_ty, "l_tax"); + header.insert(l_tax); + DB::ColumnWithTypeAndName l_returnflag(std::move(string_ty->createColumn()), string_ty, "l_returnflag"); + header.insert(l_returnflag); + DB::ColumnWithTypeAndName l_linestatus(std::move(string_ty->createColumn()), string_ty, "l_linestatus"); + header.insert(l_linestatus); + DB::ColumnWithTypeAndName l_shipdate(std::move(date_ty->createColumn()), date_ty, "l_shipdate"); + header.insert(l_shipdate); + DB::ColumnWithTypeAndName l_commitdate(std::move(date_ty->createColumn()), date_ty, "l_commitdate"); + header.insert(l_commitdate); + DB::ColumnWithTypeAndName l_receiptdate(std::move(date_ty->createColumn()), date_ty, "l_receiptdate"); + header.insert(l_receiptdate); + DB::ColumnWithTypeAndName l_shipinstruct(std::move(string_ty->createColumn()), string_ty, "l_shipinstruct"); + header.insert(l_shipinstruct); + DB::ColumnWithTypeAndName l_shipmode(std::move(string_ty->createColumn()), string_ty, "l_shipmode"); + header.insert(l_shipmode); + DB::ColumnWithTypeAndName l_comment(std::move(string_ty->createColumn()), string_ty, "l_comment"); + header.insert(l_comment); + + return header; +} + +std::vector collectRequiredStripes(DB::ReadBuffer* read_buffer) +{ + std::vector stripes; + DB::FormatSettings format_settings; + format_settings.seekable_read = true; + std::atomic is_stopped{0}; + auto arrow_file = DB::asArrowFile(*read_buffer, format_settings, is_stopped, "ORC", ORC_MAGIC_BYTES); + auto orc_reader = local_engine::OrcUtil::createOrcReader(arrow_file); + auto num_stripes = orc_reader->getNumberOfStripes(); + + size_t total_num_rows = 0; + for (size_t i = 0; i < num_stripes; ++i) + { + auto stripe_metadata = orc_reader->getStripe(i); + auto offset = stripe_metadata->getOffset(); + local_engine::StripeInformation stripe_info; + stripe_info.index = i; + stripe_info.offset = stripe_metadata->getLength(); + stripe_info.length = stripe_metadata->getLength(); + stripe_info.num_rows = stripe_metadata->getNumberOfRows(); + stripe_info.start_row = total_num_rows; + stripes.emplace_back(stripe_info); + total_num_rows += stripe_metadata->getNumberOfRows(); + } + return stripes; + +} + +TEST(OrcInputFormat, CallGenerate) +{ + auto file_in = std::make_shared(orc_file_path); + auto stripes = collectRequiredStripes(file_in.get()); + DB::FormatSettings format_settings; + auto input_format = std::make_shared(*file_in, buildLineitemHeader(), format_settings, stripes); + auto chunk = input_format->callGenerate(); + EXPECT_TRUE(chunk.getNumRows() == 2); +} diff --git a/utils/local-engine/tests/gtest_parquet_read.cpp b/utils/local-engine/tests/gtest_parquet_read.cpp new file mode 100644 index 000000000000..dd4b51fe6ed7 --- /dev/null +++ b/utils/local-engine/tests/gtest_parquet_read.cpp @@ -0,0 +1,280 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace DB; + +template +static void readSchema(const String & path) +{ + FormatSettings settings; + auto in = std::make_shared(path); + ParquetSchemaReader schema_reader(*in, settings); + auto name_and_types = schema_reader.readSchema(); + auto & factory = DataTypeFactory::instance(); + + auto check_type = [&name_and_types, &factory](const String & column, const String & expect_str_type) + { + auto expect_type = factory.get(expect_str_type); + + auto name_and_type = name_and_types.tryGetByName(column); + EXPECT_TRUE(name_and_type); + + // std::cout << "real_type:" << name_and_type->type->getName() << ", expect_type:" << expect_type->getName() << std::endl; + EXPECT_TRUE(name_and_type->type->equals(*expect_type)); + }; + + check_type("f_bool", "Nullable(UInt8)"); + check_type("f_byte", "Nullable(Int8)"); + check_type("f_short", "Nullable(Int16)"); + check_type("f_int", "Nullable(Int32)"); + check_type("f_long", "Nullable(Int64)"); + check_type("f_float", "Nullable(Float32)"); + check_type("f_double", "Nullable(Float64)"); + check_type("f_string", "Nullable(String)"); + check_type("f_binary", "Nullable(String)"); + check_type("f_decimal", "Nullable(Decimal(10, 2))"); + check_type("f_date", "Nullable(Date32)"); + check_type("f_timestamp", "Nullable(DateTime64(9))"); + check_type("f_array", "Nullable(Array(Nullable(String)))"); + check_type("f_array_array", "Nullable(Array(Nullable(Array(Nullable(String)))))"); + check_type("f_array_map", "Nullable(Array(Nullable(Map(String, Nullable(Int64)))))"); + check_type("f_array_struct", "Nullable(Array(Nullable(Tuple(a Nullable(String), b Nullable(Int64)))))"); + check_type("f_map", "Nullable(Map(String, Nullable(Int64)))"); + check_type("f_map_map", "Nullable(Map(String, Nullable(Map(String, Nullable(Int64)))))"); + check_type("f_map_array", "Nullable(Map(String, Nullable(Array(Nullable(Int64)))))"); + check_type("f_map_struct", "Nullable(Map(String, Nullable(Tuple(a Nullable(String), b Nullable(Int64)))))"); + check_type("f_struct", "Nullable(Tuple(a Nullable(String), b Nullable(Int64)))"); + check_type("f_struct_struct", "Nullable(Tuple(a Nullable(String), b Nullable(Int64), c Nullable(Tuple(x Nullable(String), y Nullable(Int64)))))"); + check_type("f_struct_array", "Nullable(Tuple(a Nullable(String), b Nullable(Int64), c Nullable(Array(Nullable(Int64)))))"); + check_type("f_struct_map", "Nullable(Tuple(a Nullable(String), b Nullable(Int64), c Nullable(Map(String, Nullable(Int64)))))"); +} + +template +static void readData(const String & path, const std::map & fields) +{ + auto in = std::make_shared(path); + FormatSettings settings; + SchemaReader schema_reader(*in, settings); + auto name_and_types = schema_reader.readSchema(); + + ColumnsWithTypeAndName columns; + columns.reserve(name_and_types.size()); + for (const auto & name_and_type : name_and_types) + if (fields.count(name_and_type.name)) + columns.emplace_back(name_and_type.type, name_and_type.name); + + Block header(columns); + in = std::make_shared(path); + auto format = std::make_shared(*in, header, settings); + auto pipeline = QueryPipeline(std::move(format)); + auto reader = std::make_unique(pipeline); + + Block block; + EXPECT_TRUE(reader->pull(block)); + EXPECT_TRUE(block.rows() == 1); + + for (const auto & name_and_type : name_and_types) + { + const auto & name = name_and_type.name; + auto it = fields.find(name); + if (it != fields.end()) + { + const auto & column = block.getByName(name); + auto field = (*column.column)[0]; + auto expect_field = it->second; + // std::cout << "field:" << toString(field) << ", expect_field:" << toString(expect_field) << std::endl; + EXPECT_TRUE(field == expect_field); + } + } +} + +TEST(ParquetRead, ReadSchema) +{ + readSchema("./utils/local-engine/tests/data/alltypes/alltypes_notnull.parquet"); + readSchema("./utils/local-engine/tests/data/alltypes/alltypes_null.parquet"); + readSchema("./utils/local-engine/tests/data/alltypes/alltypes_null.parquet"); + readSchema("./utils/local-engine/tests/data/alltypes/alltypes_null.parquet"); +} + +TEST(ParquetRead, ReadDataNotNull) +{ + const String path = "./utils/local-engine/tests/data/alltypes/alltypes_notnull.parquet"; + const std::map fields{ + {"f_array", Array{"hello", "world"}}, + {"f_bool", UInt8(1)}, + {"f_byte", Int8(1)}, + {"f_short", Int16(2)}, + {"f_int", Int32(3)}, + {"f_long", Int64(4)}, + {"f_float", Float32(5.5)}, + {"f_double", Float64(6.6)}, + {"f_string", "hello world"}, + {"f_binary", "hello world"}, + {"f_decimal", DecimalField(777, 2)}, + {"f_date", Int32(18262)}, + {"f_timestamp", DecimalField(1666162060000000L, 6)}, + {"f_array", Array{"hello", "world"}}, + { + "f_array_array", + []() -> Field + { + Array res; + res.push_back(Array{"hello"}); + res.push_back(Array{"world"}); + return std::move(res); + }(), + }, + { + "f_array_map", + []() -> Field + { + Array res; + + Map map; + map.push_back(Tuple{"hello", Int64(1)}); + res.push_back(map); + + map.clear(); + map.push_back(Tuple{"world", Int64(2)}); + res.push_back(map); + + return std::move(res); + }(), + }, + { + "f_array_struct", + []() -> Field + { + Array res; + res.push_back(Tuple{"hello", Int64(1)}); + res.push_back(Tuple{"world", Int64(2)}); + + return std::move(res); + }(), + }, + { + "f_map", + []() -> Field + { + Map res; + res.push_back(Tuple{"hello", Int64(1)}); + res.push_back(Tuple{"world", Int64(2)}); + return std::move(res); + }(), + }, + { + "f_map_map", + []() -> Field + { + Map nested_map; + nested_map.push_back(Tuple{"world", Int64(3)}); + + Map res; + res.push_back(Tuple{"hello", std::move(nested_map)}); + return std::move(res); + }(), + }, + { + "f_map_array", + []() -> Field + { + Array array{Int64(1), Int64(2), Int64(3)}; + + Map res; + res.push_back(Tuple{"hello", std::move(array)}); + return std::move(res); + }(), + }, + { + "f_map_struct", + []() -> Field + { + Tuple tuple{"world", Int64(4)}; + + Map res; + res.push_back(Tuple{"hello", std::move(tuple)}); + return std::move(res); + }(), + }, + { + "f_struct", + []() -> Field + { + Tuple res{"hello world", Int64(5)}; + return std::move(res); + }(), + }, + { + "f_struct_struct", + []() -> Field + { + Tuple tuple{"world", Int64(6)}; + Tuple res{"hello", Int64(6), std::move(tuple)}; + return std::move(res); + }(), + }, + { + "f_struct_array", + []() -> Field + { + Array array{Int64(1), Int64(2), Int64(3)}; + Tuple res{"hello", Int64(7), std::move(array)}; + return std::move(res); + }(), + }, + { + "f_struct_map", + []() -> Field + { + Map map; + map.push_back(Tuple{"world", Int64(9)}); + + Tuple res{"hello", Int64(8), std::move(map)}; + return std::move(res); + }(), + }, + }; + + readData(path, fields); + readData(path, fields); +} + + +TEST(ParquetRead, ReadDataNull) +{ + const String path = "./utils/local-engine/tests/data/alltypes/alltypes_null.parquet"; + std::map fields{ + {"f_array", Null{}}, {"f_bool", Null{}}, {"f_byte", Null{}}, {"f_short", Null{}}, + {"f_int", Null{}}, {"f_long", Null{}}, {"f_float", Null{}}, {"f_double", Null{}}, + {"f_string", Null{}}, {"f_binary", Null{}}, {"f_decimal", Null{}}, {"f_date", Null{}}, + {"f_timestamp", Null{}}, {"f_array", Null{}}, {"f_array_array", Null{}}, {"f_array_map", Null{}}, + {"f_array_struct", Null{}}, {"f_map", Null{}}, {"f_map_map", Null{}}, {"f_map_array", Null{}}, + {"f_map_struct", Null{}}, {"f_struct", Null{}}, {"f_struct_struct", Null{}}, {"f_struct_array", Null{}}, + {"f_struct_map", Null{}}, + }; + + readData(path, fields); + readData(path, fields); +} + diff --git a/utils/local-engine/tests/gtest_spark_row.cpp b/utils/local-engine/tests/gtest_spark_row.cpp new file mode 100644 index 000000000000..bcc5621502a6 --- /dev/null +++ b/utils/local-engine/tests/gtest_spark_row.cpp @@ -0,0 +1,463 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace local_engine; +using namespace DB; + + +struct DataTypeAndField +{ + DataTypePtr type; + Field field; +}; +using DataTypeAndFields = std::vector; + +using SparkRowAndBlock = std::pair; + +static SparkRowAndBlock mockSparkRowInfoAndBlock(const DataTypeAndFields & type_and_fields) +{ + /// Initialize types + ColumnsWithTypeAndName columns(type_and_fields.size()); + for (size_t i=0; iinsert(type_and_fields[i].field); + block.setColumns(std::move(mutable_colums)); + + auto converter = CHColumnToSparkRow(); + auto spark_row_info = converter.convertCHColumnToSparkRow(block); + return std::make_tuple(std::move(spark_row_info), std::make_shared(std::move(block))); +} + +static Int32 getDayNum(const String & date) +{ + ExtendedDayNum res; + ReadBufferFromString in(date); + readDateText(res, in); + return res; +} + +static DateTime64 getDateTime64(const String & datetime64, UInt32 scale) +{ + DateTime64 res; + ReadBufferFromString in(datetime64); + readDateTime64Text(res, scale, in); + return res; +} + +static void assertReadConsistentWithWritten(const SparkRowInfo & spark_row_info, const Block & in, const DataTypeAndFields type_and_fields) +{ + /// Check if output of SparkRowReader is consistent with types_and_fields + { + auto reader = SparkRowReader(spark_row_info.getDataTypes()); + reader.pointTo(spark_row_info.getBufferAddress(), spark_row_info.getTotalBytes()); + for (size_t i = 0; i < type_and_fields.size(); ++i) + { + /* + const auto read_field{std::move(reader.getField(i))}; + const auto & written_field = type_and_fields[i].field; + std::cout << "read_field:" << read_field.getType() << "," << toString(read_field) << std::endl; + std::cout << "written_field:" << written_field.getType() << "," << toString(written_field) << std::endl; + */ + EXPECT_TRUE(reader.getField(i) == type_and_fields[i].field); + } + } + + /// check if output of SparkRowToCHColumn is consistents with initial block. + { + auto block = SparkRowToCHColumn::convertSparkRowInfoToCHColumn(spark_row_info, in.cloneEmpty()); + const auto & out = *block; + EXPECT_TRUE(in.rows() == out.rows()); + EXPECT_TRUE(in.columns() == out.columns()); + for (size_t col_idx = 0; col_idx < in.columns(); ++col_idx) + { + const auto & in_col = in.getByPosition(col_idx); + const auto & out_col = out.getByPosition(col_idx); + for (size_t row_idx = 0; row_idx < in.rows(); ++row_idx) + { + const auto in_field = (*in_col.column)[row_idx]; + const auto out_field = (*out_col.column)[row_idx]; + EXPECT_TRUE(in_field == out_field); + } + } + } +} + +TEST(SparkRow, BitSetWidthCalculation) +{ + EXPECT_TRUE(calculateBitSetWidthInBytes(0) == 0); + EXPECT_TRUE(calculateBitSetWidthInBytes(1) == 8); + EXPECT_TRUE(calculateBitSetWidthInBytes(32) == 8); + EXPECT_TRUE(calculateBitSetWidthInBytes(64) == 8); + EXPECT_TRUE(calculateBitSetWidthInBytes(65) == 16); + EXPECT_TRUE(calculateBitSetWidthInBytes(128) == 16); +} + +TEST(SparkRow, GetArrayElementSize) +{ + const std::map type_to_size = { + {std::make_shared(), 1}, + {std::make_shared(), 1}, + {std::make_shared(), 2}, + {std::make_shared(), 2}, + {std::make_shared(), 2}, + {std::make_shared(), 4}, + {std::make_shared(), 4}, + {std::make_shared(), 4}, + {std::make_shared(), 4}, + {std::make_shared(9, 4), 4}, + {std::make_shared(), 8}, + {std::make_shared(), 8}, + {std::make_shared(), 8}, + {std::make_shared(6), 8}, + {std::make_shared(18, 4), 8}, + + {std::make_shared(), 8}, + {std::make_shared(38, 4), 8}, + {std::make_shared(std::make_shared(), std::make_shared()), 8}, + {std::make_shared(std::make_shared()), 8}, + {std::make_shared(DataTypes{std::make_shared(), std::make_shared()}), 8}, + }; + + for (const auto & [type, size] : type_to_size) + { + EXPECT_TRUE(BackingDataLengthCalculator::getArrayElementSize(type) == size); + if (type->canBeInsideNullable()) + { + const auto type_with_nullable = std::make_shared(type); + EXPECT_TRUE(BackingDataLengthCalculator::getArrayElementSize(type_with_nullable) == size); + } + } +} + +TEST(SparkRow, PrimitiveTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(), -1}, + {std::make_shared(), UInt64(1)}, + {std::make_shared(), -2}, + {std::make_shared(), UInt32(2)}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + 4 * 8); +} + +TEST(SparkRow, PrimitiveStringTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(), -1}, + {std::make_shared(), UInt64(1)}, + {std::make_shared(), "Hello World"}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 3) + roundNumberOfBytesToNearestWord(strlen("Hello World"))); +} + +TEST(SparkRow, PrimitiveStringDateTimestampTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(), -1}, + {std::make_shared(), UInt64(1)}, + {std::make_shared(), "Hello World"}, + {std::make_shared(), getDayNum("2015-06-22")}, + {std::make_shared(0), getDateTime64("2015-05-08 08:10:25", 0)}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 5) + roundNumberOfBytesToNearestWord(strlen("Hello World"))); +} + + +TEST(SparkRow, DecimalTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(9, 2), DecimalField(1234, 2)}, + {std::make_shared(18, 2), DecimalField(5678, 2)}, + {std::make_shared(38, 2), DecimalField(Decimal128(Int128(12345678)), 2)}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + (8 * 3) + 16); +} + + +TEST(SparkRow, NullHandling) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + {std::make_shared(std::make_shared()), Null{}}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == static_cast(8 + (8 * type_and_fields.size()))); +} + +TEST(SparkRow, StructTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(DataTypes{std::make_shared()}), Tuple{Int32(1)}}, + {std::make_shared(DataTypes{std::make_shared(DataTypes{std::make_shared()})}), + []() -> Field + { + Tuple t(1); + t.back() = Tuple{Int64(2)}; + return std::move(t); + }()}, + }; + + /* + for (size_t i=0; igetName() << ",field:" << toString(type_and_fields[i].field) + << std::endl; + } + */ + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); +} + +TEST(SparkRow, ArrayTypes) +{ + DataTypeAndFields type_and_fields = { + {std::make_shared(std::make_shared()), Array{Int32(1), Int32(2)}}, + {std::make_shared(std::make_shared(std::make_shared())), + []() -> Field + { + Array array(1); + array.back() = Array{Int32(1), Int32(2)}; + return std::move(array); + }()}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); +} + +TEST(SparkRow, MapTypes) +{ + const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); + DataTypeAndFields type_and_fields = { + {map_type, + []() -> Field + { + Map map(2); + map[0] = std::move(Tuple{Int32(1), Int32(2)}); + map[1] = std::move(Tuple{Int32(3), Int32(4)}); + return std::move(map); + }()}, + {std::make_shared(std::make_shared(), map_type), + []() -> Field + { + Map inner_map(2); + inner_map[0] = std::move(Tuple{Int32(5), Int32(6)}); + inner_map[1] = std::move(Tuple{Int32(7), Int32(8)}); + + Map map(1); + map.back() = std::move(Tuple{Int32(9), std::move(inner_map)}); + return std::move(map); + }()}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); +} + + +TEST(SparkRow, StructMapTypes) +{ + const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); + const auto tuple_type = std::make_shared(DataTypes{std::make_shared()}); + + DataTypeAndFields type_and_fields = { + {std::make_shared(DataTypes{map_type}), + []() -> Field + { + Map map(1); + map[0] = std::move(Tuple{Int32(1), Int32(2)}); + return std::move(Tuple{std::move(map)}); + }()}, + {std::make_shared(std::make_shared(), tuple_type), + []() -> Field + { + Tuple inner_tuple{Int32(4)}; + Map map(1); + map.back() = std::move(Tuple{Int32(3), std::move(inner_tuple)}); + return std::move(map); + }()}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); +} + + +TEST(SparkRow, StructArrayTypes) +{ + const auto array_type = std::make_shared(std::make_shared()); + const auto tuple_type = std::make_shared(DataTypes{std::make_shared()}); + DataTypeAndFields type_and_fields = { + {std::make_shared(DataTypes{array_type}), + []() -> Field + { + Array array{Int32(1)}; + Tuple tuple(1); + tuple[0] = std::move(array); + return std::move(tuple); + }()}, + {std::make_shared(tuple_type), + []() -> Field + { + Tuple tuple{Int64(2)}; + Array array(1); + array[0] = std::move(tuple); + return std::move(array); + }()}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); + +} + +TEST(SparkRow, ArrayMapTypes) +{ + const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); + const auto array_type = std::make_shared(std::make_shared()); + + DataTypeAndFields type_and_fields = { + {std::make_shared(map_type), + []() -> Field + { + Map map(1); + map[0] = std::move(Tuple{Int32(1),Int32(2)}); + + Array array(1); + array[0] = std::move(map); + return std::move(array); + }()}, + {std::make_shared(std::make_shared(), array_type), + []() -> Field + { + Array array{Int32(4)}; + Tuple tuple(2); + tuple[0] = Int32(3); + tuple[1] = std::move(array); + + Map map(1); + map[0] = std::move(tuple); + return std::move(map); + }()}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + + EXPECT_TRUE( + spark_row_info->getTotalBytes() + == 8 + 2 * 8 + BackingDataLengthCalculator(type_and_fields[0].type).calculate(type_and_fields[0].field) + + BackingDataLengthCalculator(type_and_fields[1].type).calculate(type_and_fields[1].field)); +} + + +TEST(SparkRow, NullableComplexTypes) +{ + const auto map_type = std::make_shared(std::make_shared(), std::make_shared()); + const auto tuple_type = std::make_shared(DataTypes{std::make_shared()}); + const auto array_type = std::make_shared(std::make_shared()); + DataTypeAndFields type_and_fields = { + {std::make_shared(map_type), Null{}}, + {std::make_shared(tuple_type), Null{}}, + {std::make_shared(array_type), Null{}}, + }; + + SparkRowInfoPtr spark_row_info; + BlockPtr block; + std::tie(spark_row_info, block) = mockSparkRowInfoAndBlock(type_and_fields); + assertReadConsistentWithWritten(*spark_row_info, *block, type_and_fields); + EXPECT_TRUE(spark_row_info->getTotalBytes() == 8 + 3 * 8); +} + diff --git a/utils/local-engine/tests/gtest_transformer.cpp b/utils/local-engine/tests/gtest_transformer.cpp new file mode 100644 index 000000000000..20f905fd85fa --- /dev/null +++ b/utils/local-engine/tests/gtest_transformer.cpp @@ -0,0 +1,113 @@ +#include +#include +#include +#include + +using namespace DB; + +TEST(TestPartitionColumnFillingTransform, TestInt32) +{ + auto int_type = DataTypeFactory::instance().get("Int32"); + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + ColumnsWithTypeAndName input_columns = {ColumnWithTypeAndName(int_type, "colA")}; + Block input(input_columns); + ColumnsWithTypeAndName output_columns = {ColumnWithTypeAndName(int_type, "colB"), ColumnWithTypeAndName(int_type, "colA")}; + Block output(output_columns); + String partition_name = "colB"; + String partition_value = "8"; + auto transformer = local_engine::PartitionColumnFillingTransform(input, output, partition_name, partition_value); + + Chunk chunk; + chunk.addColumn(std::move(column0)); + transformer.transform(chunk); + ASSERT_EQ(2, chunk.getNumColumns()); + WhichDataType which(chunk.getColumns().at(0)->getDataType()); + ASSERT_TRUE(which.isInt32()); +} + + +TEST(TestPartitionColumnFillingTransform, TestFloat32) +{ + auto int_type = DataTypeFactory::instance().get("Int32"); + auto float32_type = DataTypeFactory::instance().get("Float32"); + + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + ColumnsWithTypeAndName input_columns = {ColumnWithTypeAndName(int_type, "colA")}; + Block input(input_columns); + ColumnsWithTypeAndName output_columns = {ColumnWithTypeAndName(int_type, "colA"), ColumnWithTypeAndName(float32_type, "colB")}; + Block output(output_columns); + String partition_name = "colB"; + String partition_value = "3.1415926"; + auto transformer = local_engine::PartitionColumnFillingTransform(input, output, partition_name, partition_value); + + Chunk chunk; + chunk.addColumn(std::move(column0)); + transformer.transform(chunk); + ASSERT_EQ(2, chunk.getNumColumns()); + WhichDataType which(chunk.getColumns().at(1)->getDataType()); + ASSERT_TRUE(which.isFloat32()); +} + +TEST(TestPartitionColumnFillingTransform, TestDate) +{ + auto int_type = DataTypeFactory::instance().get("Int32"); + auto date_type = DataTypeFactory::instance().get("Date"); + + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + ColumnsWithTypeAndName input_columns = {ColumnWithTypeAndName(int_type, "colA")}; + Block input(input_columns); + ColumnsWithTypeAndName output_columns = {ColumnWithTypeAndName(int_type, "colA"), ColumnWithTypeAndName(date_type, "colB")}; + Block output(output_columns); + String partition_name = "colB"; + String partition_value = "2022-01-01"; + auto transformer = local_engine::PartitionColumnFillingTransform(input, output, partition_name, partition_value); + + Chunk chunk; + chunk.addColumn(std::move(column0)); + transformer.transform(chunk); + ASSERT_EQ(2, chunk.getNumColumns()); + WhichDataType which(chunk.getColumns().at(1)->getDataType()); + ASSERT_TRUE(which.isUInt16()); +} + +TEST(TestPartitionColumnFillingTransform, TestString) +{ + auto int_type = DataTypeFactory::instance().get("Int32"); + auto string_type = DataTypeFactory::instance().get("String"); + + auto column0 = int_type->createColumn(); + column0->insert(1); + column0->insert(2); + column0->insert(3); + column0->insert(4); + + ColumnsWithTypeAndName input_columns = {ColumnWithTypeAndName(int_type, "colA")}; + Block input(input_columns); + ColumnsWithTypeAndName output_columns = {ColumnWithTypeAndName(int_type, "colA"), ColumnWithTypeAndName(string_type, "colB")}; + Block output(output_columns); + String partition_name = "colB"; + String partition_value = "2022-01-01"; + auto transformer = local_engine::PartitionColumnFillingTransform(input, output, partition_name, partition_value); + + Chunk chunk; + chunk.addColumn(std::move(column0)); + transformer.transform(chunk); + ASSERT_EQ(2, chunk.getNumColumns()); + WhichDataType which(chunk.getColumns().at(1)->getDataType()); + ASSERT_TRUE(which.isString()); +} diff --git a/utils/local-engine/tests/gtest_utils.cpp b/utils/local-engine/tests/gtest_utils.cpp new file mode 100644 index 000000000000..64d47e797ae9 --- /dev/null +++ b/utils/local-engine/tests/gtest_utils.cpp @@ -0,0 +1,15 @@ +#include +#include + +using namespace local_engine; + +TEST(TestStringUtils, TestExtractPartitionValues) +{ + std::string path = "/tmp/col1=1/col2=test/a.parquet"; + auto values = StringUtils::parsePartitionTablePath(path); + ASSERT_EQ(2, values.size()); + ASSERT_EQ("col1", values[0].first); + ASSERT_EQ("1", values[0].second); + ASSERT_EQ("col2", values[1].first); + ASSERT_EQ("test", values[1].second); +} diff --git a/utils/local-engine/tests/testConfig.h.in b/utils/local-engine/tests/testConfig.h.in new file mode 100644 index 000000000000..75157c0126fd --- /dev/null +++ b/utils/local-engine/tests/testConfig.h.in @@ -0,0 +1,6 @@ +#pragma once + +#define TEST_DATA(file) "file://@TEST_DATA_DIR@"#file +#define PARQUET_DATA(file) "file://@PARQUET_DATA_DIR@"#file +#define MERGETREE_DATA(file) "@MERGETREE_DATA_DIR@"#file + diff --git a/utils/local-engine/tool/check-style b/utils/local-engine/tool/check-style new file mode 100755 index 000000000000..c3d4923999a6 --- /dev/null +++ b/utils/local-engine/tool/check-style @@ -0,0 +1,362 @@ +#!/usr/bin/env bash + +# For code formatting we have clang-format. +# +# But it's not sane to apply clang-format for whole code base, +# because it sometimes makes worse for properly formatted files. +# +# It's only reasonable to blindly apply clang-format only in cases +# when the code is likely to be out of style. +# +# For this purpose we have a script that will use very primitive heuristics +# (simple regexps) to check if the code is likely to have basic style violations. +# and then to run formatter only for the specified files. + +ROOT_PATH=$(git rev-parse --show-toplevel) +EXCLUDE_DIRS='build/|integration/|widechar_width/|glibc-compatibility/|memcpy/|consistent-hashing|ch_parquet/|com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper.h|com_intel_oap_row_RowIterator.h' + +# From [1]: +# But since array_to_string_internal() in array.c still loops over array +# elements and concatenates them into a string, it's probably not more +# efficient than the looping solutions proposed, but it's more readable. +# +# [1]: https://stackoverflow.com/a/15394738/328260 +function in_array() +{ + local IFS="|" + local value=$1 && shift + + [[ "${IFS}${*}${IFS}" =~ "${IFS}${value}${IFS}" ]] +} + +find $ROOT_PATH/{utils} -name '*.h' -or -name '*.cpp' 2>/dev/null | + grep -vP $EXCLUDE_DIRS | + xargs grep $@ -P '((class|struct|namespace|enum|if|for|while|else|throw|switch).*|\)(\s*const)?(\s*override)?\s*)\{$|\s$|^ {1,3}[^\* ]\S|\t|^\s*(if|else if|if constexpr|else if constexpr|for|while|catch|switch)\(|\( [^\s\\]|\S \)' | +# a curly brace not in a new line, but not for the case of C++11 init or agg. initialization | trailing whitespace | number of ws not a multiple of 4, but not in the case of comment continuation | missing whitespace after for/if/while... before opening brace | whitespaces inside braces + grep -v -P '(//|:\s+\*|\$\(\()| \)"' +# single-line comment | continuation of a multiline comment | a typical piece of embedded shell code | something like ending of raw string literal + +# Tabs +find $ROOT_PATH/{utils} -name '*.h' -or -name '*.cpp' 2>/dev/null | + grep -vP $EXCLUDE_DIRS | + xargs grep $@ -F $'\t' + +# // namespace comments are unneeded +find $ROOT_PATH/{utils} -name '*.h' -or -name '*.cpp' 2>/dev/null | + grep -vP $EXCLUDE_DIRS | + xargs grep $@ -P '}\s*//+\s*namespace\s*' + +# Broken symlinks +find -L $ROOT_PATH -type l 2>/dev/null | grep -v contrib && echo "^ Broken symlinks found" + +# Double whitespaces +find $ROOT_PATH/{utils} -name '*.h' -or -name '*.cpp' 2>/dev/null | + grep -vP $EXCLUDE_DIRS | + while read i; do $ROOT_PATH/utils/check-style/double-whitespaces.pl < $i || echo -e "^ File $i contains double whitespaces\n"; done + +# Unused/Undefined/Duplicates ErrorCodes/ProfileEvents/CurrentMetrics +declare -A EXTERN_TYPES +EXTERN_TYPES[ErrorCodes]=int +EXTERN_TYPES[ProfileEvents]=Event +EXTERN_TYPES[CurrentMetrics]=Metric + +EXTERN_TYPES_EXCLUDES=( + ProfileEvents::global_counters + ProfileEvents::Event + ProfileEvents::Count + ProfileEvents::Counters + ProfileEvents::end + ProfileEvents::increment + ProfileEvents::getName + ProfileEvents::Type + ProfileEvents::TypeEnum + ProfileEvents::dumpToMapColumn + ProfileEvents::getProfileEvents + ProfileEvents::ThreadIdToCountersSnapshot + ProfileEvents::LOCAL_NAME + ProfileEvents::CountersIncrement + + CurrentMetrics::add + CurrentMetrics::sub + CurrentMetrics::set + CurrentMetrics::end + CurrentMetrics::Increment + CurrentMetrics::Metric + CurrentMetrics::values + CurrentMetrics::Value + + ErrorCodes::ErrorCode + ErrorCodes::getName + ErrorCodes::increment + ErrorCodes::end + ErrorCodes::values + ErrorCodes::values[i] + ErrorCodes::getErrorCodeByName +) +for extern_type in ${!EXTERN_TYPES[@]}; do + type_of_extern=${EXTERN_TYPES[$extern_type]} + allowed_chars='[_A-Za-z]+' + + # Unused + # NOTE: to fix automatically, replace echo with: + # sed -i "/extern const $type_of_extern $val/d" $file + find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | { + # NOTE: the check is pretty dumb and distinguish only by the type_of_extern, + # and this matches with zkutil::CreateMode + grep -v 'src/Common/ZooKeeper/Types.h' + } | { + grep -vP $EXCLUDE_DIRS | xargs grep -l -P "extern const $type_of_extern $allowed_chars" + } | while read file; do + grep -P "extern const $type_of_extern $allowed_chars;" $file | sed -r -e "s/^.*?extern const $type_of_extern ($allowed_chars);.*?$/\1/" | while read val; do + if ! grep -q "$extern_type::$val" $file; then + # Excludes for SOFTWARE_EVENT/HARDWARE_EVENT/CACHE_EVENT in ThreadProfileEvents.cpp + if [[ ! $extern_type::$val =~ ProfileEvents::Perf.* ]]; then + echo "$extern_type::$val is defined but not used in file $file" + fi + fi + done + done + + # Undefined + # NOTE: to fix automatically, replace echo with: + # ( grep -q -F 'namespace $extern_type' $file && \ + # sed -i -r "0,/(\s*)extern const $type_of_extern [$allowed_chars]+/s//\1extern const $type_of_extern $val;\n&/" $file || \ + # awk '{ print; if (ns == 1) { ns = 2 }; if (ns == 2) { ns = 0; print "namespace $extern_type\n{\n extern const $type_of_extern '$val';\n}" } }; /namespace DB/ { ns = 1; };' < $file > ${file}.tmp && mv ${file}.tmp $file ) + find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | { + grep -vP $EXCLUDE_DIRS | xargs grep -l -P "$extern_type::$allowed_chars" + } | while read file; do + grep -P "$extern_type::$allowed_chars" $file | grep -P -v '^\s*//' | sed -r -e "s/^.*?$extern_type::($allowed_chars).*?$/\1/" | while read val; do + if ! grep -q "extern const $type_of_extern $val" $file; then + if ! in_array "$extern_type::$val" "${EXTERN_TYPES_EXCLUDES[@]}"; then + echo "$extern_type::$val is used in file $file but not defined" + fi + fi + done + done + + # Duplicates + find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | { + grep -vP $EXCLUDE_DIRS | xargs grep -l -P "$extern_type::$allowed_chars" + } | while read file; do + grep -P "extern const $type_of_extern $allowed_chars;" $file | sort | uniq -c | grep -v -P ' +1 ' && echo "Duplicate $extern_type in file $file" + done +done + +# Three or more consecutive empty lines +find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | + grep -vP $EXCLUDE_DIRS | + while read file; do awk '/^$/ { ++i; if (i > 2) { print "More than two consecutive empty lines in file '$file'" } } /./ { i = 0 }' $file; done + +# Broken XML files (requires libxml2-utils) +#find $ROOT_PATH/{src,base,programs,utils} -name '*.xml' | +# grep -vP $EXCLUDE_DIRS | +# xargs xmllint --noout --nonet + +# FIXME: for now only clickhouse-test +#pylint --rcfile=$ROOT_PATH/.pylintrc --persistent=no --score=n $ROOT_PATH/tests/clickhouse-test $ROOT_PATH/tests/ci/*.py + +#find $ROOT_PATH -not -path $ROOT_PATH'/contrib*' \( -name '*.yaml' -or -name '*.yml' \) -type f | +# grep -vP $EXCLUDE_DIRS | +# xargs yamllint --config-file=$ROOT_PATH/.yamllint + +# Machine translation to Russian is strictly prohibited +#find $ROOT_PATH/docs/ru -name '*.md' | +# grep -vP $EXCLUDE_DIRS | +# xargs grep -l -F 'machine_translated: true' + +# Tests should not be named with "fail" in their names. It makes looking at the results less convenient. +#find $ROOT_PATH/tests/queries -iname '*fail*' | +# grep -vP $EXCLUDE_DIRS | +# grep . && echo 'Tests should not be named with "fail" in their names. It makes looking at the results less convenient when you search for "fail" substring in browser.' + +# Queries to system.query_log/system.query_thread_log should have current_database = currentDatabase() condition +# NOTE: it is not that accuate, but at least something. +#tests_with_query_log=( $( +# find $ROOT_PATH/tests/queries -iname '*.sql' -or -iname '*.sh' -or -iname '*.py' -or -iname '*.j2' | +# grep -vP $EXCLUDE_DIRS | +# xargs grep --with-filename -e system.query_log -e system.query_thread_log | cut -d: -f1 | sort -u +#) ) +#for test_case in "${tests_with_query_log[@]}"; do +# grep -qE current_database.*currentDatabase "$test_case" || { +# grep -qE 'current_database.*\$CLICKHOUSE_DATABASE' "$test_case" +# } || echo "Queries to system.query_log/system.query_thread_log does not have current_database = currentDatabase() condition in $test_case" +#done + +# Queries to: +tables_with_database_column=( + system.tables + system.parts + system.detached_parts + system.parts_columns + system.columns + system.projection_parts + system.mutations +) +# should have database = currentDatabase() condition +# +# NOTE: it is not that accuate, but at least something. +#tests_with_database_column=( $( +# find $ROOT_PATH/tests/queries -iname '*.sql' -or -iname '*.sh' -or -iname '*.py' -or -iname '*.j2' | +# grep -vP $EXCLUDE_DIRS | +# xargs grep --with-filename $(printf -- "-e %s " "${tables_with_database_column[@]}") | +# grep -v -e ':--' -e ':#' | +# cut -d: -f1 | sort -u +#) ) +#for test_case in "${tests_with_database_column[@]}"; do +# grep -qE database.*currentDatabase "$test_case" || { +# grep -qE 'database.*\$CLICKHOUSE_DATABASE' "$test_case" +# } || { +# # explicit database +# grep -qE "database[ ]*=[ ]*'" "$test_case" +# } || { +# echo "Queries to ${tables_with_database_column[*]} does not have database = currentDatabase()/\$CLICKHOUSE_DATABASE condition in $test_case" +# } +#done + +# Queries with ReplicatedMergeTree +# NOTE: it is not that accuate, but at least something. +#tests_with_replicated_merge_tree=( $( +# find $ROOT_PATH/tests/queries -iname '*.sql' -or -iname '*.sh' -or -iname '*.py' -or -iname '*.j2' | +# grep -vP $EXCLUDE_DIRS | +# xargs grep --with-filename -e ReplicatedMergeTree | cut -d: -f1 | sort -u +#) ) +#for test_case in "${tests_with_replicated_merge_tree[@]}"; do +# case "$test_case" in +# *.gen.*) +# ;; +# *.sh) +# test_case_zk_prefix="\$CLICKHOUSE_TEST_ZOOKEEPER_PREFIX" +# grep -q -e "ReplicatedMergeTree[ ]*(.*$test_case_zk_prefix" "$test_case" || echo "ReplicatedMergeTree should contain '$test_case_zk_prefix' in zookeeper path to avoid overlaps ($test_case)" +# ;; +# *.sql|*.sql.j2) +# test_case_zk_prefix="\({database}\|currentDatabase()\)" +# grep -q -e "ReplicatedMergeTree[ ]*(.*$test_case_zk_prefix" "$test_case" || echo "ReplicatedMergeTree should contain '$test_case_zk_prefix' in zookeeper path to avoid overlaps ($test_case)" +# ;; +# *.py) +# # Right now there is not such tests anyway +# echo "No ReplicatedMergeTree style check for *.py ($test_case)" +# ;; +# esac +#done + +# All the submodules should be from https://github.com/ +find $ROOT_PATH -name '.gitmodules' | while read i; do grep -F 'url = ' $i | grep -v -F 'https://github.com/' && echo 'All the submodules should be from https://github.com/'; done + +# There shouldn't be any code snippets under GPL or LGPL +find $ROOT_PATH/{src,base,programs} -name '*.h' -or -name '*.cpp' 2>/dev/null | xargs grep -i -F 'General Public License' && echo "There shouldn't be any code snippets under GPL or LGPL" + +# There shouldn't be any docker containers outside docker directory +#find $ROOT_PATH -not -path $ROOT_PATH'/tests/ci*' -not -path $ROOT_PATH'/docker*' -not -path $ROOT_PATH'/contrib*' -not -path $ROOT_PATH'/utils/local-engine' -name Dockerfile -type f 2>/dev/null | xargs --no-run-if-empty -n1 echo "Please move Dockerfile to docker directory:" + +# There shouldn't be any docker compose files outside docker directory +#find $ROOT_PATH -not -path $ROOT_PATH'/tests/testflows*' -not -path $ROOT_PATH'/docker*' -not -path $ROOT_PATH'/contrib*' -name '*compose*.yml' -type f 2>/dev/null | xargs --no-run-if-empty grep -l "version:" | xargs --no-run-if-empty -n1 echo "Please move docker compose to docker directory:" + +# Check that every header file has #pragma once in first line +find $ROOT_PATH/{src,programs,utils} -name '*.h' | + grep -vP $EXCLUDE_DIRS | + while read file; do [[ $(head -n1 $file) != '#pragma once' ]] && echo "File $file must have '#pragma once' in first line"; done + +# Check for executable bit on non-executable files +find $ROOT_PATH/{src,base,programs,utils,tests,docs,cmake} '(' -name '*.cpp' -or -name '*.h' -or -name '*.sql' -or -name '*.j2' -or -name '*.xml' -or -name '*.reference' -or -name '*.txt' -or -name '*.md' ')' -and -executable | grep -P '.' && echo "These files should not be executable." + +# Check for BOM +find $ROOT_PATH/{src,base,programs,utils,tests,docs,cmake} -name '*.md' -or -name '*.cpp' -or -name '*.h' | xargs grep -l -F $'\xEF\xBB\xBF' | grep -P '.' && echo "Files should not have UTF-8 BOM" +find $ROOT_PATH/{src,base,programs,utils,tests,docs,cmake} -name '*.md' -or -name '*.cpp' -or -name '*.h' | xargs grep -l -F $'\xFF\xFE' | grep -P '.' && echo "Files should not have UTF-16LE BOM" +find $ROOT_PATH/{src,base,programs,utils,tests,docs,cmake} -name '*.md' -or -name '*.cpp' -or -name '*.h' | xargs grep -l -F $'\xFE\xFF' | grep -P '.' && echo "Files should not have UTF-16BE BOM" + +# Too many exclamation marks +find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | + grep -vP $EXCLUDE_DIRS | + xargs grep -F '!!!' | grep -P '.' && echo "Too many exclamation marks (looks dirty, unconfident)." + +# Trailing whitespaces +find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | + grep -vP $EXCLUDE_DIRS | + xargs grep -n -P ' $' | grep -n -P '.' && echo "^ Trailing whitespaces." + +# Forbid stringstream because it's easy to use them incorrectly and hard to debug possible issues +find $ROOT_PATH/{src,programs,utils} -name '*.h' -or -name '*.cpp' | + grep -vP $EXCLUDE_DIRS | + xargs grep -P 'std::[io]?stringstream' | grep -v "STYLE_CHECK_ALLOW_STD_STRING_STREAM" && echo "Use WriteBufferFromOwnString or ReadBufferFromString instead of std::stringstream" + +# Forbid std::cerr/std::cout in src (fine in programs/utils) +std_cerr_cout_excludes=( + /examples/ + /tests/ + _fuzzer + # DUMP() + base/base/iostream_debug_helpers.h + # OK + src/Common/ProgressIndication.cpp + # only under #ifdef DBMS_HASH_MAP_DEBUG_RESIZES, that is used only in tests + src/Common/HashTable/HashTable.h + # SensitiveDataMasker::printStats() + src/Common/SensitiveDataMasker.cpp + # StreamStatistics::print() + src/Compression/LZ4_decompress_faster.cpp + # ContextSharedPart with subsequent std::terminate() + src/Interpreters/Context.cpp + # IProcessor::dump() + src/Processors/IProcessor.cpp + src/Client/ClientBase.cpp + src/Client/LineReader.cpp + src/Client/QueryFuzzer.cpp + src/Client/Suggest.cpp + src/Bridge/IBridge.cpp + src/Daemon/BaseDaemon.cpp + src/Loggers/Loggers.cpp +) +sources_with_std_cerr_cout=( $( + find $ROOT_PATH/{src,base} -name '*.h' -or -name '*.cpp' | \ + grep -vP $EXCLUDE_DIRS | \ + grep -F -v $(printf -- "-e %s " "${std_cerr_cout_excludes[@]}") | \ + xargs grep -F --with-filename -e std::cerr -e std::cout | cut -d: -f1 | sort -u +) ) +# Exclude comments +for src in "${sources_with_std_cerr_cout[@]}"; do + # suppress stderr, since it may contain warning for #pargma once in headers + if gcc -fpreprocessed -dD -E "$src" 2>/dev/null | grep -F -q -e std::cerr -e std::cout; then + echo "$src: uses std::cerr/std::cout" + fi +done + +# Queries with event_date should have yesterday() not today() +# +# NOTE: it is not that accuate, but at least something. +#tests_with_event_time_date=( $( +# find $ROOT_PATH/tests/queries -iname '*.sql' -or -iname '*.sh' -or -iname '*.py' -or -iname '*.j2' | +# grep -vP $EXCLUDE_DIRS | +# xargs grep --with-filename -e event_time -e event_date | cut -d: -f1 | sort -u +#) ) +#for test_case in "${tests_with_event_time_date[@]}"; do +# cat "$test_case" | tr '\n' ' ' | grep -q -i -e 'WHERE.*event_date[ ]*=[ ]*today()' -e 'WHERE.*event_date[ ]*=[ ]*today()' && { +# echo "event_time/event_date should be filtered using >=yesterday() in $test_case (to avoid flakiness)" +# } +#done + +# Conflict markers +find $ROOT_PATH/{src,base,programs,utils,tests,docs,cmake} -name '*.md' -or -name '*.cpp' -or -name '*.h' | + xargs grep -P '^(<<<<<<<|=======|>>>>>>>)$' | grep -P '.' && echo "Conflict markers are found in files" + +# Forbid subprocess.check_call(...) in integration tests because it does not provide enough information on errors +#find $ROOT_PATH'/tests/integration' -name '*.py' | +# xargs grep -F 'subprocess.check_call' | grep -v "STYLE_CHECK_ALLOW_SUBPROCESS_CHECK_CALL" && echo "Use helpers.cluster.run_and_check or subprocess.run instead of subprocess.check_call to print detailed info on error" + +# Forbid non-unique error codes +if [[ "$(grep -Po "M\([0-9]*," $ROOT_PATH/src/Common/ErrorCodes.cpp | wc -l)" != "$(grep -Po "M\([0-9]*," $ROOT_PATH/src/Common/ErrorCodes.cpp | sort | uniq | wc -l)" ]] +then + echo "ErrorCodes.cpp contains non-unique error codes" +fi + +# Check that there is no system-wide libraries/headers in use. +# +# NOTE: it is better to override find_path/find_library in cmake, but right now +# it is not possible, see [1] for the reference. +# +# [1]: git grep --recurse-submodules -e find_library -e find_path contrib +#if git grep -e find_path -e find_library -- :**CMakeLists.txt; then +# echo "There is find_path/find_library usage. ClickHouse should use everything bundled. Consider adding one more contrib module." +#fi + +# Forbid files that differ only by character case +find $ROOT_PATH/utils | sort -f | uniq -i -c | awk '{ if ($1 > 1) print }' diff --git a/utils/local-engine/tool/parquet_to_mergetree.py b/utils/local-engine/tool/parquet_to_mergetree.py new file mode 100644 index 000000000000..92051ce9bdc6 --- /dev/null +++ b/utils/local-engine/tool/parquet_to_mergetree.py @@ -0,0 +1,108 @@ +import os +import re +import subprocess +from argparse import ArgumentParser +from multiprocessing import Pool + +parser = ArgumentParser() +parser.add_argument("--path", type=str, required=True, help="temp directory for merge tree") +parser.add_argument("--source", type=str, required=True, help="directory of parquet files") +parser.add_argument("--dst", type=str, required=True, help="destination directory for merge tree") +parser.add_argument("--schema", type=str, + default="l_orderkey Nullable(Int64),l_partkey Nullable(Int64),l_suppkey Nullable(Int64),l_linenumber Nullable(Int64),l_quantity Nullable(Float64),l_extendedprice Nullable(Float64),l_discount Nullable(Float64),l_tax Nullable(Float64),l_returnflag Nullable(String),l_linestatus Nullable(String),l_shipdate Nullable(Date),l_commitdate Nullable(Date),l_receiptdate Nullable(Date),l_shipinstruct Nullable(String),l_shipmode Nullable(String),l_comment Nullable(String)") + + +def get_transform_command(data_path, + parquet_file, + schema): + return f""" + clickhouse-local --no-system-tables --path {data_path} --file "{parquet_file}" --input-format=Parquet \\ + -S "{schema}" \\ + --query " \\ + CREATE TABLE m1 ({schema}) ENGINE = MergeTree() order by tuple(); \\ + insert into m1 SELECT * FROM table;\\ + OPTIMIZE table m1 FINAL; + " + """ + + +def get_move_command(data_path, dst_path, no): + return f"mkdir -p {dst_path}/all_{no}_{no}_1; cp -r {data_path}/data/_local/m1/all_1_1_1/* {dst_path}/all_{no}_{no}_1" + + +def get_clean_command(data_path): + return f"rm -rf {data_path}/data/*" + + +def transform(data_path, source, schema, dst): + assert os.path.exists(data_path), f"{data_path} is not exist" + for no, file in enumerate([file for file in os.listdir(source) if file.endswith(".parquet")]): + abs_file = f"{source}/{file}" + if not os.path.exists(abs_file): + raise f"{abs_file} not found" + command1 = get_transform_command(data_path, abs_file, schema) + command2 = get_move_command(data_path, dst, no+1) + command3 = get_clean_command(data_path) + if os.system(command3) != 0: + raise Exception(command3 + " failed") + if os.system(command1) != 0: + raise Exception(command1 + " failed") + if os.system(command2) != 0: + raise Exception(command2 + " failed") + print(f"{abs_file}") + +class Engine(object): + def __init__(self, source, data_path, schema, dst): + self.source = source + self.data_path = data_path + self.schema = schema + self.dst = dst + def __call__(self, ele): + no = ele[0] + file = ele[1] + abs_file = f"{self.source}/{file}" + print(abs_file) + if not os.path.exists(abs_file): + raise f"{abs_file} not found" + private_path = f"{self.data_path}/{str(no)}" + os.system(f"mkdir -p {private_path}") + command1 = get_transform_command(private_path, abs_file, self.schema) + command2 = get_move_command(private_path, self.dst, no+1) + command3 = get_clean_command(private_path) + if os.system(command3) != 0: + raise Exception(command3 + " failed") + if os.system(command1) != 0: + raise Exception(command1 + " failed") + if os.system(command2) != 0: + raise Exception(command2 + " failed") + print(f"{abs_file}") + + +def multi_transform(data_path, source, schema, dst): + assert os.path.exists(data_path), f"{data_path} is not exist" + data_inputs = enumerate([file for file in os.listdir(source) if file.endswith(".parquet")]) + pool = Pool() + engine = Engine(source, data_path, schema, dst) + pool.map(engine, list(data_inputs)) # process data_inputs iterable with pool + + +def check_version(version): + proc = subprocess.Popen(["clickhouse-local", "--version"], stdout=subprocess.PIPE, shell=False) + (out, err) = proc.communicate() + if err: + raise Exception(f"Fail to call clickhouse-local, error: {err}") + ver = re.search(r'version\s*([\d.]+)', str(out)).group(1) + ver_12 = float(ver.split('.')[0] + '.' + ver.split('.')[1]) + if ver_12 >= float(version): + raise Exception(f"Version of clickhouse-local too high({ver}), should be <= 22.5") + +""" +python3 parquet_to_mergetree.py --path=/root/data/tmp --source=/home/ubuntu/tpch-data-sf100/lineitem --dst=/root/data/mergetree +""" +if __name__ == '__main__': + args = parser.parse_args() + if not os.path.exists(args.dst): + os.mkdir(args.dst) + #transform(args.path, args.source, args.schema, args.dst) + check_version('22.6') + multi_transform(args.path, args.source, args.schema, args.dst)