Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix pickle problem #140

Merged
merged 6 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,11 @@ def __init__(


class FactorGraphRAGStrategy(RAGStrategy):
prompt = Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")

def __init__(self, knowledgebase: FactorGraphKnowledgeBase) -> None:
super().__init__(knowledgebase)
self.current_generated_trace_count = 0
self.prompt = Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")

def generate_knowledge(
self,
Expand Down
12 changes: 10 additions & 2 deletions rdagent/core/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,20 @@ class Hypothesis:
- Belief
"""

def __init__(self, hypothesis: str, reason: str, concise_reason: str, concise_observation: str, concise_justification: str, concise_knowledge: str) -> None:
def __init__(
self,
hypothesis: str,
reason: str,
concise_reason: str,
concise_observation: str,
concise_justification: str,
concise_knowledge: str,
) -> None:
self.hypothesis: str = hypothesis
self.reason: str = reason
self.concise_reason: str = concise_reason
self.concise_observation: str = concise_observation
self.concise_justification: str = concise_justification
self.concise_justification: str = concise_justification
self.concise_knowledge: str = concise_knowledge

def __str__(self) -> str:
Expand Down
19 changes: 16 additions & 3 deletions rdagent/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import importlib
import json
import multiprocessing as mp
import pickle
from collections.abc import Callable
from typing import Any, ClassVar, cast
from typing import Any, ClassVar, NoReturn, cast

from fuzzywuzzy import fuzz # type: ignore[import-untyped]

Expand All @@ -27,13 +28,25 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# TODO: this restriction can be solved.
exception_message = "Please only use kwargs in Singleton to avoid misunderstanding."
raise RDAgentException(exception_message)
all_args = [(-1, f"{cls.__module__}.{cls.__name__}")] + [(i, args[i]) for i in args] + list(sorted(kwargs.items()))
class_name = [(-1, f"{cls.__module__}.{cls.__name__}")]
args_l = [(i, args[i]) for i in args]
kwargs_l = list(sorted(kwargs.items()))
all_args = class_name + args_l + kwargs_l
kwargs_hash = hash(tuple(all_args))
if kwargs_hash not in cls._instance_dict:
cls._instance_dict[kwargs_hash] = super().__new__(cls) # Corrected call
cls._instance_dict[kwargs_hash].__init__(**kwargs) # Ensure __init__ is called
return cls._instance_dict[kwargs_hash]

def __reduce__(self) -> NoReturn:
"""
NOTE:
When loading an object from a pickle, the __new__ method does not receive the `kwargs`
it was initialized with. This makes it difficult to retrieve the correct singleton object.
Therefore, we have made it unpickable.
"""
msg = f"Instances of {self.__class__.__name__} cannot be pickled"
raise pickle.PicklingError(msg)


def parse_json(response: str) -> Any:
try:
Expand Down
29 changes: 29 additions & 0 deletions test/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,29 @@

class A(SingletonBaseClass):
def __init__(self, **kwargs):
print(self, "__init__", kwargs) # make sure the __init__ is called only once.
self.kwargs = kwargs

def __str__(self) -> str:
return f"{self.__class__.__name__}.{getattr(self, 'kwargs', None)}"

def __repr__(self) -> str:
return self.__str__()


class MiscTest(unittest.TestCase):
def test_singleton(self):
print("a1=================")
a1 = A()
print("a2=================")
a2 = A()
print("a3=================")
a3 = A(x=3)
print("a4=================")
a4 = A(x=2)
print("a5=================")
a5 = A(b=3)
print("a6=================")
a6 = A(x=3)

# Check that a1 and a2 are the same instance
Expand All @@ -37,6 +50,22 @@ def test_singleton(self):

print(id(a1), id(a2), id(a3), id(a4), id(a5), id(a6))

print("...................... Start testing pickle ......................")

# Test pickle
import pickle

with self.assertRaises(pickle.PicklingError):
with open("a3.pkl", "wb") as f:
pickle.dump(a3, f)
# NOTE: If the pickle feature is not disabled,
# loading a3.pkl will return a1, and a1 will be updated with a3's attributes.
# print(a1.kwargs)
# with open("a3.pkl", "rb") as f:
# a3_pkl = pickle.load(f)
# print(id(a3), id(a3_pkl)) # not the same object
# print(a1.kwargs) # a1 will be changed.


if __name__ == "__main__":
unittest.main()