Skip to content

Commit 47ed345

Browse files
authored
Fix parquet predicate filtering with column projection (#15113)
Fixes #15051 The predicate filtering in parquet did not work when column projection was used. This PR fixes that limitation. With this PR change, the user will be able to use both column name references and column index references in the filter. - column name reference: the filters may specify any columns by name even if they are not present in the column projection. - column reference (index): The indices used should be the indices of output columns in the requested order. This is achieved by extracting column names from the filter and adding them to the output buffers; after predicate filtering is done, these filter-only columns are removed and only the requested columns are returned. The change includes reading only the output columns' statistics data instead of that of all root columns. Summary of changes: - `get_column_names_in_expression` extracts column names in the filter. - The extra columns in the filter are added to the output buffers during reader initialization - `cpp/src/io/parquet/reader_impl_helpers.cpp`, `cpp/src/io/parquet/reader_impl.cpp` - instead of extracting statistics data of all root columns, it extracts them for only the output columns (including columns in the filter) - `cpp/src/io/parquet/predicate_pushdown.cpp` - To do this, the output column schemas and their dtypes should be cached. - the statistics data extraction code is updated to check for `schema_idx` in the row group metadata. - No need to convert the filter again for all root columns; reuse the passed output-columns reference filter. - The rest of the code is the same. - After the output filter predicate is calculated, these filter-only columns are removed - moved the `named_to_reference_converter` constructor to cpp, and removed the unused constructor. 
- small include<> cleanup Authors: - Karthikeyan (https://github.com/karthikeyann) - Vukasin Milovanovic (https://github.com/vuule) - Muhammad Haseeb (https://github.com/mhaseeb123) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Vukasin Milovanovic (https://github.com/vuule) - Muhammad Haseeb (https://github.com/mhaseeb123) URL: #15113
1 parent c7fe7fe commit 47ed345

9 files changed

+276
-65
lines changed

cpp/include/cudf/io/parquet.hpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,31 @@ class parquet_reader_options {
205205
/**
206206
* @brief Sets AST based filter for predicate pushdown.
207207
*
208+
* The filter can utilize cudf::ast::column_name_reference to reference a column by its name,
209+
* even if it's not necessarily present in the requested projected columns.
210+
* To refer to output column indices, you can use cudf::ast::column_reference.
211+
*
212+
* For a parquet with columns ["A", "B", "C", ... "X", "Y", "Z"],
213+
* Example 1: with/without column projection
214+
* @code
215+
* use_columns({"A", "X", "Z"})
216+
* .filter(operation(ast_operator::LESS, column_name_reference{"C"}, literal{100}));
217+
* @endcode
218+
* Column "C" need not be present in output table.
219+
* Example 2: without column projection
220+
* @code
221+
* filter(operation(ast_operator::LESS, column_reference{1}, literal{100}));
222+
* @endcode
223+
* Here, `1` will refer to column "B" because output will contain all columns in
224+
* order ["A", ..., "Z"].
225+
* Example 3: with column projection
226+
* @code
227+
* use_columns({"A", "Z", "X"})
228+
* .filter(operation(ast_operator::LESS, column_reference{1}, literal{100}));
229+
* @endcode
230+
* Here, `1` will refer to column "Z" because output will contain 3 columns in
231+
* order ["A", "Z", "X"].
232+
*
208233
* @param filter AST expression to use as filter
209234
*/
210235
void set_filter(ast::expression const& filter) { _filter = filter; }
@@ -309,9 +334,7 @@ class parquet_reader_options_builder {
309334
}
310335

311336
/**
312-
* @brief Sets vector of individual row groups to read.
313-
*
314-
* @param filter Vector of row groups to read
337+
* @copydoc parquet_reader_options::set_filter
315338
* @return this for chaining
316339
*/
317340
parquet_reader_options_builder& filter(ast::expression const& filter)

cpp/src/io/parquet/predicate_pushdown.cpp

+122-17
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@
3131
#include <rmm/mr/device/per_device_resource.hpp>
3232
#include <rmm/resource_ref.hpp>
3333

34+
#include <thrust/iterator/counting_iterator.h>
35+
3436
#include <algorithm>
35-
#include <list>
3637
#include <numeric>
3738
#include <optional>
39+
#include <unordered_set>
3840

3941
namespace cudf::io::parquet::detail {
4042

@@ -127,7 +129,7 @@ struct stats_caster {
127129
// Creates device columns from column statistics (min, max)
128130
template <typename T>
129131
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> operator()(
130-
size_t col_idx,
132+
int schema_idx,
131133
cudf::data_type dtype,
132134
rmm::cuda_stream_view stream,
133135
rmm::device_async_resource_ref mr) const
@@ -206,22 +208,31 @@ struct stats_caster {
206208
}; // local struct host_column
207209
host_column min(total_row_groups);
208210
host_column max(total_row_groups);
209-
210211
size_type stats_idx = 0;
211212
for (size_t src_idx = 0; src_idx < row_group_indices.size(); ++src_idx) {
212213
for (auto const rg_idx : row_group_indices[src_idx]) {
213214
auto const& row_group = per_file_metadata[src_idx].row_groups[rg_idx];
214-
auto const& colchunk = row_group.columns[col_idx];
215-
// To support deprecated min, max fields.
216-
auto const& min_value = colchunk.meta_data.statistics.min_value.has_value()
217-
? colchunk.meta_data.statistics.min_value
218-
: colchunk.meta_data.statistics.min;
219-
auto const& max_value = colchunk.meta_data.statistics.max_value.has_value()
220-
? colchunk.meta_data.statistics.max_value
221-
: colchunk.meta_data.statistics.max;
222-
// translate binary data to Type then to <T>
223-
min.set_index(stats_idx, min_value, colchunk.meta_data.type);
224-
max.set_index(stats_idx, max_value, colchunk.meta_data.type);
215+
auto col = std::find_if(
216+
row_group.columns.begin(),
217+
row_group.columns.end(),
218+
[schema_idx](ColumnChunk const& col) { return col.schema_idx == schema_idx; });
219+
if (col != std::end(row_group.columns)) {
220+
auto const& colchunk = *col;
221+
// To support deprecated min, max fields.
222+
auto const& min_value = colchunk.meta_data.statistics.min_value.has_value()
223+
? colchunk.meta_data.statistics.min_value
224+
: colchunk.meta_data.statistics.min;
225+
auto const& max_value = colchunk.meta_data.statistics.max_value.has_value()
226+
? colchunk.meta_data.statistics.max_value
227+
: colchunk.meta_data.statistics.max;
228+
// translate binary data to Type then to <T>
229+
min.set_index(stats_idx, min_value, colchunk.meta_data.type);
230+
max.set_index(stats_idx, max_value, colchunk.meta_data.type);
231+
} else {
232+
// Marking it null, if column is not present in row group
233+
min.set_index(stats_idx, thrust::nullopt, {});
234+
max.set_index(stats_idx, thrust::nullopt, {});
235+
}
225236
stats_idx++;
226237
}
227238
};
@@ -378,6 +389,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {
378389
std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::filter_row_groups(
379390
host_span<std::vector<size_type> const> row_group_indices,
380391
host_span<data_type const> output_dtypes,
392+
host_span<int const> output_column_schemas,
381393
std::reference_wrapper<ast::expression const> filter,
382394
rmm::cuda_stream_view stream) const
383395
{
@@ -412,7 +424,8 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
412424
std::vector<std::unique_ptr<column>> columns;
413425
stats_caster stats_col{total_row_groups, per_file_metadata, input_row_group_indices};
414426
for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) {
415-
auto const& dtype = output_dtypes[col_idx];
427+
auto const schema_idx = output_column_schemas[col_idx];
428+
auto const& dtype = output_dtypes[col_idx];
416429
// Only comparable types except fixed point are supported.
417430
if (cudf::is_compound(dtype) && dtype.id() != cudf::type_id::STRING) {
418431
// placeholder only for unsupported types.
@@ -423,14 +436,14 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
423436
continue;
424437
}
425438
auto [min_col, max_col] =
426-
cudf::type_dispatcher<dispatch_storage_type>(dtype, stats_col, col_idx, dtype, stream, mr);
439+
cudf::type_dispatcher<dispatch_storage_type>(dtype, stats_col, schema_idx, dtype, stream, mr);
427440
columns.push_back(std::move(min_col));
428441
columns.push_back(std::move(max_col));
429442
}
430443
auto stats_table = cudf::table(std::move(columns));
431444

432445
// Converts AST to StatsAST with reference to min, max columns in above `stats_table`.
433-
stats_expression_converter stats_expr{filter, static_cast<size_type>(output_dtypes.size())};
446+
stats_expression_converter stats_expr{filter.get(), static_cast<size_type>(output_dtypes.size())};
434447
auto stats_ast = stats_expr.get_stats_expr();
435448
auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr);
436449
auto predicate = predicate_col->view();
@@ -475,6 +488,20 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
475488
}
476489

477490
// convert column named expression to column index reference expression
491+
named_to_reference_converter::named_to_reference_converter(
492+
std::optional<std::reference_wrapper<ast::expression const>> expr, table_metadata const& metadata)
493+
{
494+
if (!expr.has_value()) return;
495+
// create map for column name.
496+
std::transform(metadata.schema_info.cbegin(),
497+
metadata.schema_info.cend(),
498+
thrust::counting_iterator<size_t>(0),
499+
std::inserter(column_name_to_index, column_name_to_index.end()),
500+
[](auto const& sch, auto index) { return std::make_pair(sch.name, index); });
501+
502+
expr.value().get().accept(*this);
503+
}
504+
478505
std::reference_wrapper<ast::expression const> named_to_reference_converter::visit(
479506
ast::literal const& expr)
480507
{
@@ -530,4 +557,82 @@ named_to_reference_converter::visit_operands(
530557
return transformed_operands;
531558
}
532559

560+
/**
561+
* @brief Converts named columns to index reference columns
562+
*
563+
*/
564+
class names_from_expression : public ast::detail::expression_transformer {
565+
public:
566+
names_from_expression(std::optional<std::reference_wrapper<ast::expression const>> expr,
567+
std::vector<std::string> const& skip_names)
568+
: _skip_names(skip_names.cbegin(), skip_names.cend())
569+
{
570+
if (!expr.has_value()) return;
571+
expr.value().get().accept(*this);
572+
}
573+
574+
/**
575+
* @copydoc ast::detail::expression_transformer::visit(ast::literal const& )
576+
*/
577+
std::reference_wrapper<ast::expression const> visit(ast::literal const& expr) override
578+
{
579+
return expr;
580+
}
581+
/**
582+
* @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& )
583+
*/
584+
std::reference_wrapper<ast::expression const> visit(ast::column_reference const& expr) override
585+
{
586+
return expr;
587+
}
588+
/**
589+
* @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& )
590+
*/
591+
std::reference_wrapper<ast::expression const> visit(
592+
ast::column_name_reference const& expr) override
593+
{
594+
// collect column names
595+
auto col_name = expr.get_column_name();
596+
if (_skip_names.count(col_name) == 0) { _column_names.insert(col_name); }
597+
return expr;
598+
}
599+
/**
600+
* @copydoc ast::detail::expression_transformer::visit(ast::operation const& )
601+
*/
602+
std::reference_wrapper<ast::expression const> visit(ast::operation const& expr) override
603+
{
604+
visit_operands(expr.get_operands());
605+
return expr;
606+
}
607+
608+
/**
609+
* @brief Returns the column names in AST.
610+
*
611+
* @return AST operation expression
612+
*/
613+
[[nodiscard]] std::vector<std::string> to_vector() &&
614+
{
615+
return {std::make_move_iterator(_column_names.begin()),
616+
std::make_move_iterator(_column_names.end())};
617+
}
618+
619+
private:
620+
void visit_operands(std::vector<std::reference_wrapper<ast::expression const>> operands)
621+
{
622+
for (auto const& operand : operands) {
623+
operand.get().accept(*this);
624+
}
625+
}
626+
627+
std::unordered_set<std::string> _column_names;
628+
std::unordered_set<std::string> _skip_names;
629+
};
630+
631+
[[nodiscard]] std::vector<std::string> get_column_names_in_expression(
632+
std::optional<std::reference_wrapper<ast::expression const>> expr,
633+
std::vector<std::string> const& skip_names)
634+
{
635+
return names_from_expression(expr, skip_names).to_vector();
636+
}
637+
533638
} // namespace cudf::io::parquet::detail

cpp/src/io/parquet/reader_impl.cpp

+18-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
#include <rmm/resource_ref.hpp>
2828

29+
#include <thrust/iterator/counting_iterator.h>
30+
2931
#include <bitset>
3032
#include <numeric>
3133

@@ -436,9 +438,18 @@ reader::impl::impl(std::size_t chunk_read_limit,
436438
// Binary columns can be read as binary or strings
437439
_reader_column_schema = options.get_column_schema();
438440

439-
// Select only columns required by the options
441+
// Select only columns required by the options and filter
442+
std::optional<std::vector<std::string>> filter_columns_names;
443+
if (options.get_filter().has_value() and options.get_columns().has_value()) {
444+
// list, struct, dictionary are not supported by AST filter yet.
445+
// extract columns not present in get_columns() & keep count to remove at end.
446+
filter_columns_names =
447+
get_column_names_in_expression(options.get_filter(), *(options.get_columns()));
448+
_num_filter_only_columns = filter_columns_names->size();
449+
}
440450
std::tie(_input_columns, _output_buffers, _output_column_schemas) =
441451
_metadata->select_columns(options.get_columns(),
452+
filter_columns_names,
442453
options.is_enabled_use_pandas_metadata(),
443454
_strings_to_categorical,
444455
_timestamp_type.id());
@@ -572,7 +583,12 @@ table_with_metadata reader::impl::finalize_output(
572583
*read_table, filter.value().get(), _stream, rmm::mr::get_current_device_resource());
573584
CUDF_EXPECTS(predicate->view().type().id() == type_id::BOOL8,
574585
"Predicate filter should return a boolean");
575-
auto output_table = cudf::detail::apply_boolean_mask(*read_table, *predicate, _stream, _mr);
586+
// Exclude columns present in filter only in output
587+
auto counting_it = thrust::make_counting_iterator<std::size_t>(0);
588+
auto const output_count = read_table->num_columns() - _num_filter_only_columns;
589+
auto only_output = read_table->select(counting_it, counting_it + output_count);
590+
auto output_table = cudf::detail::apply_boolean_mask(only_output, *predicate, _stream, _mr);
591+
if (_num_filter_only_columns > 0) { out_metadata.schema_info.resize(output_count); }
576592
return {std::move(output_table), std::move(out_metadata)};
577593
}
578594
return {std::make_unique<table>(std::move(out_columns)), std::move(out_metadata)};

cpp/src/io/parquet/reader_impl.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,9 @@ class reader::impl {
368368
// _output_buffers associated metadata
369369
std::unique_ptr<table_metadata> _output_metadata;
370370

371+
// number of extra filter columns
372+
std::size_t _num_filter_only_columns{0};
373+
371374
bool _strings_to_categorical = false;
372375

373376
// are there usable page indexes available

cpp/src/io/parquet/reader_impl_chunking.cu

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include "compact_protocol_reader.hpp"
1718
#include "io/comp/nvcomp_adapter.hpp"
1819
#include "io/utilities/config_utils.hpp"
1920
#include "io/utilities/time_utils.cuh"

cpp/src/io/parquet/reader_impl_helpers.cpp

+24-13
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "reader_impl_helpers.hpp"
1818

19+
#include "compact_protocol_reader.hpp"
1920
#include "io/parquet/parquet.hpp"
2021
#include "io/utilities/base64_utilities.hpp"
2122
#include "io/utilities/row_selection.hpp"
@@ -25,6 +26,7 @@
2526
#include <thrust/iterator/counting_iterator.h>
2627
#include <thrust/iterator/zip_iterator.h>
2728

29+
#include <functional>
2830
#include <numeric>
2931
#include <regex>
3032

@@ -954,13 +956,15 @@ aggregate_reader_metadata::select_row_groups(
954956
int64_t skip_rows_opt,
955957
std::optional<size_type> const& num_rows_opt,
956958
host_span<data_type const> output_dtypes,
959+
host_span<int const> output_column_schemas,
957960
std::optional<std::reference_wrapper<ast::expression const>> filter,
958961
rmm::cuda_stream_view stream) const
959962
{
960963
std::optional<std::vector<std::vector<size_type>>> filtered_row_group_indices;
964+
// if filter is not empty, then gather row groups to read after predicate pushdown
961965
if (filter.has_value()) {
962-
filtered_row_group_indices =
963-
filter_row_groups(row_group_indices, output_dtypes, filter.value(), stream);
966+
filtered_row_group_indices = filter_row_groups(
967+
row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream);
964968
if (filtered_row_group_indices.has_value()) {
965969
row_group_indices =
966970
host_span<std::vector<size_type> const>(filtered_row_group_indices.value());
@@ -1017,10 +1021,12 @@ aggregate_reader_metadata::select_row_groups(
10171021
std::tuple<std::vector<input_column_info>,
10181022
std::vector<cudf::io::detail::inline_column_buffer>,
10191023
std::vector<size_type>>
1020-
aggregate_reader_metadata::select_columns(std::optional<std::vector<std::string>> const& use_names,
1021-
bool include_index,
1022-
bool strings_to_categorical,
1023-
type_id timestamp_type_id) const
1024+
aggregate_reader_metadata::select_columns(
1025+
std::optional<std::vector<std::string>> const& use_names,
1026+
std::optional<std::vector<std::string>> const& filter_columns_names,
1027+
bool include_index,
1028+
bool strings_to_categorical,
1029+
type_id timestamp_type_id) const
10241030
{
10251031
auto find_schema_child = [&](SchemaElement const& schema_elem, std::string const& name) {
10261032
auto const& col_schema_idx =
@@ -1184,13 +1190,18 @@ aggregate_reader_metadata::select_columns(std::optional<std::vector<std::string>
11841190

11851191
// Find which of the selected paths are valid and get their schema index
11861192
std::vector<path_info> valid_selected_paths;
1187-
for (auto const& selected_path : *use_names) {
1188-
auto found_path =
1189-
std::find_if(all_paths.begin(), all_paths.end(), [&](path_info& valid_path) {
1190-
return valid_path.full_path == selected_path;
1191-
});
1192-
if (found_path != all_paths.end()) {
1193-
valid_selected_paths.push_back({selected_path, found_path->schema_idx});
1193+
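// Search both name lists in order: the requested columns (*use_names) first, then the
// filter-only columns (*filter_columns_names).
// NOTE(review): filter_columns_names is dereferenced unconditionally below - confirm it
// always holds a value when this path is reached.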
1194+
std::vector<std::reference_wrapper<std::vector<std::string> const>> column_names{
1195+
*use_names, *filter_columns_names};
1196+
for (auto const& used_column_names : column_names) {
1197+
for (auto const& selected_path : used_column_names.get()) {
1198+
auto found_path =
1199+
std::find_if(all_paths.begin(), all_paths.end(), [&](path_info& valid_path) {
1200+
return valid_path.full_path == selected_path;
1201+
});
1202+
if (found_path != all_paths.end()) {
1203+
valid_selected_paths.push_back({selected_path, found_path->schema_idx});
1204+
}
11941205
}
11951206
}
11961207

0 commit comments

Comments
 (0)