Skip to content

Commit

Permalink
wip: implementing tests for the atomic behavior of delete_rows
Browse files Browse the repository at this point in the history
  • Loading branch information
gmcrocetti committed Nov 21, 2024
1 parent 7065385 commit 3c33249
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,8 +2879,8 @@ def drop_table(self, name: str, schema: str | None = None) -> None:
def delete_rows(self, name: str, schema: str | None = None) -> None:
delete_sql = f"DELETE FROM {_get_valid_sqlite_name(name)}"
if self.has_table(name, schema):
with self.run_transaction():
self.execute(delete_sql)
with self.run_transaction() as cur:
cur.execute(delete_sql)

def _create_sql_schema(
self,
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2716,6 +2716,34 @@ def test_delete_rows_success(conn, test_frame1, request):
assert pandasSQL.has_table("temp_frame")


@pytest.mark.parametrize("conn", adbc_connectable)
def test_delete_rows_is_atomic(conn, request):
import adbc_driver_manager

if "sqlite" in conn:
pytest.skip("sqlite has no inspection system") # TODO: Change error message

table_name = "temp_frame"
original_df = DataFrame({"a": [1, 2, 3]})
replacing_df = DataFrame({"a": ["a", "b", "c", "d"]})

conn = request.getfixturevalue(conn)
pandasSQL = pandasSQL_builder(conn)

with pandasSQL.run_transaction():
pandasSQL.to_sql(original_df, table_name, if_exists="fail", index=False)

with pytest.raises(adbc_driver_manager.ProgrammingError):
with pandasSQL.run_transaction():
pandasSQL.to_sql(
replacing_df, table_name, if_exists="delete_rows", index=False
)

with pandasSQL.run_transaction():
unchanged_df = pandasSQL.read_query(f"SELECT * FROM {table_name}")
tm.assert_frame_equal(unchanged_df, original_df)


@pytest.mark.parametrize("conn", all_connectable)
def test_roundtrip(conn, request, test_frame1):
if conn == "sqlite_str":
Expand Down

0 comments on commit 3c33249

Please sign in to comment.