Fix Azure OpenAI hanging problem (#10153)
Signed-off-by: Serena Ruan <[email protected]>
serena-ruan authored Oct 27, 2023
1 parent b8cf267 commit cf17437
Showing 4 changed files with 114 additions and 15 deletions.
78 changes: 64 additions & 14 deletions mlflow/openai/__init__.py
@@ -44,10 +44,11 @@
import mlflow
from mlflow import pyfunc
from mlflow.environment_variables import _MLFLOW_TESTING, MLFLOW_OPENAI_SECRET_SCOPE
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelInputExample, ModelSignature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.utils import _save_example
from mlflow.openai.utils import _OAITokenHolder, _validate_model_params
from mlflow.openai.utils import _exclude_params_from_envs, _OAITokenHolder, _validate_model_params
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
@@ -90,6 +91,10 @@ class _OpenAIApiConfig(NamedTuple):
batch_size: int
max_requests_per_minute: int
max_tokens_per_minute: int
api_version: Optional[str]
api_base: str
engine: Optional[str]
deployment_id: Optional[str]


@experimental
@@ -185,23 +190,29 @@ def _get_api_config() -> _OpenAIApiConfig:
import openai

api_type = os.getenv(_OpenAIEnvVar.OPENAI_API_TYPE.value, openai.api_type)
api_version = os.getenv(_OpenAIEnvVar.OPENAI_API_VERSION.value, openai.api_version)
api_base = os.getenv(_OpenAIEnvVar.OPENAI_API_BASE.value, openai.api_base)
engine = os.getenv(_OpenAIEnvVar.OPENAI_ENGINE.value, None)
deployment_id = os.getenv(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value, None)
if api_type in ("azure", "azure_ad", "azuread"):
return _OpenAIApiConfig(
api_type=api_type,
batch_size=16,
max_requests_per_minute=3_500,
max_tokens_per_minute=60_000,
)
batch_size = 16
max_tokens_per_minute = 60_000
else:
# The maximum batch size is 2048:
# https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
# We use a smaller batch size to be safe.
return _OpenAIApiConfig(
api_type=api_type,
batch_size=1024,
max_requests_per_minute=3_500,
max_tokens_per_minute=90_000,
)
batch_size = 1024
max_tokens_per_minute = 90_000
return _OpenAIApiConfig(
api_type=api_type,
batch_size=batch_size,
max_requests_per_minute=3_500,
max_tokens_per_minute=max_tokens_per_minute,
api_base=api_base,
api_version=api_version,
engine=engine,
deployment_id=deployment_id,
)
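
A minimal sketch of how the refactored helper resolves the new Azure-related environment variables (the values below are hypothetical; _get_api_config is an internal helper of mlflow.openai):

import os
from mlflow.openai import _get_api_config  # internal helper shown above

os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://example-resource.openai.azure.com/"  # hypothetical
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
os.environ["OPENAI_DEPLOYMENT_NAME"] = "my-deployment"  # hypothetical deployment

config = _get_api_config()
assert config.batch_size == 16  # Azure API types use the smaller batch size
assert config.deployment_id == "my-deployment"
assert config.engine is None  # OPENAI_ENGINE was not set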


def _get_openai_package_version():
@@ -221,7 +232,12 @@ class _OpenAIEnvVar(str, Enum):
OPENAI_API_BASE = "OPENAI_API_BASE"
OPENAI_API_KEY = "OPENAI_API_KEY"
OPENAI_API_KEY_PATH = "OPENAI_API_KEY_PATH"
OPENAI_API_VERSION = "OPENAI_API_VERSION"
OPENAI_ORGANIZATION = "OPENAI_ORGANIZATION"
OPENAI_ENGINE = "OPENAI_ENGINE"
# use deployment_name instead of deployment_id to be
# consistent with gateway
OPENAI_DEPLOYMENT_NAME = "OPENAI_DEPLOYMENT_NAME"

@property
def secret_key(self):
@@ -662,6 +678,33 @@ def __init__(self, model):
self.task = task
self.api_config = _get_api_config()
self.api_token = _OAITokenHolder(self.api_config.api_type)
# If the same parameter exists in self.model & self.api_config,
# we use the parameter from self.model
self.envs = {
x: getattr(self.api_config, x)
for x in ["api_base", "api_version", "api_type", "engine", "deployment_id"]
if getattr(self.api_config, x) is not None and x not in self.model
}
api_type = self.model.get("api_type") or self.envs.get("api_type")
if api_type in ("azure", "azure_ad", "azuread"):
deployment_id = self.model.get("deployment_id") or self.envs.get("deployment_id")
if self.model.get("engine") or self.envs.get("engine"):
# Avoid using both parameters as they serve the same purpose
# Invalid inputs:
# - Wrong engine + correct/wrong deployment_id
# - No engine + wrong deployment_id
# Valid inputs:
# - Correct engine + correct/wrong deployment_id
# - No engine + correct deployment_id
if deployment_id is not None:
_logger.warning(
"Both engine and deployment_id are set. "
"Using engine as it takes precedence."
)
elif deployment_id is None:
raise MlflowException(
"Either engine or deployment_id must be set for Azure OpenAI API",
)

if self.task != "embeddings":
self._setup_completions()
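
A sketch of the resulting load-time behavior (the model URI and env values are hypothetical; the failure case is exercised by the new test at the bottom of this diff):

import os
import mlflow
from mlflow.exceptions import MlflowException

os.environ["OPENAI_API_TYPE"] = "azure"

try:
    mlflow.pyfunc.load_model("models:/azure-openai-model/1")  # hypothetical URI
except MlflowException:
    # Neither engine nor deployment name is configured, so loading fails fast
    # instead of sending requests that would hang or be rejected.
    pass

# With an engine (or OPENAI_DEPLOYMENT_NAME) set, loading succeeds; if both are
# set, engine takes precedence and a warning is logged.
os.environ["OPENAI_ENGINE"] = "my-engine"  # hypothetical engine name
mlflow.pyfunc.load_model("models:/azure-openai-model/1")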
@@ -693,8 +736,11 @@ def _predict_chat(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests

_validate_model_params(self.task, self.model, params)
envs = _exclude_params_from_envs(params, self.envs)
messages_list = self.format_completions(self.get_params_list(data))
requests = [{**self.model, **params, "messages": messages} for messages in messages_list]
requests = [
{**self.model, **envs, **params, "messages": messages} for messages in messages_list
]
results = process_api_requests(
requests,
openai.ChatCompletion,
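
The merge order of the request payload above encodes precedence: keys already present in self.model never enter self.envs (they are filtered out in __init__), and inference-time params win over both. A minimal sketch with hypothetical values:

model = {"model": "gpt-3.5-turbo", "temperature": 0.2}
envs = {"api_base": "https://example-resource.openai.azure.com/"}  # env-derived
params = {"temperature": 0.9}  # supplied at predict() time

request = {**model, **envs, **params, "messages": [{"role": "user", "content": "hi"}]}
assert request["temperature"] == 0.9  # inference-time params take precedence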
@@ -710,13 +756,15 @@ def _predict_completions(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests

_validate_model_params(self.task, self.model, params)
envs = _exclude_params_from_envs(params, self.envs)
prompts_list = self.format_completions(self.get_params_list(data))

batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")
requests = [
{
**self.model,
**envs,
**params,
"prompt": prompts_list[i : i + batch_size],
}
@@ -737,6 +785,7 @@ def _predict_embeddings(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests

_validate_model_params(self.task, self.model, params)
envs = _exclude_params_from_envs(params, self.envs)
batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")

Expand All @@ -745,6 +794,7 @@ def _predict_embeddings(self, data, params):
requests = [
{
**self.model,
**envs,
**params,
"input": texts[i : i + batch_size],
}
7 changes: 6 additions & 1 deletion mlflow/openai/api_request_parallel_processor.py
@@ -89,14 +89,15 @@ class APIRequest:
token_consumption: int
attempts_left: int
results: list[tuple[int, OpenAIObject]]
timeout: int = 60

def call_api(self, retry_queue: queue.Queue, status_tracker: StatusTracker):
"""
Calls the OpenAI API and stores results.
"""
_logger.debug(f"Request #{self.index} started")
try:
response = self.task.create(**self.request_json)
response = self.task.create(**self.request_json, timeout=self.timeout)
_logger.debug(f"Request #{self.index} succeeded")
status_tracker.complete_task(success=True)
self.results.append((self.index, response))
@@ -128,6 +129,10 @@ def call_api(self, retry_queue: queue.Queue, status_tracker: StatusTracker):
)
status_tracker.increment_num_api_errors()
status_tracker.complete_task(success=False)
else:
_logger.warning(f"Request #{self.index} failed with {e!r}")
status_tracker.increment_num_api_errors()
status_tracker.complete_task(success=False)
except Exception as e:
_logger.debug(f"Request #{self.index} failed with {e!r}")
status_tracker.increment_num_api_errors()
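These two hunks are the core of the hanging fix: task.create previously had no time bound, so a stalled connection could block a worker thread indefinitely, and the added else branch now logs and records other OpenAI errors as failures. A before/after sketch (the request body is hypothetical):

import openai

request_json = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}

# Before: unbounded -- a stuck connection hangs forever.
# response = openai.ChatCompletion.create(**request_json)

# After: bounded by the new 60-second default; a stuck request raises and is
# retried or recorded as a failure by the status tracker.
response = openai.ChatCompletion.create(**request_json, timeout=60)
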
7 changes: 7 additions & 0 deletions mlflow/openai/utils.py
@@ -151,6 +151,13 @@ def _validate_model_params(task, model, params):
)


def _exclude_params_from_envs(params, envs):
"""
params passed at inference time should override envs.
"""
return {k: v for k, v in envs.items() if k not in params} if params else envs
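
The helper keeps inference-time params authoritative without mutating the caller's envs. For example (values are hypothetical):

envs = {"api_base": "https://example-resource.openai.azure.com/", "deployment_id": "prod"}
params = {"deployment_id": "experimental"}

assert _exclude_params_from_envs(params, envs) == {
    "api_base": "https://example-resource.openai.azure.com/"
}
# "deployment_id" is dropped from envs because params supplies it; envs is unchanged.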


class _OAITokenHolder:
def __init__(self, api_type):
import openai
37 changes: 37 additions & 0 deletions tests/openai/test_openai_model_export.py
@@ -1,5 +1,6 @@
import importlib
import json
from copy import deepcopy
from unittest import mock

import numpy as np
@@ -12,8 +13,10 @@

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow.exceptions import MlflowException
from mlflow.models.signature import ModelSignature
from mlflow.openai.utils import (
_exclude_params_from_envs,
_mock_chat_completion_response,
_mock_models_retrieve_response,
_mock_request,
@@ -458,6 +461,9 @@ def test_save_model_with_secret_scope(tmp_path, monkeypatch):
"OPENAI_API_KEY_PATH": f"{scope}:openai_api_key_path",
"OPENAI_API_BASE": f"{scope}:openai_api_base",
"OPENAI_ORGANIZATION": f"{scope}:openai_organization",
"OPENAI_API_VERSION": f"{scope}:openai_api_version",
"OPENAI_DEPLOYMENT_NAME": f"{scope}:openai_deployment_name",
"OPENAI_ENGINE": f"{scope}:openai_engine",
}


@@ -565,6 +571,7 @@ def test_embeddings(tmp_path):

def test_embeddings_batch_size_azure(tmp_path, monkeypatch):
monkeypatch.setenv("OPENAI_API_TYPE", "azure")
monkeypatch.setenv("OPENAI_ENGINE", "test_engine")
mlflow.openai.save_model(
model="text-embedding-ada-002",
task=openai.Embedding,
@@ -648,3 +655,33 @@ def test_inference_params_overlap(tmp_path):
params=ParamSchema([ParamSpec(name="prefix", default=None, dtype="string")]),
),
)


def test_engine_and_deployment_id_for_azure_openai(tmp_path, monkeypatch):
monkeypatch.setenv("OPENAI_API_TYPE", "azure")
mlflow.openai.save_model(
model="text-embedding-ada-002",
task=openai.Embedding,
path=tmp_path,
)
with pytest.raises(
MlflowException, match=r"Either engine or deployment_id must be set for Azure OpenAI API"
):
mlflow.pyfunc.load_model(tmp_path)
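
For contrast, a sketch of the passing configuration (not part of this commit; it mirrors test_embeddings_batch_size_azure above and assumes the same API-key fixtures as the surrounding tests):

def test_azure_openai_loads_with_engine(tmp_path, monkeypatch):
    monkeypatch.setenv("OPENAI_API_TYPE", "azure")
    monkeypatch.setenv("OPENAI_ENGINE", "test_engine")
    mlflow.openai.save_model(
        model="text-embedding-ada-002",
        task=openai.Embedding,
        path=tmp_path,
    )
    # With an engine configured, loading no longer raises MlflowException.
    mlflow.pyfunc.load_model(tmp_path)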


@pytest.mark.parametrize(
("params", "envs"),
[
({"a": None, "b": "b"}, {"a": "a", "c": "c"}),
({"a": "a", "b": "b"}, {"a": "a", "d": "d"}),
({}, {"a": "a", "b": "b"}),
({"a": "a"}, {"b": "b"}),
],
)
def test_exclude_params_from_envs(params, envs):
original_envs = deepcopy(envs)
result = _exclude_params_from_envs(params, envs)
assert envs == original_envs
assert not any(key in params for key in result)
assert all(key in envs for key in result)
