Skip to content

Commit

Permalink
change def in load and save, and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
kicksent committed Feb 5, 2025
1 parent 4965a6c commit 78f5d4f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions burr/integrations/persisters/b_pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
return app_ids

def load(
self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs
self, partition_key: Optional[str], app_id: str, sequence_id: int = None, **kwargs
) -> Optional[persistence.PersistedStateData]:
"""Load the state data for a given partition key, app id, and sequence id."""
query = {"partition_key": partition_key, "app_id": app_id}
Expand All @@ -118,7 +118,7 @@ def load(

def save(
self,
partition_key: str,
partition_key: Optional[str],
app_id: str,
sequence_id: int,
position: str,
Expand Down
13 changes: 13 additions & 0 deletions tests/integrations/persisters/test_b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def test_serialization_with_pickle(mongodb_persister):
data = deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}

def test_partition_key_is_optional(mongodb_persister):
# 1. Save and load with partition key = None
mongodb_persister.save(None, "app_id_none", 1, "pos1", state.State({"foo": "bar"}), "in_progress")
loaded_data = mongodb_persister.load(None, "app_id_none", 1)
assert loaded_data is not None
assert loaded_data["state"].get_all() == {"foo": "bar"}

# 2. Save and load again (different key/index) with partition key = None
mongodb_persister.save(None, "app_id_none2", 2, "pos2", state.State({"hello": "world"}), "completed")
loaded_data2 = mongodb_persister.load(None, "app_id_none2", 2)
assert loaded_data2 is not None
assert loaded_data2["state"].get_all() == {"hello": "world"}

0 comments on commit 78f5d4f

Please sign in to comment.