From 7ef977b1a4e652d2aae3c2992da72866e72d5cd6 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Tue, 26 Nov 2024 22:23:42 +1100 Subject: [PATCH] c --- .../src/chunked_array/list/namespace.rs | 11 ++++++- .../operations/namespaces/list/test_list.py | 30 ++++++++++++++----- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index fc498c25ce44..e2fffa4beb10 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -326,6 +326,13 @@ pub trait ListNameSpaceImpl: AsList { fn lst_lengths(&self) -> IdxCa { let ca = self.as_list(); + + let ca_validity = ca.rechunk_validity(); + + if ca_validity.as_ref().map_or(false, |x| x.set_bits() == 0) { + return IdxCa::full_null(ca.name().clone(), ca.len()); + } + let mut lengths = Vec::with_capacity(ca.len()); ca.downcast_iter().for_each(|arr| { let offsets = arr.offsets().as_slice(); @@ -335,7 +342,9 @@ pub trait ListNameSpaceImpl: AsList { last = *o; } }); - IdxCa::from_vec(ca.name().clone(), lengths) + + let arr = IdxArr::from_vec(lengths).with_validity(ca_validity); + IdxCa::with_chunk(ca.name().clone(), arr) } /// Get the value by index in the sublists. diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 1264a5ed8773..9e5eb7611b8d 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -754,13 +754,6 @@ def test_utf8_empty_series_arg_min_max_10703() -> None: } -def test_list_len() -> None: - s = pl.Series([[1, 2, None], [5]]) - result = s.list.len() - expected = pl.Series([3, 1], dtype=pl.UInt32) - assert_series_equal(result, expected) - - def test_list_to_array() -> None: data = [[1.0, 2.0], [3.0, 4.0]] s = pl.Series(data, dtype=pl.List(pl.Float32)) @@ -804,6 +797,11 @@ def test_list_to_array_wrong_dtype() -> None: def test_list_lengths() -> None: + s = pl.Series([[1, 2, None], [5]]) + result = s.list.len() + expected = pl.Series([3, 1], dtype=pl.UInt32) + assert_series_equal(result, expected) + s = pl.Series("a", [[1, 2], [1, 2, 3]]) assert_series_equal(s.list.len(), pl.Series("a", [2, 3], dtype=pl.UInt32)) df = pl.DataFrame([s]) @@ -811,6 +809,24 @@ def test_list_lengths() -> None: df.select(pl.col("a").list.len())["a"], pl.Series("a", [2, 3], dtype=pl.UInt32) ) + assert_series_equal( + pl.select( + pl.when(pl.Series([True, False])) + .then(pl.Series([[1, 1], [1, 1]])) + .list.len() + ).to_series(), + pl.Series([2, None], dtype=pl.UInt32), + ) + + assert_series_equal( + pl.select( + pl.when(pl.Series([False, False])) + .then(pl.Series([[1, 1], [1, 1]])) + .list.len() + ).to_series(), + pl.Series([None, None], dtype=pl.UInt32), + ) + def test_list_arithmetic() -> None: s = pl.Series("a", [[1, 2], [1, 2, 3]])