diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 8e892e17162aa..00ccecfdc71a8 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -1,5 +1,7 @@ +import asyncio +from functools import partial from contextlib import contextmanager -from contextvars import ContextVar, Token +from contextvars import Context, ContextVar, Token, copy_context from typing import Any, Callable, Generator, List, Optional, Dict, Protocol import inspect import uuid @@ -251,6 +253,10 @@ def wrapper(func: Callable, instance: Any, args: list, kwargs: dict) -> Any: bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" tags = active_instrument_tags.get() + result = None + + # Copy the current context + context = copy_context() token = active_span_id.set(id_) parent_id = None if token.old_value is Token.MISSING else token.old_value @@ -261,20 +267,60 @@ def wrapper(func: Callable, instance: Any, args: list, kwargs: dict) -> Any: parent_id=parent_id, tags=tags, ) + + def handle_future_result( + future: asyncio.Future, + span_id: str, + bound_args: inspect.BoundArguments, + instance: Any, + context: Context, + ) -> None: + try: + result = future.result() + self.span_exit( + id_=span_id, + bound_args=bound_args, + instance=instance, + result=result, + ) + return result + except BaseException as e: + self.event(SpanDropEvent(span_id=span_id, err_str=str(e))) + self.span_drop( + id_=span_id, bound_args=bound_args, instance=instance, err=e + ) + raise + finally: + context.run(active_span_id.reset, token) + try: result = func(*args, **kwargs) + if isinstance(result, asyncio.Future): + # If the result is a Future, wrap it + new_future = asyncio.ensure_future(result) + new_future.add_done_callback( + partial( + handle_future_result, + span_id=id_, + bound_args=bound_args, + instance=instance, + context=context, + ) + ) + return new_future + else: + # For non-Future results, proceed as before + self.span_exit( + id_=id_, bound_args=bound_args, instance=instance, result=result + ) + return result except BaseException as e: self.event(SpanDropEvent(span_id=id_, err_str=str(e))) self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e) raise - else: - self.span_exit( - id_=id_, bound_args=bound_args, instance=instance, result=result - ) - return result finally: - # clean up - active_span_id.reset(token) + if not isinstance(result, asyncio.Future): + active_span_id.reset(token) @wrapt.decorator async def async_wrapper(