Commit
fix: Column name mismatch or not found in Parquet scan with filter (#…
nameexhaustion authored Dec 6, 2024
1 parent dc54699 commit bedaefc
Showing 3 changed files with 56 additions and 8 deletions.
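
A minimal sketch of the query shape this commit addresses, distilled from the regression test added below (test_prefilter_with_projection_column_order_20175); the in-memory buffer and one-row frame are illustrative only, not part of the commit:

import io

import polars as pl
from polars.testing import assert_frame_equal

f = io.BytesIO()
pl.DataFrame({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}).write_parquet(f)
f.seek(0)

# A prefiltered scan that combines a filter with a projection whose order
# ("a", "d", "c") differs from the file's column order is the pattern that
# previously raised the column-name mismatch / not-found error.
q = (
    pl.scan_parquet(f, parallel="prefiltered")
    .filter(pl.col("a") == 1)
    .select("a", "d", "c")
)
assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "d": 1, "c": 1}))
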
crates/polars-io/src/parquet/read/read_impl.rs (10 changes: 3 additions & 7 deletions)
@@ -350,16 +350,12 @@ fn rg_to_dfs_prefiltered(
         eprintln!("parquet live columns = {num_live_columns}, dead columns = {num_dead_columns}");
     }
 
-    // @NOTE: This is probably already sorted, but just to be sure.
-    let mut projection_sorted = projection.to_vec();
-    projection_sorted.sort();
-
     // We create two look-up tables that map indexes offsets into the live- and dead-set onto
     // column indexes of the schema.
     // Note: This may contain less than `num_live_columns` if there are hive columns involved.
     let mut live_idx_to_col_idx = Vec::with_capacity(num_live_columns);
-    let mut dead_idx_to_col_idx = Vec::with_capacity(num_dead_columns);
-    for &i in projection_sorted.iter() {
+    let mut dead_idx_to_col_idx: Vec<usize> = Vec::with_capacity(num_dead_columns);
+    for &i in projection.iter() {
         let name = schema.get_at_index(i).unwrap().0.as_str();
 
         if live_variables.contains(name) {
@@ -547,7 +543,7 @@ fn rg_to_dfs_prefiltered(
 
     assert_eq!(
         live_columns.len() + dead_columns.len(),
-        projection_sorted.len() + hive_partition_columns.map_or(0, |x| x.len())
+        projection.len() + hive_partition_columns.map_or(0, |x| x.len())
     );
 
     let mut merged = Vec::with_capacity(live_columns.len() + dead_columns.len());
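
The substance of the fix above: the live/dead look-up tables are now built by walking `projection` in the order the caller supplied instead of a sorted copy, so the merged columns follow the requested order; the sorting appears to be what produced the mismatch whenever the projection was not already in file order. A rough illustration with plain Python lists (column names taken from the new test, everything else made up):

schema = ["a", "b", "c", "d", "e"]   # column order as stored in the file
projection = [0, 3, 2]               # caller asked for ("a", "d", "c")

# Old behaviour: the projection was sorted first, yielding file order
# rather than the requested order.
print([schema[i] for i in sorted(projection)])  # ['a', 'c', 'd']

# New behaviour: iterate the projection as given, preserving the requested order.
print([schema[i] for i in projection])          # ['a', 'd', 'c']
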
py-polars/tests/unit/io/test_hive.py (8 changes: 7 additions & 1 deletion)
@@ -502,7 +502,13 @@ def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> N
     ("scan_func", "write_func"),
     [
         (partial(pl.scan_parquet, parallel="row_groups"), pl.DataFrame.write_parquet),
-        (partial(pl.scan_parquet, parallel="prefiltered"), pl.DataFrame.write_parquet),
+        (partial(pl.scan_parquet, parallel="columns"), pl.DataFrame.write_parquet),
+        (
+            lambda *a, **kw: pl.scan_parquet(*a, parallel="prefiltered", **kw).filter(
+                pl.col("b") == pl.col("b")
+            ),
+            pl.DataFrame.write_parquet,
+        ),
         (pl.scan_ipc, pl.DataFrame.write_ipc),
     ],
 )
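
In the hive test above, the prefiltered case is now wrapped in a lambda that tacks on a trivially-true predicate, presumably because the prefiltered strategy only has work to do when the scan carries a filter. A hedged way to confirm that path is exercised, assuming the eprintln! in the first hunk is surfaced by Polars' POLARS_VERBOSE flag:

import io
import os

import polars as pl

# Assumption: POLARS_VERBOSE=1 surfaces the "parquet live columns = ..." log
# line from read_impl.rs, confirming the prefiltered reader actually ran.
os.environ["POLARS_VERBOSE"] = "1"

f = io.BytesIO()
pl.DataFrame({"a": [1, 2], "b": [3, 4]}).write_parquet(f)
f.seek(0)

# The trivially-true predicate keeps a filter in the plan so there is
# something to prefilter.
out = (
    pl.scan_parquet(f, parallel="prefiltered")
    .filter(pl.col("b") == pl.col("b"))
    .collect()
)
print(out)
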
py-polars/tests/unit/io/test_parquet.py (46 changes: 46 additions & 0 deletions)
@@ -2533,3 +2533,49 @@ def test_categorical_parametric_sliced(s: pl.Series, start: int, length: int) ->
         pl.scan_parquet(f).slice(start, length).collect(),
         df.slice(start, length),
     )
+
+
+@pytest.mark.write_disk
+def test_prefilter_with_projection_column_order_20175(tmp_path: Path) -> None:
+    path = tmp_path / "1"
+
+    pl.DataFrame({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}).write_parquet(path)
+
+    q = (
+        pl.scan_parquet(path, parallel="prefiltered")
+        .filter(pl.col("a") == 1)
+        .select("a", "d", "c")
+    )
+
+    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "d": 1, "c": 1}))
+
+    f = io.BytesIO()
+
+    pl.read_csv(b"""\
+c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10
+1,1,1,1,1,1,1,1,1,1,1
+1,1,1,1,1,1,1,1,1,1,1
+""").write_parquet(f)
+
+    f.seek(0)
+
+    q = (
+        pl.scan_parquet(
+            f,
+            rechunk=True,
+            parallel="prefiltered",
+        )
+        .filter(
+            pl.col("c0") == 1,
+        )
+        .select("c0", "c9", "c3")
+    )
+
+    assert_frame_equal(
+        q.collect(),
+        pl.read_csv(b"""\
+c0,c9,c3
+1,1,1
+1,1,1
+"""),
+    )
