Skip to content

Commit

Permalink
Bugfix: MockBackend wasn't mocking savetable properly when the mo…
Browse files Browse the repository at this point in the history
…de is `append` (#289)

This PR updates the `MockBackend` implementation for mocking
`SQLBackend` so that the `.savetable(…, mode='overwrite')` works
correctly.

- Prior to this PR using overwrite mode would replace all rows, not just
the table being saved.
- As of this PR using overwrite mode will only replace any existing rows
for the given table; other tables are left-alone.

Some incidental changes:

- Update the test for `append`-mode to check that rows for a table
accumulate.
 - The type signature for `.save_table()` has been improved.
  • Loading branch information
asnare authored Sep 17, 2024
1 parent d88bb65 commit 621647f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any, ClassVar, Protocol, TypeVar
from typing import Any, ClassVar, Literal, Protocol, TypeVar

from databricks.labs.blueprint.commands import CommandExecutor
from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -56,7 +56,13 @@ def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = No
raise NotImplementedError

@abstractmethod
def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
def save_table(
self,
full_name: str,
rows: Sequence[DataclassInstance],
klass: Dataclass,
mode: Literal["append", "overwrite"] = "append",
) -> None:
raise NotImplementedError

def create_table(self, full_name: str, klass: Dataclass):
Expand Down Expand Up @@ -259,7 +265,13 @@ def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = No
error_message = str(e)
raise self._api_error_from_message(error_message) from None

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
def save_table(
self,
full_name: str,
rows: Sequence[DataclassInstance],
klass: Dataclass,
mode: Literal["append", "overwrite"] = "append",
) -> None:
rows = self._filter_none_rows(rows, klass)

if len(rows) == 0:
Expand Down Expand Up @@ -336,10 +348,17 @@ def fetch(self, sql, *, catalog=None, schema=None) -> Iterator[Row]:
logger.debug(f"Returning rows: {rows}")
return iter(rows)

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
def save_table(
self,
full_name: str,
rows: Sequence[DataclassInstance],
klass: Dataclass,
mode: Literal["append", "overwrite"] = "append",
) -> None:
rows = self._filter_none_rows(rows, klass)
if mode == "overwrite":
self._save_table = []
# Remove prior rows written for (only) this table.
self._save_table = [row for row in self._save_table if row[0] != full_name]
if klass.__class__ == type: # noqa: E721
row_factory = self._row_factory(klass)
rows = [row_factory(*dataclasses.astuple(r)) for r in rows]
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,28 @@ def test_mock_backend_save_table():
mock_backend = MockBackend()

mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
mock_backend.save_table("a.b.c", [Foo("ccc", True)], Foo)

assert mock_backend.rows_written_for("a.b.c", "append") == [
Row(first="aaa", second=True),
Row(first="bbb", second=False),
Row(first="ccc", second=True),
]


def test_mock_backend_save_table_overwrite() -> None:
mock_backend = MockBackend()

mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode="overwrite")
mock_backend.save_table("d.e.f", [Foo("ddd", True), Foo("eee", False)], Foo, mode="overwrite")
mock_backend.save_table("d.e.f", [Foo("fff", True)], Foo, mode="overwrite")

assert mock_backend.rows_written_for("a.b.c", "overwrite") == [
Row(first="aaa", second=True),
Row(first="bbb", second=False),
]
assert mock_backend.rows_written_for("d.e.f", "overwrite") == [
Row(first="fff", second=True),
]


Expand Down

0 comments on commit 621647f

Please sign in to comment.