Skip to content

Commit

Permalink
Added full outer join (#822)
Browse files Browse the repository at this point in the history
* added main logic for outer join

* fixing filters

* removign datasetquery tests and added more datachain unit tests
  • Loading branch information
ilongin authored Jan 20, 2025
1 parent 78f8953 commit c6faa5f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 15 deletions.
45 changes: 38 additions & 7 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy.dialects import sqlite
from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
from sqlalchemy.sql import func
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
from sqlalchemy.sql.expression import bindparam, cast
from sqlalchemy.sql.selectable import Select
from tqdm.auto import tqdm
Expand All @@ -40,7 +41,6 @@
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.selectable import Join
from sqlalchemy.types import TypeEngine

from datachain.lib.file import File
Expand Down Expand Up @@ -654,16 +654,47 @@ def join(
right: "_FromClauseArgument",
onclause: "_OnClauseArgument",
inner: bool = True,
) -> "Join":
full: bool = False,
columns=None,
) -> "Select":
"""
Join two tables together.
"""
return sqlalchemy.join(
left,
right,
onclause,
isouter=not inner,
if not full:
join_query = sqlalchemy.join(
left,
right,
onclause,
isouter=not inner,
)
return sqlalchemy.select(*columns).select_from(join_query)

left_right_join = sqlalchemy.select(*columns).select_from(
sqlalchemy.join(left, right, onclause, isouter=True)
)
right_left_join = sqlalchemy.select(*columns).select_from(
sqlalchemy.join(right, left, onclause, isouter=True)
)

def add_left_rows_filter(exp: BinaryExpression):
"""
Adds filter to right_left_join to remove unmatched left table rows by
getting column names that need to be NULL from BinaryExpressions in onclause
"""
return right_left_join.where(
getattr(left.c, exp.left.name) == None # type: ignore[union-attr] # noqa: E711
)

if isinstance(onclause, BinaryExpression):
right_left_join = add_left_rows_filter(onclause)

if isinstance(onclause, BooleanClauseList):
for c in onclause.get_children():
if isinstance(c, BinaryExpression):
right_left_join = add_left_rows_filter(c)

union = sqlalchemy.union(left_right_join, right_left_join).subquery()
return sqlalchemy.select(*union.c).select_from(union)

def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Expand Down
4 changes: 2 additions & 2 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
_FromClauseArgument,
_OnClauseArgument,
)
from sqlalchemy.sql.selectable import Join, Select
from sqlalchemy.sql.selectable import Select
from sqlalchemy.types import TypeEngine

from datachain.data_storage import schema
Expand Down Expand Up @@ -873,7 +873,7 @@ def join(
right: "_FromClauseArgument",
onclause: "_OnClauseArgument",
inner: bool = True,
) -> "Join":
) -> "Select":
"""
Join two tables together.
"""
Expand Down
4 changes: 3 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,7 @@ def merge(
on: Union[MergeColType, Sequence[MergeColType]],
right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None,
inner=False,
full=False,
rname="right_",
) -> "Self":
"""Merge two chains based on the specified criteria.
Expand All @@ -1345,6 +1346,7 @@ def merge(
right_on: Optional predicate or list of Predicates for the `right_ds`
to join.
inner (bool): Whether to run inner join or outer join.
full (bool): Whether to run full outer join.
rname (str): Name prefix for conflicting signal names.
Examples:
Expand Down Expand Up @@ -1419,7 +1421,7 @@ def _resolve(
)

query = self._query.join(
right_ds._query, sqlalchemy.and_(*ops), inner, rname + "{name}"
right_ds._query, sqlalchemy.and_(*ops), inner, full, rname + "{name}"
)
query.feature_schema = None
ds = self._evolve(query=query)
Expand Down
12 changes: 8 additions & 4 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ class SQLJoin(Step):
query2: "DatasetQuery"
predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
inner: bool
full: bool
rname: str

def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
Expand Down Expand Up @@ -977,14 +978,14 @@ def apply(
self.validate_expression(join_expression, q1, q2)

def q(*columns):
join_query = self.catalog.warehouse.join(
return self.catalog.warehouse.join(
q1,
q2,
join_expression,
inner=self.inner,
full=self.full,
columns=columns,
)
return sqlalchemy.select(*columns).select_from(join_query)
# return sqlalchemy.select(*subquery.c).select_from(subquery)

return step_result(
q,
Expand Down Expand Up @@ -1489,6 +1490,7 @@ def join(
dataset_query: "DatasetQuery",
predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
inner=False,
full=False,
rname="{name}_right",
) -> "Self":
left = self.clone(new_table=False)
Expand All @@ -1504,7 +1506,9 @@ def join(
if isinstance(predicates, (str, ColumnClause, ColumnElement))
else tuple(predicates)
)
new_query.steps = [SQLJoin(self.catalog, left, right, predicates, inner, rname)]
new_query.steps = [
SQLJoin(self.catalog, left, right, predicates, inner, full, rname)
]
return new_query

@detach
Expand Down
49 changes: 48 additions & 1 deletion tests/unit/lib/test_datachain_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TeamMember(BaseModel):
TeamMember(player="Alice", sport="volleyball", weight=120.3, height=5.5),
TeamMember(player="Charlie", sport="football", weight=200.0, height=6.0),
TeamMember(player="David", sport="football", weight=158.7, height=5.7),
TeamMember(player="John", sport="basketball", weight=250.3, height=7.0),
]


Expand Down Expand Up @@ -77,7 +78,53 @@ def test_merge_objects(test_session):
assert pd.isnull(player.height)

assert i == len(employees)
assert j == len(team)
assert j == len(team) - 1


@pytest.mark.parametrize("multiple_predicates", [True, False])
def test_merge_objects_full_join(test_session, multiple_predicates):
ch1 = DataChain.from_values(emp=employees, session=test_session)
ch2 = DataChain.from_values(team=team, session=test_session)
if multiple_predicates:
ch = ch1.merge(
ch2,
["emp.person.name", "emp.person.name"],
["team.player", "team.player"],
full=True,
)
else:
ch = ch1.merge(ch2, "emp.person.name", "team.player", full=True)

str_default = String.default_value(test_session.catalog.warehouse.db.dialect)
int_default = Int.default_value(test_session.catalog.warehouse.db.dialect)

i = 0
for items in ch.order_by("emp.person.name", "team.player").collect():
assert len(items) == 2

empl, player = items
assert isinstance(empl, Employee)
assert isinstance(player, TeamMember)

if player.player == "John":
assert empl.person.name == str_default
assert empl.person.age == int_default
continue

if empl.person.name == "Bob":
assert player.player == str_default
assert player.sport == str_default
assert pd.isnull(player.weight)
assert pd.isnull(player.height)
continue

assert player.player == team[i].player
assert player.sport == team[i].sport
assert math.isclose(player.weight, team[i].weight, rel_tol=1e-7)
assert math.isclose(player.height, team[i].height, rel_tol=1e-7)
i += 1

assert i == len(employees) - 1 == len(team) - 1


def test_merge_similar_objects(test_session):
Expand Down

0 comments on commit c6faa5f

Please sign in to comment.