Skip to content

Commit

Permalink
fix(drivers-prompt-azure-openai): fix AzureOpenAiChatPromptDriver by removing unsupported "modalities"
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 11, 2025
1 parent 914bfb5 commit 7a095b6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
del params["stream_options"]
if "parallel_tool_calls" in params:
del params["parallel_tool_calls"]

# TODO: Add once Azure supports modalities
del params["modalities"]
return params
30 changes: 22 additions & 8 deletions griptape/drivers/prompt/griptape_cloud_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from griptape.common import DeltaMessage, Message, PromptStack, observable
from griptape.configs.defaults_config import Defaults
from griptape.drivers.prompt import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, SimpleTokenizer

if TYPE_CHECKING:
from collections.abc import Iterator
Expand All @@ -29,10 +30,21 @@ class GriptapeCloudPromptDriver(BasePromptDriver):
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
kw_only=True,
)
# Default tokenizer for the Griptape Cloud prompt driver.
# Uses a character-count heuristic (4 chars/token) rather than a
# model-specific tokenizer; presumably a rough estimate is acceptable
# because the cloud service enforces real limits server-side — TODO confirm.
tokenizer: BaseTokenizer = field(
default=Factory(
lambda self: SimpleTokenizer(
characters_per_token=4,
max_input_tokens=2000,
max_output_tokens=self.max_tokens,
),
takes_self=True,
),
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
url = urljoin(self.base_url.strip("/"), "/api/prompt-driver")
url = urljoin(self.base_url.strip("/"), "/api/chat/completions")

params = self._base_params(prompt_stack)
logger.debug(params)
Expand All @@ -45,7 +57,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
url = urljoin(self.base_url.strip("/"), "/api/prompt-driver")
url = urljoin(self.base_url.strip("/"), "/api/chat/completions")
params = self._base_params(prompt_stack)
logger.debug(params)
with requests.post(url, headers=self.headers, json=params, stream=True) as response:
Expand All @@ -60,11 +72,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the JSON payload sent to the Griptape Cloud chat completions endpoint.

    Args:
        prompt_stack: The conversation/prompt state to serialize into the request.

    Returns:
        A dict with the serialized ``prompt_stack``, the target ``model``, and a
        nested ``params`` object carrying the generation settings. The nested
        shape matches the ``/api/chat/completions`` request schema (the flat
        top-level keys were the old ``/api/prompt-driver`` schema).
    """
    return {
        "prompt_stack": prompt_stack.to_dict(),
        "model": self.model,
        # Generation settings are namespaced under "params" per the API schema.
        "params": {
            "max_tokens": self.max_tokens,
            "use_native_tools": self.use_native_tools,
            "temperature": self.temperature,
            "structured_output_strategy": self.structured_output_strategy,
            "extra_params": self.extra_params,
        },
    }

0 comments on commit 7a095b6

Please sign in to comment.