Fix OpenAI autolog (mlflow#14276)
Signed-off-by: B-Step62 <[email protected]>
Signed-off-by: k99kurella <[email protected]>
B-Step62 authored and karthikkurella committed Jan 30, 2025
1 parent 45572e5 commit 97b606d
Showing 3 changed files with 87 additions and 54 deletions.
68 changes: 26 additions & 42 deletions mlflow/openai/_openai_autolog.py
@@ -17,6 +17,7 @@
 from mlflow.openai.utils.chat_schema import set_span_chat_attributes
 from mlflow.tracing.constant import TraceMetadataKey
 from mlflow.tracing.trace_manager import InMemoryTraceManager
+from mlflow.tracing.utils import TraceJSONEncoder, start_client_span_or_trace
 from mlflow.tracking.context import registry as context_registry
 from mlflow.tracking.fluent import _get_experiment_id
 from mlflow.utils.autologging_utils import disable_autologging, get_autologging_config
@@ -72,14 +73,6 @@ def _set_api_key_env_var(client):
     os.environ.pop("OPENAI_API_KEY")


-class _OpenAIJsonEncoder(json.JSONEncoder):
-    def default(self, o):
-        try:
-            return super().default(o)
-        except TypeError:
-            return str(o)
-
-
 def _get_span_type(task) -> str:
     from openai.resources.chat.completions import Completions as ChatCompletions
     from openai.resources.completions import Completions
@@ -152,22 +145,13 @@ def patched_call(original, self, *args, **kwargs):
     attributes = {k: v for k, v in kwargs.items() if k != "messages"}

     # If there is an active span, create a child span under it, otherwise create a new trace
-    if active_span := mlflow.get_current_active_span():
-        span = mlflow_client.start_span(
-            name=self.__class__.__name__,
-            request_id=active_span.request_id,
-            parent_id=active_span.span_id,
-            span_type=_get_span_type(self.__class__),
-            inputs=kwargs,
-            attributes=attributes,
-        )
-    else:
-        span = mlflow_client.start_trace(
-            name=self.__class__.__name__,
-            span_type=_get_span_type(self.__class__),
-            inputs=kwargs,
-            attributes=attributes,
-        )
+    span = start_client_span_or_trace(
+        mlflow_client,
+        name=self.__class__.__name__,
+        span_type=_get_span_type(self.__class__),
+        inputs=kwargs,
+        attributes=attributes,
+    )

     request_id = span.request_id
     # Associate run ID to the trace manually, because if a new run is created by
@@ -185,7 +169,7 @@ def patched_call(original, self, *args, **kwargs):
         if config.log_traces and request_id:
             try:
                 span.add_event(SpanEvent.from_exception(e))
-                mlflow_client.end_trace(request_id=request_id, status=SpanStatusCode.ERROR)
+                mlflow_client.end_span(request_id, span.span_id, status=SpanStatusCode.ERROR)
             except Exception as inner_e:
                 _logger.warning(f"Encountered unexpected error when ending trace: {inner_e}")
         raise e
@@ -196,29 +180,34 @@
         # If the output is a stream, we add a hook to store the intermediate chunks
         # and then log the outputs as a single artifact when the stream ends
        def _stream_output_logging_hook(stream: Iterator) -> Iterator:
-            chunks = []
             output = []
-            for chunk in stream:
+            for i, chunk in enumerate(stream):
                 # `chunk.choices` can be empty: https://github.com/mlflow/mlflow/issues/13361
                 if isinstance(chunk, Completion) and chunk.choices:
                     output.append(chunk.choices[0].text or "")
                 elif isinstance(chunk, ChatCompletionChunk) and chunk.choices:
                     output.append(chunk.choices[0].delta.content or "")
-                chunks.append(chunk)
+
+                # Record the raw chunks as events
+                span.add_event(
+                    SpanEvent(
+                        name=f"chunk_{i}",
+                        # NB: OTel event attributes only accept dictionaries with primitive
+                        # types (or lists of them), not nested dicts.
+                        # TODO: Define a consistent format for stream chunks across all providers
+                        attributes={
+                            chunk.__class__.__name__: json.dumps(chunk, cls=TraceJSONEncoder)
+                        },
+                    )
+                )

                 yield chunk

             try:
-                chunk_dicts = [chunk.to_dict() for chunk in chunks]
                 if config.log_traces and request_id:
                     outputs = "".join(output)

                     set_span_chat_attributes(span, kwargs, outputs)

-                    mlflow_client.end_trace(
-                        request_id=request_id,
-                        attributes={"events": chunk_dicts},
-                        outputs=outputs,
-                    )
+                    mlflow_client.end_span(request_id, span.span_id, outputs=outputs)
             except Exception as e:
                 _logger.warning(f"Encountered unexpected error during openai autologging: {e}")
@@ -227,12 +216,7 @@ def _stream_output_logging_hook(stream: Iterator) -> Iterator:
         if config.log_traces and request_id:
             try:
                 set_span_chat_attributes(span, kwargs, result)
-                if span.parent_id is None:
-                    mlflow_client.end_trace(request_id=request_id, outputs=result)
-                else:
-                    mlflow_client.end_span(
-                        request_id=request_id, span_id=span.span_id, outputs=result
-                    )
+                mlflow_client.end_span(request_id, span.span_id, outputs=result)
             except Exception as e:
                 _logger.warning(f"Encountered unexpected error when ending trace: {e}")
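Note on this file's change: streaming chunks are no longer buffered into chunk_dicts and attached to end_trace as a span attribute; each chunk is serialized with TraceJSONEncoder and recorded as its own chunk_{i} span event, and the span is always closed with end_span. A minimal sketch of that event-recording pattern outside the autolog machinery — TolerantJSONEncoder is a hypothetical stand-in, assuming TraceJSONEncoder falls back to str() the way the removed _OpenAIJsonEncoder did:

import json

import mlflow
from mlflow.entities import SpanEvent


class TolerantJSONEncoder(json.JSONEncoder):
    # Hypothetical stand-in for TraceJSONEncoder: fall back to str() for
    # objects the default encoder cannot serialize.
    def default(self, o):
        try:
            return super().default(o)
        except TypeError:
            return str(o)


@mlflow.trace
def consume_stream(stream):
    span = mlflow.get_current_active_span()
    output = []
    for i, chunk in enumerate(stream):
        output.append(str(chunk))
        # One event per chunk; OTel event attributes must be primitives
        # (or lists of them), so each chunk is stored as a JSON string.
        span.add_event(
            SpanEvent(
                name=f"chunk_{i}",
                attributes={type(chunk).__name__: json.dumps(chunk, cls=TolerantJSONEncoder)},
            )
        )
    return "".join(output)

Calling consume_stream(iter(["Hello", " world"])) produces a trace whose root span carries chunk_0 and chunk_1 events, mirroring the assertions in the updated tests below.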
2 changes: 1 addition & 1 deletion mlflow/tracing/utils/__init__.py
@@ -381,7 +381,7 @@ def start_client_span_or_trace(
         return client.start_span(
             name=name,
             request_id=parent_span.request_id,
-            span_id=parent_span.span_id,
+            parent_id=parent_span.span_id,
             span_type=span_type,
             inputs=inputs,
             attributes=attributes,
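This one-word fix is the crux of the commit: MlflowClient.start_span takes the parent span via parent_id, so passing span_id broke span creation whenever a parent span was already active — exactly the path the autolog code now routes through this helper. For reference, a rough sketch of how the helper presumably branches, reconstructed from the if/else it replaces in _openai_autolog.py above (not the verbatim implementation):

import mlflow
from mlflow import MlflowClient


def start_span_or_trace_sketch(client: MlflowClient, name, span_type, inputs, attributes):
    # If a span is already open, attach a child to it; note parent_id, not span_id.
    if parent_span := mlflow.get_current_active_span():
        return client.start_span(
            name=name,
            request_id=parent_span.request_id,
            parent_id=parent_span.span_id,
            span_type=span_type,
            inputs=inputs,
            attributes=attributes,
        )
    # Otherwise start a fresh trace, which creates the root span.
    return client.start_trace(
        name=name, span_type=span_type, inputs=inputs, attributes=attributes
    )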
71 changes: 60 additions & 11 deletions tests/openai/test_openai_autolog.py
@@ -5,6 +5,7 @@

 import mlflow
 from mlflow import MlflowClient
+from mlflow.exceptions import MlflowException
 from mlflow.tracing.constant import SpanAttributeKey, TraceMetadataKey

 from tests.openai.conftest import is_v1
@@ -136,12 +137,15 @@ def test_chat_completions_autolog_streaming(client):
     }
     assert span.outputs == "Hello world"  # aggregated string of streaming response

-    stream_event_data = trace.data.spans[0].attributes["events"]
-
-    assert stream_event_data[0]["id"] == "chatcmpl-123"
-    assert stream_event_data[0]["choices"][0]["delta"]["content"] == "Hello"
-    assert stream_event_data[1]["id"] == "chatcmpl-123"
-    assert stream_event_data[1]["choices"][0]["delta"]["content"] == " world"
+    stream_event_data = trace.data.spans[0].events
+    assert stream_event_data[0].name == "chunk_0"
+    chunk_1 = json.loads(stream_event_data[0].attributes["ChatCompletionChunk"])
+    assert chunk_1["id"] == "chatcmpl-123"
+    assert chunk_1["choices"][0]["delta"]["content"] == "Hello"
+    assert stream_event_data[1].name == "chunk_1"
+    chunk_2 = json.loads(stream_event_data[1].attributes["ChatCompletionChunk"])
+    assert chunk_2["id"] == "chatcmpl-123"
+    assert chunk_2["choices"][0]["delta"]["content"] == " world"


 @pytest.mark.skipif(not is_v1, reason="Requires OpenAI SDK v1")
@@ -170,6 +174,47 @@ def test_chat_completions_autolog_tracing_error(client):
     assert span.events[0].attributes["exception.type"] == "UnprocessableEntityError"


+@pytest.mark.skipif(not is_v1, reason="Requires OpenAI SDK v1")
+def test_chat_completions_autolog_tracing_error_with_parent_span(client):
+    mlflow.openai.autolog()
+
+    @mlflow.trace
+    def create_completions(text: str) -> str:
+        try:
+            response = client.chat.completions.create(
+                messages=[{"role": "user", "content": text}],
+                model="gpt-4o-mini",
+                temperature=5.0,
+            )
+            return response.choices[0].message.content
+        except openai.OpenAIError as e:
+            raise MlflowException("Failed to create completions") from e
+
+    with pytest.raises(MlflowException, match="Failed to create completions"):
+        create_completions("test")
+
+    trace = mlflow.get_last_active_trace()
+    assert trace.info.status == "ERROR"
+
+    assert len(trace.data.spans) == 2
+    parent_span = trace.data.spans[0]
+    assert parent_span.name == "create_completions"
+    assert parent_span.inputs == {"text": "test"}
+    assert parent_span.outputs is None
+    assert parent_span.status.status_code == "ERROR"
+    assert parent_span.events[0].name == "exception"
+    assert parent_span.events[0].attributes["exception.type"] == "mlflow.exceptions.MlflowException"
+    assert parent_span.events[0].attributes["exception.message"] == "Failed to create completions"
+
+    child_span = trace.data.spans[1]
+    assert child_span.name == "Completions"
+    assert child_span.inputs["messages"][0]["content"] == "test"
+    assert child_span.outputs is None
+    assert child_span.status.status_code == "ERROR"
+    assert child_span.events[0].name == "exception"
+    assert child_span.events[0].attributes["exception.type"] == "UnprocessableEntityError"
+
+
 def test_chat_completions_streaming_empty_choices(client):
     mlflow.openai.autolog()
     stream = client.chat.completions.create(
@@ -270,12 +315,16 @@ def test_completions_autolog_streaming(client):
     }
     assert span.outputs == "Hello world"  # aggregated string of streaming response

-    stream_event_data = trace.data.spans[0].attributes["events"]
-
-    assert stream_event_data[0]["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
-    assert stream_event_data[0]["choices"][0]["text"] == "Hello"
-    assert stream_event_data[1]["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
-    assert stream_event_data[1]["choices"][0]["text"] == " world"
+    stream_event_data = trace.data.spans[0].events
+
+    assert stream_event_data[0].name == "chunk_0"
+    chunk_1 = json.loads(stream_event_data[0].attributes["Completion"])
+    assert chunk_1["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
+    assert chunk_1["choices"][0]["text"] == "Hello"
+    assert stream_event_data[1].name == "chunk_1"
+    chunk_2 = json.loads(stream_event_data[1].attributes["Completion"])
+    assert chunk_2["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
+    assert chunk_2["choices"][0]["text"] == " world"


 @pytest.mark.skipif(not is_v1, reason="Requires OpenAI SDK v1")
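With the fix in place, a streaming call can be verified end to end much as the tests above do. A usage sketch against a live client (assumes OPENAI_API_KEY is set and an mlflow build containing this commit; the model name is illustrative):

import json

import mlflow
import openai

mlflow.openai.autolog()

client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)
# Drain the stream; the autolog hook records each chunk as a span event.
text = "".join(c.choices[0].delta.content or "" for c in stream if c.choices)

trace = mlflow.get_last_active_trace()
span = trace.data.spans[0]
assert span.outputs == text  # aggregated string of the streaming response
for event in span.events:
    chunk = json.loads(event.attributes["ChatCompletionChunk"])
    if chunk.get("choices"):
        print(event.name, chunk["choices"][0]["delta"].get("content"))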
