Ignore empty choices in openai autolog (mlflow#13372)
Signed-off-by: harupy <[email protected]>
harupy authored Oct 11, 2024
1 parent 33116b8 commit 7f46072
Showing 3 changed files with 97 additions and 5 deletions.
5 changes: 3 additions & 2 deletions mlflow/openai/_openai_autolog.py
@@ -150,9 +150,10 @@ def _stream_output_logging_hook(stream: Iterator) -> Iterator:
         chunks = []
         output = []
         for chunk in stream:
-            if isinstance(chunk, Completion):
+            # `chunk.choices` can be empty: https://github.com/mlflow/mlflow/issues/13361
+            if isinstance(chunk, Completion) and chunk.choices:
                 output.append(chunk.choices[0].text or "")
-            elif isinstance(chunk, ChatCompletionChunk):
+            elif isinstance(chunk, ChatCompletionChunk) and chunk.choices:
                 output.append(chunk.choices[0].delta.content or "")
             chunks.append(chunk)
             yield chunk
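
Why the guard matters: the OpenAI streaming API can legitimately yield chunks whose `choices` list is empty; one common source is the final usage-only chunk sent when `stream_options={"include_usage": True}` is requested, and the pre-fix code raised `IndexError` on `chunk.choices[0]` for such chunks. A hedged repro sketch against the real API, not part of this commit (assumes the v1 `openai` SDK and an `OPENAI_API_KEY` in the environment):

import openai

# Illustrative sketch only; assumes OPENAI_API_KEY is set.
client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    # With include_usage, the final chunk reports token usage and has an
    # empty `choices` list -- the case the guard above skips over.
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:  # same guard as the fix
        print(chunk.choices[0].delta.content or "", end="")
    else:
        print("\nusage-only chunk:", chunk.usage)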
56 changes: 53 additions & 3 deletions tests/openai/mock_openai.py
@@ -5,6 +5,8 @@
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
+EMPTY_CHOICES = "EMPTY_CHOICES"
+
 app = fastapi.FastAPI()
 
 
@@ -78,11 +80,28 @@ def _make_chat_stream_chunk(content):
     }
 
 
+def _make_chat_stream_chunk_empty_choices():
+    return {
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1677652288,
+        "model": "gpt-4o-mini",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [],
+        "usage": None,
+    }
+
+
 async def chat_response_stream():
     yield _make_chat_stream_chunk("Hello")
     yield _make_chat_stream_chunk(" world")
 
 
+async def chat_response_stream_empty_choices():
+    yield _make_chat_stream_chunk_empty_choices()
+    yield _make_chat_stream_chunk("Hello")
+
+
 @app.post("/chat/completions")
 async def chat(payload: ChatPayload):
     if not 0.0 <= payload.temperature <= 2.0:
@@ -92,8 +111,15 @@ async def chat(payload: ChatPayload):
         )
     if payload.stream:
         # SSE stream
+        if EMPTY_CHOICES == payload.messages[0].content:
+            content = (
+                f"data: {json.dumps(d)}\n\n" async for d in chat_response_stream_empty_choices()
+            )
+        else:
+            content = (f"data: {json.dumps(d)}\n\n" async for d in chat_response_stream())
+
         return StreamingResponse(
-            (f"data: {json.dumps(d)}\n\n" async for d in chat_response_stream()),
+            content,
             media_type="text/event-stream",
         )
     else:
@@ -136,17 +162,41 @@ def _make_completions_stream_chunk(content):
     }
 
 
+def _make_completions_stream_chunk_empty_choices():
+    return {
+        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
+        "object": "text_completion",
+        "created": 1589478378,
+        "model": "gpt-4o-mini",
+        "choices": [],
+        "system_fingerprint": None,
+        "usage": None,
+    }
+
+
 async def completions_response_stream():
     yield _make_completions_stream_chunk("Hello")
     yield _make_completions_stream_chunk(" world")
 
 
+async def completions_response_stream_empty_choices():
+    yield _make_completions_stream_chunk_empty_choices()
+    yield _make_completions_stream_chunk("Hello")
+
+
 @app.post("/completions")
 def completions(payload: CompletionsPayload):
     if payload.stream:
         # SSE stream
+        if EMPTY_CHOICES == payload.prompt:
+            content = (
+                f"data: {json.dumps(d)}\n\n"
+                async for d in completions_response_stream_empty_choices()
+            )
+        else:
+            content = (f"data: {json.dumps(d)}\n\n" async for d in completions_response_stream())
+
         return StreamingResponse(
-            (f"data: {json.dumps(d)}\n\n" async for d in completions_response_stream()),
+            content,
             media_type="text/event-stream",
         )
     else:
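
The mock routes on a sentinel: when the first message's content (or the prompt) equals EMPTY_CHOICES, the server streams the empty-choices chunk first, framing each chunk as a Server-Sent Events line of the form `data: <json>` followed by a blank line. A minimal sketch of poking the endpoint by hand, assuming the app is served locally with uvicorn (the test suite wires it up through a fixture instead; the payload fields mirror what the handler reads and validates):

# Sketch only. Start the mock first, e.g.:
#   uvicorn tests.openai.mock_openai:app --port 8000
import httpx

with httpx.stream(
    "POST",
    "http://127.0.0.1:8000/chat/completions",
    json={
        "model": "gpt-4o-mini",
        "temperature": 1.0,  # handler rejects values outside [0.0, 2.0]
        "stream": True,
        "messages": [{"role": "user", "content": "EMPTY_CHOICES"}],
    },
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)  # the first data: frame carries "choices": []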
41 changes: 41 additions & 0 deletions tests/openai/test_openai_autolog.py
@@ -8,6 +8,7 @@
 from mlflow.tracing.constant import TraceMetadataKey
 
 from tests.openai.conftest import is_v1
+from tests.openai.mock_openai import EMPTY_CHOICES
 
 
 @pytest.fixture
@@ -139,6 +140,26 @@ def test_chat_completions_autolog_tracing_error(client):
     assert span.events[0].attributes["exception.type"] == "BadRequestError"
 
 
+def test_chat_completions_streaming_empty_choices(client):
+    mlflow.openai.autolog()
+    stream = client.chat.completions.create(
+        messages=[{"role": "user", "content": EMPTY_CHOICES}],
+        model="gpt-4o-mini",
+        stream=True,
+    )
+
+    # Ensure the stream has a chunk with empty choices
+    first_chunk = next(stream)
+    assert first_chunk.choices == []
+
+    # Exhaust the stream
+    for _ in stream:
+        pass
+
+    trace = mlflow.get_last_active_trace()
+    assert trace.info.status == "OK"
+
+
 @pytest.mark.skipif(not is_v1, reason="Requires OpenAI SDK v1")
 @pytest.mark.parametrize("log_models", [True, False])
 def test_completions_autolog(client, log_models):
@@ -172,6 +193,26 @@ def test_completions_autolog(client, log_models):
     assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata
 
 
+def test_completions_autolog_streaming_empty_choices(client):
+    mlflow.openai.autolog()
+    stream = client.completions.create(
+        prompt=EMPTY_CHOICES,
+        model="gpt-4o-mini",
+        stream=True,
+    )
+
+    # Ensure the stream has a chunk with empty choices
+    first_chunk = next(stream)
+    assert first_chunk.choices == []
+
+    # Exhaust the stream
+    for _ in stream:
+        pass
+
+    trace = mlflow.get_last_active_trace()
+    assert trace.info.status == "OK"
+
+
 @pytest.mark.skipif(not is_v1, reason="Requires OpenAI SDK v1")
 def test_completions_autolog_streaming(client):
     mlflow.openai.autolog()
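
Both new tests pin down the same contract: the empty-choices chunk still reaches the caller unchanged, and exhausting the stream finalizes the trace with status OK instead of erroring inside the logging hook. A self-contained before/after sketch of that hook behavior (stand-in objects, not mlflow internals):

from types import SimpleNamespace


def buggy_hook(stream):
    # Pre-fix: assumes every chunk has at least one choice.
    for chunk in stream:
        _ = chunk.choices[0].delta.content  # IndexError when choices == []
        yield chunk


def fixed_hook(stream):
    # Post-fix: capture output only when `choices` is non-empty.
    for chunk in stream:
        if chunk.choices:
            _ = chunk.choices[0].delta.content
        yield chunk


empty = SimpleNamespace(choices=[])
hello = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello"))]
)

assert len(list(fixed_hook(iter([empty, hello])))) == 2  # both chunks pass through
# list(buggy_hook(iter([empty, hello])))  # would raise IndexError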
