Skip to content

Commit

Permalink
feat(drivers): add GriptapeCloudPromptDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 11, 2025
1 parent 4f6f8ce commit 914bfb5
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .prompt.google import GooglePromptDriver
from .prompt.dummy import DummyPromptDriver
from .prompt.ollama import OllamaPromptDriver
from .prompt.griptape_cloud import GriptapeCloudPromptDriver

from .memory.conversation import BaseConversationMemoryDriver
from .memory.conversation.local import LocalConversationMemoryDriver
Expand Down Expand Up @@ -139,6 +140,7 @@
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"GriptapeCloudPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/griptape_cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from griptape.drivers.prompt.griptape_cloud_prompt_driver import GriptapeCloudPromptDriver

__all__ = ["GriptapeCloudPromptDriver"]
70 changes: 70 additions & 0 deletions griptape/drivers/prompt/griptape_cloud_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING
from urllib.parse import urljoin

import requests
from attrs import Factory, define, field

from griptape.common import DeltaMessage, Message, PromptStack, observable
from griptape.configs.defaults_config import Defaults
from griptape.drivers.prompt import BasePromptDriver

if TYPE_CHECKING:
from collections.abc import Iterator


logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class GriptapeCloudPromptDriver(BasePromptDriver):
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
url = urljoin(self.base_url.strip("/"), "/api/prompt-driver")

params = self._base_params(prompt_stack)
logger.debug(params)
response = requests.post(url, headers=self.headers, json=params)
response.raise_for_status()
response_json = response.json()
logger.debug(response_json)

return Message.from_dict(response_json)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
url = urljoin(self.base_url.strip("/"), "/api/prompt-driver")
params = self._base_params(prompt_stack)
logger.debug(params)
with requests.post(url, headers=self.headers, json=params, stream=True) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data: "):
delta_message_payload = decoded_line[6:]
logger.debug(delta_message_payload)
yield DeltaMessage.from_json(delta_message_payload)

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"model": self.model,
"use_native_tools": self.use_native_tools,
"structured_output_strategy": self.structured_output_strategy,
"extra_params": self.extra_params,
"prompt_stack": prompt_stack.to_dict(),
}

0 comments on commit 914bfb5

Please sign in to comment.