Skip to content

Commit

Permalink
add arrays contains to sqlite (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Feb 3, 2025
1 parent aed4ae7 commit 07df868
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 5 deletions.
17 changes: 16 additions & 1 deletion examples/get_started/common_sql_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def num_chars_udf(file):
return ([],)


dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")
dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True)
dc.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select(
"file.path", "num_chars"
).show(5)
Expand All @@ -32,6 +32,12 @@ def num_chars_udf(file):
.show(5)
)

parts = string.split(path.name(C("file.path")), ".")
chain = dc.mutate(
isdog=array.contains(parts, "dog"),
iscat=array.contains(parts, "cat"),
)
chain.select("file.path", "isdog", "iscat").show(5)

chain = dc.mutate(
a=array.length(string.split("file.path", "/")),
Expand Down Expand Up @@ -79,6 +85,15 @@ def num_chars_udf(file):
3 dogs-and-cats/cat.10.json cat.10 json
4 dogs-and-cats/cat.100.jpg cat.100 jpg
[Limited by 5 rows]
file isdog iscat
path
0 dogs-and-cats/cat.1.jpg 0 1
1 dogs-and-cats/cat.1.json 0 1
2 dogs-and-cats/cat.10.jpg 0 1
3 dogs-and-cats/cat.10.json 0 1
4 dogs-and-cats/cat.100.jpg 0 1
[Limited by 5 rows]
Processed: 400 rows [00:00, 16496.93 rows/s]
a b greatest least
Expand Down
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
row_number,
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, ifelse, isnone, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
Expand All @@ -34,6 +34,7 @@
"case",
"collect",
"concat",
"contains",
"cosine_distance",
"count",
"dense_rank",
Expand Down
40 changes: 39 additions & 1 deletion src/datachain/func/array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Union
from typing import Any, Union

from datachain.sql.functions import array

Expand Down Expand Up @@ -140,6 +140,44 @@ def length(arg: Union[str, Sequence, Func]) -> Func:
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)


def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
"""
Checks whether the `arr` array has the `elem` element.
Args:
arr (str | Sequence | Func): Array to check for the element.
If a string is provided, it is assumed to be the name of the array column.
If a sequence is provided, it is assumed to be an array of values.
If a Func is provided, it is assumed to be a function returning an array.
elem (Any): Element to check for in the array.
Returns:
Func: A Func object that represents the contains function. Result of the
function will be 1 if the element is present in the array, and 0 otherwise.
Example:
```py
dc.mutate(
contains1=func.array.contains("signal.values", 3),
contains2=func.array.contains([1, 2, 3, 4, 5], 7),
)
```
"""

def inner(arg):
is_json = type(elem) in [list, dict]
return array.contains(arg, elem, is_json)

if isinstance(arr, (str, Func)):
cols = [arr]
args = None
else:
cols = None
args = [arr]

return Func("contains", inner=inner, cols=cols, args=args, result_type=int)


def sip_hash_64(arg: Union[str, Sequence]) -> Func:
"""
Computes the SipHash-64 hash of the array.
Expand Down
14 changes: 13 additions & 1 deletion src/datachain/sql/functions/array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.sql.functions import GenericFunction

from datachain.sql.types import Float, Int64
from datachain.sql.types import Boolean, Float, Int64
from datachain.sql.utils import compiler_not_implemented


Expand Down Expand Up @@ -37,6 +37,17 @@ class length(GenericFunction): # noqa: N801
inherit_cache = True


class contains(GenericFunction): # noqa: N801
"""
Checks if element is in the array.
"""

type = Boolean()
package = "array"
name = "contains"
inherit_cache = True


class sip_hash_64(GenericFunction): # noqa: N801
"""
Computes the SipHash-64 hash of the array.
Expand All @@ -51,4 +62,5 @@ class sip_hash_64(GenericFunction): # noqa: N801
compiler_not_implemented(cosine_distance)
compiler_not_implemented(euclidean_distance)
compiler_not_implemented(length)
compiler_not_implemented(contains)
compiler_not_implemented(sip_hash_64)
18 changes: 17 additions & 1 deletion src/datachain/sql/sqlite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def setup():
compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
compiles(array.length, "sqlite")(compile_array_length)
compiles(array.contains, "sqlite")(compile_array_contains)
compiles(string.length, "sqlite")(compile_string_length)
compiles(string.split, "sqlite")(compile_string_split)
compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
Expand Down Expand Up @@ -269,13 +270,16 @@ def create_string_functions(conn):

