From 73323229d6c5450549741355150dd6ccb2c54706 Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Tue, 28 Jan 2025 12:01:35 +0100 Subject: [PATCH] Added ability for constant literals in `DataChain.mutate(...)` (#869) Added ability for constant literals in `DataChain.mutate(...)` --- src/datachain/lib/dc.py | 12 +++++++++++- tests/unit/test_func.py | 6 ++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 945ba3c1f..cd3f95f5d 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -25,6 +25,7 @@ from sqlalchemy.sql.sqltypes import NullType from datachain.dataset import DatasetRecord +from datachain.func import literal from datachain.func.base import Function from datachain.func.func import Func from datachain.lib.convert.python_to_sql import python_to_sql @@ -1129,8 +1130,12 @@ def mutate(self, **kwargs) -> "Self": ) ``` """ + primitives = (bool, str, int, float) + for col_name, expr in kwargs.items(): - if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType): + if not isinstance(expr, (*primitives, Column, Func)) and isinstance( + expr.type, NullType + ): raise DataChainColumnError( col_name, f"Cannot infer type with expression {expr}" ) @@ -1145,6 +1150,11 @@ def mutate(self, **kwargs) -> "Self": elif isinstance(value, Func): # adding new signal mutated[name] = value.get_column(schema) + elif isinstance(value, primitives): + # adding simple python constant primitives like str, int, float, bool + val = literal(value) + val.type = python_to_sql(type(value))() + mutated[name] = val # type: ignore[assignment] else: # adding new signal mutated[name] = value diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index 4b4237f12..ce14c85b2 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -425,6 +425,12 @@ def test_lt_mutate(dc): assert list(res) == [0, 0, 0, 0, 0] +@pytest.mark.parametrize("value", [1, 0.5, "a", True]) +def test_mutate_with_literal(dc, value): + res = dc.mutate(test=value).collect("test") + assert list(res) == [value] * 5 + + def test_le(): rnd1, rnd2 = rand(), rand()