diff --git a/morpheus/llm/services/openai_chat_service.py b/morpheus/llm/services/openai_chat_service.py
index 45db9e1114..021ef48785 100644
--- a/morpheus/llm/services/openai_chat_service.py
+++ b/morpheus/llm/services/openai_chat_service.py
@@ -339,7 +339,7 @@ def __init__(self, *, api_key: str = None, base_url: str = None, default_model_k
             The API key for the LLM service, by default None. If `None` the API key will be read from the
             `OPENAI_API_KEY` environment variable. If neither are present an error will be raised.
         base_url : str, optional
-            The api host url, by default None. If `None` the url will be read from the `OPENAI_API_BASE` environment
+            The API host URL, by default None. If `None` the URL will be read from the `OPENAI_BASE_URL` environment
             variable. If neither are present the OpenAI default will be used., by default None
         default_model_kwargs : dict, optional
             Default arguments to use when creating a client via the `get_client` function. Any argument specified here
@@ -369,6 +369,9 @@ def __init__(self, *, api_key: str = None, base_url: str = None, default_model_k
 
         log_file = os.path.join(appdirs.user_log_dir(appauthor="NVIDIA", appname="morpheus"), "openai.log")
 
+        # Ensure the log directory exists
+        os.makedirs(os.path.dirname(log_file), exist_ok=True)
+
         # Add a file handler
         file_handler = logging.FileHandler(log_file)
 
diff --git a/tests/llm/services/test_openai_chat_client.py b/tests/llm/services/test_openai_chat_client.py
index 577c83c7bb..4976eb1655 100644
--- a/tests/llm/services/test_openai_chat_client.py
+++ b/tests/llm/services/test_openai_chat_client.py
@@ -19,18 +19,29 @@
 import pytest
 
 from _utils.llm import mk_mock_openai_response
-from morpheus.llm.services.llm_service import LLMClient
-from morpheus.llm.services.openai_chat_service import OpenAIChatClient
 from morpheus.llm.services.openai_chat_service import OpenAIChatService
 
 
+@pytest.mark.parametrize("api_key", ["12345", None])
+@pytest.mark.parametrize("base_url", ["http://test.openai.com/v1", None])
 @pytest.mark.parametrize("max_retries", [5, 10])
-def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock], max_retries: int):
-    client = OpenAIChatClient(OpenAIChatService(), model_name="test_model", max_retries=max_retries)
-    assert isinstance(client, LLMClient)
+def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
+                     api_key: str,
+                     base_url: str,
+                     max_retries: int):
+    OpenAIChatService(api_key=api_key, base_url=base_url).get_client(model_name="test_model", max_retries=max_retries)
 
     for mock_client in mock_chat_completion:
-        mock_client.assert_called_once_with(max_retries=max_retries)
+        mock_client.assert_called_once_with(api_key=api_key, base_url=base_url, max_retries=max_retries)
+
+
+@pytest.mark.parametrize("max_retries", [5, 10])
+def test_constructor_default_service_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
+                                                 max_retries: int):
+    OpenAIChatService().get_client(model_name="test_model", max_retries=max_retries)
+
+    for mock_client in mock_chat_completion:
+        mock_client.assert_called_once_with(api_key=None, base_url=None, max_retries=max_retries)
 
 
 @pytest.mark.parametrize("use_async", [True, False])
@@ -56,10 +67,10 @@ def test_generate(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock],
                   expected_messages: list[dict],
                   temperature: int):
     (mock_client, mock_async_client) = mock_chat_completion
-    client = OpenAIChatClient(OpenAIChatService(),
-                              model_name="test_model",
-                              set_assistant=set_assistant,
-                              temperature=temperature)
+    client = OpenAIChatService().get_client(model_name="test_model",
+                                            set_assistant=set_assistant,
+                                            temperature=temperature)
+
     if use_async:
         results = asyncio.run(client.generate_async(**input_dict))
         mock_async_client.chat.completions.create.assert_called_once_with(model="test_model",
@@ -108,10 +119,9 @@ def test_generate_batch(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMo
                         expected_messages: list[list[dict]],
                         temperature: int):
     (mock_client, mock_async_client) = mock_chat_completion
-    client = OpenAIChatClient(OpenAIChatService(),
-                              model_name="test_model",
-                              set_assistant=set_assistant,
-                              temperature=temperature)
+    client = OpenAIChatService().get_client(model_name="test_model",
+                                            set_assistant=set_assistant,
+                                            temperature=temperature)
 
     expected_results = ["test_output" for _ in range(len(inputs["prompt"]))]
     expected_calls = [
@@ -134,7 +144,7 @@ def test_generate_batch(mock_chat_completion: tuple[mock.MagicMo
 @pytest.mark.parametrize("completion", [[], [None]], ids=["no_choices", "no_content"])
 @pytest.mark.usefixtures("mock_chat_completion")
 def test_extract_completion_errors(completion: list):
-    client = OpenAIChatClient(OpenAIChatService(), model_name="test_model")
+    client = OpenAIChatService().get_client(model_name="test_model")
 
     mock_completion = mk_mock_openai_response(completion)
     with pytest.raises(ValueError):
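
For reference, a minimal sketch of the construction pattern these changes standardize on, based on the calls exercised by the updated tests. The base URL, model name, and prompt are placeholder values; the environment-variable fallback described in the comments comes from the revised docstring.

    import os

    from morpheus.llm.services.openai_chat_service import OpenAIChatService

    # Both arguments are optional; when omitted they fall back to the
    # OPENAI_API_KEY and OPENAI_BASE_URL environment variables.
    service = OpenAIChatService(api_key=os.environ.get("OPENAI_API_KEY"),
                                base_url="http://test.openai.com/v1")

    # Clients are obtained from the service via get_client() rather than by
    # constructing OpenAIChatClient directly.
    client = service.get_client(model_name="test_model", max_retries=5)

    # Hypothetical invocation; the tests drive generate()/generate_async()
    # with a parametrized input dict whose keys include "prompt".
    result = client.generate(prompt="Hello, world")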
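Design note: get_client() now threads the service-level api_key and base_url through to the underlying client constructors (presumably the patched openai.OpenAI and openai.AsyncOpenAI in the mock_chat_completion fixture), which is exactly what the parametrized test_constructor asserts via assert_called_once_with(api_key=..., base_url=..., max_retries=...).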