Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly serialize pydantic models #1757

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact
from .generic_artifact import GenericArtifact
from .model_artifact import ModelArtifact


__all__ = [
Expand All @@ -25,4 +26,5 @@
"AudioArtifact",
"ActionArtifact",
"GenericArtifact",
"ModelArtifact",
]
24 changes: 24 additions & 0 deletions griptape/artifacts/model_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from attrs import define, field
from pydantic import BaseModel

from griptape.artifacts.generic_artifact import GenericArtifact


@define
class ModelArtifact(GenericArtifact[BaseModel]):
    """Artifact wrapper around a Pydantic model instance.

    Pydantic models need their own serialization method, which is why they
    get a dedicated Artifact type.

    Attributes:
        value: The Pydantic model held by this Artifact.
    """

    # The concrete type is restated here instead of being inherited through the
    # parent's type variable, because generic type information is lost at runtime.
    value: BaseModel = field(metadata={"serializable": True})

    def to_text(self) -> str:
        """Return the stored model rendered as a JSON string."""
        model = self.value
        return model.model_dump_json()
4 changes: 3 additions & 1 deletion griptape/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@

from .union_field import Union

from .pydantic_model_field import PydanticModel

__all__ = ["BaseSchema", "PolymorphicSchema", "Bytes", "Union"]

__all__ = ["BaseSchema", "PolymorphicSchema", "Bytes", "Union", "PydanticModel"]
10 changes: 9 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@

import attrs
from marshmallow import INCLUDE, Schema, fields
from pydantic import BaseModel

from griptape.schemas.bytes_field import Bytes
from griptape.schemas.pydantic_model_field import PydanticModel
from griptape.schemas.union_field import Union as UnionField


class BaseSchema(Schema):
class Meta:
unknown = INCLUDE

DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw}
DATACLASS_TYPE_MAPPING = {
**Schema.TYPE_MAPPING,
dict: fields.Dict,
bytes: Bytes,
Any: fields.Raw,
BaseModel: PydanticModel,
}

@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
Expand Down
21 changes: 21 additions & 0 deletions griptape/schemas/pydantic_model_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from marshmallow import fields

if TYPE_CHECKING:
from pydantic import BaseModel


class PydanticModel(fields.Field):
    """Marshmallow field that serializes Pydantic models; deserialization is unsupported."""

    def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
        """Dump the model via ``model_dump()``, passing ``None`` through unchanged."""
        return None if value is None else value.model_dump()

    def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> dict:
        """Always raise, because the target model class is not known here.

        Raises:
            NotImplementedError: On every call.
        """
        # Rebuilding a model from plain data requires knowing which model class
        # to instantiate, which is non-trivial from here. Refusing outright is
        # preferred over shipping an incorrect implementation.
        raise NotImplementedError("Model fields cannot be deserialized directly.")
3 changes: 2 additions & 1 deletion griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from griptape.tasks import PromptTask

if TYPE_CHECKING:
from pydantic import BaseModel
from schema import Schema

from griptape.artifacts import BaseArtifact
Expand All @@ -27,7 +28,7 @@ class Agent(Structure):
)
stream: bool = field(default=None, kw_only=True)
prompt_driver: BasePromptDriver = field(default=None, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)
output_schema: Optional[Union[Schema, type[BaseModel]]] = field(default=None, kw_only=True)
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
fail_fast: bool = field(default=False, kw_only=True)
Expand Down
16 changes: 12 additions & 4 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@
from schema import Schema

from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact
from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.artifacts.generic_artifact import GenericArtifact
from griptape.artifacts import (
ActionArtifact,
AudioArtifact,
BaseArtifact,
ErrorArtifact,
GenericArtifact,
JsonArtifact,
ListArtifact,
ModelArtifact,
TextArtifact,
)
from griptape.common import PromptStack, ToolAction
from griptape.configs import Defaults
from griptape.memory.structure import Run
Expand Down Expand Up @@ -216,7 +224,7 @@ def try_run(self) -> ListArtifact | TextArtifact | AudioArtifact | GenericArtifa
if isinstance(self.output_schema, Schema):
return JsonArtifact(output.value)
elif isinstance(self.output_schema, type) and issubclass(self.output_schema, BaseModel):
return GenericArtifact(TypeAdapter(self.output_schema).validate_json(output.value))
return ModelArtifact(TypeAdapter(self.output_schema).validate_json(output.value))
else:
raise ValueError(f"Unsupported output schema type: {type(self.output_schema)}")
elif isinstance(output, (TextArtifact, AudioArtifact, JsonArtifact, ErrorArtifact)):
Expand Down
5 changes: 5 additions & 0 deletions tests/mocks/mock_serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from attrs import define, field
from pydantic import BaseModel

