Skip to content

Commit

Permalink
fix: Fix merge_sorted producing incorrect results or panicking for some logical types (#21018)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Jan 31, 2025
1 parent 233f9b3 commit 8837dc8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
11 changes: 10 additions & 1 deletion crates/polars-ops/src/frame/join/merge_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ pub fn _merge_sorted_dfs(
ComputeError: "merge-sort datatype mismatch: {} != {}", dtype_lhs, dtype_rhs
);

if dtype_lhs.is_categorical() {
let rev_map_lhs = left_s.categorical().unwrap().get_rev_map();
let rev_map_rhs = right_s.categorical().unwrap().get_rev_map();
polars_ensure!(
rev_map_lhs.same_src(rev_map_rhs),
ComputeError: "can only merge-sort categoricals with the same categories"
);
}

// If one frame is empty, we can return the other immediately.
if right_s.is_empty() {
return Ok(left.clone());
Expand All @@ -41,7 +50,7 @@ pub fn _merge_sorted_dfs(
rhs_phys.as_materialized_series(),
&merge_indicator,
)?);
let mut out = out.cast(lhs.dtype()).unwrap();
let mut out = unsafe { out.from_physical_unchecked(lhs.dtype()) }.unwrap();
out.rename(lhs.name().clone());
Ok(out)
})
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/operations/test_merge_sorted.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal

left = pl.DataFrame({"a": [42, 13, 37], "b": [3, 8, 9]})
right = pl.DataFrame({"a": [5, 10, 1996], "b": [1, 5, 7]})
Expand Down Expand Up @@ -43,3 +44,27 @@ def test_merge_sorted_proj_pd() -> None:
lf.select("a").collect(),
lf.collect().select("a"),
)


@pytest.mark.parametrize("precision", [2, 3])
def test_merge_sorted_decimal_20990(precision: int) -> None:
    """Merge-sort a sorted Decimal frame with itself and check the result.

    The merged column must contain every input value twice, in ascending
    order, and compare equal to a Series built with the same Decimal dtype.
    The numeric suffix presumably refers to the originating issue -- confirm.
    """
    decimal_dtype = pl.Decimal(precision=precision, scale=1)
    sorted_col = pl.Series("a", ["1.0", "0.1"], decimal_dtype).sort()
    frame = pl.DataFrame([sorted_col])

    # Both inputs are the same already-sorted frame, exercised via the lazy API.
    merged = frame.lazy().merge_sorted(frame.lazy(), "a").collect()

    assert_series_equal(
        merged.get_column("a"),
        pl.Series("a", ["0.1", "0.1", "1.0", "1.0"], decimal_dtype),
    )


def test_merge_sorted_categorical() -> None:
    """Check merge_sorted on Categorical columns: one success, one error case."""
    lhs = pl.Series("a", ["a", "b"], pl.Categorical()).sort().to_frame()

    # First right-hand frame: the merge is expected to succeed.
    rhs_ok = pl.Series("a", ["a", "b", "b"], pl.Categorical()).sort().to_frame()
    merged = lhs.merge_sorted(rhs_ok, "a").get_column("a")
    assert_series_equal(
        merged,
        pl.Series("a", ["a", "a", "b", "b", "b"], pl.Categorical()),
    )

    # Second right-hand frame is built from the values in the opposite order;
    # the test expects merge_sorted to reject it with a ComputeError.
    rhs_bad = pl.Series("a", ["b", "a"], pl.Categorical()).sort().to_frame()
    expected_msg = "can only merge-sort categoricals with the same categories"
    with pytest.raises(ComputeError, match=expected_msg):
        lhs.merge_sorted(rhs_bad, "a")

0 comments on commit 8837dc8

Please sign in to comment.