From be089f39dc05aac9dc6aee2ada89c22644114c43 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Thu, 22 Feb 2024 01:24:39 +0530 Subject: [PATCH 1/8] fix stats filter conversion dtypes and names --- cpp/src/io/parquet/predicate_pushdown.cpp | 49 ++++++++++++++++++-- cpp/src/io/parquet/reader_impl.cpp | 5 +- cpp/src/io/parquet/reader_impl_helpers.cpp | 5 +- cpp/src/io/parquet/reader_impl_helpers.hpp | 28 ++++++++--- cpp/src/io/parquet/reader_impl_preprocess.cu | 11 +---- cpp/tests/io/parquet_reader_test.cpp | 21 +++++++++ 6 files changed, 94 insertions(+), 25 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index f43a8fd24c4..ba496eb68ba 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -374,9 +374,46 @@ class stats_expression_converter : public ast::detail::expression_transformer { }; } // namespace +std::tuple, host_span> +aggregate_reader_metadata::get_schema_dtypes(bool strings_to_categorical, type_id timestamp_type_id) +{ + // TODO, get types and names for only names present in filter.? and their col_idx. + // create root column types and names as vector + if (!_root_level_types.empty()) return {_root_level_types, _root_level_names}; + std::function get_dtype = [strings_to_categorical, + timestamp_type_id, + &get_dtype, + this](int schema_idx) -> cudf::data_type { + // returns type of root level columns only. + // if (schema_idx < 0) { return false; } + auto const& schema_elem = get_schema(schema_idx); + if (schema_elem.is_stub()) { + CUDF_EXPECTS(schema_elem.num_children == 1, "Unexpected number of children for stub"); + return get_dtype(schema_elem.children_idx[0]); + } + + auto const one_level_list = schema_elem.is_one_level_list(get_schema(schema_elem.parent_idx)); + // if we're at the root, this is a new output column + auto const col_type = one_level_list + ? type_id::LIST + : to_type_id(schema_elem, strings_to_categorical, timestamp_type_id); + auto const dtype = to_data_type(col_type, schema_elem); + // path_is_valid is skipped for nested columns here. TODO: more test cases where no leaf. + return dtype; + }; + + auto const& root = get_schema(0); + for (auto const& schema_idx : root.children_idx) { + if (schema_idx < 0) { continue; } + _root_level_types.push_back(get_dtype(schema_idx)); + _root_level_names.push_back(get_schema(schema_idx).name); + } + return {_root_level_types, _root_level_names}; + ; +} + std::optional>> aggregate_reader_metadata::filter_row_groups( host_span const> row_group_indices, - host_span output_dtypes, std::reference_wrapper filter, rmm::cuda_stream_view stream) const { @@ -410,8 +447,8 @@ std::optional>> aggregate_reader_metadata::fi // For each column, it contains #sources * #column_chunks_per_src rows. std::vector> columns; stats_caster stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; - for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) { - auto const& dtype = output_dtypes[col_idx]; + for (size_t col_idx = 0; col_idx < _root_level_types.size(); col_idx++) { + auto const& dtype = _root_level_types[col_idx]; // Only comparable types except fixed point are supported. if (cudf::is_compound(dtype) && dtype.id() != cudf::type_id::STRING) { // placeholder only for unsupported types. @@ -427,9 +464,13 @@ std::optional>> aggregate_reader_metadata::fi columns.push_back(std::move(max_col)); } auto stats_table = cudf::table(std::move(columns)); + // named filter to reference filter w.r.t parquet schema order. + auto expr_conv = named_to_reference_converter(filter, _root_level_names); + auto reference_filter = expr_conv.get_converted_expr(); // Converts AST to StatsAST with reference to min, max columns in above `stats_table`. - stats_expression_converter stats_expr{filter, static_cast(output_dtypes.size())}; + stats_expression_converter stats_expr{reference_filter.value().get(), + static_cast(_root_level_types.size())}; auto stats_ast = stats_expr.get_stats_expr(); auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr); auto predicate = predicate_col->view(); diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 26d810a3337..8e63dd05246 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -363,6 +363,9 @@ reader::impl::impl(std::size_t chunk_read_limit, _strings_to_categorical, _timestamp_type.id()); + // Find the name, and dtypes of parquet root level schema. (save it in _metadata.) + _metadata->get_schema_dtypes(_strings_to_categorical, _timestamp_type.id()); + // Save the states of the output buffers for reuse in `chunk_read()`. for (auto const& buff : _output_buffers) { _output_buffers_template.emplace_back(cudf::io::detail::inline_column_buffer::empty_like(buff)); @@ -508,7 +511,7 @@ table_with_metadata reader::impl::read( auto expr_conv = named_to_reference_converter(filter, metadata); auto output_filter = expr_conv.get_converted_expr(); - prepare_data(skip_rows, num_rows, uses_custom_row_bounds, row_group_indices, output_filter); + prepare_data(skip_rows, num_rows, uses_custom_row_bounds, row_group_indices, filter); return read_chunk_internal(uses_custom_row_bounds, output_filter); } diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index 6f11debb8df..759b35caf6a 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -444,14 +444,13 @@ aggregate_reader_metadata::select_row_groups( host_span const> row_group_indices, int64_t skip_rows_opt, std::optional const& num_rows_opt, - host_span output_dtypes, std::optional> filter, rmm::cuda_stream_view stream) const { std::optional>> filtered_row_group_indices; + // if filter is not empty, then gather row groups to read after predicate pushdown if (filter.has_value()) { - filtered_row_group_indices = - filter_row_groups(row_group_indices, output_dtypes, filter.value(), stream); + filtered_row_group_indices = filter_row_groups(row_group_indices, filter.value(), stream); if (filtered_row_group_indices.has_value()) { row_group_indices = host_span const>(filtered_row_group_indices.value()); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index 8d8ab8707be..cd4c0114341 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,6 +82,8 @@ class aggregate_reader_metadata { int64_t num_rows; size_type num_row_groups; + std::vector _root_level_types; + std::vector _root_level_names; /** * @brief Create a metadata object from each element in the source vector */ @@ -167,18 +169,19 @@ class aggregate_reader_metadata { */ [[nodiscard]] std::vector get_pandas_index_names() const; + std::tuple, host_span> get_schema_dtypes( + bool strings_to_categorical, type_id timestamp_type_id); + /** * @brief Filters the row groups based on predicate filter * * @param row_group_indices Lists of row groups to read, one per source - * @param output_dtypes List of output column datatypes * @param filter AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches * @return Filtered row group indices, if any is filtered. */ [[nodiscard]] std::optional>> filter_row_groups( host_span const> row_group_indices, - host_span output_dtypes, std::reference_wrapper filter, rmm::cuda_stream_view stream) const; @@ -191,7 +194,6 @@ class aggregate_reader_metadata { * @param row_group_indices Lists of row groups to read, one per source * @param row_start Starting row of the selection * @param row_count Total number of rows selected - * @param output_dtypes List of output column datatypes * @param filter Optional AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches * @return A tuple of corrected row_start, row_count and list of row group indexes and its @@ -201,7 +203,6 @@ class aggregate_reader_metadata { host_span const> row_group_indices, int64_t row_start, std::optional const& row_count, - host_span output_dtypes, std::optional> filter, rmm::cuda_stream_view stream) const; @@ -234,7 +235,6 @@ class named_to_reference_converter : public ast::detail::expression_transformer public: named_to_reference_converter(std::optional> expr, table_metadata const& metadata) - : metadata(metadata) { if (!expr.has_value()) return; // create map for column name. @@ -251,6 +251,21 @@ class named_to_reference_converter : public ast::detail::expression_transformer expr.value().get().accept(*this); } + named_to_reference_converter(std::reference_wrapper expr, + host_span root_column_names) + { + // create map for column name. + std::transform( + thrust::make_zip_iterator(root_column_names.begin(), thrust::counting_iterator(0)), + thrust::make_zip_iterator(root_column_names.end(), + thrust::counting_iterator(root_column_names.size())), + std::inserter(column_name_to_index, column_name_to_index.end()), + [](auto const& name_index) { + return std::make_pair(thrust::get<0>(name_index), thrust::get<1>(name_index)); + }); + + expr.get().accept(*this); + } /** * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) */ @@ -284,7 +299,6 @@ class named_to_reference_converter : public ast::detail::expression_transformer std::vector> visit_operands( std::vector> operands); - table_metadata const& metadata; std::unordered_map column_name_to_index; std::optional> _stats_expr; // Using std::list or std::deque to avoid reference invalidation diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 48ff32038b3..79f053c75fc 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -1089,18 +1089,9 @@ void reader::impl::preprocess_file( { CUDF_EXPECTS(!_file_preprocessed, "Attempted to preprocess file more than once"); - // if filter is not empty, then create output types as vector and pass for filtering. - std::vector output_types; - if (filter.has_value()) { - std::transform(_output_buffers.cbegin(), - _output_buffers.cend(), - std::back_inserter(output_types), - [](auto const& col) { return col.type; }); - } std::tie( _file_itm_data.global_skip_rows, _file_itm_data.global_num_rows, _file_itm_data.row_groups) = - _metadata->select_row_groups( - row_group_indices, skip_rows, num_rows, output_types, filter, _stream); + _metadata->select_row_groups(row_group_indices, skip_rows, num_rows, filter, _stream); if (_file_itm_data.global_num_rows > 0 && not _file_itm_data.row_groups.empty() && not _input_columns.empty()) { diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp index abbd0c97f07..6e8a635a055 100644 --- a/cpp/tests/io/parquet_reader_test.cpp +++ b/cpp/tests/io/parquet_reader_test.cpp @@ -1406,6 +1406,27 @@ TEST_F(ParquetReaderTest, FilterIdentity) CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *result2.tbl); } +TEST_F(ParquetReaderTest, FilterWithColumnProjection) +{ + auto [src, filepath] = create_parquet_with_stats("FilterWithColumnProjection.parquet"); + auto val = cudf::numeric_scalar{10}; + auto lit = cudf::ast::literal{val}; + auto col_ref = cudf::ast::column_name_reference{"col_uint32"}; + auto col_index = cudf::ast::column_reference{0}; + auto read_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref, lit); + auto filter_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index, lit); + + auto predicate = cudf::compute_column(src, filter_expr); + auto projected_table = cudf::table_view{{src.get_column(2), src.get_column(0)}}; + auto expected = cudf::apply_boolean_mask(projected_table, *predicate); + + auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) + .columns({"col_double", "col_uint32"}) + .filter(read_expr); + auto result = cudf::io::read_parquet(read_opts); + CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); +} + TEST_F(ParquetReaderTest, FilterReferenceExpression) { auto [src, filepath] = create_parquet_with_stats("FilterReferenceExpression.parquet"); From f458410ebc1e4b1df1b6c2b898bde2ee52bf71b4 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Fri, 1 Mar 2024 14:28:58 +0530 Subject: [PATCH 2/8] filter columns limitation fixed. get_columns() need not specify all columns in filter. columns in filter are read, and discard finally at the output after filtering. --- cpp/src/io/parquet/predicate_pushdown.cpp | 36 +++++++++++++++ cpp/src/io/parquet/reader_impl.cpp | 20 ++++++++- cpp/src/io/parquet/reader_impl.hpp | 3 ++ cpp/src/io/parquet/reader_impl_helpers.cpp | 30 ++++++++----- cpp/src/io/parquet/reader_impl_helpers.hpp | 52 ++++++++++++++++++++++ cpp/tests/io/parquet_reader_test.cpp | 4 +- 6 files changed, 130 insertions(+), 15 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index ba496eb68ba..12d17f6e049 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -570,4 +570,40 @@ named_to_reference_converter::visit_operands( return transformed_operands; } +// extract column names from expression +std::reference_wrapper names_from_expression::visit(ast::literal const& expr) +{ + return expr; +} + +std::reference_wrapper names_from_expression::visit( + ast::column_reference const& expr) +{ + return expr; +} + +std::reference_wrapper names_from_expression::visit( + ast::column_name_reference const& expr) +{ + // collect column names + auto col_name = expr.get_column_name(); + if (_skip_names.count(col_name) == 0) { _column_names.insert(col_name); } + return expr; +} + +std::reference_wrapper names_from_expression::visit( + ast::operation const& expr) +{ + visit_operands(expr.get_operands()); + return expr; +} + +void names_from_expression::visit_operands( + std::vector> operands) +{ + for (auto const& operand : operands) { + operand.get().accept(*this); + } +} + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 8e63dd05246..ed6637bcef5 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -23,6 +23,8 @@ #include #include +#include + #include #include @@ -356,9 +358,18 @@ reader::impl::impl(std::size_t chunk_read_limit, // Binary columns can be read as binary or strings _reader_column_schema = options.get_column_schema(); - // Select only columns required by the options + // Select only columns required by the options and filter + std::optional> filter_columns_names; + if (options.get_filter().has_value() and options.get_columns().has_value()) { + // list, struct, dictionary are not supported by AST filter yet. + // extract columns not present in get_columns() & keep count to remove at end. + auto extractor = names_from_expression(options.get_filter(), *(options.get_columns())); + filter_columns_names = extractor.get_column_names(); + _num_filter_columns = filter_columns_names->size(); + } std::tie(_input_columns, _output_buffers, _output_column_schemas) = _metadata->select_columns(options.get_columns(), + filter_columns_names, options.is_enabled_use_pandas_metadata(), _strings_to_categorical, _timestamp_type.id()); @@ -491,7 +502,12 @@ table_with_metadata reader::impl::finalize_output( *read_table, filter.value().get(), _stream, rmm::mr::get_current_device_resource()); CUDF_EXPECTS(predicate->view().type().id() == type_id::BOOL8, "Predicate filter should return a boolean"); - auto output_table = cudf::detail::apply_boolean_mask(*read_table, *predicate, _stream, _mr); + // Exclude columns present in filter only in output + auto counting_it = thrust::make_counting_iterator(0); + auto const output_count = read_table->num_columns() - _num_filter_columns; + auto only_output = read_table->select(counting_it, counting_it + output_count); + auto output_table = cudf::detail::apply_boolean_mask(only_output, *predicate, _stream, _mr); + if (_num_filter_columns > 0) { out_metadata.schema_info.resize(output_count); } return {std::move(output_table), std::move(out_metadata)}; } return {std::make_unique(std::move(out_columns)), std::move(out_metadata)}; diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index 67c56c9c2d7..d08b37c3fd8 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -364,6 +364,9 @@ class reader::impl { // _output_buffers associated metadata std::unique_ptr _output_metadata; + // number of extra filter columns + std::size_t _num_filter_columns{0}; + bool _strings_to_categorical = false; std::optional> _reader_column_schema; data_type _timestamp_type{type_id::EMPTY}; diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index 759b35caf6a..19e3ddadd8d 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -18,6 +18,7 @@ #include "io/utilities/row_selection.hpp" +#include #include #include @@ -498,10 +499,12 @@ aggregate_reader_metadata::select_row_groups( std::tuple, std::vector, std::vector> -aggregate_reader_metadata::select_columns(std::optional> const& use_names, - bool include_index, - bool strings_to_categorical, - type_id timestamp_type_id) const +aggregate_reader_metadata::select_columns( + std::optional> const& use_names, + std::optional> const& filter_columns_names, + bool include_index, + bool strings_to_categorical, + type_id timestamp_type_id) const { auto find_schema_child = [&](SchemaElement const& schema_elem, std::string const& name) { auto const& col_schema_idx = @@ -666,13 +669,18 @@ aggregate_reader_metadata::select_columns(std::optional // Find which of the selected paths are valid and get their schema index std::vector valid_selected_paths; - for (auto const& selected_path : *use_names) { - auto found_path = - std::find_if(all_paths.begin(), all_paths.end(), [&](path_info& valid_path) { - return valid_path.full_path == selected_path; - }); - if (found_path != all_paths.end()) { - valid_selected_paths.push_back({selected_path, found_path->schema_idx}); + // vector reference pushback (*use_names). If filter names passed. + std::vector const>> column_names{ + *use_names, *filter_columns_names}; + for (auto const& used_column_names : column_names) { + for (auto const& selected_path : used_column_names.get()) { + auto found_path = + std::find_if(all_paths.begin(), all_paths.end(), [&](path_info& valid_path) { + return valid_path.full_path == selected_path; + }); + if (found_path != all_paths.end()) { + valid_selected_paths.push_back({selected_path, found_path->schema_idx}); + } } } diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index cd4c0114341..c09ba83d5dc 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -30,6 +30,7 @@ #include #include +#include #include namespace cudf::io::parquet::detail { @@ -211,6 +212,7 @@ class aggregate_reader_metadata { * * @param use_names List of paths of column names to select; `nullopt` if user did not select * columns to read + * @param filter_columns_names List of paths of column names that are present only in filter * @param include_index Whether to always include the PANDAS index column(s) * @param strings_to_categorical Type conversion parameter * @param timestamp_type_id Type conversion parameter @@ -222,6 +224,7 @@ class aggregate_reader_metadata { std::vector, std::vector> select_columns(std::optional> const& use_names, + std::optional> const& filter_columns_names, bool include_index, bool strings_to_categorical, type_id timestamp_type_id) const; @@ -306,4 +309,53 @@ class named_to_reference_converter : public ast::detail::expression_transformer std::list _operators; }; +/** + * @brief Converts named columns to index reference columns + * + */ +class names_from_expression : public ast::detail::expression_transformer { + public: + names_from_expression(std::optional> expr, + std::vector const& skip_names) + : _skip_names(skip_names.cbegin(), skip_names.cend()) + { + if (!expr.has_value()) return; + expr.value().get().accept(*this); + } + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) + */ + std::reference_wrapper visit(ast::literal const& expr) override; + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) + */ + std::reference_wrapper visit(ast::column_reference const& expr) override; + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) + */ + std::reference_wrapper visit( + ast::column_name_reference const& expr) override; + /** + * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) + */ + std::reference_wrapper visit(ast::operation const& expr) override; + + /** + * @brief Returns the column names in AST. + * + * @return AST operation expression + */ + [[nodiscard]] std::vector get_column_names() const + { + return {_column_names.begin(), _column_names.end()}; + } + + private: + void visit_operands(std::vector> operands); + + std::unordered_set _column_names; + std::unordered_set _skip_names; +}; + } // namespace cudf::io::parquet::detail diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp index 6e8a635a055..80658b50543 100644 --- a/cpp/tests/io/parquet_reader_test.cpp +++ b/cpp/tests/io/parquet_reader_test.cpp @@ -1417,11 +1417,11 @@ TEST_F(ParquetReaderTest, FilterWithColumnProjection) auto filter_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index, lit); auto predicate = cudf::compute_column(src, filter_expr); - auto projected_table = cudf::table_view{{src.get_column(2), src.get_column(0)}}; + auto projected_table = cudf::table_view{{src.get_column(2)}}; auto expected = cudf::apply_boolean_mask(projected_table, *predicate); auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) - .columns({"col_double", "col_uint32"}) + .columns({"col_double"}) .filter(read_expr); auto result = cudf::io::read_parquet(read_opts); CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); From b01b2d8a0c887dc8cbd0d3ddc8d6966cca2a9ca4 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Fri, 1 Mar 2024 14:50:58 +0530 Subject: [PATCH 3/8] address review comments, added docstring --- cpp/src/io/parquet/predicate_pushdown.cpp | 8 +++---- cpp/src/io/parquet/reader_impl.cpp | 2 +- cpp/src/io/parquet/reader_impl_helpers.hpp | 27 +++++++++++++--------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index 12d17f6e049..d40a7741ad7 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -374,12 +374,12 @@ class stats_expression_converter : public ast::detail::expression_transformer { }; } // namespace -std::tuple, host_span> -aggregate_reader_metadata::get_schema_dtypes(bool strings_to_categorical, type_id timestamp_type_id) +void aggregate_reader_metadata::cache_root_dtypes_names(bool strings_to_categorical, + type_id timestamp_type_id) { // TODO, get types and names for only names present in filter.? and their col_idx. // create root column types and names as vector - if (!_root_level_types.empty()) return {_root_level_types, _root_level_names}; + if (!_root_level_types.empty()) return; std::function get_dtype = [strings_to_categorical, timestamp_type_id, &get_dtype, @@ -408,8 +408,6 @@ aggregate_reader_metadata::get_schema_dtypes(bool strings_to_categorical, type_i _root_level_types.push_back(get_dtype(schema_idx)); _root_level_names.push_back(get_schema(schema_idx).name); } - return {_root_level_types, _root_level_names}; - ; } std::optional>> aggregate_reader_metadata::filter_row_groups( diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index ed6637bcef5..2f3a1283606 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -375,7 +375,7 @@ reader::impl::impl(std::size_t chunk_read_limit, _timestamp_type.id()); // Find the name, and dtypes of parquet root level schema. (save it in _metadata.) - _metadata->get_schema_dtypes(_strings_to_categorical, _timestamp_type.id()); + _metadata->cache_root_dtypes_names(_strings_to_categorical, _timestamp_type.id()); // Save the states of the output buffers for reuse in `chunk_read()`. for (auto const& buff : _output_buffers) { diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index c09ba83d5dc..8aea35dcdc8 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -170,8 +170,13 @@ class aggregate_reader_metadata { */ [[nodiscard]] std::vector get_pandas_index_names() const; - std::tuple, host_span> get_schema_dtypes( - bool strings_to_categorical, type_id timestamp_type_id); + /** + * @brief Extract root level column data types and names and caches them. + * + * @param strings_to_categorical Type conversion parameter + * @param timestamp_type_id Type conversion parameter + */ + void cache_root_dtypes_names(bool strings_to_categorical, type_id timestamp_type_id); /** * @brief Filters the row groups based on predicate filter @@ -241,15 +246,15 @@ class named_to_reference_converter : public ast::detail::expression_transformer { if (!expr.has_value()) return; // create map for column name. - std::transform( - thrust::make_zip_iterator(metadata.schema_info.cbegin(), - thrust::counting_iterator(0)), - thrust::make_zip_iterator(metadata.schema_info.cend(), - thrust::counting_iterator(metadata.schema_info.size())), - std::inserter(column_name_to_index, column_name_to_index.end()), - [](auto const& name_index) { - return std::make_pair(thrust::get<0>(name_index).name, thrust::get<1>(name_index)); - }); + auto it_name_id = thrust::make_zip_iterator(metadata.schema_info.cbegin(), + thrust::counting_iterator(0)); + std::transform(it_name_id, + it_name_id + metadata.schema_info.size(), + std::inserter(column_name_to_index, column_name_to_index.end()), + [](auto const& name_index) { + return std::make_pair(thrust::get<0>(name_index).name, + thrust::get<1>(name_index)); + }); expr.value().get().accept(*this); } From 4a07e3dbb7546aef10061016dc0927d1ce3285fe Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Fri, 1 Mar 2024 15:31:03 +0530 Subject: [PATCH 4/8] add docstring for filter --- cpp/include/cudf/io/parquet.hpp | 12 ++++++++-- cpp/tests/io/parquet_reader_test.cpp | 34 ++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/cpp/include/cudf/io/parquet.hpp b/cpp/include/cudf/io/parquet.hpp index dc035db8d39..df9641aa59e 100644 --- a/cpp/include/cudf/io/parquet.hpp +++ b/cpp/include/cudf/io/parquet.hpp @@ -195,6 +195,10 @@ class parquet_reader_options { /** * @brief Sets AST based filter for predicate pushdown. * + * The filter can utilize cudf::ast::column_name_reference to reference a column by its name, + * even if it's not necessarily present in the requested projected columns. + * To refer to output column indices, you can use cudf::ast::column_reference. + * * @param filter AST expression to use as filter */ void set_filter(ast::expression const& filter) { _filter = filter; } @@ -292,9 +296,13 @@ class parquet_reader_options_builder { } /** - * @brief Sets vector of individual row groups to read. + * @brief Sets AST based filter for predicate pushdown. * - * @param filter Vector of row groups to read + * The filter can utilize cudf::ast::column_name_reference to reference a column by its name, + * even if it's not necessarily present in the requested projected columns. + * To refer to output column indices, you can use cudf::ast::column_reference. + * + * @param filter AST expression to use as filter * @return this for chaining */ parquet_reader_options_builder& filter(ast::expression const& filter) diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp index 80658b50543..7d9fd28fb42 100644 --- a/cpp/tests/io/parquet_reader_test.cpp +++ b/cpp/tests/io/parquet_reader_test.cpp @@ -1413,18 +1413,34 @@ TEST_F(ParquetReaderTest, FilterWithColumnProjection) auto lit = cudf::ast::literal{val}; auto col_ref = cudf::ast::column_name_reference{"col_uint32"}; auto col_index = cudf::ast::column_reference{0}; - auto read_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref, lit); auto filter_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index, lit); - auto predicate = cudf::compute_column(src, filter_expr); - auto projected_table = cudf::table_view{{src.get_column(2)}}; - auto expected = cudf::apply_boolean_mask(projected_table, *predicate); + auto predicate = cudf::compute_column(src, filter_expr); - auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) - .columns({"col_double"}) - .filter(read_expr); - auto result = cudf::io::read_parquet(read_opts); - CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); + { // column_name_reference in parquet filter (not present in column projection) + auto read_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref, lit); + auto projected_table = cudf::table_view{{src.get_column(2)}}; + auto expected = cudf::apply_boolean_mask(projected_table, *predicate); + + auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) + .columns({"col_double"}) + .filter(read_expr); + auto result = cudf::io::read_parquet(read_opts); + CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); + } + + { // column_reference in parquet filter (indices as per order of column projection) + auto col_index2 = cudf::ast::column_reference{1}; + auto read_ref_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index, lit); + + auto projected_table = cudf::table_view{{src.get_column(2), src.get_column(0)}}; + auto expected = cudf::apply_boolean_mask(projected_table, *predicate); + auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) + .columns({"col_double", "col_uint32"}) + .filter(read_ref_expr); + auto result = cudf::io::read_parquet(read_opts); + CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); + } } TEST_F(ParquetReaderTest, FilterReferenceExpression) From acb0723c30385b818f0e3505cea7b9bb824f05d3 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Wed, 6 Mar 2024 15:47:02 +0530 Subject: [PATCH 5/8] update docs with example --- cpp/include/cudf/io/parquet.hpp | 29 +++++++++++++++++++++------- cpp/tests/io/parquet_reader_test.cpp | 4 ++-- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cpp/include/cudf/io/parquet.hpp b/cpp/include/cudf/io/parquet.hpp index df9641aa59e..44919448001 100644 --- a/cpp/include/cudf/io/parquet.hpp +++ b/cpp/include/cudf/io/parquet.hpp @@ -199,6 +199,27 @@ class parquet_reader_options { * even if it's not necessarily present in the requested projected columns. * To refer to output column indices, you can use cudf::ast::column_reference. * + * For a parquet with columns ["A", "B", "C", ... "X", "Y", "Z"], + * Example 1: with/without column projection + * @code + * use_columns({"A", "X", "Z"}) + * .filter(operation(ast_operator::LESS, column_name_reference{"C"}, literal{100})); + * @endcode + * Column "C" need not be present in output column. + * Example 2: without column projection + * @code + * filter(operation(ast_operator::LESS, column_reference{1}, literal{100})); + * @endcode + * Here, `1` will refer to column "B" because output will contain all columns in + * order ["A", ..., "Z"]. + * Example 3: with column projection + * @code + * use_columns({"A", "Z", "X"}) + * .filter(operation(ast_operator::LESS, column_reference{1}, literal{100})); + * @endcode + * Here, `1` will refer to column "Z" because output will contain 3 columns in + * order ["A", "Z", "X"]. + * * @param filter AST expression to use as filter */ void set_filter(ast::expression const& filter) { _filter = filter; } @@ -296,13 +317,7 @@ class parquet_reader_options_builder { } /** - * @brief Sets AST based filter for predicate pushdown. - * - * The filter can utilize cudf::ast::column_name_reference to reference a column by its name, - * even if it's not necessarily present in the requested projected columns. - * To refer to output column indices, you can use cudf::ast::column_reference. - * - * @param filter AST expression to use as filter + * @copydoc parquet_reader_options::set_filter * @return this for chaining */ parquet_reader_options_builder& filter(ast::expression const& filter) diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp index 7d9fd28fb42..eba238b036d 100644 --- a/cpp/tests/io/parquet_reader_test.cpp +++ b/cpp/tests/io/parquet_reader_test.cpp @@ -1430,8 +1430,8 @@ TEST_F(ParquetReaderTest, FilterWithColumnProjection) } { // column_reference in parquet filter (indices as per order of column projection) - auto col_index2 = cudf::ast::column_reference{1}; - auto read_ref_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index, lit); + auto col_index2 = cudf::ast::column_reference{0}; + auto read_ref_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index2, lit); auto projected_table = cudf::table_view{{src.get_column(2), src.get_column(0)}}; auto expected = cudf::apply_boolean_mask(projected_table, *predicate); From e40cffcfe3ee070e9d4077c70622c94d265a40dd Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Wed, 24 Apr 2024 02:35:11 +0000 Subject: [PATCH 6/8] address review comments, include cleanup, reorg code --- cpp/include/cudf/io/parquet.hpp | 2 +- cpp/src/io/parquet/predicate_pushdown.cpp | 133 ++++++++++++++++----- cpp/src/io/parquet/reader_impl.cpp | 10 +- cpp/src/io/parquet/reader_impl.hpp | 2 +- cpp/src/io/parquet/reader_impl_chunking.cu | 1 + cpp/src/io/parquet/reader_impl_helpers.cpp | 1 + cpp/src/io/parquet/reader_impl_helpers.hpp | 90 ++------------ 7 files changed, 121 insertions(+), 118 deletions(-) diff --git a/cpp/include/cudf/io/parquet.hpp b/cpp/include/cudf/io/parquet.hpp index 44919448001..0603536c7ea 100644 --- a/cpp/include/cudf/io/parquet.hpp +++ b/cpp/include/cudf/io/parquet.hpp @@ -205,7 +205,7 @@ class parquet_reader_options { * use_columns({"A", "X", "Z"}) * .filter(operation(ast_operator::LESS, column_name_reference{"C"}, literal{100})); * @endcode - * Column "C" need not be present in output column. + * Column "C" need not be present in output table. * Example 2: without column projection * @code * filter(operation(ast_operator::LESS, column_reference{1}, literal{100})); diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index d40a7741ad7..d2ad4d86b14 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -30,10 +30,12 @@ #include +#include + #include -#include #include #include +#include namespace cudf::io::parquet::detail { @@ -385,7 +387,6 @@ void aggregate_reader_metadata::cache_root_dtypes_names(bool strings_to_categori &get_dtype, this](int schema_idx) -> cudf::data_type { // returns type of root level columns only. - // if (schema_idx < 0) { return false; } auto const& schema_elem = get_schema(schema_idx); if (schema_elem.is_stub()) { CUDF_EXPECTS(schema_elem.num_children == 1, "Unexpected number of children for stub"); @@ -513,6 +514,34 @@ std::optional>> aggregate_reader_metadata::fi } // convert column named expression to column index reference expression +named_to_reference_converter::named_to_reference_converter( + std::optional> expr, table_metadata const& metadata) +{ + if (!expr.has_value()) return; + // create map for column name. + std::transform(metadata.schema_info.cbegin(), + metadata.schema_info.cend(), + thrust::counting_iterator(0), + std::inserter(column_name_to_index, column_name_to_index.end()), + [](auto const& sch, auto index) { return std::make_pair(sch.name, index); }); + + expr.value().get().accept(*this); +} + +named_to_reference_converter::named_to_reference_converter( + std::reference_wrapper expr, + host_span root_column_names) +{ + // create map for column name. + std::transform(root_column_names.begin(), + root_column_names.end(), + thrust::counting_iterator(0), + std::inserter(column_name_to_index, column_name_to_index.end()), + [](auto const& name, auto index) { return std::make_pair(name, index); }); + + expr.get().accept(*this); +} + std::reference_wrapper named_to_reference_converter::visit( ast::literal const& expr) { @@ -568,40 +597,82 @@ named_to_reference_converter::visit_operands( return transformed_operands; } -// extract column names from expression -std::reference_wrapper names_from_expression::visit(ast::literal const& expr) -{ - return expr; -} +/** + * @brief Converts named columns to index reference columns + * + */ +class names_from_expression : public ast::detail::expression_transformer { + public: + names_from_expression(std::optional> expr, + std::vector const& skip_names) + : _skip_names(skip_names.cbegin(), skip_names.cend()) + { + if (!expr.has_value()) return; + expr.value().get().accept(*this); + } -std::reference_wrapper names_from_expression::visit( - ast::column_reference const& expr) -{ - return expr; -} + /** + * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) + */ + std::reference_wrapper visit(ast::literal const& expr) override + { + return expr; + } + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) + */ + std::reference_wrapper visit(ast::column_reference const& expr) override + { + return expr; + } + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) + */ + std::reference_wrapper visit( + ast::column_name_reference const& expr) override + { + // collect column names + auto col_name = expr.get_column_name(); + if (_skip_names.count(col_name) == 0) { _column_names.insert(col_name); } + return expr; + } + /** + * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) + */ + std::reference_wrapper visit(ast::operation const& expr) override + { + visit_operands(expr.get_operands()); + return expr; + } -std::reference_wrapper names_from_expression::visit( - ast::column_name_reference const& expr) -{ - // collect column names - auto col_name = expr.get_column_name(); - if (_skip_names.count(col_name) == 0) { _column_names.insert(col_name); } - return expr; -} + /** + * @brief Returns the column names in AST. + * + * @return AST operation expression + */ + [[nodiscard]] std::vector to_vector() && + { + return {std::make_move_iterator(_column_names.begin()), + std::make_move_iterator(_column_names.end())}; + } -std::reference_wrapper names_from_expression::visit( - ast::operation const& expr) -{ - visit_operands(expr.get_operands()); - return expr; -} + private: + void visit_operands(std::vector> operands) + { + for (auto const& operand : operands) { + operand.get().accept(*this); + } + } -void names_from_expression::visit_operands( - std::vector> operands) + std::unordered_set _column_names; + std::unordered_set _skip_names; +}; + +[[nodiscard]] std::vector get_column_names_in_expression( + std::optional> expr, + std::vector const& skip_names) { - for (auto const& operand : operands) { - operand.get().accept(*this); - } + return names_from_expression(expr, skip_names).to_vector(); } } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 4bc53e56baf..12cd4438c96 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -405,9 +405,9 @@ reader::impl::impl(std::size_t chunk_read_limit, if (options.get_filter().has_value() and options.get_columns().has_value()) { // list, struct, dictionary are not supported by AST filter yet. // extract columns not present in get_columns() & keep count to remove at end. - auto extractor = names_from_expression(options.get_filter(), *(options.get_columns())); - filter_columns_names = extractor.get_column_names(); - _num_filter_columns = filter_columns_names->size(); + filter_columns_names = + get_column_names_in_expression(options.get_filter(), *(options.get_columns())); + _num_filter_only_columns = filter_columns_names->size(); } std::tie(_input_columns, _output_buffers, _output_column_schemas) = _metadata->select_columns(options.get_columns(), @@ -546,10 +546,10 @@ table_with_metadata reader::impl::finalize_output( "Predicate filter should return a boolean"); // Exclude columns present in filter only in output auto counting_it = thrust::make_counting_iterator(0); - auto const output_count = read_table->num_columns() - _num_filter_columns; + auto const output_count = read_table->num_columns() - _num_filter_only_columns; auto only_output = read_table->select(counting_it, counting_it + output_count); auto output_table = cudf::detail::apply_boolean_mask(only_output, *predicate, _stream, _mr); - if (_num_filter_columns > 0) { out_metadata.schema_info.resize(output_count); } + if (_num_filter_only_columns > 0) { out_metadata.schema_info.resize(output_count); } return {std::move(output_table), std::move(out_metadata)}; } return {std::make_unique
(std::move(out_columns)), std::move(out_metadata)}; diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index a483dfcfdea..bc9d77dbb02 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -367,7 +367,7 @@ class reader::impl { std::unique_ptr _output_metadata; // number of extra filter columns - std::size_t _num_filter_columns{0}; + std::size_t _num_filter_only_columns{0}; bool _strings_to_categorical = false; diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu index 912f53a8277..577ab37252d 100644 --- a/cpp/src/io/parquet/reader_impl_chunking.cu +++ b/cpp/src/io/parquet/reader_impl_chunking.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "compact_protocol_reader.hpp" #include "io/comp/nvcomp_adapter.hpp" #include "io/utilities/config_utils.hpp" #include "io/utilities/time_utils.cuh" diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index bc31c806a45..3bc1991fa2f 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -16,6 +16,7 @@ #include "reader_impl_helpers.hpp" +#include "compact_protocol_reader.hpp" #include "io/parquet/parquet.hpp" #include "io/utilities/row_selection.hpp" diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index f23a3afbe61..01df8c98153 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -16,7 +16,6 @@ #pragma once -#include "compact_protocol_reader.hpp" #include "parquet_gpu.hpp" #include @@ -25,12 +24,8 @@ #include #include -#include -#include - #include #include -#include #include namespace cudf::io::parquet::detail { @@ -295,38 +290,11 @@ class aggregate_reader_metadata { class named_to_reference_converter : public ast::detail::expression_transformer { public: named_to_reference_converter(std::optional> expr, - table_metadata const& metadata) - { - if (!expr.has_value()) return; - // create map for column name. - auto it_name_id = thrust::make_zip_iterator(metadata.schema_info.cbegin(), - thrust::counting_iterator(0)); - std::transform(it_name_id, - it_name_id + metadata.schema_info.size(), - std::inserter(column_name_to_index, column_name_to_index.end()), - [](auto const& name_index) { - return std::make_pair(thrust::get<0>(name_index).name, - thrust::get<1>(name_index)); - }); - - expr.value().get().accept(*this); - } + table_metadata const& metadata); named_to_reference_converter(std::reference_wrapper expr, - host_span root_column_names) - { - // create map for column name. - std::transform( - thrust::make_zip_iterator(root_column_names.begin(), thrust::counting_iterator(0)), - thrust::make_zip_iterator(root_column_names.end(), - thrust::counting_iterator(root_column_names.size())), - std::inserter(column_name_to_index, column_name_to_index.end()), - [](auto const& name_index) { - return std::make_pair(thrust::get<0>(name_index), thrust::get<1>(name_index)); - }); - - expr.get().accept(*this); - } + host_span root_column_names); + /** * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) */ @@ -368,52 +336,14 @@ class named_to_reference_converter : public ast::detail::expression_transformer }; /** - * @brief Converts named columns to index reference columns + * @brief Get the column names in expression object * + * @param expr The optional expression object to get the column names from + * @param skip_names The names of column names to skip in returned column names + * @return The column names present in expression object except the skip_names */ -class names_from_expression : public ast::detail::expression_transformer { - public: - names_from_expression(std::optional> expr, - std::vector const& skip_names) - : _skip_names(skip_names.cbegin(), skip_names.cend()) - { - if (!expr.has_value()) return; - expr.value().get().accept(*this); - } - - /** - * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) - */ - std::reference_wrapper visit(ast::literal const& expr) override; - /** - * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) - */ - std::reference_wrapper visit(ast::column_reference const& expr) override; - /** - * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) - */ - std::reference_wrapper visit( - ast::column_name_reference const& expr) override; - /** - * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) - */ - std::reference_wrapper visit(ast::operation const& expr) override; - - /** - * @brief Returns the column names in AST. - * - * @return AST operation expression - */ - [[nodiscard]] std::vector get_column_names() const - { - return {_column_names.begin(), _column_names.end()}; - } - - private: - void visit_operands(std::vector> operands); - - std::unordered_set _column_names; - std::unordered_set _skip_names; -}; +[[nodiscard]] std::vector get_column_names_in_expression( + std::optional> expr, + std::vector const& skip_names); } // namespace cudf::io::parquet::detail From a220d7d6b3a8bfeed91e3f16b5ea46a61ad5f039 Mon Sep 17 00:00:00 2001 From: Karthikeyan Natarajan Date: Fri, 10 May 2024 07:37:28 +0000 Subject: [PATCH 7/8] fix col index ref on projection --- cpp/src/io/parquet/predicate_pushdown.cpp | 83 +++++++++----------- cpp/src/io/parquet/reader_impl.cpp | 10 ++- cpp/src/io/parquet/reader_impl_helpers.cpp | 4 +- cpp/src/io/parquet/reader_impl_helpers.hpp | 17 ++-- cpp/src/io/parquet/reader_impl_preprocess.cu | 3 +- cpp/tests/io/parquet_reader_test.cpp | 15 +++- 6 files changed, 72 insertions(+), 60 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index 3dee6138a3b..d0cb7475dac 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -129,7 +129,7 @@ struct stats_caster { // Creates device columns from column statistics (min, max) template std::pair, std::unique_ptr> operator()( - size_t col_idx, + int schema_idx, cudf::data_type dtype, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const @@ -208,22 +208,31 @@ struct stats_caster { }; // local struct host_column host_column min(total_row_groups); host_column max(total_row_groups); - size_type stats_idx = 0; for (size_t src_idx = 0; src_idx < row_group_indices.size(); ++src_idx) { for (auto const rg_idx : row_group_indices[src_idx]) { auto const& row_group = per_file_metadata[src_idx].row_groups[rg_idx]; - auto const& colchunk = row_group.columns[col_idx]; - // To support deprecated min, max fields. - auto const& min_value = colchunk.meta_data.statistics.min_value.has_value() - ? colchunk.meta_data.statistics.min_value - : colchunk.meta_data.statistics.min; - auto const& max_value = colchunk.meta_data.statistics.max_value.has_value() - ? colchunk.meta_data.statistics.max_value - : colchunk.meta_data.statistics.max; - // translate binary data to Type then to - min.set_index(stats_idx, min_value, colchunk.meta_data.type); - max.set_index(stats_idx, max_value, colchunk.meta_data.type); + auto col = std::find_if( + row_group.columns.begin(), + row_group.columns.end(), + [schema_idx](ColumnChunk const& col) { return col.schema_idx == schema_idx; }); + if (col != std::end(row_group.columns)) { + auto const& colchunk = *col; + // To support deprecated min, max fields. + auto const& min_value = colchunk.meta_data.statistics.min_value.has_value() + ? colchunk.meta_data.statistics.min_value + : colchunk.meta_data.statistics.min; + auto const& max_value = colchunk.meta_data.statistics.max_value.has_value() + ? colchunk.meta_data.statistics.max_value + : colchunk.meta_data.statistics.max; + // translate binary data to Type then to + min.set_index(stats_idx, min_value, colchunk.meta_data.type); + max.set_index(stats_idx, max_value, colchunk.meta_data.type); + } else { + // Marking it null, if column present in row group + min.set_index(stats_idx, thrust::nullopt, {}); + max.set_index(stats_idx, thrust::nullopt, {}); + } stats_idx++; } }; @@ -377,17 +386,17 @@ class stats_expression_converter : public ast::detail::expression_transformer { }; } // namespace -void aggregate_reader_metadata::cache_root_dtypes_names(bool strings_to_categorical, - type_id timestamp_type_id) +void aggregate_reader_metadata::cache_output_dtypes(host_span output_schemas, + bool strings_to_categorical, + type_id timestamp_type_id) { - // TODO, get types and names for only names present in filter.? and their col_idx. - // create root column types and names as vector - if (!_root_level_types.empty()) return; + // store output column types as vector + if (!_output_types.empty()) return; std::function get_dtype = [strings_to_categorical, timestamp_type_id, &get_dtype, this](int schema_idx) -> cudf::data_type { - // returns type of root level columns only. + // returns type of columns by using schema_idx. auto const& schema_elem = get_schema(schema_idx); if (schema_elem.is_stub()) { CUDF_EXPECTS(schema_elem.num_children == 1, "Unexpected number of children for stub"); @@ -400,20 +409,19 @@ void aggregate_reader_metadata::cache_root_dtypes_names(bool strings_to_categori ? type_id::LIST : to_type_id(schema_elem, strings_to_categorical, timestamp_type_id); auto const dtype = to_data_type(col_type, schema_elem); - // path_is_valid is skipped for nested columns here. TODO: more test cases where no leaf. + // path_is_valid is skipped for nested columns here. return dtype; }; - auto const& root = get_schema(0); - for (auto const& schema_idx : root.children_idx) { + for (auto const& schema_idx : output_schemas) { if (schema_idx < 0) { continue; } - _root_level_types.push_back(get_dtype(schema_idx)); - _root_level_names.push_back(get_schema(schema_idx).name); + _output_types.push_back(get_dtype(schema_idx)); } } std::optional>> aggregate_reader_metadata::filter_row_groups( host_span const> row_group_indices, + host_span output_column_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const { @@ -447,8 +455,9 @@ std::optional>> aggregate_reader_metadata::fi // For each column, it contains #sources * #column_chunks_per_src rows. std::vector> columns; stats_caster stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; - for (size_t col_idx = 0; col_idx < _root_level_types.size(); col_idx++) { - auto const& dtype = _root_level_types[col_idx]; + for (size_t col_idx = 0; col_idx < _output_types.size(); col_idx++) { + auto const schema_idx = output_column_schemas[col_idx]; + auto const& dtype = _output_types[col_idx]; // Only comparable types except fixed point are supported. if (cudf::is_compound(dtype) && dtype.id() != cudf::type_id::STRING) { // placeholder only for unsupported types. @@ -459,18 +468,14 @@ std::optional>> aggregate_reader_metadata::fi continue; } auto [min_col, max_col] = - cudf::type_dispatcher(dtype, stats_col, col_idx, dtype, stream, mr); + cudf::type_dispatcher(dtype, stats_col, schema_idx, dtype, stream, mr); columns.push_back(std::move(min_col)); columns.push_back(std::move(max_col)); } auto stats_table = cudf::table(std::move(columns)); - // named filter to reference filter w.r.t parquet schema order. - auto expr_conv = named_to_reference_converter(filter, _root_level_names); - auto reference_filter = expr_conv.get_converted_expr(); // Converts AST to StatsAST with reference to min, max columns in above `stats_table`. - stats_expression_converter stats_expr{reference_filter.value().get(), - static_cast(_root_level_types.size())}; + stats_expression_converter stats_expr{filter.get(), static_cast(_output_types.size())}; auto stats_ast = stats_expr.get_stats_expr(); auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr); auto predicate = predicate_col->view(); @@ -529,20 +534,6 @@ named_to_reference_converter::named_to_reference_converter( expr.value().get().accept(*this); } -named_to_reference_converter::named_to_reference_converter( - std::reference_wrapper expr, - host_span root_column_names) -{ - // create map for column name. - std::transform(root_column_names.begin(), - root_column_names.end(), - thrust::counting_iterator(0), - std::inserter(column_name_to_index, column_name_to_index.end()), - [](auto const& name, auto index) { return std::make_pair(name, index); }); - - expr.get().accept(*this); -} - std::reference_wrapper named_to_reference_converter::visit( ast::literal const& expr) { diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 3df322892b0..d3fe2a8df3b 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -23,9 +23,10 @@ #include #include -#include #include +#include + #include #include @@ -417,8 +418,9 @@ reader::impl::impl(std::size_t chunk_read_limit, _strings_to_categorical, _timestamp_type.id()); - // Find the name, and dtypes of parquet root level schema. (save it in _metadata.) - _metadata->cache_root_dtypes_names(_strings_to_categorical, _timestamp_type.id()); + // Find the dtypes of output columns (save it in _metadata). + _metadata->cache_output_dtypes( + _output_column_schemas, _strings_to_categorical, _timestamp_type.id()); // Save the states of the output buffers for reuse in `chunk_read()`. for (auto const& buff : _output_buffers) { @@ -570,7 +572,7 @@ table_with_metadata reader::impl::read( auto expr_conv = named_to_reference_converter(filter, metadata); auto output_filter = expr_conv.get_converted_expr(); - prepare_data(skip_rows, num_rows, uses_custom_row_bounds, row_group_indices, filter); + prepare_data(skip_rows, num_rows, uses_custom_row_bounds, row_group_indices, output_filter); return read_chunk_internal(uses_custom_row_bounds, output_filter); } diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index 04abfa98961..41d763d3740 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -635,13 +635,15 @@ aggregate_reader_metadata::select_row_groups( host_span const> row_group_indices, int64_t skip_rows_opt, std::optional const& num_rows_opt, + host_span output_column_schemas, std::optional> filter, rmm::cuda_stream_view stream) const { std::optional>> filtered_row_group_indices; // if filter is not empty, then gather row groups to read after predicate pushdown if (filter.has_value()) { - filtered_row_group_indices = filter_row_groups(row_group_indices, filter.value(), stream); + filtered_row_group_indices = + filter_row_groups(row_group_indices, output_column_schemas, filter.value(), stream); if (filtered_row_group_indices.has_value()) { row_group_indices = host_span const>(filtered_row_group_indices.value()); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index e23e12d5314..ab52d10d9f8 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -123,8 +123,7 @@ class aggregate_reader_metadata { int64_t num_rows; size_type num_row_groups; - std::vector _root_level_types; - std::vector _root_level_names; + std::vector _output_types; /** * @brief Create a metadata object from each element in the source vector */ @@ -227,23 +226,28 @@ class aggregate_reader_metadata { [[nodiscard]] std::vector get_pandas_index_names() const; /** - * @brief Extract root level column data types and names and caches them. + * @brief Extract output column data types and caches them. * + * @param output_schemas schema indices of output columns * @param strings_to_categorical Type conversion parameter * @param timestamp_type_id Type conversion parameter */ - void cache_root_dtypes_names(bool strings_to_categorical, type_id timestamp_type_id); + void cache_output_dtypes(host_span output_schemas, + bool strings_to_categorical, + type_id timestamp_type_id); /** * @brief Filters the row groups based on predicate filter * * @param row_group_indices Lists of row groups to read, one per source + * @param output_column_schemas schema indices of output columns * @param filter AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches * @return Filtered row group indices, if any is filtered. */ [[nodiscard]] std::optional>> filter_row_groups( host_span const> row_group_indices, + host_span output_column_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const; @@ -256,6 +260,7 @@ class aggregate_reader_metadata { * @param row_group_indices Lists of row groups to read, one per source * @param row_start Starting row of the selection * @param row_count Total number of rows selected + * @param output_column_schemas schema indices of output columns * @param filter Optional AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches * @return A tuple of corrected row_start, row_count and list of row group indexes and its @@ -265,6 +270,7 @@ class aggregate_reader_metadata { host_span const> row_group_indices, int64_t row_start, std::optional const& row_count, + host_span output_column_schemas, std::optional> filter, rmm::cuda_stream_view stream) const; @@ -300,9 +306,6 @@ class named_to_reference_converter : public ast::detail::expression_transformer named_to_reference_converter(std::optional> expr, table_metadata const& metadata); - named_to_reference_converter(std::reference_wrapper expr, - host_span root_column_names); - /** * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) */ diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index c2c261a1753..7b988fa4ceb 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -1222,7 +1222,8 @@ void reader::impl::preprocess_file( std::tie( _file_itm_data.global_skip_rows, _file_itm_data.global_num_rows, _file_itm_data.row_groups) = - _metadata->select_row_groups(row_group_indices, skip_rows, num_rows, filter, _stream); + _metadata->select_row_groups( + row_group_indices, skip_rows, num_rows, _output_column_schemas, filter, _stream); // check for page indexes _has_page_index = std::all_of(_file_itm_data.row_groups.begin(), diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp index 8d7aa8d0586..aa9172b0608 100644 --- a/cpp/tests/io/parquet_reader_test.cpp +++ b/cpp/tests/io/parquet_reader_test.cpp @@ -1408,6 +1408,7 @@ TEST_F(ParquetReaderTest, FilterIdentity) TEST_F(ParquetReaderTest, FilterWithColumnProjection) { + // col_uint32, col_int64, col_double auto [src, filepath] = create_parquet_with_stats("FilterWithColumnProjection.parquet"); auto val = cudf::numeric_scalar{10}; auto lit = cudf::ast::literal{val}; @@ -1430,7 +1431,7 @@ TEST_F(ParquetReaderTest, FilterWithColumnProjection) } { // column_reference in parquet filter (indices as per order of column projection) - auto col_index2 = cudf::ast::column_reference{0}; + auto col_index2 = cudf::ast::column_reference{1}; auto read_ref_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index2, lit); auto projected_table = cudf::table_view{{src.get_column(2), src.get_column(0)}}; @@ -1441,6 +1442,18 @@ TEST_F(ParquetReaderTest, FilterWithColumnProjection) auto result = cudf::io::read_parquet(read_opts); CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected); } + + // Error cases + { // column_reference is not same type as literal, column_reference index is out of bounds + for (auto const index : {0, 2}) { + auto col_index2 = cudf::ast::column_reference{index}; + auto read_ref_expr = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_index2, lit); + auto read_opts = cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath}) + .columns({"col_double", "col_uint32"}) + .filter(read_ref_expr); + EXPECT_THROW(cudf::io::read_parquet(read_opts), cudf::logic_error); + } + } } TEST_F(ParquetReaderTest, FilterReferenceExpression) From 9e4008ea27c094b861a8e77203e8524c627a6f44 Mon Sep 17 00:00:00 2001 From: karthikeyann Date: Wed, 15 May 2024 22:34:40 -0500 Subject: [PATCH 8/8] remove caching output dtypes --- cpp/src/io/parquet/predicate_pushdown.cpp | 40 ++------------------ cpp/src/io/parquet/reader_impl.cpp | 4 -- cpp/src/io/parquet/reader_impl_helpers.cpp | 5 ++- cpp/src/io/parquet/reader_impl_helpers.hpp | 16 ++------ cpp/src/io/parquet/reader_impl_preprocess.cu | 18 ++++++++- 5 files changed, 27 insertions(+), 56 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index d0cb7475dac..0109be661a7 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -386,41 +386,9 @@ class stats_expression_converter : public ast::detail::expression_transformer { }; } // namespace -void aggregate_reader_metadata::cache_output_dtypes(host_span output_schemas, - bool strings_to_categorical, - type_id timestamp_type_id) -{ - // store output column types as vector - if (!_output_types.empty()) return; - std::function get_dtype = [strings_to_categorical, - timestamp_type_id, - &get_dtype, - this](int schema_idx) -> cudf::data_type { - // returns type of columns by using schema_idx. - auto const& schema_elem = get_schema(schema_idx); - if (schema_elem.is_stub()) { - CUDF_EXPECTS(schema_elem.num_children == 1, "Unexpected number of children for stub"); - return get_dtype(schema_elem.children_idx[0]); - } - - auto const one_level_list = schema_elem.is_one_level_list(get_schema(schema_elem.parent_idx)); - // if we're at the root, this is a new output column - auto const col_type = one_level_list - ? type_id::LIST - : to_type_id(schema_elem, strings_to_categorical, timestamp_type_id); - auto const dtype = to_data_type(col_type, schema_elem); - // path_is_valid is skipped for nested columns here. - return dtype; - }; - - for (auto const& schema_idx : output_schemas) { - if (schema_idx < 0) { continue; } - _output_types.push_back(get_dtype(schema_idx)); - } -} - std::optional>> aggregate_reader_metadata::filter_row_groups( host_span const> row_group_indices, + host_span output_dtypes, host_span output_column_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const @@ -455,9 +423,9 @@ std::optional>> aggregate_reader_metadata::fi // For each column, it contains #sources * #column_chunks_per_src rows. std::vector> columns; stats_caster stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; - for (size_t col_idx = 0; col_idx < _output_types.size(); col_idx++) { + for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) { auto const schema_idx = output_column_schemas[col_idx]; - auto const& dtype = _output_types[col_idx]; + auto const& dtype = output_dtypes[col_idx]; // Only comparable types except fixed point are supported. if (cudf::is_compound(dtype) && dtype.id() != cudf::type_id::STRING) { // placeholder only for unsupported types. @@ -475,7 +443,7 @@ std::optional>> aggregate_reader_metadata::fi auto stats_table = cudf::table(std::move(columns)); // Converts AST to StatsAST with reference to min, max columns in above `stats_table`. - stats_expression_converter stats_expr{filter.get(), static_cast(_output_types.size())}; + stats_expression_converter stats_expr{filter.get(), static_cast(output_dtypes.size())}; auto stats_ast = stats_expr.get_stats_expr(); auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr); auto predicate = predicate_col->view(); diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index a5bea11ee4f..05aac6ad1c8 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -453,10 +453,6 @@ reader::impl::impl(std::size_t chunk_read_limit, _strings_to_categorical, _timestamp_type.id()); - // Find the dtypes of output columns (save it in _metadata). - _metadata->cache_output_dtypes( - _output_column_schemas, _strings_to_categorical, _timestamp_type.id()); - // Save the states of the output buffers for reuse in `chunk_read()`. for (auto const& buff : _output_buffers) { _output_buffers_template.emplace_back(cudf::io::detail::inline_column_buffer::empty_like(buff)); diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index bc60313e9cf..2cee5079e4c 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -635,6 +635,7 @@ aggregate_reader_metadata::select_row_groups( host_span const> row_group_indices, int64_t skip_rows_opt, std::optional const& num_rows_opt, + host_span output_dtypes, host_span output_column_schemas, std::optional> filter, rmm::cuda_stream_view stream) const @@ -642,8 +643,8 @@ aggregate_reader_metadata::select_row_groups( std::optional>> filtered_row_group_indices; // if filter is not empty, then gather row groups to read after predicate pushdown if (filter.has_value()) { - filtered_row_group_indices = - filter_row_groups(row_group_indices, output_column_schemas, filter.value(), stream); + filtered_row_group_indices = filter_row_groups( + row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream); if (filtered_row_group_indices.has_value()) { row_group_indices = host_span const>(filtered_row_group_indices.value()); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index ab52d10d9f8..157f5a54f0a 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -123,7 +123,6 @@ class aggregate_reader_metadata { int64_t num_rows; size_type num_row_groups; - std::vector _output_types; /** * @brief Create a metadata object from each element in the source vector */ @@ -225,21 +224,11 @@ class aggregate_reader_metadata { */ [[nodiscard]] std::vector get_pandas_index_names() const; - /** - * @brief Extract output column data types and caches them. - * - * @param output_schemas schema indices of output columns - * @param strings_to_categorical Type conversion parameter - * @param timestamp_type_id Type conversion parameter - */ - void cache_output_dtypes(host_span output_schemas, - bool strings_to_categorical, - type_id timestamp_type_id); - /** * @brief Filters the row groups based on predicate filter * * @param row_group_indices Lists of row groups to read, one per source + * @param output_dtypes Datatypes of of output columns * @param output_column_schemas schema indices of output columns * @param filter AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches @@ -247,6 +236,7 @@ class aggregate_reader_metadata { */ [[nodiscard]] std::optional>> filter_row_groups( host_span const> row_group_indices, + host_span output_dtypes, host_span output_column_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const; @@ -260,6 +250,7 @@ class aggregate_reader_metadata { * @param row_group_indices Lists of row groups to read, one per source * @param row_start Starting row of the selection * @param row_count Total number of rows selected + * @param output_dtypes Datatypes of of output columns * @param output_column_schemas schema indices of output columns * @param filter Optional AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches @@ -270,6 +261,7 @@ class aggregate_reader_metadata { host_span const> row_group_indices, int64_t row_start, std::optional const& row_count, + host_span output_dtypes, host_span output_column_schemas, std::optional> filter, rmm::cuda_stream_view stream) const; diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 81530da4ce8..084f82a2ca0 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -1229,10 +1229,24 @@ void reader::impl::preprocess_file( { CUDF_EXPECTS(!_file_preprocessed, "Attempted to preprocess file more than once"); + // if filter is not empty, then create output types as vector and pass for filtering. + std::vector output_dtypes; + if (filter.has_value()) { + std::transform(_output_buffers_template.cbegin(), + _output_buffers_template.cend(), + std::back_inserter(output_dtypes), + [](auto const& col) { return col.type; }); + } + std::tie( _file_itm_data.global_skip_rows, _file_itm_data.global_num_rows, _file_itm_data.row_groups) = - _metadata->select_row_groups( - row_group_indices, skip_rows, num_rows, _output_column_schemas, filter, _stream); + _metadata->select_row_groups(row_group_indices, + skip_rows, + num_rows, + output_dtypes, + _output_column_schemas, + filter, + _stream); // check for page indexes _has_page_index = std::all_of(_file_itm_data.row_groups.begin(),