Skip to content

Commit

Permalink
fix: properly serialize pydantic models (#1757)
Browse files Browse the repository at this point in the history
* fix: properly serialize pydantic models

* fix(agent): update `output_schema` type hint to include pydantic model
  • Loading branch information
collindutter authored Feb 24, 2025
1 parent ef52194 commit ef83084
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 11 deletions.
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact
from .generic_artifact import GenericArtifact
from .model_artifact import ModelArtifact


__all__ = [
Expand All @@ -25,4 +26,5 @@
"AudioArtifact",
"ActionArtifact",
"GenericArtifact",
"ModelArtifact",
]
24 changes: 24 additions & 0 deletions griptape/artifacts/model_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from attrs import define, field
from pydantic import BaseModel

from griptape.artifacts.generic_artifact import GenericArtifact


@define
class ModelArtifact(GenericArtifact[BaseModel]):
    """Artifact wrapper around a Pydantic model instance.

    A dedicated artifact type is needed because Pydantic models require a
    custom serialization method.

    Attributes:
        value: The pydantic model to store.
    """

    # Generic type information is lost at runtime, so the field type is
    # declared explicitly here instead of relying on the parent's T.
    value: BaseModel = field(metadata={"serializable": True})

    def to_text(self) -> str:
        """Return the wrapped model rendered as a JSON string."""
        model = self.value
        return model.model_dump_json()
4 changes: 3 additions & 1 deletion griptape/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@

from .union_field import Union

from .pydantic_model_field import PydanticModel

__all__ = ["BaseSchema", "PolymorphicSchema", "Bytes", "Union"]

__all__ = ["BaseSchema", "PolymorphicSchema", "Bytes", "Union", "PydanticModel"]
10 changes: 9 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@

import attrs
from marshmallow import INCLUDE, Schema, fields
from pydantic import BaseModel

from griptape.schemas.bytes_field import Bytes
from griptape.schemas.pydantic_model_field import PydanticModel
from griptape.schemas.union_field import Union as UnionField


class BaseSchema(Schema):
class Meta:
unknown = INCLUDE

DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw}
DATACLASS_TYPE_MAPPING = {
**Schema.TYPE_MAPPING,
dict: fields.Dict,
bytes: Bytes,
Any: fields.Raw,
BaseModel: PydanticModel,
}

@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
Expand Down
21 changes: 21 additions & 0 deletions griptape/schemas/pydantic_model_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from marshmallow import fields

if TYPE_CHECKING:
from pydantic import BaseModel


class PydanticModel(fields.Field):
    """Marshmallow field that serializes Pydantic models into plain dicts.

    Only serialization is supported; see ``_deserialize`` for why loading
    is deliberately left unimplemented.
    """

    def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
        """Dump ``value`` to a JSON-compatible dict.

        Args:
            value: The model instance to serialize, or ``None``.
            attr: Name of the attribute being serialized (unused).
            obj: The object the value was pulled from (unused).
            **kwargs: Extra field-specific arguments (unused).

        Returns:
            ``None`` when ``value`` is ``None``, otherwise the dumped model.
        """
        if value is None:
            return None
        # mode="json" coerces values such as datetime, UUID, Decimal or set
        # into JSON-encodable primitives. The default python mode would leave
        # them as-is, so a later json.dumps of the serialized dict would fail.
        return value.model_dump(mode="json")

    def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> dict:
        """Reject deserialization of model fields.

        Raises:
            NotImplementedError: Always.
        """
        # Turning JSON back into a model is non-trivial because we need to
        # know which model class to instantiate. Leaving this unimplemented
        # is preferable to implementing it incorrectly.
        raise NotImplementedError("Model fields cannot be deserialized directly.")
3 changes: 2 additions & 1 deletion griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from griptape.tasks import PromptTask

if TYPE_CHECKING:
from pydantic import BaseModel
from schema import Schema

from griptape.artifacts import BaseArtifact
Expand All @@ -27,7 +28,7 @@ class Agent(Structure):
)
stream: bool = field(default=None, kw_only=True)
prompt_driver: BasePromptDriver = field(default=None, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)
output_schema: Optional[Union[Schema, type[BaseModel]]] = field(default=None, kw_only=True)
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
fail_fast: bool = field(default=False, kw_only=True)
Expand Down
16 changes: 12 additions & 4 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@
from schema import Schema

