From 3c3324962f6f5d658c642f1f20c8fa086ec3daa7 Mon Sep 17 00:00:00 2001 From: Guilherme Martins Crocetti <24530683+gmcrocetti@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:37:30 -0300 Subject: [PATCH] wip: implementing tests for the atomic behavior of delete_rows --- pandas/io/sql.py | 4 ++-- pandas/tests/io/test_sql.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 4f7bfed64eb48f..99b45115826342 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -2879,8 +2879,8 @@ def drop_table(self, name: str, schema: str | None = None) -> None: def delete_rows(self, name: str, schema: str | None = None) -> None: delete_sql = f"DELETE FROM {_get_valid_sqlite_name(name)}" if self.has_table(name, schema): - with self.run_transaction(): - self.execute(delete_sql) + with self.run_transaction() as cur: + cur.execute(delete_sql) def _create_sql_schema( self, diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index b4ca60d77e77fe..4085aff10c782b 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -2716,6 +2716,34 @@ def test_delete_rows_success(conn, test_frame1, request): assert pandasSQL.has_table("temp_frame") +@pytest.mark.parametrize("conn", adbc_connectable) +def test_delete_rows_is_atomic(conn, request): + import adbc_driver_manager + + if "sqlite" in conn: + pytest.skip("sqlite has no inspection system") # TODO: Change error message + + table_name = "temp_frame" + original_df = DataFrame({"a": [1, 2, 3]}) + replacing_df = DataFrame({"a": ["a", "b", "c", "d"]}) + + conn = request.getfixturevalue(conn) + pandasSQL = pandasSQL_builder(conn) + + with pandasSQL.run_transaction(): + pandasSQL.to_sql(original_df, table_name, if_exists="fail", index=False) + + with pytest.raises(adbc_driver_manager.ProgrammingError): + with pandasSQL.run_transaction(): + pandasSQL.to_sql( + replacing_df, table_name, if_exists="delete_rows", index=False + ) + + with pandasSQL.run_transaction(): + unchanged_df = pandasSQL.read_query(f"SELECT * FROM {table_name}") + tm.assert_frame_equal(unchanged_df, original_df) + + @pytest.mark.parametrize("conn", all_connectable) def test_roundtrip(conn, request, test_frame1): if conn == "sqlite_str":