Skip to content

Commit

Permalink
Merge pull request #498 from airtai/FastAPI-include
Browse files Browse the repository at this point in the history
feat: add FastAPI Router include_router method, fix #497
  • Loading branch information
sternakt authored Sep 1, 2023
2 parents 8f7abc4 + 913eaec commit 8a813a1
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 117 deletions.
45 changes: 44 additions & 1 deletion faststream/broker/fastapi/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from abc import abstractmethod
from contextlib import asynccontextmanager
from enum import Enum
from typing import (
Expand All @@ -25,7 +26,7 @@
from fastapi.utils import generate_unique_id
from starlette import routing
from starlette.responses import JSONResponse, Response
from starlette.routing import _DefaultLifespan
from starlette.routing import BaseRoute, _DefaultLifespan
from starlette.types import AppType, ASGIApp, Lifespan

from faststream.asyncapi import schema as asyncapi
Expand Down Expand Up @@ -350,3 +351,45 @@ def serve_asyncapi_schema(
docs_router.get(f"{schema_url}.json")(download_app_json_schema)
docs_router.get(f"{schema_url}.yaml")(download_app_yaml_schema)
return docs_router

def include_router(
self,
router: "APIRouter",
*,
prefix: str = "",
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
) -> None:
if isinstance(router, StreamRouter):
self._setup_log_context(self.broker, router.broker)
self.broker.handlers.update(router.broker.handlers)
self.broker._publishers.update(router.broker._publishers)

super().include_router(
router=router,
prefix=prefix,
tags=tags,
dependencies=dependencies,
default_response_class=default_response_class,
responses=responses,
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
generate_unique_id_function=generate_unique_id_function,
)

@staticmethod
@abstractmethod
def _setup_log_context(
main_broker: BrokerAsyncUsecase[MsgType, Any],
including_broker: BrokerAsyncUsecase[MsgType, Any],
) -> None:
raise NotImplementedError()
8 changes: 8 additions & 0 deletions faststream/kafka/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@

class KafkaRouter(StreamRouter[ConsumerRecord]):
broker_class = KafkaBroker

@staticmethod
def _setup_log_context(
main_broker: KafkaBroker,
including_broker: KafkaBroker,
) -> None:
for h in including_broker.handlers.values():
main_broker._setup_log_context(h.topics)
9 changes: 8 additions & 1 deletion faststream/kafka/fastapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from kafka.partitioner.default import DefaultPartitioner
from starlette import routing
from starlette.responses import JSONResponse, Response
from starlette.types import AppType, ASGIApp
from starlette.types import AppType, ASGIApp, Lifespan

from faststream.__about__ import __version__
from faststream._compat import override
Expand Down Expand Up @@ -122,6 +122,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]):
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
lifespan: Optional[Lifespan[Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
Expand Down Expand Up @@ -421,3 +422,9 @@ class KafkaRouter(StreamRouter[ConsumerRecord]):
self,
func: Callable[[AppType], Awaitable[None]],
) -> Callable[[AppType], Awaitable[None]]: ...
@override
@staticmethod
def _setup_log_context( # type: ignore[override]
main_broker: KafkaBroker,
including_broker: KafkaBroker,
) -> None: ...
30 changes: 28 additions & 2 deletions faststream/kafka/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def publish( # type: ignore[override]
if topic in handler.topics:
r = await call_handler(
handler=handler,
message=incoming,
message=[incoming] if handler.batch else incoming,
rpc=rpc,
rpc_timeout=rpc_timeout,
raise_timeout=raise_timeout,
Expand All @@ -109,6 +109,32 @@ async def publish( # type: ignore[override]

return None

async def publish_batch(
self,
*msgs: SendableMessage,
topic: str,
partition: Optional[int] = None,
timestamp_ms: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> None:
for handler in self.broker.handlers.values(): # pragma: no branch
if topic in handler.topics:
await call_handler(
handler=handler,
message=[
build_message(
message=message,
topic=topic,
partition=partition,
timestamp_ms=timestamp_ms,
headers=headers,
)
for message in msgs
],
)

return None


async def _fake_connect(self: KafkaBroker, *args: Any, **kwargs: Any) -> None:
self._producer = FakeProducer(self)
Expand All @@ -120,7 +146,7 @@ async def _fake_close(
exc_val: Optional[BaseException] = None,
exec_tb: Optional[TracebackType] = None,
) -> None:
for _key, p in self._publishers.items():
for p in self._publishers.values():
p.mock.reset_mock()
if getattr(p, "_fake_handler", False):
self.handlers.pop(p.topic, None)
Expand Down
8 changes: 8 additions & 0 deletions faststream/rabbit/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@

class RabbitRouter(StreamRouter[IncomingMessage]):
broker_class = RabbitBroker

@staticmethod
def _setup_log_context(
main_broker: RabbitBroker,
including_broker: RabbitBroker,
) -> None:
for h in including_broker.handlers.values():
main_broker._setup_log_context(h.queue, h.exchange)
9 changes: 8 additions & 1 deletion faststream/rabbit/fastapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from fastapi.utils import generate_unique_id
from pamqp.common import FieldTable
from starlette import routing
from starlette.responses import JSONResponse, Response
from starlette.types import AppType, ASGIApp
from starlette.types import AppType, ASGIApp, Lifespan
from yarl import URL

from faststream._compat import override
Expand Down Expand Up @@ -95,6 +95,7 @@ class RabbitRouter(StreamRouter[IncomingMessage]):
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
lifespan: Optional[Lifespan[Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
Expand Down Expand Up @@ -194,3 +195,9 @@ class RabbitRouter(StreamRouter[IncomingMessage]):
self,
func: Callable[[AppType], Awaitable[None]],
) -> Callable[[AppType], Awaitable[None]]: ...
@override
@staticmethod
def _setup_log_context( # type: ignore[override]
main_broker: RabbitBroker,
including_broker: RabbitBroker,
) -> None: ...
70 changes: 24 additions & 46 deletions tests/brokers/base/consume.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from unittest.mock import Mock

import pytest

Expand All @@ -17,26 +16,24 @@ async def test_consume(
self,
queue: str,
consume_broker: BrokerUsecase,
event: asyncio.Event,
):
@consume_broker.subscriber(queue)
def subscriber(m):
event.set()
...

await consume_broker.start()
await asyncio.wait(
(
asyncio.create_task(consume_broker.publish("hello", queue)),
asyncio.create_task(event.wait()),
asyncio.create_task(subscriber.wait_call()),
),
timeout=3,
)

assert event.is_set()
assert subscriber.event.is_set()

async def test_consume_from_multi(
self,
mock: Mock,
queue: str,
consume_broker: BrokerUsecase,
):
Expand All @@ -50,7 +47,6 @@ def subscriber(m):
consume.set()
else:
consume2.set()
mock()

await consume_broker.start()
await asyncio.wait(
Expand All @@ -65,11 +61,10 @@ def subscriber(m):

assert consume2.is_set()
assert consume.is_set()
assert mock.call_count == 2
assert subscriber.mock.call_count == 2

async def test_consume_double(
self,
mock: Mock,
queue: str,
consume_broker: BrokerUsecase,
):
Expand All @@ -82,7 +77,6 @@ async def handler(m):
consume.set()
else:
consume2.set()
mock()

await consume_broker.start()

Expand All @@ -98,83 +92,71 @@ async def handler(m):

assert consume2.is_set()
assert consume.is_set()
assert mock.call_count == 2
assert handler.mock.call_count == 2

async def test_different_consume(
self,
mock: Mock,
queue: str,
consume_broker: BrokerUsecase,
):
first_consume = asyncio.Event()
second_consume = asyncio.Event()

@consume_broker.subscriber(queue)
def handler(m):
first_consume.set()
mock.method()
...

another_topic = queue + "1"

@consume_broker.subscriber(another_topic)
def handler2(m):
second_consume.set()
mock.method2()
...

await consume_broker.start()

await asyncio.wait(
(
asyncio.create_task(consume_broker.publish("hello", queue)),
asyncio.create_task(consume_broker.publish("hello", another_topic)),
asyncio.create_task(first_consume.wait()),
asyncio.create_task(second_consume.wait()),
asyncio.create_task(handler.wait_call()),
asyncio.create_task(handler2.wait_call()),
),
timeout=3,
)

assert first_consume.is_set()
assert second_consume.is_set()
mock.method.assert_called_once()
mock.method2.assert_called_once()
assert handler.event.is_set()
assert handler2.event.is_set()
handler.mock.assert_called_once()
handler2.mock.assert_called_once()

async def test_consume_with_filter(
self,
mock: Mock,
queue: str,
consume_broker: BrokerUsecase,
):
consume = asyncio.Event()
consume2 = asyncio.Event()

@consume_broker.subscriber(
queue, filter=lambda m: m.content_type == "application/json"
)
async def handler(m):
consume2.set()
mock.call1(m)
...

@consume_broker.subscriber(queue)
async def handler2(m):
consume.set()
mock.call2(m)
...

await consume_broker.start()

await asyncio.wait(
(
asyncio.create_task(consume_broker.publish({"msg": "hello"}, queue)),
asyncio.create_task(consume_broker.publish("hello", queue)),
asyncio.create_task(consume.wait()),
asyncio.create_task(consume2.wait()),
asyncio.create_task(handler2.wait_call()),
asyncio.create_task(handler.wait_call()),
),
timeout=3,
)

assert consume2.is_set()
assert consume.is_set()
mock.call1.assert_called_once_with({"msg": "hello"})
mock.call2.assert_called_once_with("hello")
assert handler2.event.is_set()
assert handler2.event.is_set()
handler.mock.assert_called_once_with({"msg": "hello"})
handler2.mock.assert_called_once_with("hello")


@pytest.mark.asyncio
Expand All @@ -183,27 +165,23 @@ class BrokerRealConsumeTestcase(BrokerConsumeTestcase):
async def test_stop_consume_exc(
self,
queue: str,
mock: Mock,
consume_broker: BrokerUsecase,
event: asyncio.Event,
):
@consume_broker.subscriber(queue)
def subscriber(m):
event.set()
mock()
raise StopConsume()

await consume_broker.start()
await asyncio.wait(
(
asyncio.create_task(consume_broker.publish("hello", queue)),
asyncio.create_task(event.wait()),
asyncio.create_task(subscriber.wait_call()),
),
timeout=3,
)
await asyncio.sleep(0.5)
await consume_broker.publish("hello", queue)
await asyncio.sleep(0.5)

assert event.is_set()
mock.assert_called_once()
assert subscriber.event.is_set()
subscriber.mock.assert_called_once()
Loading

0 comments on commit 8a813a1

Please sign in to comment.