diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py index f8b9230b..28c15c31 100644 --- a/src/databricks/labs/lsql/backends.py +++ b/src/databricks/labs/lsql/backends.py @@ -5,7 +5,7 @@ import re from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence -from typing import Any, ClassVar, Protocol, TypeVar +from typing import Any, ClassVar, Literal, Protocol, TypeVar from databricks.labs.blueprint.commands import CommandExecutor from databricks.sdk import WorkspaceClient @@ -56,7 +56,13 @@ def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = No raise NotImplementedError @abstractmethod - def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + def save_table( + self, + full_name: str, + rows: Sequence[DataclassInstance], + klass: Dataclass, + mode: Literal["append", "overwrite"] = "append", + ) -> None: raise NotImplementedError def create_table(self, full_name: str, klass: Dataclass): @@ -259,7 +265,13 @@ def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = No error_message = str(e) raise self._api_error_from_message(error_message) from None - def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + def save_table( + self, + full_name: str, + rows: Sequence[DataclassInstance], + klass: Dataclass, + mode: Literal["append", "overwrite"] = "append", + ) -> None: rows = self._filter_none_rows(rows, klass) if len(rows) == 0: @@ -336,10 +348,17 @@ def fetch(self, sql, *, catalog=None, schema=None) -> Iterator[Row]: logger.debug(f"Returning rows: {rows}") return iter(rows) - def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + def save_table( + self, + full_name: str, + rows: Sequence[DataclassInstance], + klass: Dataclass, + mode: Literal["append", "overwrite"] = "append", + ) -> None: rows = self._filter_none_rows(rows, klass) if mode == "overwrite": - self._save_table = [] + # Remove prior rows written for (only) this table. + self._save_table = [row for row in self._save_table if row[0] != full_name] if klass.__class__ == type: # noqa: E721 row_factory = self._row_factory(klass) rows = [row_factory(*dataclasses.astuple(r)) for r in rows] diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index 2d95160b..694ccce6 100644 --- a/tests/unit/test_backends.py +++ b/tests/unit/test_backends.py @@ -402,10 +402,28 @@ def test_mock_backend_save_table(): mock_backend = MockBackend() mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + mock_backend.save_table("a.b.c", [Foo("ccc", True)], Foo) assert mock_backend.rows_written_for("a.b.c", "append") == [ Row(first="aaa", second=True), Row(first="bbb", second=False), + Row(first="ccc", second=True), + ] + + +def test_mock_backend_save_table_overwrite() -> None: + mock_backend = MockBackend() + + mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode="overwrite") + mock_backend.save_table("d.e.f", [Foo("ddd", True), Foo("eee", False)], Foo, mode="overwrite") + mock_backend.save_table("d.e.f", [Foo("fff", True)], Foo, mode="overwrite") + + assert mock_backend.rows_written_for("a.b.c", "overwrite") == [ + Row(first="aaa", second=True), + Row(first="bbb", second=False), + ] + assert mock_backend.rows_written_for("d.e.f", "overwrite") == [ + Row(first="fff", second=True), ]