Skip to content

Commit

Permalink
feat(serialization: add serializable_overrides parameter to `to_dic…
Browse files Browse the repository at this point in the history
…t/json` and `from_dict/json` methods

Allows for users to override the serializable metadata for a class.
  • Loading branch information
collindutter committed Mar 10, 2025
1 parent 88e0f28 commit 24450e9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 22 deletions.
36 changes: 25 additions & 11 deletions griptape/mixins/serializable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,19 @@ class SerializableMixin(Generic[T]):
)

@classmethod
def get_schema(cls: type[T], subclass_name: Optional[str] = None, *, module_name: Optional[str] = None) -> Schema:
def get_schema(
cls: type[T],
subclass_name: Optional[str] = None,
*,
module_name: Optional[str] = None,
serializable_overrides: Optional[dict[str, bool]] = None,
) -> Schema:
"""Generates a Marshmallow schema for the class.
Args:
subclass_name: An optional subclass name. Required if the class is abstract.
module_name: An optional module name. Defaults to the class's module.
serializable_overrides: An optional dictionary of field names to override serializable status.
"""
if ABC in cls.__bases__:
if subclass_name is None:
Expand All @@ -43,28 +50,35 @@ def get_schema(cls: type[T], subclass_name: Optional[str] = None, *, module_name
module_name = module_name or cls.__module__
subclass_cls = cls._import_cls_rec(module_name, subclass_name)

schema_class = BaseSchema.from_attrs_cls(subclass_cls)
schema_class = BaseSchema.from_attrs_cls(subclass_cls, serializable_overrides=serializable_overrides)
else:
schema_class = BaseSchema.from_attrs_cls(cls)
schema_class = BaseSchema.from_attrs_cls(cls, serializable_overrides=serializable_overrides)

return schema_class()

@classmethod
def from_dict(cls: type[T], data: dict) -> T:
return cast(T, cls.get_schema(subclass_name=data.get("type"), module_name=data.get("module_name")).load(data))
def from_dict(cls: type[T], data: dict, *, serializable_overrides: Optional[dict[str, bool]] = None) -> T:
return cast(
T,
cls.get_schema(
subclass_name=data.get("type"),
module_name=data.get("module_name"),
serializable_overrides=serializable_overrides,
).load(data),
)

@classmethod
def from_json(cls: type[T], data: str) -> T:
return cls.from_dict(json.loads(data))
def from_json(cls: type[T], data: str, *, serializable_overrides: Optional[dict[str, bool]] = None) -> T:
return cls.from_dict(json.loads(data), serializable_overrides=serializable_overrides)

def __str__(self) -> str:
return json.dumps(self.to_dict())

def to_json(self) -> str:
return json.dumps(self.to_dict())
def to_json(self, *, serializable_overrides: Optional[dict[str, bool]] = None) -> str:
return json.dumps(self.to_dict(serializable_overrides=serializable_overrides))

def to_dict(self) -> dict:
schema = BaseSchema.from_attrs_cls(self.__class__)
def to_dict(self, *, serializable_overrides: Optional[dict[str, bool]] = None) -> dict:
schema = BaseSchema.from_attrs_cls(self.__class__, serializable_overrides=serializable_overrides)

return dict(schema().dump(self))

Expand Down
24 changes: 14 additions & 10 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@ class Meta:
}

@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
def from_attrs_cls(cls, attrs_cls: type, *, serializable_overrides: Optional[dict[str, bool]] = None) -> type:
"""Generate a Schema from an attrs class.
Args:
attrs_cls: An attrs class.
serializable_overrides: A dictionary of field names to whether they are serializable.
"""
from marshmallow import post_load

from griptape.mixins.serializable_mixin import SerializableMixin

if serializable_overrides is None:
serializable_overrides = {}

class SubSchema(cls):
@post_load
def make_obj(self, data: Any, **kwargs) -> Any:
Expand All @@ -52,16 +56,14 @@ def make_obj(self, data: Any, **kwargs) -> Any:

if issubclass(attrs_cls, SerializableMixin):
cls._resolve_types(attrs_cls)
return SubSchema.from_dict(
{
a.alias or a.name: cls._get_field_for_type(
a.type, serialization_key=a.metadata.get("serialization_key")
fields = {}
for field in attrs.fields(attrs_cls):
field_key = field.alias or field.name
if serializable_overrides.get(field_key, field.metadata.get("serializable", False)):
fields[field_key] = cls._get_field_for_type(
field.type, serialization_key=field.metadata.get("serialization_key")
)
for a in attrs.fields(attrs_cls)
if a.metadata.get("serializable")
},
name=f"{attrs_cls.__name__}Schema",
)
return SubSchema.from_dict(fields, name=f"{attrs_cls.__name__}Schema")
else:
raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

Expand Down Expand Up @@ -244,6 +246,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from griptape.engines.rag import RagContext
from griptape.events import EventListener
from griptape.memory import TaskMemory
from griptape.memory.meta import BaseMetaEntry
from griptape.memory.structure import BaseConversationMemory, Run
from griptape.memory.task.storage import BaseArtifactStorage
from griptape.rules.base_rule import BaseRule
Expand Down Expand Up @@ -275,6 +278,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseObservabilityDriver": BaseObservabilityDriver,
"BaseAssistantDriver": BaseAssistantDriver,
"BaseArtifact": BaseArtifact,
"BaseMetaEntry": BaseMetaEntry,
"PromptStack": PromptStack,
"EventListener": EventListener,
"BaseMessageContent": BaseMessageContent,
Expand Down
10 changes: 9 additions & 1 deletion tests/mocks/mock_meta_entry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from __future__ import annotations

import json
from typing import Optional

from griptape.memory.meta import BaseMetaEntry


class MockMetaEntry(BaseMetaEntry):
def to_dict(self) -> dict:
def to_json(self, *, serializable_overrides: Optional[dict[str, bool]] = None) -> str:
return json.dumps(self.to_dict())

def to_dict(self, *, serializable_overrides: Optional[dict[str, bool]] = None) -> dict:
return {"foo": "bar"}

0 comments on commit 24450e9

Please sign in to comment.