Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use body for streaming instead of params #6098

Merged
merged 35 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
94348f6
fix: add post endpoint for streaming
NarekA Oct 20, 2023
b645afb
test: stream complex docs
NarekA Oct 20, 2023
72037b5
docs: include docstring for issue test
NarekA Oct 20, 2023
99cdd7d
Merge branch 'master' into fix-streaming-6091
NarekA Oct 20, 2023
cc315c9
Merge branch 'master' into fix-streaming-6091
NarekA Oct 24, 2023
8bc5124
fix: use random port
NarekA Oct 24, 2023
804b665
fix: remove import
NarekA Oct 24, 2023
3ce6e1e
fix: simplify client code
NarekA Oct 24, 2023
a5376af
fix: use json field
NarekA Oct 24, 2023
025413f
fix: use to_dict for docarray v1
NarekA Oct 24, 2023
60c9b5e
docs: add docs about http get
NarekA Oct 24, 2023
03cc4c4
fix: typo in deployment
NarekA Oct 24, 2023
b89c01a
fix: don't use body.data
NarekA Oct 24, 2023
cf2f145
fix: do unpack body
NarekA Oct 24, 2023
fd30b0e
fix: docarray v2 cast model
NarekA Oct 24, 2023
f1f9a88
fix: change start time delay
NarekA Oct 24, 2023
3aa6534
Revert "fix: change start time delay"
NarekA Oct 25, 2023
57af50a
fix: use get only
NarekA Oct 25, 2023
5f27f0a
Merge remote-tracking branch 'origin/master' into fix-streaming-2-6091
NarekA Oct 25, 2023
a72a797
fix: delay test
NarekA Oct 25, 2023
ee77d13
fix: fix get and post endpoints
NarekA Oct 25, 2023
524aba4
fix: remove post
NarekA Oct 25, 2023
5cfb423
fix: remove endpoint tags
NarekA Oct 25, 2023
3d1b07f
Merge branch 'master' into fix-streaming-2-6091
NarekA Oct 25, 2023
ca3a707
test: use get with url params
NarekA Oct 26, 2023
92405db
docs: fix docs on streaming endpoints
NarekA Oct 26, 2023
2c626bf
test: increase tolerance
NarekA Oct 26, 2023
38a577b
fix: iteration over chunks
NarekA Oct 26, 2023
c4afd99
fix: gateway forwarding
NarekA Oct 27, 2023
c654c14
fix: pre-commit changes
NarekA Oct 27, 2023
bb90d14
fix: don't check for docarray_v2
NarekA Oct 27, 2023
08bfd6b
fix: remove type-hint
NarekA Oct 27, 2023
3eff072
fix: update test output
NarekA Oct 27, 2023
46af5fa
fix: adding payload
NarekA Oct 27, 2023
bf3cbe2
fix: use json payload
NarekA Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions docs/concepts/serving/executor/add-endpoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,12 @@ Streaming endpoints receive one Document as input and yields one Document at a t

Streaming endpoints are only supported for HTTP and gRPC protocols and for Deployment and Flow with one single Executor.

