From b019e4256ff42526baeb893b6100f4c8b9601900 Mon Sep 17 00:00:00 2001
From: Gijs Burghoorn
Date: Thu, 5 Dec 2024 10:39:08 +0100
Subject: [PATCH] fix: Deal with masked out list elements (#20161)

---
 crates/polars-row/src/encode.rs           | 109 +++++++++++++++++++---
 py-polars/tests/unit/test_row_encoding.py |  12 +++
 2 files changed, 107 insertions(+), 14 deletions(-)

diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs
index 702d2228d2f1..1de7c07c7dfc 100644
--- a/crates/polars-row/src/encode.rs
+++ b/crates/polars-row/src/encode.rs
@@ -61,11 +61,20 @@ pub fn convert_columns_amortized<'a>(
     fields: impl IntoIterator<Item = (RowEncodingOptions, Option<&'a RowEncodingCatOrder>)> + Clone,
     rows: &mut RowsEncoded,
 ) {
+    let mut masked_out_max_length = 0;
     let mut row_widths = RowWidths::new(num_rows);
     let mut encoders = columns
         .iter()
         .zip(fields.clone())
-        .map(|(column, (opt, dicts))| get_encoder(column.as_ref(), opt, dicts, &mut row_widths))
+        .map(|(column, (opt, dicts))| {
+            get_encoder(
+                column.as_ref(),
+                opt,
+                dicts,
+                &mut row_widths,
+                &mut masked_out_max_length,
+            )
+        })
         .collect::<Vec<_>>();
 
     // Create an offsets array, we append 0 at the beginning here so it can serve as the final
@@ -76,9 +85,10 @@
 
     // Create a buffer without initializing everything to zero.
     let total_num_bytes = row_widths.sum();
-    let mut out = Vec::<u8>::with_capacity(total_num_bytes);
-    let buffer = &mut out.spare_capacity_mut()[..total_num_bytes];
+    let mut out = Vec::<u8>::with_capacity(total_num_bytes + masked_out_max_length);
+    let buffer = &mut out.spare_capacity_mut()[..total_num_bytes + masked_out_max_length];
 
+    let masked_out_write_offset = total_num_bytes;
     let mut scratches = EncodeScratches::default();
     for (encoder, (opt, dict)) in encoders.iter_mut().zip(fields) {
         unsafe {
@@ -88,6 +98,7 @@
                 opt,
                 dict,
                 &mut offsets[1..],
+                masked_out_write_offset,
                 &mut scratches,
             )
         };
@@ -108,13 +119,20 @@ fn list_num_column_bytes<O: Offset>(
     opt: RowEncodingOptions,
     dicts: Option<&RowEncodingCatOrder>,
     row_widths: &mut RowWidths,
+    masked_out_max_width: &mut usize,
 ) -> Encoder {
     let array = array.as_any().downcast_ref::<ListArray<O>>().unwrap();
     let array = array.trim_to_normalized_offsets_recursive();
     let values = array.values();
 
     let mut list_row_widths = RowWidths::new(values.len());
-    let encoder = get_encoder(values.as_ref(), opt, dicts, &mut list_row_widths);
+    let encoder = get_encoder(
+        values.as_ref(),
+        opt,
+        dicts,
+        &mut list_row_widths,
+        masked_out_max_width,
+    );
 
     match array.validity() {
         None => row_widths.push_iter(array.offsets().offset_and_length_iter().map(
@@ -133,6 +151,12 @@
                 .zip(validity.iter())
                 .map(|((offset, length), is_valid)| {
                     if !is_valid {
+                        if length > 0 {
+                            for i in offset..offset + length {
+                                *masked_out_max_width =
+                                    (*masked_out_max_width).max(list_row_widths.get(i));
+                            }
+                        }
                         return 1;
                     }
 
@@ -261,6 +285,7 @@ fn get_encoder(
     opt: RowEncodingOptions,
     dict: Option<&RowEncodingCatOrder>,
     row_widths: &mut RowWidths,
+    masked_out_max_width: &mut usize,
 ) -> Encoder {
     use ArrowDataType as D;
     let dtype = array.dtype();
@@ -275,8 +300,13 @@ fn get_encoder(
             debug_assert_eq!(array.values().len(), array.len() * width);
 
             let mut nested_row_widths = RowWidths::new(array.values().len());
-            let nested_encoder =
-                get_encoder(array.values().as_ref(), opt, dict, &mut nested_row_widths);
+            let nested_encoder = get_encoder(
+                array.values().as_ref(),
+                opt,
+                dict,
+                &mut nested_row_widths,
+                masked_out_max_width,
+            );
             Some(EncoderState::FixedSizeList(
                 Box::new(nested_encoder),
                 *width,
@@ -297,6 +327,7 @@ fn get_encoder(
                             opt,
                             None,
                             &mut RowWidths::new(row_widths.num_rows()),
+                            masked_out_max_width,
                         )
                     })
                     .collect(),
@@ -310,6 +341,7 @@ fn get_encoder(
                             opt,
                             dict.as_ref(),
                             &mut RowWidths::new(row_widths.num_rows()),
+                            masked_out_max_width,
                         )
                     })
                     .collect(),
@@ -333,8 +365,13 @@ fn get_encoder(
             debug_assert_eq!(array.values().len(), array.len() * width);
 
             let mut nested_row_widths = RowWidths::new(array.values().len());
-            let nested_encoder =
-                get_encoder(array.values().as_ref(), opt, dict, &mut nested_row_widths);
+            let nested_encoder = get_encoder(
+                array.values().as_ref(),
+                opt,
+                dict,
+                &mut nested_row_widths,
+                masked_out_max_width,
+            );
             let mut fsl_row_widths = nested_row_widths.collapse_chunks(*width, array.len());
             fsl_row_widths.push_constant(1); // validity byte
 
@@ -358,13 +395,25 @@ fn get_encoder(
             match dict {
                 None => {
                     for array in array.values() {
-                        let encoder = get_encoder(array.as_ref(), opt, None, row_widths);
+                        let encoder = get_encoder(
+                            array.as_ref(),
+                            opt,
+                            None,
+                            row_widths,
+                            masked_out_max_width,
+                        );
                         nested_encoders.push(encoder);
                     }
                 },
                 Some(RowEncodingCatOrder::Struct(dicts)) => {
                     for (array, dict) in array.values().iter().zip(dicts) {
-                        let encoder = get_encoder(array.as_ref(), opt, dict.as_ref(), row_widths);
+                        let encoder = get_encoder(
+                            array.as_ref(),
+                            opt,
+                            dict.as_ref(),
+                            row_widths,
+                            masked_out_max_width,
+                        );
                         nested_encoders.push(encoder);
                     }
                 },
@@ -376,8 +425,12 @@ fn get_encoder(
             }
         },
 
-        D::List(_) => list_num_column_bytes::<i32>(array, opt, dict, row_widths),
-        D::LargeList(_) => list_num_column_bytes::<i64>(array, opt, dict, row_widths),
+        D::List(_) => {
+            list_num_column_bytes::<i32>(array, opt, dict, row_widths, masked_out_max_width)
+        },
+        D::LargeList(_) => {
+            list_num_column_bytes::<i64>(array, opt, dict, row_widths, masked_out_max_width)
+        },
 
         D::BinaryView => {
             let dc_array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
@@ -654,6 +707,9 @@ unsafe fn encode_array(
     opt: RowEncodingOptions,
     dict: Option<&RowEncodingCatOrder>,
     offsets: &mut [usize],
+    masked_out_write_offset: usize, // Masked out values need to be written somewhere. We just
+                                    // reserved space at the end and tell all values to write
+                                    // there.
     scratches: &mut EncodeScratches,
 ) {
     let Some(state) = &encoder.state else {
@@ -709,6 +765,13 @@ unsafe fn encode_array(
                     if !is_valid {
                         buffer[offsets[i]] = MaybeUninit::new(list_null_sentinel);
                         offsets[i] += 1;
+
+                        // Values might have been masked out.
+                        if length > 0 {
+                            nested_offsets
+                                .extend(std::iter::repeat_n(masked_out_write_offset, length));
+                        }
+
                         continue;
                     }
 
@@ -732,6 +795,7 @@
                     opt,
                     dict,
                     nested_offsets,
+                    masked_out_write_offset,
                     &mut EncodeScratches::default(),
                 )
             };
@@ -756,6 +820,7 @@
                 opt,
                 dict,
                 &mut child_offsets,
+                masked_out_write_offset,
                 scratches,
             );
             for (i, offset) in offsets.iter_mut().enumerate() {
@@ -768,12 +833,28 @@
             match dict {
                 None => {
                     for array in arrays {
-                        encode_array(buffer, array, opt, None, offsets, scratches);
+                        encode_array(
+                            buffer,
+                            array,
+                            opt,
+                            None,
+                            offsets,
+                            masked_out_write_offset,
+                            scratches,
+                        );
                     }
                 },
                 Some(RowEncodingCatOrder::Struct(dicts)) => {
                     for (array, dict) in arrays.iter().zip(dicts) {
-                        encode_array(buffer, array, opt, dict.as_ref(), offsets, scratches);
+                        encode_array(
+                            buffer,
+                            array,
+                            opt,
+                            dict.as_ref(),
+                            offsets,
+                            masked_out_write_offset,
+                            scratches,
+                        );
                     }
                 },
                 _ => unreachable!(),
diff --git a/py-polars/tests/unit/test_row_encoding.py b/py-polars/tests/unit/test_row_encoding.py
index 7fdf2c1a9a80..0705b7dd685f 100644
--- a/py-polars/tests/unit/test_row_encoding.py
+++ b/py-polars/tests/unit/test_row_encoding.py
@@ -316,6 +316,18 @@ def test_list_nulls(field: tuple[bool, bool, bool]) -> None:
     roundtrip_series_re([[None], [None, None], [None, None, None]], dtype, field)
 
 
+@pytest.mark.parametrize("field", FIELD_COMBS)
+def test_masked_out_list_20151(field: tuple[bool, bool, bool]) -> None:
+    dtype = pl.List(pl.Int64())
+
+    values = [[1, 2], None, [4, 5], [None, 3]]
+
+    array_series = pl.Series(values, dtype=pl.Array(pl.Int64(), 2))
+    list_from_array_series = array_series.cast(dtype)
+
+    roundtrip_series_re(list_from_array_series, dtype, field)
+
+
 def test_int_after_null() -> None:
     roundtrip_re(
         pl.DataFrame(
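
A note on the approach (this addendum is not part of the patch): the sketch below restates the buffer layout the fix relies on, using illustrative names and types rather than the polars-row API. A masked-out (invalid) list row only contributes a one-byte null sentinel to the row encoding, but its child values are still pushed through the nested encoder and need a valid place to write to. The patch therefore tracks the widest masked-out child, reserves that many extra bytes once past the real output, and redirects every masked-out child's write offset to that tail.

    // Sketch of the scratch-region layout (illustrative names, not the
    // polars-row API). `child_widths_per_row[i]` holds the encoded widths of
    // the children of list row `i`; `validity[i]` says whether that row is
    // masked out.
    struct BufferPlan {
        capacity: usize,                // total_num_bytes + masked_out_max_length
        total_num_bytes: usize,         // bytes that form the actual row encoding
        masked_out_write_offset: usize, // where masked-out children are told to write
    }

    fn plan_buffer(child_widths_per_row: &[Vec<usize>], validity: &[bool]) -> BufferPlan {
        let mut total_num_bytes = 0;
        let mut masked_out_max_length = 0;
        for (widths, &is_valid) in child_widths_per_row.iter().zip(validity) {
            if is_valid {
                // A valid row stores all of its children (length and validity
                // bookkeeping is elided in this sketch).
                total_num_bytes += widths.iter().sum::<usize>();
            } else {
                // A masked-out row stores only the null sentinel byte...
                total_num_bytes += 1;
                // ...but its children still run through the encoder. They all
                // write at the same scratch offset, overwriting each other, so
                // only the single widest masked-out child must be reserved.
                for &w in widths {
                    masked_out_max_length = masked_out_max_length.max(w);
                }
            }
        }
        BufferPlan {
            capacity: total_num_bytes + masked_out_max_length,
            total_num_bytes,
            masked_out_write_offset: total_num_bytes,
        }
    }

Because all masked-out children write to the same offset and simply overwrite one another, the extra reservation stays at the width of the single largest masked-out value rather than growing with the number of masked-out elements; only the first `total_num_bytes` bytes are kept as the final encoding, so the scratch tail is discarded.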