diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 931540fa7..345a8c445 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -85,23 +85,46 @@ def __init__(
                 max_tokens >= 5000 and temperature == 1.0
             ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
-    @with_callbacks
-    def __call__(self, prompt=None, messages=None, **kwargs):
-        # Build the request.
+    def _build_request(self, prompt=None, messages=None, **kwargs):
+        """Assemble everything needed to issue one completion call.
+
+        Args:
+            prompt: Plain-text prompt, wrapped into a single user message when
+                `messages` is not given.
+            messages: Chat messages; take precedence over `prompt`.
+            **kwargs: Per-call overrides merged on top of `self.kwargs`. A `cache`
+                entry is popped and only selects the completion callable.
+
+        Returns:
+            A dict with `request` (model, messages and the merged kwargs),
+            `completion` (the synchronous callable to invoke) and `num_retries`.
+        """
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
-        # Make the request and handle LRU & disk caching.
+        # Select the completion callable: chat vs. text, cached vs. uncached.
         if self.model_type == "chat":
             completion = cached_litellm_completion if cache else litellm_completion
         else:
             completion = cached_litellm_text_completion if cache else litellm_text_completion
 
-        response = completion(
+        return dict(
             request=dict(model=self.model, messages=messages, **kwargs),
+            completion=completion,
             num_retries=self.num_retries,
         )
+
+    @with_callbacks
+    def __call__(self, prompt=None, messages=None, **kwargs):
+        built = self._build_request(prompt, messages, **kwargs)
+        request = built["request"]
+        # Rebind to the effective values so the code that follows (starting with
+        # the logprobs check) also sees the defaults from `self.kwargs`, as it did
+        # before request construction moved into `_build_request`.
+        messages = request["messages"]
+        kwargs = {k: v for k, v in request.items() if k not in ("model", "messages")}
+        response = built["completion"](request=request, num_retries=built["num_retries"])
         if kwargs.get("logprobs"):
             outputs = [
                 {
@@ -216,6 +239,69 @@ def infer_adapter(self) -> Adapter:
         model_type = self.model_type
         return model_type_to_adapter[model_type]
 
+    async def _async_request(self, request: dict, prompt=None) -> list:
+        """Send a request built by `_build_request` through `litellm.acompletion`.
+
+        NOTE(review): the callable stored under `request["completion"]` is
+        synchronous and is not used here, so the cached and text-completion
+        variants chosen by `_build_request` do not apply on this path - confirm
+        that this is intended.
+
+        Args:
+            request: The dict returned by `_build_request`.
+            prompt: The original plain-text prompt, recorded in the history entry.
+
+        Returns:
+            One item per choice: the text, or a dict with `text` and `logprobs`
+            when logprobs were requested.
+        """
+        params = request["request"]
+        response = await litellm.acompletion(num_retries=request["num_retries"], **params)
+        if params.get("logprobs"):
+            outputs = [
+                {
+                    "text": c.message.content if hasattr(c, "message") else c["text"],
+                    "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
+                }
+                for c in response["choices"]
+            ]
+        else:
+            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+
+        # Logging
+        kwargs = {
+            k: v
+            for k, v in params.items()
+            if k not in ("model", "messages") and not k.startswith("api_")
+        }
+        entry = dict(
+            prompt=prompt,
+            messages=params.get("messages"),
+            kwargs=kwargs,
+            response=response,
+            outputs=outputs,
+            usage=dict(response["usage"]),
+            cost=response.get("_hidden_params", {}).get("response_cost"),
+            timestamp=datetime.now().isoformat(),
+            uuid=str(uuid.uuid4()),
+            model=self.model,
+            response_model=response["model"],
+            model_type=self.model_type,
+        )
+        self.history.append(entry)
+        self.update_global_history(entry)
+
+        return outputs
+
+    async def __acall__(self, prompt=None, messages=None, **kwargs):
+        """Async counterpart of `__call__`; returns the list of outputs.
+
+        NOTE(review): `__acall__` is not a name Python dispatches to, so callers
+        must invoke it explicitly (`await lm.__acall__(...)`).
+        """
+        request = self._build_request(prompt, messages, **kwargs)
+        return await self._async_request(request, prompt=prompt)
+
     def copy(self, **kwargs):
         """Returns a copy of the language model with possibly updated parameters."""
 
diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py
index 80ddc5c8e..26f813803 100644
--- a/dspy/primitives/program.py
+++ b/dspy/primitives/program.py
@@ -1,9 +1,28 @@
+import asyncio
+import inspect
+from typing import Any, Awaitable, TypeVar, Union
+
 import magicattr
 
 from dspy.predict.parallel import Parallel
 from dspy.primitives.module import BaseModule
 from dspy.utils.callback import with_callbacks
 
+T = TypeVar("T")
+
+# Placeholder for an argument whose value has not been produced yet. It is not
+# awaitable itself, so `Module.__call__` rejects it instead of awaiting it.
+ASYNC_MARKER = object()
+
+
+def is_async_arg(arg):
+    """Return True if `arg` is `ASYNC_MARKER` or an awaitable.
+
+    `inspect.isawaitable` already covers coroutine objects, asyncio futures and
+    tasks, and any object defining `__await__`.
+    """
+    return arg is ASYNC_MARKER or inspect.isawaitable(arg)
+
 
 class ProgramMeta(type):
     pass
@@ -18,9 +37,127 @@ def __init__(self, callbacks=None):
         self._compiled = False
 
     @with_callbacks
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> Union[T, Awaitable[T]]:
+        """Run the module synchronously or asynchronously.
+
+        The async path is taken when any positional or keyword argument is
+        awaitable (see `is_async_arg`) or when the subclass overrides `aforward`.
+        A coroutine is then returned and the caller must await it. Otherwise
+        `forward` is called and its result is returned directly.
+
+        NOTE(review): this method is wrapped by `with_callbacks`, which on the
+        async path receives the coroutine before `aforward` has run - confirm
+        the callbacks handle that as intended.
+
+        Args:
+            *args: Positional arguments; awaitables are resolved before use.
+            **kwargs: Keyword arguments; awaitable values are resolved before use.
+
+        Returns:
+            The result of `forward`, or a coroutine resolving to the result of
+            `aforward`.
+
+        Raises:
+            ValueError: When the returned coroutine is awaited and `ASYNC_MARKER`
+                was passed as an argument.
+        """
+        # Compare on the class so that only a real override counts; this also
+        # avoids relying on the bound method's `__func__` attribute.
+        overrides_aforward = type(self).aforward is not Module.aforward
+        use_async = (
+            overrides_aforward
+            or any(is_async_arg(arg) for arg in args)
+            or any(is_async_arg(value) for value in kwargs.values())
+        )
+
+        if use_async:
+
+            async def _async_call():
+                # The marker is a placeholder, not an awaitable: report it here
+                # instead of letting `asyncio.gather` fail with a TypeError.
+                if any(arg is ASYNC_MARKER for arg in args):
+                    raise ValueError("Unresolved async argument in args")
+                if any(value is ASYNC_MARKER for value in kwargs.values()):
+                    raise ValueError("Unresolved async argument in kwargs")
+
+                # Resolve all awaitables concurrently. `gather` keeps input order,
+                # so results are consumed in the order they were collected:
+                # positional arguments first, then keyword values.
+                pending = [arg for arg in args if is_async_arg(arg)]
+                pending += [value for value in kwargs.values() if is_async_arg(value)]
+                resolved = iter(await asyncio.gather(*pending) if pending else [])
+
+                new_args = [next(resolved) if is_async_arg(arg) else arg for arg in args]
+                new_kwargs = {
+                    key: next(resolved) if is_async_arg(value) else value
+                    for key, value in kwargs.items()
+                }
+                return await self.aforward(*new_args, **new_kwargs)
+
+            return _async_call()
+
+        # Use sync execution
         return self.forward(*args, **kwargs)
 
+    async def aforward(self, *args: Any, **kwargs: Any) -> T:
+        """Async version of forward.
+
+        This method should be implemented by subclasses to provide async execution.
+        By default, raises NotImplementedError to encourage proper async implementation.
+
+        When implementing this method:
+        1. Use 'async def' and 'await' for async operations
+        2. Avoid blocking operations - they should be properly awaited
+        3. Consider using asyncio.create_task for concurrent operations
+        4. Be mindful of async context managers (use 'async with')
+
+        Example:
+        ```python
+        class MyAsyncModule(Module):
+            async def aforward(self, x):
+                # Good: proper async operation
+                result = await async_operation(x)
+                return result
+
+                # Bad: blocking operation
+                # time.sleep(1)  # Don't do this!
+
+                # Bad: sync operation without proper async
+                # return self.forward(x)  # Don't do this!
+        ```
+
+        Args:
+            *args: Positional arguments
+            **kwargs: Keyword arguments
+
+        Returns:
+            The result of the async computation
+
+        Raises:
+            NotImplementedError: Subclasses must implement this method for async operations
+        """
+        raise NotImplementedError(
+            "Subclasses must implement aforward for async operations. "
+            "Do not use sync operations or blocking calls in this method."
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> T:
+        """Synchronous forward pass.
+
+        Must be implemented by subclasses to define the module's computation.
+
+        Args:
+            *args: Positional arguments
+            **kwargs: Keyword arguments
+
+        Returns:
+            The result of the computation
+
+        Raises:
+            NotImplementedError: If not implemented by subclass
+        """
+        raise NotImplementedError("Subclasses must implement forward method")
+
     def named_predictors(self):
         from dspy.predict.predict import Predict
 
diff --git a/dspy/utils/streaming.py b/dspy/utils/streaming.py
index bf1f4f5bd..4d466b0e3 100644
--- a/dspy/utils/streaming.py
+++ b/dspy/utils/streaming.py
@@ -37,25 +37,50 @@ def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
         >>>     print(value)  # Print each streamed value incrementally
     """
     import dspy
+    import inspect
 
     if not iscoroutinefunction(program):
         program = asyncify(program)
 
+    # Private end-of-stream sentinel: unlike None, it can never collide with a
+    # value the program itself produces.
+    end_of_stream = object()
+
     async def generator(args, kwargs, stream: MemoryObjectSendStream):
-        with dspy.settings.context(send_stream=stream):
-            prediction = await program(*args, **kwargs)
+        try:
+            with dspy.settings.context(send_stream=stream):
+                output = program(*args, **kwargs)
+                if inspect.isawaitable(output):
+                    output = await output
 
-        await stream.send(prediction)
+                # Stream generator items one by one; a plain generator needs a
+                # regular `for`, only an async generator supports `async for`.
+                if inspect.isasyncgen(output):
+                    async for chunk in output:
+                        await stream.send(chunk)
+                elif inspect.isgenerator(output):
+                    for chunk in output:
+                        await stream.send(chunk)
+                else:
+                    await stream.send(output)
+
+            await stream.send(end_of_stream)
+        finally:
+            # Always close the sending side so the consumer's loop terminates.
+            await stream.aclose()
 
     async def streamer(*args, **kwargs):
         send_stream, receive_stream = create_memory_object_stream(16)
 
-        async with create_task_group() as tg, send_stream, receive_stream:
+        async with create_task_group() as tg:
             tg.start_soon(generator, args, kwargs, send_stream)
 
-            async for value in receive_stream:
-                yield value
-                if isinstance(value, Prediction):
-                    return
+            try:
+                async for value in receive_stream:
+                    if value is end_of_stream:
+                        break
+                    yield value
+            finally:
+                await receive_stream.aclose()
 
     return streamer
diff --git a/tests/primitives/test_async_module.py b/tests/primitives/test_async_module.py
new file mode 100644
index 000000000..f3157e475
--- /dev/null
+++ b/tests/primitives/test_async_module.py
@@ -0,0 +1,255 @@
+import asyncio
+from unittest.mock import Mock
+
+import pytest
+
+import dspy
+from dspy.primitives.program import ASYNC_MARKER, Module
+
+
+class SimpleAsyncModule(Module):
+    async def aforward(self, x):
+        await asyncio.sleep(0.1)  # Simulate async work
+        return x * 2
+
+    def forward(self, x):
+        return x * 2
+
+
+class SimpleSyncModule(Module):
+    def forward(self, x):
+        return x * 2
+
+
+class MixedModule(Module):
+    async def aforward(self, x, y=None):
+        if y is not None:
+            await asyncio.sleep(0.1)
+            return x + y
+        return x * 2
+
+    def forward(self, x, y=None):
+        if y is not None:
+            return x + y
+        return x * 2
+
+
+@pytest.mark.asyncio
+async def test_module_async_call():
+    """Test that async arguments trigger aforward"""
+    module = SimpleAsyncModule()
+
+    async def async_input():
+        await asyncio.sleep(0.1)
+        return 5
+
+    result = await module(async_input())
+    assert result == 10
+
+
+@pytest.mark.asyncio
+async def test_module_sync_call():
+    """Test that sync arguments use forward"""
+    module = SimpleSyncModule()
+    result = module(5)
+    assert result == 10
+
+
+@pytest.mark.asyncio
+async def test_mixed_module_async():
+    """Test mixed module with async arguments"""
+    module = MixedModule()
+
+    async def async_value():
+        await asyncio.sleep(0.1)
+        return 7
+
+    result = await module(5, y=async_value())
+    assert result == 12
+
+
+@pytest.mark.asyncio
+async def test_mixed_args():
+    """Test handling of mixed sync/async arguments"""
+    module = MixedModule()
+
+    async def async_value():
+        await asyncio.sleep(0.1)
+        return 5
+
+    # One plain positional argument next to one awaitable positional argument
+    result = await module(5, async_value())
+    assert result == 10
+
+
+@pytest.mark.asyncio
+async def test_async_marker_rejected():
+    """Test that the placeholder marker is reported instead of awaited"""
+    module = SimpleAsyncModule()
+    with pytest.raises(ValueError, match="Unresolved async argument"):
+        await module(ASYNC_MARKER)
+
+
+@pytest.mark.asyncio
+async def test_error_handling():
+    """Test error handling in async context"""
+
+    class ErrorModule(Module):
+        async def aforward(self, x):
+            raise ValueError("Test error")
+
+        def forward(self, x):
+            return x * 2
+
+    module = ErrorModule()
+    with pytest.raises(ValueError, match="Test error"):
+        await module(5)
+
+
+@pytest.mark.asyncio
+async def test_callback_handling():
+    """Test that callbacks work in async context"""
+
+    class TestCallback(dspy.utils.callback.BaseCallback):
+        def __init__(self):
+            self.mock = Mock()
+
+        def on_module_start(self, call_id, instance, inputs):
+            self.mock.on_module_start(call_id, instance, inputs)
+
+        def on_module_end(self, call_id, outputs, exception):
+            self.mock.on_module_end(call_id, outputs, exception)
+
+    callback = TestCallback()
+    module = SimpleAsyncModule()
+    module.callbacks.append(callback)
+
+    result = await module(5)
+    assert result == 10
+    assert callback.mock.on_module_start.called
+    assert callback.mock.on_module_end.called
+
+
+def test_not_implemented():
+    """Test NotImplementedError is raised when neither forward nor aforward is implemented"""
+    module = Module()
+    with pytest.raises(NotImplementedError):
+        module(5)
+
+
+@pytest.mark.asyncio
+async def test_async_type_detection():
+    """Test detection of different async types"""
+    module = SimpleAsyncModule()
+
+    # Test coroutine
+    async def coro():
+        return 5
+
+    result = await module(coro())
+    assert result == 10
+
+    # Test future
+    loop = asyncio.get_running_loop()
+    future = loop.create_future()
+    future.set_result(5)
+    result = await module(future)
+    assert result == 10
+
+    # Test custom awaitable
+    class CustomAwaitable:
+        def __await__(self):
+            async def inner():
+                return 5
+
+            return inner().__await__()
+
+    result = await module(CustomAwaitable())
+    assert result == 10
+
+
+def test_aforward_docstring():
+    """Verify aforward has proper documentation"""
+    doc = Module.aforward.__doc__
+
+    # Check docstring exists and has key elements
+    assert doc is not None, "aforward should have a docstring"
+
+    # Check key implementation guidance
+    assert "async def" in doc, "Should mention async def usage"
+    assert "await" in doc, "Should mention await usage"
+    assert "Example:" in doc, "Should include example code"
+
+    # Check warning about blocking operations
+    assert "blocking" in doc.lower(), "Should warn about blocking operations"
+
+    # Check it has proper sections
+    assert "Args:" in doc, "Should document arguments"
+    assert "Returns:" in doc, "Should document return value"
+    assert "Raises:" in doc, "Should document exceptions"
+
+
+@pytest.mark.asyncio
+async def test_edge_case_args():
+    """Test edge cases in argument detection and handling"""
+
+    class EdgeModule(Module):
+        async def aforward(self, x, y=None):
+            if x is None and y is None:
+                return 0
+            if x is None:
+                return y
+            if y is None:
+                return x
+            return x + y
+
+        def forward(self, x, y=None):
+            if x is None and y is None:
+                return 0
+            if x is None:
+                return y
+            if y is None:
+                return x
+            return x + y
+
+    module = EdgeModule()
+
+    # Test None arguments
+    result = await module(None)
+    assert result == 0
+
+    # Test empty async value
+    async def empty_async():
+        return None
+
+    result = await module(empty_async())
+    assert result == 0
+
+    # Test mixed None and async
+    async def value_async():
+        return 5
+
+    result = await module(None, y=value_async())
+    assert result == 5
+
+    # Test multiple async args with None
+    result = await module(empty_async(), y=value_async())
+    assert result == 5
+
+
+@pytest.mark.asyncio
+async def test_async_kwargs_only():
+    """Test async detection in kwargs-only calls"""
+    module = MixedModule()
+
+    async def async_value():
+        await asyncio.sleep(0.01)
+        return 7
+
+    # Call with only kwargs, no positional args
+    result = await module(x=5, y=async_value())
+    assert result == 12
+
+    # Call with async value in first kwarg
+    result = await module(x=async_value(), y=5)
+    assert result == 12