diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 34cdd48da..a82db0e75 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -12,6 +12,7 @@ from griptape.tasks import PromptTask if TYPE_CHECKING: + from pydantic import BaseModel from schema import Schema from griptape.artifacts import BaseArtifact @@ -27,7 +28,7 @@ class Agent(Structure): ) stream: bool = field(default=None, kw_only=True) prompt_driver: BasePromptDriver = field(default=None, kw_only=True) - output_schema: Optional[Schema] = field(default=None, kw_only=True) + output_schema: Optional[Union[Schema, type[BaseModel]]] = field(default=None, kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True)