Skip to content

Commit

Permalink
fix(drivers-prompt-openai): conditionally add modalities/reasoning_effort based on model (#1668)
Browse files Browse the repository at this point in the history

Co-authored-by: Collin Dutter <[email protected]>
  • Loading branch information
vachillo and collindutter authored Feb 10, 2025
1 parent 3af4640 commit b9311c1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
8 changes: 6 additions & 2 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"model": self.model,
"user": self.user,
"seed": self.seed,
"modalities": self.modalities,
**({"reasoning_effort": self.reasoning_effort} if self.is_reasoning_model else {}),
**({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
**(
{"reasoning_effort": self.reasoning_effort}
if self.is_reasoning_model and self.model != "o1-mini"
else {}
),
**({"temperature": self.temperature} if not self.is_reasoning_model else {}),
**({"audio": self.audio} if "audio" in self.modalities else {}),
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
Expand Down
44 changes: 30 additions & 14 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def test_init(self):

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3", "o3-mini"])
@pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]])
def test_try_run(
self,
Expand Down Expand Up @@ -483,7 +483,11 @@ def test_try_run(
user=driver.user,
messages=reasoning_messages if driver.is_reasoning_model else messages,
seed=driver.seed,
modalities=driver.modalities,
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
**{
"audio": driver.audio,
}
Expand All @@ -492,7 +496,7 @@ def test_try_run(
**{
"reasoning_effort": driver.reasoning_effort,
}
if driver.is_reasoning_model
if driver.is_reasoning_model and model != "o1-mini"
else {},
**{
"temperature": driver.temperature,
Expand Down Expand Up @@ -559,7 +563,11 @@ def test_try_run_response_format_json_object(self, mock_chat_completion_create,
}
if "audio" in driver.modalities
else {},
modalities=driver.modalities,
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
response_format={"type": "json_object"},
)
assert message.value[0].value == "model-output"
Expand Down Expand Up @@ -596,7 +604,11 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create,
}
if "audio" in driver.modalities
else {},
modalities=driver.modalities,
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
response_format={
"json_schema": {
"schema": {
Expand All @@ -619,7 +631,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create,

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3"])
@pytest.mark.parametrize("model", ["gpt-4o", "o1", "o3", "o3-mini"])
@pytest.mark.parametrize("modalities", [["text"], ["text", "audio"], ["audio"]])
def test_try_stream_run(
self,
Expand Down Expand Up @@ -659,12 +671,8 @@ def test_try_stream_run(
}
if "audio" in driver.modalities
else {},
modalities=driver.modalities,
**{
"reasoning_effort": driver.reasoning_effort,
}
if driver.is_reasoning_model
else {},
**{"modalities": driver.modalities} if not driver.is_reasoning_model else {},
**{"reasoning_effort": driver.reasoning_effort} if driver.is_reasoning_model and model != "o1-mini" else {},
**{
"temperature": driver.temperature,
}
Expand Down Expand Up @@ -747,7 +755,11 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack
}
if "audio" in driver.modalities
else {},
modalities=driver.modalities,
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
)
assert event.value[0].value == "model-output"

Expand Down Expand Up @@ -788,7 +800,11 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa
}
if "audio" in driver.modalities
else {},
modalities=driver.modalities,
**{
"modalities": driver.modalities,
}
if not driver.is_reasoning_model
else {},
max_tokens=1,
)
assert event.value[0].value == "model-output"

0 comments on commit b9311c1

Please sign in to comment.