diff --git a/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs b/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs index a8ed034f0266..32119dcd7a7e 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs @@ -73,7 +73,11 @@ pub(super) fn linear_space(s: &[Column], closed: ClosedInterval) -> PolarsResult } } -pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResult { +pub(super) fn linear_spaces( + s: &[Column], + closed: ClosedInterval, + array_width: Option, +) -> PolarsResult { let start = &s[0]; let end = &s[1]; let num_samples = &s[2]; @@ -108,7 +112,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul let out = linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; - let to_type = DataType::List(Box::new(DataType::Float32)); + let to_type = array_width.map_or_else( + || DataType::List(Box::new(DataType::Float32)), + |width| DataType::Array(Box::new(DataType::Float32), width), + ); out.cast(&to_type) }, (mut dt, dt2) if dt.is_temporal() && dt == dt2 => { @@ -147,7 +154,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul let out = linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; - let to_type = DataType::List(Box::new(dt.clone())); + let to_type = array_width.map_or_else( + || DataType::List(Box::new(dt.clone())), + |width| DataType::Array(Box::new(dt.clone()), width), + ); out.cast(&to_type) }, (dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => { @@ -185,7 +195,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul let out = linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; - let to_type = DataType::List(Box::new(DataType::Float64)); + let to_type = array_width.map_or_else( + || 
DataType::List(Box::new(DataType::Float64)), + |width| DataType::Array(Box::new(DataType::Float64), width), + ); out.cast(&to_type) }, } diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs index 35d1a33e27df..e3bc12bbb6a3 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -35,6 +35,7 @@ pub enum RangeFunction { }, LinearSpaces { closed: ClosedInterval, + array_width: Option, }, #[cfg(feature = "dtype-date")] DateRange { @@ -102,10 +103,17 @@ impl RangeFunction { match self { IntRange { dtype, .. } => mapper.with_dtype(dtype.clone()), IntRanges => mapper.with_dtype(DataType::List(Box::new(DataType::Int64))), - LinearSpace { closed: _ } => mapper.with_dtype(map_linspace_dtype(&mapper)?), - LinearSpaces { closed: _ } => { - let inner_dt = map_linspace_dtype(&mapper)?; - mapper.with_dtype(DataType::List(Box::new(inner_dt))) + LinearSpace { .. } => mapper.with_dtype(map_linspace_dtype(&mapper)?), + LinearSpaces { + closed: _, + array_width, + } => { + let inner = Box::new(map_linspace_dtype(&mapper)?); + let dt = match array_width { + Some(width) => DataType::Array(inner, *width), + None => DataType::List(inner), + }; + mapper.with_dtype(dt) }, #[cfg(feature = "dtype-date")] DateRange { .. 
} => mapper.with_dtype(DataType::Date),
@@ -181,8 +189,11 @@ impl From for SpecialEq> {
             LinearSpace { closed } => {
                 map_as_slice!(linear_space::linear_space, closed)
             },
-            LinearSpaces { closed } => {
-                map_as_slice!(linear_space::linear_spaces, closed)
+            LinearSpaces {
+                closed,
+                array_width,
+            } => {
+                map_as_slice!(linear_space::linear_spaces, closed, array_width)
             },
             #[cfg(feature = "dtype-date")]
             DateRange { interval, closed } => {
diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs
index 80061432ebb1..0f218e04c67b 100644
--- a/crates/polars-plan/src/dsl/functions/range.rs
+++ b/crates/polars-plan/src/dsl/functions/range.rs
@@ -176,17 +176,88 @@ pub fn linear_space(start: Expr, end: Expr, num_samples: Expr, closed: ClosedInt
     }
 }
 
+/// Convert the value of an integer literal into an array width.
+///
+/// Returns an `InvalidOperation` error, rather than panicking, when the value cannot be
+/// represented as `usize` (for example because it is negative).
+fn literal_to_width<T>(value: T) -> PolarsResult<Option<usize>>
+where
+    usize: TryFrom<T>,
+{
+    match usize::try_from(value) {
+        Ok(width) => Ok(Some(width)),
+        Err(_) => {
+            polars_bail!(InvalidOperation: "'as_array' requires 'num_samples' to be a non-negative integer within the range of usize")
+        },
+    }
+}
+
+/// Extract the constant integer held by `expr`, for use as a fixed array width.
+///
+/// Accepts an integer literal, or a cast to an integer dtype whose inner expression is itself
+/// accepted. Every other expression is rejected with an `InvalidOperation` error, as is a
+/// literal whose value does not fit in `usize`.
+fn match_literal_int(expr: &Expr) -> PolarsResult<Option<usize>> {
+    match expr {
+        // NOTE(review): if any of these `LiteralValue` variants are feature-gated, the arms
+        // below need matching `#[cfg]` attributes; confirm against the enum definition.
+        Expr::Literal(n) => match n {
+            LiteralValue::Int(v) => literal_to_width(*v),
+            LiteralValue::UInt8(v) => literal_to_width(*v),
+            LiteralValue::UInt16(v) => literal_to_width(*v),
+            LiteralValue::UInt32(v) => literal_to_width(*v),
+            LiteralValue::UInt64(v) => literal_to_width(*v),
+            LiteralValue::Int8(v) => literal_to_width(*v),
+            LiteralValue::Int16(v) => literal_to_width(*v),
+            LiteralValue::Int32(v) => literal_to_width(*v),
+            LiteralValue::Int64(v) => literal_to_width(*v),
+            LiteralValue::Int128(v) => literal_to_width(*v),
+            _ => {
+                polars_bail!(InvalidOperation: "'as_array' is only valid when 'num_samples' is constant")
+            },
+        },
+        // lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
+        Expr::Cast { expr, dtype, .. } if dtype.is_integer() => match_literal_int(expr),
+        _ => {
+            polars_bail!(InvalidOperation: "'as_array' is only valid when 'num_samples' is constant")
+        },
+    }
+}
+
 /// Create a column of linearly-spaced sequences from 'start', 'end', and 'num_samples' expressions.
+///
+/// When `as_array` is true, `num_samples` must be a constant integer expression; its value is
+/// stored as the `array_width` of the resulting `RangeFunction::LinearSpaces`. When `as_array`
+/// is false, `array_width` is `None`.
+///
+/// # Errors
+///
+/// Returns an error if `as_array` is true and `num_samples` is not a constant integer usable as
+/// an array width.
-pub fn linear_spaces(start: Expr, end: Expr, num_samples: Expr, closed: ClosedInterval) -> Expr {
+pub fn linear_spaces(
+    start: Expr,
+    end: Expr,
+    num_samples: Expr,
+    closed: ClosedInterval,
+    as_array: bool,
+) -> PolarsResult<Expr> {
+    let array_width = if as_array {
+        match_literal_int(&num_samples)?
+    } else {
+        None
+    };
+
     let input = vec![start, end, num_samples];
 
-    Expr::Function {
+    Ok(Expr::Function {
         input,
-        function: FunctionExpr::Range(RangeFunction::LinearSpaces { closed }),
+        function: FunctionExpr::Range(RangeFunction::LinearSpaces {
+            closed,
+            array_width,
+        }),
         options: FunctionOptions {
             collect_groups: ApplyOptions::GroupWise,
             flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
             ..Default::default()
         },
-    }
+    })
 }
diff --git a/crates/polars-python/src/functions/range.rs b/crates/polars-python/src/functions/range.rs
index 5295dade30e7..6695205487ca 100644
--- a/crates/polars-python/src/functions/range.rs
+++ b/crates/polars-python/src/functions/range.rs
@@ -179,10 +179,13 @@ pub fn linear_spaces(
     end: PyExpr,
     num_samples: PyExpr,
     closed: Wrap,
+    as_array: bool,
 ) -> PyResult {
     let start = start.inner;
     let end = end.inner;
     let num_samples = num_samples.inner;
     let closed = closed.0;
-    Ok(dsl::linear_spaces(start, end, num_samples, closed).into())
+    let out =
+        dsl::linear_spaces(start, end, num_samples, closed, as_array).map_err(PyPolarsErr::from)?;
+    Ok(out.into())
 }
diff --git a/py-polars/polars/functions/range/linear_space.py b/py-polars/polars/functions/range/linear_space.py
index df00be914445..5a7c49190718 100644
--- 
a/py-polars/polars/functions/range/linear_space.py +++ b/py-polars/polars/functions/range/linear_space.py @@ -198,6 +198,7 @@ def linear_spaces( num_samples: int | IntoExprColumn, *, closed: ClosedInterval = ..., + as_array: bool = ..., eager: Literal[False] = ..., ) -> Expr: ... @@ -209,6 +210,7 @@ def linear_spaces( num_samples: int | IntoExprColumn, *, closed: ClosedInterval = ..., + as_array: bool = ..., eager: Literal[True], ) -> Series: ... @@ -220,6 +222,7 @@ def linear_spaces( num_samples: int | IntoExprColumn, *, closed: ClosedInterval = ..., + as_array: bool = ..., eager: bool, ) -> Expr | Series: ... @@ -230,6 +233,7 @@ def linear_spaces( num_samples: int | IntoExprColumn, *, closed: ClosedInterval = "both", + as_array: bool = False, eager: bool = False, ) -> Expr | Series: """ @@ -245,6 +249,8 @@ def linear_spaces( Number of samples in the output sequence. closed : {'both', 'left', 'right', 'none'} Define which sides of the interval are closed (inclusive). + as_array + Return result as a fixed-length pl.Array. `num_samples` must be a constant. eager Evaluate immediately and return a `Series`. If set to `False` (default), return an expression instead. 
@@ -265,21 +271,31 @@ def linear_spaces( Examples -------- >>> df = pl.DataFrame({"start": [1, -1], "end": [3, 2], "step": [4, 5]}) - >>> df.with_columns(linear_space=pl.linear_spaces("start", "end", "step")) + >>> df.with_columns(ls=pl.linear_spaces("start", "end", "step")) shape: (2, 4) ┌───────┬─────┬──────┬────────────────────────┐ - │ start ┆ end ┆ step ┆ linear_space │ + │ start ┆ end ┆ step ┆ ls │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ list[f64] │ ╞═══════╪═════╪══════╪════════════════════════╡ │ 1 ┆ 3 ┆ 4 ┆ [1.0, 1.666667, … 3.0] │ │ -1 ┆ 2 ┆ 5 ┆ [-1.0, -0.25, … 2.0] │ └───────┴─────┴──────┴────────────────────────┘ + >>> df.with_columns(ls=pl.linear_spaces("start", "end", 3, as_array=True)) + shape: (2, 4) + ┌───────┬─────┬──────┬──────────────────┐ + │ start ┆ end ┆ step ┆ ls │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ array[f64, 3] │ + ╞═══════╪═════╪══════╪══════════════════╡ + │ 1 ┆ 3 ┆ 4 ┆ [1.0, 2.0, 3.0] │ + │ -1 ┆ 2 ┆ 5 ┆ [-1.0, 0.5, 2.0] │ + └───────┴─────┴──────┴──────────────────┘ """ start = parse_into_expression(start) end = parse_into_expression(end) num_samples = parse_into_expression(num_samples) - result = wrap_expr(plr.linear_spaces(start, end, num_samples, closed)) + result = wrap_expr(plr.linear_spaces(start, end, num_samples, closed, as_array)) if eager: return F.select(result).to_series() diff --git a/py-polars/tests/unit/functions/range/test_linear_space.py b/py-polars/tests/unit/functions/range/test_linear_space.py index 4ca23ec49485..70e277c69cd0 100644 --- a/py-polars/tests/unit/functions/range/test_linear_space.py +++ b/py-polars/tests/unit/functions/range/test_linear_space.py @@ -12,6 +12,7 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: + from polars import Expr from polars._typing import ClosedInterval, PolarsDataType @@ -260,6 +261,168 @@ def test_linear_spaces_values(interval: ClosedInterval) -> None: assert_series_equal(row, expected) 
+@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_one_numeric(interval: ClosedInterval) -> None: + # Two expressions, one numeric input + starts = [1, 2] + ends = [5, 6] + num_samples = [3, 4] + lf = pl.LazyFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + result = lf.select( + pl.linear_spaces(starts[0], "end", "num_samples", closed=interval).alias( + "start" + ), + pl.linear_spaces("start", ends[0], "num_samples", closed=interval).alias("end"), + pl.linear_spaces("start", "end", num_samples[0], closed=interval).alias( + "num_samples" + ), + ) + expected_start0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + starts[0], ends[1], num_samples[1], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[1], ends[0], num_samples[1], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_ns1 = pl.linear_space( + starts[1], ends[1], num_samples[0], closed=interval, eager=True + ) + expected = pl.LazyFrame( + { + "start": [expected_start0, expected_start1], + "end": [expected_end0, expected_end1], + "num_samples": [expected_ns0, expected_ns1], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_two_numeric(interval: ClosedInterval) -> None: + # One expression, two numeric inputs + starts = [1, 2] + ends = [5, 6] + num_samples = [3, 4] + lf = pl.LazyFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + result = lf.select( + pl.linear_spaces("start", ends[0], num_samples[0], closed=interval).alias( + "start" + ), + pl.linear_spaces(starts[0], "end", num_samples[0], 
closed=interval).alias( + "end" + ), + pl.linear_spaces(starts[0], ends[0], "num_samples", closed=interval).alias( + "num_samples" + ), + ) + expected_start0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + starts[1], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[0], ends[1], num_samples[0], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_ns1 = pl.linear_space( + starts[0], ends[0], num_samples[1], closed=interval, eager=True + ) + expected = pl.LazyFrame( + { + "start": [expected_start0, expected_start1], + "end": [expected_end0, expected_end1], + "num_samples": [expected_ns0, expected_ns1], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "num_samples", + [ + 5, + pl.lit(5), + pl.lit(5, dtype=pl.UInt8), + pl.lit(5, dtype=pl.UInt16), + pl.lit(5, dtype=pl.UInt32), + pl.lit(5, dtype=pl.UInt64), + pl.lit(5, dtype=pl.Int8), + pl.lit(5, dtype=pl.Int16), + pl.lit(5, dtype=pl.Int32), + pl.lit(5, dtype=pl.Int64), + ], +) +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +@pytest.mark.parametrize( + "dtype", + [ + pl.Float32, + pl.Float64, + pl.Datetime, + ], +) +def test_linear_spaces_as_array( + interval: ClosedInterval, + num_samples: int | Expr, + dtype: PolarsDataType, +) -> None: + starts = [1, 2] + ends = [5, 6] + lf = pl.LazyFrame( + { + "start": pl.Series(starts, dtype=dtype), + "end": pl.Series(ends, dtype=dtype), + } + ) + result = lf.select( + a=pl.linear_spaces("start", "end", num_samples, closed=interval, as_array=True) + ) + expected_0 = pl.linear_space( + pl.lit(starts[0], dtype=dtype), + pl.lit(ends[0], dtype=dtype), + num_samples, + closed=interval, + eager=True, + ) + 
expected_1 = pl.linear_space( + pl.lit(starts[1], dtype=dtype), + pl.lit(ends[1], dtype=dtype), + num_samples, + closed=interval, + eager=True, + ) + expected = pl.LazyFrame( + {"a": pl.Series([expected_0, expected_1], dtype=pl.Array(dtype, 5))} + ) + assert_frame_equal(result, expected) + + @pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) def test_linear_spaces_numeric_input(interval: ClosedInterval) -> None: starts = [1, 2] @@ -278,15 +441,26 @@ def test_linear_spaces_numeric_input(interval: ClosedInterval) -> None: pl.linear_spaces("start", 10, "num_samples", closed=interval).alias("end"), pl.linear_spaces("start", "end", 8, closed=interval).alias("num_samples"), ) - args = {"closed": interval, "eager": True} - expected_all0 = pl.linear_space(starts[0], ends[0], num_samples[0], **args) # type: ignore[arg-type] - expected_all1 = pl.linear_space(starts[1], ends[1], num_samples[1], **args) # type: ignore[arg-type] - expected_start0 = pl.linear_space(0, ends[0], num_samples[0], **args) # type: ignore[arg-type] - expected_start1 = pl.linear_space(0, ends[1], num_samples[1], **args) # type: ignore[arg-type] - expected_end0 = pl.linear_space(starts[0], 10, num_samples[0], **args) # type: ignore[arg-type] - expected_end1 = pl.linear_space(starts[1], 10, num_samples[1], **args) # type: ignore[arg-type] - expected_ns0 = pl.linear_space(starts[0], ends[0], 8, **args) # type: ignore[arg-type] - expected_ns1 = pl.linear_space(starts[1], ends[1], 8, **args) # type: ignore[arg-type] + expected_all0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_all1 = pl.linear_space( + starts[1], ends[1], num_samples[1], closed=interval, eager=True + ) + expected_start0 = pl.linear_space( + 0, ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + 0, ends[1], num_samples[1], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], 10, num_samples[0], 
closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[1], 10, num_samples[1], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space(starts[0], ends[0], 8, closed=interval, eager=True) + expected_ns1 = pl.linear_space(starts[1], ends[1], 8, closed=interval, eager=True) expected = pl.LazyFrame( { "all": [expected_all0, expected_all1],