diff --git a/docs/concepts/serving/executor/add-endpoints.md b/docs/concepts/serving/executor/add-endpoints.md index cf66b12cf3fa3..6b25591960e31 100644 --- a/docs/concepts/serving/executor/add-endpoints.md +++ b/docs/concepts/serving/executor/add-endpoints.md @@ -382,10 +382,12 @@ Streaming endpoints receive one Document as input and yields one Document at a t Streaming endpoints are only supported for HTTP and gRPC protocols and for Deployment and Flow with one single Executor. -For HTTP deployment streaming executors generate both a GET and POST endpoint. -The GET endpoint support documents with string, integer, or float fields only, -whereas, POST requests support all docarrays. -The Jina client uses the POST endpoints. +For HTTP deployments, streaming executors generate a GET endpoint. +The GET endpoint supports passing document fields in +the request body or as URL query parameters; +however, query parameters only support string, integer, or float fields, +whereas the request body supports all serializable docarrays. +The Jina client uses the request body. 
``` A streaming endpoint has the following signature: diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index d3b76dc014e72..0a2d7481164a8 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -197,13 +197,14 @@ async def send_streaming_message(self, doc: 'Document', on: str): :param on: Request endpoint :yields: responses """ + req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict() request_kwargs = { 'url': self.url, 'headers': {'Accept': 'text/event-stream'}, - 'json': doc.dict() if docarray_v2 else doc.to_dict(), + 'json': req_dict, } - async with self.session.post(**request_kwargs) as response: + async with self.session.get(**request_kwargs) as response: async for chunk in response.content.iter_any(): events = chunk.split(b'event: ')[1:] for event in events: diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index 6feb721ab0794..c4153ec3480fc 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -252,35 +252,13 @@ def add_streaming_routes( methods=['GET'], summary=f'Streaming Endpoint {endpoint_path}', ) - async def streaming_get(request: Request): - query_params = dict(request.query_params) + async def streaming_get(request: Request, body: input_doc_model = None): + body = body or dict(request.query_params) + body = input_doc_model.parse_obj(body) async def event_generator(): async for doc, error in streamer.stream_doc( - doc=input_doc_model(**query_params), exec_endpoint=endpoint_path - ): - if error: - raise HTTPException(status_code=499, detail=str(error)) - yield { - 'event': 'update', - 'data': doc.dict() - } - yield { - 'event': 'end' - } - - return EventSourceResponse(event_generator()) - - @app.api_route( - path=f'/{endpoint_path.strip("/")}', - methods=['POST'], - summary=f'Streaming Endpoint {endpoint_path}', - ) - async def 
streaming_post(body: dict): - - async def event_generator(): - async for doc, error in streamer.stream_doc( - doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path + doc=body, exec_endpoint=endpoint_path ): if error: raise HTTPException(status_code=499, detail=str(error)) diff --git a/jina/serve/runtimes/gateway/request_handling.py b/jina/serve/runtimes/gateway/request_handling.py index 3b2fe6fdab111..1e4a47a5f2410 100644 --- a/jina/serve/runtimes/gateway/request_handling.py +++ b/jina/serve/runtimes/gateway/request_handling.py @@ -184,8 +184,19 @@ async def _load_balance(self, request): try: async with aiohttp.ClientSession() as session: + if request.method == 'GET': - async with session.get(target_url) as response: + request_kwargs = {} + try: + payload = await request.json() + if payload: + request_kwargs['json'] = payload + except Exception: + self.logger.debug('No JSON payload found in request') + + async with session.get( + url=target_url, **request_kwargs + ) as response: # Create a StreamResponse with the same headers and status as the target response stream_response = web.StreamResponse( status=response.status, diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index 0407bb87d8760..47006dd4be329 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -1,7 +1,7 @@ import inspect from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from jina import DocumentArray, Document +from jina import Document, DocumentArray from jina._docarray import docarray_v2 from jina.importer import ImportExtensions from jina.serve.networking.sse import EventSourceResponse @@ -11,15 +11,15 @@ from jina.logging.logger import JinaLogger if docarray_v2: - from docarray import DocList, BaseDoc + from docarray import BaseDoc, DocList def get_fastapi_app( - request_models_map: Dict, - caller: Callable, - logger: 'JinaLogger', - cors: bool = 
False, - **kwargs, + request_models_map: Dict, + caller: Callable, + logger: 'JinaLogger', + cors: bool = False, + **kwargs, ): """ Get the app from FastAPI as the REST interface. @@ -35,15 +35,18 @@ def get_fastapi_app( from fastapi import FastAPI, Response, HTTPException import pydantic from fastapi.middleware.cors import CORSMiddleware + import os + from pydantic import BaseModel, Field from pydantic.config import BaseConfig, inherit_config from jina.proto import jina_pb2 from jina.serve.runtimes.gateway.models import _to_camel_case - import os class Header(BaseModel): - request_id: Optional[str] = Field(description='Request ID', example=os.urandom(16).hex()) + request_id: Optional[str] = Field( + description='Request ID', example=os.urandom(16).hex() + ) class Config(BaseConfig): alias_generator = _to_camel_case @@ -66,11 +69,11 @@ class InnerConfig(BaseConfig): logger.warning('CORS is enabled. This service is accessible from any website!') def add_post_route( - endpoint_path, - input_model, - output_model, - input_doc_list_model=None, - output_doc_list_model=None, + endpoint_path, + input_model, + output_model, + input_doc_list_model=None, + output_doc_list_model=None, ): app_kwargs = dict( path=f'/{endpoint_path.strip("/")}', @@ -123,8 +126,8 @@ async def post(body: input_model, response: Response): return ret def add_streaming_routes( - endpoint_path, - input_doc_model=None, + endpoint_path, + input_doc_model=None, ): from fastapi import Request @@ -133,26 +136,17 @@ def add_streaming_routes( methods=['GET'], summary=f'Streaming Endpoint {endpoint_path}', ) - async def streaming_get(request: Request): - query_params = dict(request.query_params) - req = DataRequest() - req.header.exec_endpoint = endpoint_path - if not docarray_v2: - req.data.docs = DocumentArray([Document.from_dict(query_params)]) - else: - req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_model]( - [input_doc_model(**query_params)] + async def 
streaming_get(request: Request = None, body: input_doc_model = None): + if not body: + query_params = dict(request.query_params) + body = ( + input_doc_model.parse_obj(query_params) + if docarray_v2 + else Document.from_dict(query_params) ) - event_generator = _gen_dict_documents(await caller(req)) - return EventSourceResponse(event_generator) - - @app.api_route( - path=f'/{endpoint_path.strip("/")}', - methods=['POST'], - summary=f'Streaming Endpoint {endpoint_path}', - ) - async def streaming_post(body: input_doc_model, request: Request): + else: + if not docarray_v2: + body = Document.from_pydantic_model(body) req = DataRequest() req.header.exec_endpoint = endpoint_path if not docarray_v2: @@ -169,7 +163,9 @@ async def streaming_post(body: input_doc_model, request: Request): output_doc_model = input_output_map['output']['model'] is_generator = input_output_map['is_generator'] parameters_model = input_output_map['parameters']['model'] or Optional[Dict] - default_parameters = ... if input_output_map['parameters']['model'] else None + default_parameters = ( + ... 
if input_output_map['parameters']['model'] else None + ) if docarray_v2: _config = inherit_config(InnerConfig, BaseDoc.__config__) diff --git a/tests/integration/docarray_v2/test_issues.py b/tests/integration/docarray_v2/test_issues.py index a7757b7516e2c..906a895b8ed3a 100644 --- a/tests/integration/docarray_v2/test_issues.py +++ b/tests/integration/docarray_v2/test_issues.py @@ -1,10 +1,12 @@ -from typing import List, Optional, Dict +from typing import Dict, List, Optional +import aiohttp import pytest from docarray import BaseDoc, DocList from pydantic import Field -from jina import Executor, Flow, requests, Deployment, Client +from jina import Client, Deployment, Executor, Flow, requests +from jina.clients.base.helper import HTTPClientlet class Nested2Doc(BaseDoc): @@ -78,6 +80,7 @@ def test_issue_6019_with_nested_list(): assert res[0].text == 'hello world' assert res[0].nested[0].nested.value == 'test' + def test_issue_6084(): class EnvInfo(BaseDoc): history: str = '' @@ -86,7 +89,6 @@ class A(BaseDoc): b: EnvInfo class MyIssue6084Exec(Executor): - @requests def foo(self, docs: DocList[A], **kwargs) -> DocList[A]: pass @@ -96,48 +98,106 @@ def foo(self, docs: DocList[A], **kwargs) -> DocList[A]: pass +class NestedFieldSchema(BaseDoc): + name: str = "test_name" + dict_field: Dict = Field(default_factory=dict) + + +class InputWithComplexFields(BaseDoc): + text: str = "test" + nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema) + dict_field: Dict = Field(default_factory=dict) + bool_field: bool = False + + +class SimpleInput(BaseDoc): + text: str = "test" + + +class MyExecutor(Executor): + @requests(on="/stream") + async def stream( + self, + doc: InputWithComplexFields, + parameters: Optional[Dict] = None, + **kwargs, + ) -> InputWithComplexFields: + for i in range(4): + yield InputWithComplexFields(text=f"hello world {doc.text} {i}") + + @requests(on="/stream-simple") + async def stream_simple( + self, + doc: SimpleInput, + parameters: 
Optional[Dict] = None, + **kwargs, + ) -> SimpleInput: + for i in range(4): + yield SimpleInput(text=f"hello world {doc.text} {i}") + + +@pytest.fixture(scope="module") +def streaming_deployment(): + protocol = "http" + with Deployment(uses=MyExecutor, protocol=protocol) as dep: + yield dep + + @pytest.mark.asyncio -async def test_issue_6090(): +async def test_issue_6090(streaming_deployment): """Tests if streaming works with pydantic models with complex fields which are not str, int, or float. """ - class NestedFieldSchema(BaseDoc): - name: str = "test_name" - dict_field: Dict = Field(default_factory=dict) - - class InputWithComplexFields(BaseDoc): - text: str = "test" - nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema) - dict_field: Dict = Field(default_factory=dict) - bool_field: bool = False - - class MyExecutor(Executor): - @requests(on="/stream") - async def stream( - self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs - ) -> InputWithComplexFields: - for i in range(4): - yield InputWithComplexFields(text=f"hello world {doc.text} {i}") - docs = [] protocol = "http" - with Deployment(uses=MyExecutor, protocol=protocol) as dep: - client = Client(port=dep.port, protocol=protocol, asyncio=True) - example_doc = InputWithComplexFields(text="my input text") - async for doc in client.stream_doc( - on="/stream", - inputs=example_doc, - input_type=InputWithComplexFields, - return_type=InputWithComplexFields, - ): - docs.append(doc) + client = Client(port=streaming_deployment.port, protocol=protocol, asyncio=True) + example_doc = InputWithComplexFields(text="my input text") + async for doc in client.stream_doc( + on="/stream", + inputs=example_doc, + input_type=InputWithComplexFields, + return_type=InputWithComplexFields, + ): + docs.append(doc) assert [d.text for d in docs] == [ - "hello world my input text 0", - "hello world my input text 1", - "hello world my input text 2", - "hello world my input text 3", + 'hello 
world my input text 0', + 'hello world my input text 1', + 'hello world my input text 2', + 'hello world my input text 3', ] assert docs[0].nested_field.name == "test_name" + +@pytest.mark.asyncio +async def test_issue_6090_get_params(streaming_deployment): + """Tests if streaming works when document fields are passed as URL query + parameters rather than in the request body. + """ + + docs = [] + url = ( + f"http://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text" + ) + async with aiohttp.ClientSession() as session: + + async with session.get(url) as resp: + async for chunk in resp.content.iter_any(): + print(chunk) + events = chunk.split(b'event: ')[1:] + for event in events: + if event.startswith(b'update'): + parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode() + parsed = SimpleInput.parse_raw(parsed) + print(parsed) + docs.append(parsed) + elif event.startswith(b'end'): + pass + + assert [d.text for d in docs] == [ + 'hello world my_input_text 0', + 'hello world my_input_text 1', + 'hello world my_input_text 2', + 'hello world my_input_text 3', + ] diff --git a/tests/integration/docarray_v2/test_streaming.py b/tests/integration/docarray_v2/test_streaming.py index 25befe58d814a..b6d89086a79b7 100644 --- a/tests/integration/docarray_v2/test_streaming.py +++ b/tests/integration/docarray_v2/test_streaming.py @@ -143,19 +143,17 @@ async def test_streaming_delay(protocol, include_gateway): ): client = Client(port=port, protocol=protocol, asyncio=True) i = 0 - stream = client.stream_doc( + start_time = time.time() + async for doc in client.stream_doc( on='/hello', inputs=MyDocument(text='hello world', number=i), return_type=MyDocument, - ) - start_time = None - async for doc in stream: - start_time = start_time or time.time() + ): assert doc.text == f'hello world {i}' i += 1 - delay = time.time() - start_time + # 0.5 seconds between each request + 0.5 seconds tolerance interval - assert delay < (0.5 * i), f'Expected delay to be less than {0.5 * i}, got 
{delay} on iteration {i}' + assert time.time() - start_time < (0.5 * i) + 0.6 @pytest.mark.asyncio diff --git a/tests/integration/streaming/test_streaming.py b/tests/integration/streaming/test_streaming.py index d43ed878c96db..5d2f6e4af848b 100644 --- a/tests/integration/streaming/test_streaming.py +++ b/tests/integration/streaming/test_streaming.py @@ -24,9 +24,9 @@ async def non_gen_task(self, docs: DocumentArray, **kwargs): @pytest.mark.parametrize('protocol', ['http', 'grpc']) @pytest.mark.parametrize('include_gateway', [False, True]) async def test_streaming_deployment(protocol, include_gateway): - from jina import Deployment port = random_port() + docs = [] with Deployment( uses=MyExecutor, @@ -38,10 +38,15 @@ async def test_streaming_deployment(protocol, include_gateway): client = Client(port=port, protocol=protocol, asyncio=True) i = 0 async for doc in client.stream_doc( - on='/hello', inputs=Document(text='hello world') + on='/hello', + inputs=Document(text='hello world'), + return_type=Document, + input_type=Document, ): - assert doc.text == f'hello world {i}' + docs.append(doc.text) i += 1 + assert docs == [f'hello world {i}' for i in range(100)] + assert len(docs) == 100 class WaitStreamExecutor(Executor):