Skip to content

Commit

Permalink
Add 'as_array' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Feb 1, 2025
1 parent 2c3bb48 commit 3a61c92
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 27 deletions.
21 changes: 17 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/range/linear_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ pub(super) fn linear_space(s: &[Column], closed: ClosedInterval) -> PolarsResult
}
}

pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResult<Column> {
pub(super) fn linear_spaces(
s: &[Column],
closed: ClosedInterval,
array_width: Option<usize>,
) -> PolarsResult<Column> {
let start = &s[0];
let end = &s[1];
let num_samples = &s[2];
Expand Down Expand Up @@ -108,7 +112,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul
let out =
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Float32));
let to_type = array_width.map_or_else(
|| DataType::List(Box::new(DataType::Float32)),
|width| DataType::Array(Box::new(DataType::Float32), width),
);
out.cast(&to_type)
},
(mut dt, dt2) if dt.is_temporal() && dt == dt2 => {
Expand Down Expand Up @@ -147,7 +154,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul
let out =
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;

let to_type = DataType::List(Box::new(dt.clone()));
let to_type = array_width.map_or_else(
|| DataType::List(Box::new(dt.clone())),
|width| DataType::Array(Box::new(dt.clone()), width),
);
out.cast(&to_type)
},
(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {
Expand Down Expand Up @@ -185,7 +195,10 @@ pub(super) fn linear_spaces(s: &[Column], closed: ClosedInterval) -> PolarsResul
let out =
linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;

let to_type = DataType::List(Box::new(DataType::Float64));
let to_type = array_width.map_or_else(
|| DataType::List(Box::new(DataType::Float64)),
|width| DataType::Array(Box::new(DataType::Float64), width),
);
out.cast(&to_type)
},
}
Expand Down
23 changes: 17 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/range/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub enum RangeFunction {
},
LinearSpaces {
closed: ClosedInterval,
array_width: Option<usize>,
},
#[cfg(feature = "dtype-date")]
DateRange {
Expand Down Expand Up @@ -102,10 +103,17 @@ impl RangeFunction {
match self {
IntRange { dtype, .. } => mapper.with_dtype(dtype.clone()),
IntRanges => mapper.with_dtype(DataType::List(Box::new(DataType::Int64))),
LinearSpace { closed: _ } => mapper.with_dtype(map_linspace_dtype(&mapper)?),
LinearSpaces { closed: _ } => {
let inner_dt = map_linspace_dtype(&mapper)?;
mapper.with_dtype(DataType::List(Box::new(inner_dt)))
LinearSpace { .. } => mapper.with_dtype(map_linspace_dtype(&mapper)?),
LinearSpaces {
closed: _,
array_width,
} => {
let inner = Box::new(map_linspace_dtype(&mapper)?);
let dt = match array_width {
Some(width) => DataType::Array(inner, *width),
None => DataType::List(inner),
};
mapper.with_dtype(dt)
},
#[cfg(feature = "dtype-date")]
DateRange { .. } => mapper.with_dtype(DataType::Date),
Expand Down Expand Up @@ -181,8 +189,11 @@ impl From<RangeFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
LinearSpace { closed } => {
map_as_slice!(linear_space::linear_space, closed)
},
LinearSpaces { closed } => {
map_as_slice!(linear_space::linear_spaces, closed)
LinearSpaces {
closed,
array_width,
} => {
map_as_slice!(linear_space::linear_spaces, closed, array_width)
},
#[cfg(feature = "dtype-date")]
DateRange { interval, closed } => {
Expand Down
54 changes: 50 additions & 4 deletions crates/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,63 @@ pub fn linear_space(start: Expr, end: Expr, num_samples: Expr, closed: ClosedInt
}
}

/// Extract a fixed array width from `expr` when it is a constant integer.
///
/// Accepts an integer literal, or an integer-typed `Cast` wrapping such a literal
/// (recursively). Used to determine the width of the output `Array` dtype when
/// `as_array` is requested.
///
/// # Errors
///
/// Returns `InvalidOperation` when:
/// * `expr` is not a literal integer (or an integer cast of one), or
/// * the literal is negative or otherwise cannot be represented as a `usize`.
///   These inputs come straight from the user, so they are reported as errors
///   rather than panicking.
fn match_literal_int(expr: &Expr) -> PolarsResult<Option<usize>> {
    /// Checked conversion of any integer literal value into an array width.
    fn width_from_int<T>(value: T) -> PolarsResult<Option<usize>>
    where
        usize: TryFrom<T>,
    {
        match usize::try_from(value) {
            Ok(width) => Ok(Some(width)),
            Err(_) => {
                polars_bail!(InvalidOperation: "'as_array' requires 'num_samples' to be a non-negative integer that fits in usize")
            },
        }
    }

    /// Shared error for every shape of expression that is not a constant integer.
    fn not_constant() -> PolarsResult<Option<usize>> {
        polars_bail!(InvalidOperation: "'as_array' is only valid when 'num_samples' is constant")
    }

    match expr {
        Expr::Literal(literal) => match literal {
            LiteralValue::Int(v) => width_from_int(*v),
            LiteralValue::UInt8(v) => width_from_int(*v),
            LiteralValue::UInt16(v) => width_from_int(*v),
            LiteralValue::UInt32(v) => width_from_int(*v),
            LiteralValue::UInt64(v) => width_from_int(*v),
            LiteralValue::Int8(v) => width_from_int(*v),
            LiteralValue::Int16(v) => width_from_int(*v),
            LiteralValue::Int32(v) => width_from_int(*v),
            LiteralValue::Int64(v) => width_from_int(*v),
            LiteralValue::Int128(v) => width_from_int(*v),
            _ => not_constant(),
        },
        // lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
        Expr::Cast { expr, dtype, .. } if dtype.is_integer() => match_literal_int(expr),
        _ => not_constant(),
    }
}

/// Create a column of linearly-spaced sequences from 'start', 'end', and 'num_samples' expressions.
pub fn linear_spaces(start: Expr, end: Expr, num_samples: Expr, closed: ClosedInterval) -> Expr {
pub fn linear_spaces(
start: Expr,
end: Expr,
num_samples: Expr,
closed: ClosedInterval,
as_array: bool,
) -> PolarsResult<Expr> {
let array_width = if as_array {
match_literal_int(&num_samples)?
} else {
None
};

let input = vec![start, end, num_samples];

Expr::Function {
Ok(Expr::Function {
input,
function: FunctionExpr::Range(RangeFunction::LinearSpaces { closed }),
function: FunctionExpr::Range(RangeFunction::LinearSpaces {
closed,
array_width,
}),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
..Default::default()
},
}
})
}
5 changes: 4 additions & 1 deletion crates/polars-python/src/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ pub fn linear_spaces(
end: PyExpr,
num_samples: PyExpr,
closed: Wrap<ClosedInterval>,
as_array: bool,
) -> PyResult<PyExpr> {
let start = start.inner;
let end = end.inner;
let num_samples = num_samples.inner;
let closed = closed.0;
Ok(dsl::linear_spaces(start, end, num_samples, closed).into())
let out =
dsl::linear_spaces(start, end, num_samples, closed, as_array).map_err(PyPolarsErr::from)?;
Ok(out.into())
}
22 changes: 19 additions & 3 deletions py-polars/polars/functions/range/linear_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def linear_spaces(
num_samples: int | IntoExprColumn,
*,
closed: ClosedInterval = ...,
as_array: bool = ...,
eager: Literal[False] = ...,
) -> Expr: ...

Expand All @@ -209,6 +210,7 @@ def linear_spaces(
num_samples: int | IntoExprColumn,
*,
closed: ClosedInterval = ...,
as_array: bool = ...,
eager: Literal[True],
) -> Series: ...

Expand All @@ -220,6 +222,7 @@ def linear_spaces(
num_samples: int | IntoExprColumn,
*,
closed: ClosedInterval = ...,
as_array: bool = ...,
eager: bool,
) -> Expr | Series: ...

Expand All @@ -230,6 +233,7 @@ def linear_spaces(
num_samples: int | IntoExprColumn,
*,
closed: ClosedInterval = "both",
as_array: bool = False,
eager: bool = False,
) -> Expr | Series:
"""
Expand All @@ -245,6 +249,8 @@ def linear_spaces(
Number of samples in the output sequence.
closed : {'both', 'left', 'right', 'none'}
Define which sides of the interval are closed (inclusive).
as_array
Return the result as a fixed-width `Array` column instead of a `List` column.
This requires `num_samples` to be a constant integer (not a column or expression).
eager
Evaluate immediately and return a `Series`.
If set to `False` (default), return an expression instead.
Expand All @@ -265,21 +271,31 @@ def linear_spaces(
Examples
--------
>>> df = pl.DataFrame({"start": [1, -1], "end": [3, 2], "step": [4, 5]})
>>> df.with_columns(linear_space=pl.linear_spaces("start", "end", "step"))
>>> df.with_columns(ls=pl.linear_spaces("start", "end", "step"))
shape: (2, 4)
┌───────┬─────┬──────┬────────────────────────┐
│ start ┆ end ┆ step ┆ linear_space
│ start ┆ end ┆ step ┆ ls
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ list[f64] │
╞═══════╪═════╪══════╪════════════════════════╡
│ 1 ┆ 3 ┆ 4 ┆ [1.0, 1.666667, … 3.0] │
│ -1 ┆ 2 ┆ 5 ┆ [-1.0, -0.25, … 2.0] │
└───────┴─────┴──────┴────────────────────────┘
>>> df.with_columns(ls=pl.linear_spaces("start", "end", 3, as_array=True))
shape: (2, 4)
┌───────┬─────┬──────┬──────────────────┐
│ start ┆ end ┆ step ┆ ls │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ array[f64, 3] │
╞═══════╪═════╪══════╪══════════════════╡
│ 1 ┆ 3 ┆ 4 ┆ [1.0, 2.0, 3.0] │
│ -1 ┆ 2 ┆ 5 ┆ [-1.0, 0.5, 2.0] │
└───────┴─────┴──────┴──────────────────┘
"""
start = parse_into_expression(start)
end = parse_into_expression(end)
num_samples = parse_into_expression(num_samples)
result = wrap_expr(plr.linear_spaces(start, end, num_samples, closed))
result = wrap_expr(plr.linear_spaces(start, end, num_samples, closed, as_array))

if eager:
return F.select(result).to_series()
Expand Down
Loading

0 comments on commit 3a61c92

Please sign in to comment.