from griptape.mixins.serializable_mixin import SerializableMixin

Expand All @@ -13,10 +14,14 @@ class MockSerializable(SerializableMixin):
class NestedMockSerializable(SerializableMixin):
foo: str = field(default="bar", kw_only=True, metadata={"serializable": True})

class MockOutput(BaseModel):
foo: str

foo: str = field(default="bar", kw_only=True, metadata={"serializable": True})
bar: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
baz: Optional[list[int]] = field(default=None, kw_only=True, metadata={"serializable": True})
secret: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
nested: Optional[MockSerializable.NestedMockSerializable] = field(
default=None, kw_only=True, metadata={"serializable": True}
)
model: Optional[BaseModel] = field(default=None, kw_only=True, metadata={"serializable": True})
25 changes: 25 additions & 0 deletions tests/unit/artifacts/test_model_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from pydantic import create_model

from griptape.artifacts import BaseArtifact, ModelArtifact


class TestModelArtifact:
    """Unit tests for ModelArtifact text rendering and (de)serialization."""

    @pytest.fixture()
    def model_artifact(self):
        # Build a throwaway Pydantic model class with one required string field,
        # then wrap an instance of it in the artifact under test.
        model_cls = create_model("ModelArtifact", value=(str, ...))
        return ModelArtifact(value=model_cls(value="foo"))

    def test_to_text(self, model_artifact: ModelArtifact):
        expected_json = '{"value":"foo"}'
        assert model_artifact.to_text() == expected_json

    def test_to_dict(self, model_artifact: ModelArtifact):
        serialized = model_artifact.to_dict()

        assert serialized["value"] == {"value": "foo"}

    def test_deserialization(self, model_artifact):
        # Loading the serialized form back is expected to raise rather than
        # produce an artifact.
        serialized = model_artifact.to_dict()
        with pytest.raises(NotImplementedError):
            BaseArtifact.from_dict(serialized)
16 changes: 12 additions & 4 deletions tests/unit/mixins/test_seriliazable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,29 @@ def test_from_json(self):

def test_str(self):
assert str(MockSerializable()) == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None, "model": None}
)

def test_to_json(self):
assert MockSerializable().to_json() == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
assert MockSerializable(model=MockSerializable.MockOutput(foo="bar")).to_json() == json.dumps(
{
"type": "MockSerializable",
"foo": "bar",
"bar": None,
"baz": None,
"nested": None,
"model": {"foo": "bar"},
}
)

def test_to_dict(self):
assert MockSerializable().to_dict() == {
assert MockSerializable(model=MockSerializable.MockOutput(foo="bar")).to_dict() == {
"type": "MockSerializable",
"foo": "bar",
"bar": None,
"baz": None,
"nested": None,
"model": {"foo": "bar"},
}

def test_import_class_rec(self):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/schemas/test_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import pytest
from marshmallow import fields
from pydantic import BaseModel

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.loaders import TextLoader
from griptape.schemas import PolymorphicSchema
from griptape.schemas.base_schema import BaseSchema
from griptape.schemas.bytes_field import Bytes
from griptape.schemas.pydantic_model_field import PydanticModel
from griptape.schemas.union_field import Union as UnionField
from tests.mocks.mock_serializable import MockSerializable

Expand All @@ -26,6 +28,10 @@ class UnsupportedType:
pass


# Minimal Pydantic model declaring a single required string field.
# A `#` comment is used instead of a docstring so the model's generated schema is unaffected.
class MockModel(BaseModel):
    foo: str


class TestBaseSchema:
def test_from_attrs_cls(self):
schema = BaseSchema.from_attrs_cls(MockSerializable)()
Expand Down Expand Up @@ -64,6 +70,8 @@ def test_get_field_for_type(self):
assert isinstance(BaseSchema._get_field_for_type(bool), fields.Bool)
assert isinstance(BaseSchema._get_field_for_type(tuple), fields.Raw)
assert isinstance(BaseSchema._get_field_for_type(dict), fields.Dict)

assert isinstance(BaseSchema._get_field_for_type(BaseModel), PydanticModel)
with pytest.raises(ValueError):
BaseSchema._get_field_for_type(list)

Expand Down
Loading