Skip to content

Commit

Permalink
Update DSPy tracing to handle tool calling (mlflow#14196)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 authored Jan 9, 2025
1 parent d3ee78b commit 7e0292f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
27 changes: 27 additions & 0 deletions mlflow/dspy/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,33 @@ def on_adapter_parse_end(
):
self._end_span(call_id, outputs, exception)

def on_tool_start(self, call_id: str, instance: Any, inputs: dict[str, Any]):
    """Open a TOOL span when a DSPy tool invocation begins.

    Args:
        call_id: Identifier of this callback invocation; the span is
            registered under it via ``_start_span``.
        instance: The tool being invoked. Its ``name``, ``desc`` and
            ``args`` attributes are recorded as span attributes.
        inputs: Raw callback inputs, normalized with ``_unpack_kwargs``
            before being attached to the span.
    """
    tool_name = instance.name

    # DSPy uses the special "finish" tool to signal the end of the agent,
    # so no span is opened for it.
    if tool_name == "finish":
        return

    tool_inputs = self._unpack_kwargs(inputs)
    # Tools are always called with keyword arguments only, so the
    # positional "args" entry is dropped from the recorded inputs.
    tool_inputs.pop("args", None)

    tool_attributes = {
        "name": tool_name,
        "description": instance.desc,
        "args": instance.args,
    }
    self._start_span(
        call_id,
        name=f"Tool.{tool_name}",
        span_type=SpanType.TOOL,
        inputs=tool_inputs,
        attributes=tool_attributes,
    )

def on_tool_end(
    self, call_id: str, outputs: Optional[Any], exception: Optional[Exception] = None
):
    """Close the span of a finished tool call, if one was opened.

    Args:
        call_id: Identifier of the callback invocation being closed.
        outputs: Value produced by the tool, forwarded to ``_end_span``.
        exception: Error raised by the tool, if any, forwarded to
            ``_end_span``.
    """
    # Only call ids registered in ``_call_id_to_span`` have a span to
    # close; anything else is ignored.
    if call_id not in self._call_id_to_span:
        return
    self._end_span(call_id, outputs, exception)

def _start_span(
self,
call_id: str,
Expand Down
41 changes: 24 additions & 17 deletions tests/dspy/test_dspy_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,8 @@ def __call__(self, prompt=None, messages=None, **kwargs):


@pytest.mark.skipif(
# NB: We also need to filter out version < 2.5.17 because installing DSPy
# from source will have hard-coded version number 2.5.15.
# https://github.com/stanfordnlp/dspy/blob/803dff03c42d2f436aa67398ce5aba17e7b45611/pyproject.toml#L8-L9
_DSPY_VERSION >= Version("2.5.19") or _DSPY_VERSION < Version("2.5.17"),
reason="dspy.ReAct is broken in >=2.5.19",
_DSPY_VERSION < Version("2.5.42"),
reason="DSPy callback does not handle Tool in versions < 2.5.42",
)
def test_autolog_react():
mlflow.dspy.autolog()
Expand All @@ -157,24 +154,29 @@ def test_autolog_react():
lm=DummyLM(
[
{
"Thought_1": "I need to search for the highest mountain in the world",
"Action_1": "Search['Highest mountain in the world']",
"next_thought": "I need to search for the highest mountain in the world",
"next_tool_name": "search",
"next_tool_args": {"query": "Highest mountain in the world"},
},
{
"Thought_2": "I found the highest mountain in the world",
"Action_2": "Finish[Mount Everest]",
"next_thought": "I found the highest mountain in the world",
"next_tool_name": "finish",
"next_tool_args": {"answer": "Mount Everest"},
},
{
"answer": "Mount Everest",
"reasoning": "No more responses",
},
]
),
adapter=dspy.ChatAdapter(),
)

class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""
def search(query: str) -> list[str]:
return "Mount Everest"

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")

react = dspy.ReAct(BasicQA, tools=[])
tools = [dspy.Tool(search)]
react = dspy.ReAct("question -> answer", tools=tools)
result = react(question="What is the highest mountain in the world?")
assert result["answer"] == "Mount Everest"

Expand All @@ -184,18 +186,23 @@ class BasicQA(dspy.Signature):
assert trace.info.execution_time_ms > 0

spans = trace.data.spans
assert len(spans) == 10
assert len(spans) == 15
assert [span.name for span in spans] == [
"ReAct.forward",
"Predict.forward_1",
"ChatAdapter.format_1",
"DummyLM.__call___1",
"ChatAdapter.parse_1",
"Retrieve.forward",
"Tool.search",
"Predict.forward_2",
"ChatAdapter.format_2",
"DummyLM.__call___2",
"ChatAdapter.parse_2",
"ChainOfThought.forward",
"Predict.forward_3",
"ChatAdapter.format_3",
"DummyLM.__call___3",
"ChatAdapter.parse_3",
]


Expand Down

0 comments on commit 7e0292f

Please sign in to comment.