from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact
from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.artifacts.generic_artifact import GenericArtifact
from griptape.artifacts import (
ActionArtifact,
AudioArtifact,
BaseArtifact,
ErrorArtifact,
GenericArtifact,
JsonArtifact,
ListArtifact,
ModelArtifact,
TextArtifact,
)
from griptape.common import PromptStack, ToolAction
from griptape.configs import Defaults
from griptape.memory.structure import Run
Expand Down Expand Up @@ -216,7 +224,7 @@ def try_run(self) -> ListArtifact | TextArtifact | AudioArtifact | GenericArtifa
if isinstance(self.output_schema, Schema):
return JsonArtifact(output.value)
elif isinstance(self.output_schema, type) and issubclass(self.output_schema, BaseModel):
return GenericArtifact(TypeAdapter(self.output_schema).validate_json(output.value))
return ModelArtifact(TypeAdapter(self.output_schema).validate_json(output.value))
else:
raise ValueError(f"Unsupported output schema type: {type(self.output_schema)}")
elif isinstance(output, (TextArtifact, AudioArtifact, JsonArtifact, ErrorArtifact)):
Expand Down
5 changes: 5 additions & 0 deletions tests/mocks/mock_serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from attrs import define, field
from pydantic import BaseModel

from griptape.mixins.serializable_mixin import SerializableMixin

Expand All @@ -13,10 +14,14 @@ class MockSerializable(SerializableMixin):
class NestedMockSerializable(SerializableMixin):
foo: str = field(default="bar", kw_only=True, metadata={"serializable": True})

class MockOutput(BaseModel):
foo: str

foo: str = field(default="bar", kw_only=True, metadata={"serializable": True})
bar: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
baz: Optional[list[int]] = field(default=None, kw_only=True, metadata={"serializable": True})
secret: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
nested: Optional[MockSerializable.NestedMockSerializable] = field(
default=None, kw_only=True, metadata={"serializable": True}
)
model: Optional[BaseModel] = field(default=None, kw_only=True, metadata={"serializable": True})
25 changes: 25 additions & 0 deletions tests/unit/artifacts/test_model_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from pydantic import create_model

from griptape.artifacts import BaseArtifact, ModelArtifact


class TestModelArtifact:
    """Unit tests covering ModelArtifact text and dict serialization."""

    @pytest.fixture()
    def model_artifact(self):
        """Provide a ModelArtifact wrapping a dynamically created model."""
        model_cls = create_model("ModelArtifact", value=(str, ...))
        return ModelArtifact(value=model_cls(value="foo"))

    def test_to_text(self, model_artifact: ModelArtifact):
        expected = '{"value":"foo"}'
        assert model_artifact.to_text() == expected

    def test_to_dict(self, model_artifact: ModelArtifact):
        serialized = model_artifact.to_dict()

        assert serialized["value"] == {"value": "foo"}

    def test_deserialization(self, model_artifact):
        # Loading a serialized model artifact is expected to raise.
        payload = model_artifact.to_dict()
        with pytest.raises(NotImplementedError):
            BaseArtifact.from_dict(payload)
16 changes: 12 additions & 4 deletions tests/unit/mixins/test_seriliazable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,29 @@ def test_from_json(self):

def test_str(self):
assert str(MockSerializable()) == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None, "model": None}
)

def test_to_json(self):
assert MockSerializable().to_json() == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
assert MockSerializable(model=MockSerializable.MockOutput(foo="bar")).to_json() == json.dumps(
{
"type": "MockSerializable",
"foo": "bar",
"bar": None,
"baz": None,
"nested": None,
"model": {"foo": "bar"},
}
)

def test_to_dict(self):
assert MockSerializable().to_dict() == {
assert MockSerializable(model=MockSerializable.MockOutput(foo="bar")).to_dict() == {
"type": "MockSerializable",
"foo": "bar",
"bar": None,
"baz": None,
"nested": None,
"model": {"foo": "bar"},
}

def test_import_class_rec(self):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/schemas/test_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import pytest
from marshmallow import fields
from pydantic import BaseModel

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.loaders import TextLoader
from griptape.schemas import PolymorphicSchema
from griptape.schemas.base_schema import BaseSchema
from griptape.schemas.bytes_field import Bytes
from griptape.schemas.pydantic_model_field import PydanticModel
from griptape.schemas.union_field import Union as UnionField
from tests.mocks.mock_serializable import MockSerializable

Expand All @@ -26,6 +28,10 @@ class UnsupportedType:
pass


class MockModel(BaseModel):
    # Minimal Pydantic model with a single required string field.
    foo: str


class TestBaseSchema:
def test_from_attrs_cls(self):
schema = BaseSchema.from_attrs_cls(MockSerializable)()
Expand Down Expand Up @@ -64,6 +70,8 @@ def test_get_field_for_type(self):
assert isinstance(BaseSchema._get_field_for_type(bool), fields.Bool)
assert isinstance(BaseSchema._get_field_for_type(tuple), fields.Raw)
assert isinstance(BaseSchema._get_field_for_type(dict), fields.Dict)

assert isinstance(BaseSchema._get_field_for_type(BaseModel), PydanticModel)
with pytest.raises(ValueError):
BaseSchema._get_field_for_type(list)

Expand Down

0 comments on commit ef83084

Please sign in to comment.