-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: properly serialize pydantic models (#1757)
* fix: properly serialize pydantic models * fix(agent): update `output_schema` type hint to include pydantic model
- Loading branch information
1 parent
ef52194
commit ef83084
Showing
11 changed files
with
123 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
from pydantic import BaseModel | ||
|
||
from griptape.artifacts.generic_artifact import GenericArtifact | ||
|
||
|
||
@define | ||
class ModelArtifact(GenericArtifact[BaseModel]): | ||
"""Stores Pydantic models as Artifacts. | ||
Required since Pydantic models require a custom serialization method. | ||
Attributes: | ||
value: The pydantic model to store. | ||
""" | ||
|
||
# We must explicitly define the type rather than rely on the parent T since | ||
# generic type information is lost at runtime. | ||
value: BaseModel = field(metadata={"serializable": True}) | ||
|
||
def to_text(self) -> str: | ||
return self.value.model_dump_json() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from marshmallow import fields | ||
|
||
if TYPE_CHECKING: | ||
from pydantic import BaseModel | ||
|
||
|
||
class PydanticModel(fields.Field): | ||
def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]: | ||
if value is None: | ||
return None | ||
return value.model_dump() | ||
|
||
def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> dict: | ||
# Not implemented as it is non-trivial to deserialize json back into a model | ||
# since we need to know the model class to instantiate it. | ||
# Would rather not implement right now rather than implement incorrectly. | ||
raise NotImplementedError("Model fields cannot be deserialized directly.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pytest | ||
from pydantic import create_model | ||
|
||
from griptape.artifacts import BaseArtifact, ModelArtifact | ||
|
||
|
||
class TestModelArtifact: | ||
@pytest.fixture() | ||
def model_artifact(self): | ||
return ModelArtifact( | ||
value=create_model("ModelArtifact", value=(str, ...))(value="foo"), | ||
) | ||
|
||
def test_to_text(self, model_artifact: ModelArtifact): | ||
assert model_artifact.to_text() == '{"value":"foo"}' | ||
|
||
def test_to_dict(self, model_artifact: ModelArtifact): | ||
generic_dict = model_artifact.to_dict() | ||
|
||
assert generic_dict["value"] == {"value": "foo"} | ||
|
||
def test_deserialization(self, model_artifact): | ||
artifact_dict = model_artifact.to_dict() | ||
with pytest.raises(NotImplementedError): | ||
BaseArtifact.from_dict(artifact_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters