Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added case() function #770

Merged
merged 11 commits into from
Jan 3, 2025
4 changes: 2 additions & 2 deletions src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import case, literal
from sqlalchemy import literal

from . import array, path, random, string
from .aggregate import (
Expand All @@ -16,7 +16,7 @@
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import greatest, least
from .conditional import case, greatest, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
from .string import byte_hamming_distance
Expand Down
52 changes: 52 additions & 0 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Union

from sqlalchemy import case as sql_case
from sqlalchemy.sql.elements import BinaryExpression

from datachain.lib.utils import DataChainParamsError
from datachain.sql.functions import conditional

from .func import ColT, Func
Expand Down Expand Up @@ -79,3 +83,51 @@ def least(*args: Union[ColT, float]) -> Func:
return Func(
"least", inner=conditional.least, cols=cols, args=func_args, result_type=int
)


def case(
    *args: tuple[BinaryExpression, Union[int, float, complex, bool, str]], else_=None
) -> Func:
    """
    Returns the case function that produces a case expression which has a list of
    conditions and corresponding results. Results can only be python primitives
    like strings, numbers or booleans. Result type is inferred from condition
    results.

    Args:
        args (tuple(BinaryExpression, value(str | int | float | complex | bool)):
            - Tuple of binary expression and value pair which corresponds to one
            case condition - value
        else_ (str | int | float | complex | bool): else value in case expression

    Returns:
        Func: A Func object that represents the case function.

    Raises:
        DataChainParamsError: If no case statements are given, if the values
            (including `else_`, when provided) are not of the same type, or if
            the value type is not one of the supported python literals.

    Example:
        ```py
        dc.mutate(
            res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
        )
        ```

    Note:
        - The result type of the returned Func is the python type of the case
          values.
    """
    supported_types = [int, float, complex, str, bool]

    if not args:
        raise DataChainParamsError("Missing case statements")

    # Compare against None explicitly: falsy but perfectly valid else values
    # (0, 0.0, "", False) must still take part in the same-type validation.
    type_ = type(else_) if else_ is not None else None

    for arg in args:
        value = arg[1]
        if type_ is not None and not isinstance(value, type_):
            raise DataChainParamsError("Case statement values must be of the same type")
        # Each value is checked against the type of the previous one.
        type_ = type(value)

    if type_ not in supported_types:
        raise DataChainParamsError(
            f"Case supports only python literals ({supported_types}) for values"
        )

    # `else_` is forwarded as a keyword argument because the SQL case construct
    # accepts it by keyword, while the condition tuples are passed positionally.
    kwargs = {"else_": else_}
    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
6 changes: 5 additions & 1 deletion src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
inner: Callable,
cols: Optional[Sequence[ColT]] = None,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
result_type: Optional["DataType"] = None,
is_array: bool = False,
is_window: bool = False,
Expand All @@ -45,6 +46,7 @@ def __init__(
self.inner = inner
self.cols = cols or []
self.args = args or []
self.kwargs = kwargs or {}
self.result_type = result_type
self.is_array = is_array
self.is_window = is_window
Expand All @@ -63,6 +65,7 @@ def over(self, window: "Window") -> "Func":
self.inner,
self.cols,
self.args,
self.kwargs,
self.result_type,
self.is_array,
self.is_window,
Expand Down Expand Up @@ -333,6 +336,7 @@ def label(self, label: str) -> "Func":
self.inner,
self.cols,
self.args,
self.kwargs,
self.result_type,
self.is_array,
self.is_window,
Expand Down Expand Up @@ -387,7 +391,7 @@ def get_col(col: ColT) -> ColT:
return col

cols = [get_col(col) for col in self._db_cols]
func_col = self.inner(*cols, *self.args)
func_col = self.inner(*cols, *self.args, **self.kwargs)

if self.is_window:
if not self.window:
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from datachain import func
from datachain.lib.utils import DataChainParamsError
from datachain.sql import select, values


Expand Down Expand Up @@ -64,3 +65,45 @@ def test_conditionals_with_multiple_rows(warehouse, expr, expected):
query = select(expr).select_from(values([(3, 5), (8, 7), (9, 1)], ["a", "b"]))
result = list(warehouse.db.execute(query))
assert result == expected


@pytest.mark.parametrize(
    "val,expected",
    [(1, "A"), (2, "D"), (3, "B"), (4, "D"), (5, "C"), (100, "D")],
)
def test_case(warehouse, val, expected):
    """Check that `func.case` picks the matching branch or falls back to else."""
    # Conditions are evaluated in Python here, so each branch is a plain bool.
    branches = [(val < 2, "A"), (2 < val < 4, "B"), (4 < val < 6, "C")]
    query = select(func.case(*branches, else_="D"))
    rows = tuple(warehouse.db.execute(query))
    assert rows == ((expected,),)


def test_case_missing_statements(warehouse):
    """Calling `func.case` with no condition/value pairs must be rejected."""
    no_statements = []
    with pytest.raises(DataChainParamsError) as exc_info:
        select(func.case(*no_statements, else_="D"))
    message = str(exc_info.value)
    assert message == "Missing case statements"


def test_case_not_same_result_types(warehouse):
    """Mixing value types (str and int) across branches must be rejected."""
    val = 2
    mixed_branches = [(val > 1, "A"), (2 < val < 4, 5)]
    with pytest.raises(DataChainParamsError) as exc_info:
        select(func.case(*mixed_branches, else_="D"))
    message = str(exc_info.value)
    assert message == "Case statement values must be of the same type"


def test_case_wrong_result_type(warehouse):
    """Values that are not supported python literals (lists) must be rejected."""
    val = 2
    expected_message = (
        "Case supports only python literals ([<class 'int'>, <class 'float'>, "
        "<class 'complex'>, <class 'str'>, <class 'bool'>]) for values"
    )
    list_branches = [(val > 1, ["a", "b"]), (2 < val < 4, [])]
    with pytest.raises(DataChainParamsError) as exc_info:
        select(func.case(*list_branches, else_=[]))
    assert str(exc_info.value) == expected_message
20 changes: 19 additions & 1 deletion tests/unit/test_func.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from sqlalchemy import Label

from datachain import DataChain
from datachain import C, DataChain
from datachain.func import (
bit_hamming_distance,
byte_hamming_distance,
case,
int_hash_64,
literal,
)
Expand Down Expand Up @@ -642,3 +643,20 @@ def test_byte_hamming_distance_mutate(dc):
.collect("test")
)
assert list(res) == [2, 1, 0, 1, 2]


@pytest.mark.parametrize(
    "val,else_,type_",
    [
        ["A", "D", str],
        [1, 2, int],
        [1.5, 2.5, float],
        [True, False, bool],
    ],
)
def test_case_mutate(dc, val, else_, type_):
    """Check `case` inside `mutate`: produced values and the inferred schema type.

    The expected list below has one `val` and four `else_` entries, i.e. the
    test expects exactly one of five rows to satisfy `num < 2`.
    """
    condition = C("num") < 2
    res = dc.mutate(test=case((condition, val), else_=else_))

    collected = list(res.order_by("test").collect("test"))
    expected = sorted([val] + [else_] * 4)
    assert collected == expected
    assert res.schema["test"] == type_
Loading