Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and unit test updates to OpenAI client #1698

Merged
14 changes: 11 additions & 3 deletions morpheus/llm/services/openai_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def __init__(self,

# Create the client objects for both sync and async
self._client = openai.OpenAI(api_key=parent._api_key, base_url=parent._base_url, max_retries=max_retries)
self._client_async = openai.AsyncOpenAI(api_key=parent._api_key, base_url=parent._base_url, max_retries=max_retries)
self._client_async = openai.AsyncOpenAI(api_key=parent._api_key,
base_url=parent._base_url,
max_retries=max_retries)

def get_input_names(self) -> list[str]:
input_names = [self._prompt_key]
Expand Down Expand Up @@ -326,8 +328,9 @@ def __init__(self, *, api_key: str = None, base_url: str = None, default_model_k
The API key for the LLM service, by default None. If `None`, the API key will be read from the
`OPENAI_API_KEY` environment variable. If neither is present, an error will be raised.
base_url : str, optional
The api host url, by default None. If `None` the url will be read from the `OPENAI_API_BASE` environment
variable. If neither are present the OpenAI default will be used., by default None
The API host URL, by default None. If the `OPENAI_BASE_URL` environment variable is present,
it will always take precedence over this parameter. If neither is present, the OpenAI
default will be used.
default_model_kwargs : dict, optional
Default arguments to use when creating a client via the `get_client` function. Any argument specified here
will automatically be used when calling `get_client`. Arguments specified in the `get_client` function will
Expand Down Expand Up @@ -356,6 +359,9 @@ def __init__(self, *, api_key: str = None, base_url: str = None, default_model_k

log_file = os.path.join(appdirs.user_log_dir(appauthor="NVIDIA", appname="morpheus"), "openai.log")

# Ensure the log directory exists
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Add a file handler
file_handler = logging.FileHandler(log_file)

Expand Down Expand Up @@ -398,6 +404,8 @@ def get_client(self,
"""

final_model_kwargs = {**self._default_model_kwargs, **model_kwargs}
final_model_kwargs.pop("base_url", None)
final_model_kwargs.pop("api_key", None)

return OpenAIChatClient(self,
model_name=model_name,
Expand Down
21 changes: 19 additions & 2 deletions tests/llm/services/test_openai_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,30 @@
from morpheus.llm.services.openai_chat_service import OpenAIChatService


@pytest.mark.parametrize("api_key", ["12345", None])
@pytest.mark.parametrize("base_url", ["http://test.openai.com/v1", None])
@pytest.mark.parametrize("max_retries", [5, 10])
def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], max_retries: int):
def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
api_key: str,
base_url: str,
max_retries: int):
client = OpenAIChatClient(OpenAIChatService(api_key=api_key, base_url=base_url),
model_name="test_model",
max_retries=max_retries)
assert isinstance(client, LLMClient)

for mock_client in mock_chat_completion:
mock_client.assert_called_once_with(api_key=api_key, base_url=base_url, max_retries=max_retries)


@pytest.mark.parametrize("max_retries", [5, 10])
def test_constructor_default_service_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
max_retries: int):
client = OpenAIChatClient(OpenAIChatService(), model_name="test_model", max_retries=max_retries)
assert isinstance(client, LLMClient)

for mock_client in mock_chat_completion:
mock_client.assert_called_once_with(max_retries=max_retries)
mock_client.assert_called_once_with(api_key=None, base_url=None, max_retries=max_retries)


@pytest.mark.parametrize("use_async", [True, False])
Expand Down
Loading