_registered_function_creators["string_functions"] = create_string_functions

has_json_extension = functions_exist(["json_array_length"])
has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
if not has_json_extension:

def create_json_functions(conn):
conn.create_function(
"json_array_length", 1, py_json_array_length, deterministic=True
)
conn.create_function(
"json_array_contains", 3, py_json_array_contains, deterministic=True
)

_registered_function_creators["json_functions"] = create_json_functions

Expand Down Expand Up @@ -428,10 +432,22 @@ def py_json_array_length(arr):
return len(orjson.loads(arr))


def py_json_array_contains(arr, value, is_json):
if is_json:
value = orjson.loads(value)
return value in orjson.loads(arr)


def compile_array_length(element, compiler, **kwargs):
return compiler.process(func.json_array_length(*element.clauses.clauses), **kwargs)


def compile_array_contains(element, compiler, **kwargs):
return compiler.process(
func.json_array_contains(*element.clauses.clauses), **kwargs
)


def compile_string_length(element, compiler, **kwargs):
return compiler.process(func.length(*element.clauses.clauses), **kwargs)

Expand Down
5 changes: 5 additions & 0 deletions src/datachain/sql/sqlite/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def adapt_array(arr):
return orjson.dumps(arr).decode("utf-8")


def adapt_dict(dct):
return orjson.dumps(dct).decode("utf-8")


def convert_array(arr):
return orjson.loads(arr)

Expand All @@ -52,6 +56,7 @@ def adapt_np_generic(val):

def register_type_converters():
sqlite3.register_adapter(list, adapt_array)
sqlite3.register_adapter(dict, adapt_dict)
sqlite3.register_converter("ARRAY", convert_array)
if numpy_imported:
sqlite3.register_adapter(np.ndarray, adapt_np_array)
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def test_length(warehouse):
assert result == ((4, 5, 2),)


def test_contains(warehouse):
query = select(
func.contains(["abc", "def", "g", "hi"], "abc").label("contains1"),
func.contains(["abc", "def", "g", "hi"], "cdf").label("contains2"),
func.contains([3.0, 5.0, 1.0, 6.0, 1.0], 1.0).label("contains3"),
func.contains([[1, None, 3], [4, 5, 6]], [1, None, 3]).label("contains4"),
# Not supported yet by CH, need to add it later + some Pydantic model as
# an input:
# func.contains(
# [{"c": 1, "a": True}, {"b": False}], {"a": True, "c": 1}
# ).label("contains5"),
func.contains([1, None, 3], None).label("contains6"),
func.contains([1, True, 3], True).label("contains7"),
)
result = tuple(warehouse.db.execute(query))
assert result == ((1, 0, 1, 1, 1, 1),)


def test_length_on_split(warehouse):
query = select(
func.array.length(func.string.split(func.literal("abc/def/g/hi"), "/")),
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
isnone,
literal,
)
from datachain.func.array import contains
from datachain.func.random import rand
from datachain.func.string import length as strlen
from datachain.lib.signal_schema import SignalSchema
Expand Down Expand Up @@ -797,3 +798,27 @@ def test_isnone_with_ifelse_mutate(col):
res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
assert res.schema["test"] is str


def test_array_contains():
dc = DataChain.from_values(
arr=[list(range(1, i)) * i for i in range(2, 7)],
val=list(range(2, 7)),
)

assert list(dc.mutate(res=contains("arr", 3)).order_by("val").collect("res")) == [
0,
0,
1,
1,
1,
]
assert list(
dc.mutate(res=contains(C("arr"), 3)).order_by("val").collect("res")
) == [0, 0, 1, 1, 1]
assert list(
dc.mutate(res=contains(C("arr"), 10)).order_by("val").collect("res")
) == [0, 0, 0, 0, 0]
assert list(
dc.mutate(res=contains(C("arr"), None)).order_by("val").collect("res")
) == [0, 0, 0, 0, 0]

0 comments on commit 07df868

Please sign in to comment.