diff --git a/examples/get_started/common_sql_functions.py b/examples/get_started/common_sql_functions.py index 7f6b90ef8..dcad97735 100644 --- a/examples/get_started/common_sql_functions.py +++ b/examples/get_started/common_sql_functions.py @@ -9,7 +9,7 @@ def num_chars_udf(file): return ([],) -dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/") +dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True) dc.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select( "file.path", "num_chars" ).show(5) @@ -32,6 +32,12 @@ def num_chars_udf(file): .show(5) ) +parts = string.split(path.name(C("file.path")), ".") +chain = dc.mutate( + isdog=array.contains(parts, "dog"), + iscat=array.contains(parts, "cat"), +) +chain.select("file.path", "isdog", "iscat").show(5) chain = dc.mutate( a=array.length(string.split("file.path", "/")), @@ -79,6 +85,15 @@ def num_chars_udf(file): 3 dogs-and-cats/cat.10.json cat.10 json 4 dogs-and-cats/cat.100.jpg cat.100 jpg +[Limited by 5 rows] + file isdog iscat + path +0 dogs-and-cats/cat.1.jpg 0 1 +1 dogs-and-cats/cat.1.json 0 1 +2 dogs-and-cats/cat.10.jpg 0 1 +3 dogs-and-cats/cat.10.json 0 1 +4 dogs-and-cats/cat.100.jpg 0 1 + [Limited by 5 rows] Processed: 400 rows [00:00, 16496.93 rows/s] a b greatest least diff --git a/src/datachain/func/__init__.py b/src/datachain/func/__init__.py index fc7249e0f..da9474055 100644 --- a/src/datachain/func/__init__.py +++ b/src/datachain/func/__init__.py @@ -15,7 +15,7 @@ row_number, sum, ) -from .array import cosine_distance, euclidean_distance, length, sip_hash_64 +from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64 from .conditional import case, greatest, ifelse, isnone, least from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64 from .random import rand @@ -34,6 +34,7 @@ "case", "collect", "concat", + "contains", "cosine_distance", "count", "dense_rank", diff --git a/src/datachain/func/array.py 
def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
    """
    Checks whether the `arr` array has the `elem` element.

    Args:
        arr (str | Sequence | Func): Array to check for the element.
            If a string is provided, it is assumed to be the name of the array column.
            If a sequence is provided, it is assumed to be an array of values.
            If a Func is provided, it is assumed to be a function returning an array.
        elem (Any): Element to check for in the array. When it is a list or a
            dict, it is flagged as a JSON value for the underlying SQL function.

    Returns:
        Func: A Func object that represents the contains function. Result of the
            function will be 1 if the element is present in the array, and 0 otherwise.

    Example:
        ```py
        dc.mutate(
            contains1=func.array.contains("signal.values", 3),
            contains2=func.array.contains([1, 2, 3, 4, 5], 7),
        )
        ```
    """
    # `elem` is fixed for the lifetime of the returned Func, so decide once
    # whether it is a JSON container instead of on every `inner` call.
    is_json = isinstance(elem, (list, dict))

    def inner(arg):
        return array.contains(arg, elem, is_json)

    if isinstance(arr, (str, Func)):
        # Column name or nested function: resolved as a column reference.
        cols = [arr]
        args = None
    else:
        # Anything else is passed through as a literal argument.
        cols = None
        args = [arr]

    return Func("contains", inner=inner, cols=cols, args=args, result_type=int)
class contains(GenericFunction):  # noqa: N801
    """
    Tests whether the array holds the given element.
    """

    name = "contains"
    package = "array"
    type = Boolean()
    inherit_cache = True
def py_json_array_contains(arr, value, is_json):
    """
    Report whether `value` is a member of the JSON-encoded array `arr`.

    When `is_json` is truthy, `value` is itself JSON text and is decoded
    before the membership test; otherwise it is compared as given.
    """
    needle = orjson.loads(value) if is_json else value
    haystack = orjson.loads(arr)
    return needle in haystack


def compile_array_length(element, compiler, **kwargs):
    """Compile `array.length` as a call to `json_array_length`."""
    operands = element.clauses.clauses
    return compiler.process(func.json_array_length(*operands), **kwargs)


def compile_array_contains(element, compiler, **kwargs):
    """Compile `array.contains` as a call to `json_array_contains`."""
    operands = element.clauses.clauses
    return compiler.process(func.json_array_contains(*operands), **kwargs)
def test_array_contains():
    """
    Check `contains` against an array column referenced by name and via `C`.

    The rows hold [1, 1], [1, 2] * 3, [1, 2, 3] * 4, ... so the value 3 first
    appears in the third row when ordered by `val`.
    """
    chain = DataChain.from_values(
        arr=[list(range(1, i)) * i for i in range(2, 7)],
        val=list(range(2, 7)),
    )

    def lookup(array_ref, needle):
        mutated = chain.mutate(res=contains(array_ref, needle))
        return list(mutated.order_by("val").collect("res"))

    assert lookup("arr", 3) == [0, 0, 1, 1, 1]
    assert lookup(C("arr"), 3) == [0, 0, 1, 1, 1]
    assert lookup(C("arr"), 10) == [0, 0, 0, 0, 0]
    assert lookup(C("arr"), None) == [0, 0, 0, 0, 0]