For HTTP deployment streaming executors generate both a GET and POST endpoint.
The GET endpoint support documents with string, integer, or float fields only,
whereas, POST requests support all docarrays.
The Jina client uses the POST endpoints.
For HTTP deployment streaming executors generate a GET endpoint.
The GET endpoint support passing documet fields in
the request body or as URL query parameters,
however, query parameters only support string, integer, or float fields,
whereas, the request body support all serializable docarrays.
The Jina client uses the request body.
```

A streaming endpoint has the following signature:
Expand Down
5 changes: 3 additions & 2 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,14 @@ async def send_streaming_message(self, doc: 'Document', on: str):
:param on: Request endpoint
:yields: responses
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'headers': {'Accept': 'text/event-stream'},
'json': doc.dict() if docarray_v2 else doc.to_dict(),
'json': req_dict,
}

async with self.session.post(**request_kwargs) as response:
async with self.session.get(**request_kwargs) as response:
async for chunk in response.content.iter_any():
events = chunk.split(b'event: ')[1:]
for event in events:
Expand Down
30 changes: 4 additions & 26 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,35 +252,13 @@ def add_streaming_routes(
methods=['GET'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
async def streaming_get(request: Request, body: input_doc_model = None):
body = body or dict(request.query_params)
body = input_doc_model.parse_obj(body)

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model(**query_params), exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
yield {
'event': 'update',
'data': doc.dict()
}
yield {
'event': 'end'
}

return EventSourceResponse(event_generator())

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: dict):

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path
doc=body, exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
Expand Down
13 changes: 12 additions & 1 deletion jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,19 @@ async def _load_balance(self, request):

try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
async with session.get(target_url) as response:
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')

async with session.get(
url=target_url, **request_kwargs
) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
Expand Down
68 changes: 32 additions & 36 deletions jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from jina import DocumentArray, Document
from jina import Document, DocumentArray
from jina._docarray import docarray_v2
from jina.importer import ImportExtensions
from jina.serve.networking.sse import EventSourceResponse
Expand All @@ -11,15 +11,15 @@
from jina.logging.logger import JinaLogger

if docarray_v2:
from docarray import DocList, BaseDoc
from docarray import BaseDoc, DocList


def get_fastapi_app(
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
):
"""
Get the app from FastAPI as the REST interface.
Expand All @@ -35,15 +35,18 @@ def get_fastapi_app(
from fastapi import FastAPI, Response, HTTPException
import pydantic
from fastapi.middleware.cors import CORSMiddleware
import os

from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

from jina.proto import jina_pb2
from jina.serve.runtimes.gateway.models import _to_camel_case
import os

class Header(BaseModel):
request_id: Optional[str] = Field(description='Request ID', example=os.urandom(16).hex())
request_id: Optional[str] = Field(
description='Request ID', example=os.urandom(16).hex()
)

class Config(BaseConfig):
alias_generator = _to_camel_case
Expand All @@ -66,11 +69,11 @@ class InnerConfig(BaseConfig):
logger.warning('CORS is enabled. This service is accessible from any website!')

def add_post_route(
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
):
app_kwargs = dict(
path=f'/{endpoint_path.strip("/")}',
Expand Down Expand Up @@ -123,8 +126,8 @@ async def post(body: input_model, response: Response):
return ret

def add_streaming_routes(
endpoint_path,
input_doc_model=None,
endpoint_path,
input_doc_model=None,
):
from fastapi import Request

Expand All @@ -133,26 +136,17 @@ def add_streaming_routes(
methods=['GET'],
Copy link
Contributor Author

@NarekA NarekA Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following code works, but I did not allow get since it doesn't work with gateway deployments.

methods=['GET', 'POST'],

summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
req.data.docs = DocumentArray([Document.from_dict(query_params)])
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_model](
[input_doc_model(**query_params)]
async def streaming_get(request: Request = None, body: input_doc_model = None):
if not body:
query_params = dict(request.query_params)
body = (
input_doc_model.parse_obj(query_params)
if docarray_v2
else Document.from_dict(query_params)
)
event_generator = _gen_dict_documents(await caller(req))
return EventSourceResponse(event_generator)

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: input_doc_model, request: Request):
else:
if not docarray_v2:
body = Document.from_pydantic_model(body)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
Expand All @@ -169,7 +163,9 @@ async def streaming_post(body: input_doc_model, request: Request):
output_doc_model = input_output_map['output']['model']
is_generator = input_output_map['is_generator']
parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
default_parameters = ... if input_output_map['parameters']['model'] else None
default_parameters = (
... if input_output_map['parameters']['model'] else None
)

if docarray_v2:
_config = inherit_config(InnerConfig, BaseDoc.__config__)
Expand Down
132 changes: 96 additions & 36 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Dict
from typing import Dict, List, Optional

import aiohttp
import pytest
from docarray import BaseDoc, DocList
from pydantic import Field

from jina import Executor, Flow, requests, Deployment, Client
from jina import Client, Deployment, Executor, Flow, requests
from jina.clients.base.helper import HTTPClientlet


class Nested2Doc(BaseDoc):
Expand Down Expand Up @@ -78,6 +80,7 @@ def test_issue_6019_with_nested_list():
assert res[0].text == 'hello world'
assert res[0].nested[0].nested.value == 'test'


def test_issue_6084():
class EnvInfo(BaseDoc):
history: str = ''
Expand All @@ -86,7 +89,6 @@ class A(BaseDoc):
b: EnvInfo

class MyIssue6084Exec(Executor):

@requests
def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
pass
Expand All @@ -96,48 +98,106 @@ def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
pass


class NestedFieldSchema(BaseDoc):
name: str = "test_name"
dict_field: Dict = Field(default_factory=dict)


class InputWithComplexFields(BaseDoc):
text: str = "test"
nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema)
dict_field: Dict = Field(default_factory=dict)
bool_field: bool = False


class SimpleInput(BaseDoc):
text: str = "test"


class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self,
doc: InputWithComplexFields,
parameters: Optional[Dict] = None,
**kwargs,
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")

@requests(on="/stream-simple")
async def stream_simple(
self,
doc: SimpleInput,
parameters: Optional[Dict] = None,
**kwargs,
) -> SimpleInput:
for i in range(4):
yield SimpleInput(text=f"hello world {doc.text} {i}")


@pytest.fixture(scope="module")
def streaming_deployment():
protocol = "http"
with Deployment(uses=MyExecutor, protocol=protocol) as dep:
yield dep


@pytest.mark.asyncio
async def test_issue_6090():
async def test_issue_6090(streaming_deployment):
"""Tests if streaming works with pydantic models with complex fields which are not
str, int, or float.
"""

class NestedFieldSchema(BaseDoc):
name: str = "test_name"
dict_field: Dict = Field(default_factory=dict)

class InputWithComplexFields(BaseDoc):
text: str = "test"
nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema)
dict_field: Dict = Field(default_factory=dict)
bool_field: bool = False

class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")

docs = []
protocol = "http"
with Deployment(uses=MyExecutor, protocol=protocol) as dep:
client = Client(port=dep.port, protocol=protocol, asyncio=True)
example_doc = InputWithComplexFields(text="my input text")
async for doc in client.stream_doc(
on="/stream",
inputs=example_doc,
input_type=InputWithComplexFields,
return_type=InputWithComplexFields,
):
docs.append(doc)
client = Client(port=streaming_deployment.port, protocol=protocol, asyncio=True)
example_doc = InputWithComplexFields(text="my input text")
async for doc in client.stream_doc(
on="/stream",
inputs=example_doc,
input_type=InputWithComplexFields,
return_type=InputWithComplexFields,
):
docs.append(doc)

assert [d.text for d in docs] == [
"hello world my input text 0",
"hello world my input text 1",
"hello world my input text 2",
"hello world my input text 3",
'hello world my input text 0',
'hello world my input text 1',
'hello world my input text 2',
'hello world my input text 3',
]
assert docs[0].nested_field.name == "test_name"


@pytest.mark.asyncio
async def test_issue_6090_get_params(streaming_deployment):
"""Tests if streaming works with pydantic models with complex fields which are not
str, int, or float.
"""

docs = []
url = (
f"htto://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
)
async with aiohttp.ClientSession() as session:

async with session.get(url) as resp:
async for chunk in resp.content.iter_any():
print(chunk)
events = chunk.split(b'event: ')[1:]
for event in events:
if event.startswith(b'update'):
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode()
parsed = SimpleInput.parse_raw(parsed)
print(parsed)
docs.append(parsed)
elif event.startswith(b'end'):
pass

assert [d.text for d in docs] == [
'hello world my_input_text 0',
'hello world my_input_text 1',
'hello world my_input_text 2',
'hello world my_input_text 3',
]
Loading