Skip to content

Commit

Permalink
fix: update GeminiModelSettings type handling - Use type casting to h…
Browse files Browse the repository at this point in the history
…andle model settings conversion - Keep GeminiModelSettings as a separate type - Fix safety settings handling with sequences
  • Loading branch information
hafsatariq18 committed Feb 7, 2025
1 parent d6e193d commit 2af1b59
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Annotated, Any, Literal, Protocol, Union
from typing import Annotated, Any, Literal, Protocol, TypedDict, Union, cast
from uuid import uuid4

import pydantic
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
from typing_extensions import NotRequired, TypedDict, assert_never
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse, Timeout
from typing_extensions import NotRequired, assert_never

from .. import UnexpectedModelBehavior, _utils, exceptions, usage
from ..messages import (
Expand Down Expand Up @@ -59,12 +59,27 @@
"""


class GeminiModelSettings(TypedDict, total=False):
"""Settings used for a Gemini model request."""

max_tokens: int
temperature: float
top_p: float
timeout: float | Timeout
parallel_tool_calls: bool
seed: int
presence_penalty: float
frequency_penalty: float
logit_bias: dict[str, int]
gemini_safety_settings: Sequence[GeminiSafetySettings]


@dataclass(init=False)
class GeminiModel(Model):
"""A model that uses Gemini via `generativelanguage.googleapis.com` API.
This is implemented from scratch rather than using a dedicated SDK, good API documentation is
available [here](https://ai.google.dev/api).
available [here](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
Apart from `__init__`, all methods are private or match those of the base class.
"""
Expand Down Expand Up @@ -122,7 +137,9 @@ async def request(
model_request_parameters: ModelRequestParameters,
) -> tuple[ModelResponse, usage.Usage]:
check_allow_model_requests()
async with self._make_request(messages, False, model_settings or {}, model_request_parameters) as http_response:
async with self._make_request(
messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
) as http_response:
response = _gemini_response_ta.validate_json(await http_response.aread())
return self._process_response(response), _metadata_as_usage(response)

Expand All @@ -134,7 +151,9 @@ async def request_stream(
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[StreamedResponse]:
check_allow_model_requests()
async with self._make_request(messages, True, model_settings or {}, model_request_parameters) as http_response:
async with self._make_request(
messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
) as http_response:
yield await self._process_streamed_response(http_response)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
Expand All @@ -158,7 +177,7 @@ async def _make_request(
self,
messages: list[ModelMessage],
streamed: bool,
model_settings: ModelSettings,
model_settings: GeminiModelSettings,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[HTTPResponse]:
tools = self._get_tools(model_request_parameters)
Expand Down Expand Up @@ -186,8 +205,8 @@ async def _make_request(
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
generation_config['frequency_penalty'] = frequency_penalty
if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) is not None:
if gemini_safety_settings: # Only set if non-empty list
request_data['safety_settings'] = gemini_safety_settings
if gemini_safety_settings: # Only set if non-empty sequence
request_data['safety_settings'] = list(gemini_safety_settings)
if generation_config:
request_data['generation_config'] = generation_config

Expand Down

0 comments on commit 2af1b59

Please sign in to comment.