From 6928d07346ebb7cf599117e0a8be139d8e1405d5 Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Fri, 1 Sep 2023 17:48:31 +0300 Subject: [PATCH 1/4] feat: add FastAPI Router include_router method, fix #497 --- faststream/broker/fastapi/router.py | 45 +++++++++++++++- faststream/kafka/fastapi.py | 8 +++ faststream/kafka/fastapi.pyi | 9 +++- faststream/kafka/test.py | 30 ++++++++++- faststream/rabbit/fastapi.py | 8 +++ faststream/rabbit/fastapi.pyi | 10 +++- tests/asyncapi/base/fastapi.py | 10 +--- tests/brokers/base/consume.py | 70 +++++++++---------------- tests/brokers/base/fastapi.py | 36 +++++++++++++ tests/brokers/base/parser.py | 55 ++++++++----------- tests/brokers/base/publish.py | 64 +++++++++++++--------- tests/brokers/base/rpc.py | 14 ++--- tests/brokers/kafka/test_test_client.py | 28 ++++++++++ 13 files changed, 261 insertions(+), 126 deletions(-) diff --git a/faststream/broker/fastapi/router.py b/faststream/broker/fastapi/router.py index 1b298ce2ba..ead67ff5ab 100644 --- a/faststream/broker/fastapi/router.py +++ b/faststream/broker/fastapi/router.py @@ -1,4 +1,5 @@ import json +from abc import abstractmethod from contextlib import asynccontextmanager from enum import Enum from typing import ( @@ -25,7 +26,7 @@ from fastapi.utils import generate_unique_id from starlette import routing from starlette.responses import JSONResponse, Response -from starlette.routing import _DefaultLifespan +from starlette.routing import BaseRoute, _DefaultLifespan from starlette.types import AppType, ASGIApp, Lifespan from faststream.asyncapi import schema as asyncapi @@ -350,3 +351,45 @@ def serve_asyncapi_schema( docs_router.get(f"{schema_url}.json")(download_app_json_schema) docs_router.get(f"{schema_url}.yaml")(download_app_yaml_schema) return docs_router + + def include_router( + self, + router: "APIRouter", + *, + prefix: str = "", + tags: Optional[List[Union[str, Enum]]] = None, + dependencies: Optional[Sequence[params.Depends]] = None, + default_response_class: Type[Response] = Default(JSONResponse), + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + callbacks: Optional[List[BaseRoute]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + generate_unique_id_function: Callable[[APIRoute], str] = Default( + generate_unique_id + ), + ) -> None: + if isinstance(router, StreamRouter): + self._setup_log_context(self.broker, router.broker) + self.broker.handlers.update(router.broker.handlers) + self.broker._publishers.update(router.broker._publishers) + + super().include_router( + router=router, + prefix=prefix, + tags=tags, + dependencies=dependencies, + default_response_class=default_response_class, + responses=responses, + callbacks=callbacks, + deprecated=deprecated, + include_in_schema=include_in_schema, + generate_unique_id_function=generate_unique_id_function, + ) + + @staticmethod + @abstractmethod + def _setup_log_context( + main_broker: BrokerAsyncUsecase[MsgType, Any], + including_broker: BrokerAsyncUsecase[MsgType, Any], + ) -> None: + raise NotImplementedError() diff --git a/faststream/kafka/fastapi.py b/faststream/kafka/fastapi.py index 8c89c93b61..e26418a3e5 100644 --- a/faststream/kafka/fastapi.py +++ b/faststream/kafka/fastapi.py @@ -6,3 +6,11 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): broker_class = KafkaBroker + + @staticmethod + def _setup_log_context( + main_broker: KafkaBroker, + including_broker: KafkaBroker, + ) -> None: + for h in including_broker.handlers.values(): + main_broker._setup_log_context(h.topics) diff --git a/faststream/kafka/fastapi.pyi b/faststream/kafka/fastapi.pyi index c1d7c73106..2828974ce1 100644 --- a/faststream/kafka/fastapi.pyi +++ b/faststream/kafka/fastapi.pyi @@ -34,7 +34,7 @@ from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.partitioner.default import DefaultPartitioner from starlette import routing from starlette.responses import JSONResponse, Response -from starlette.types import AppType, ASGIApp +from starlette.types import AppType, ASGIApp, Lifespan from faststream.__about__ import __version__ from faststream._compat import override @@ -122,6 +122,7 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + lifespan: Optional[Lifespan[Any]] = None, generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), @@ -421,3 +422,9 @@ class KafkaRouter(StreamRouter[ConsumerRecord]): self, func: Callable[[AppType], Awaitable[None]], ) -> Callable[[AppType], Awaitable[None]]: ... + @override + @staticmethod + def _setup_log_context( # type: ignore[override] + main_broker: KafkaBroker, + including_broker: KafkaBroker, + ) -> None: ... diff --git a/faststream/kafka/test.py b/faststream/kafka/test.py index 5fad8d1df2..3104c30fc8 100644 --- a/faststream/kafka/test.py +++ b/faststream/kafka/test.py @@ -98,7 +98,7 @@ async def publish( # type: ignore[override] if topic in handler.topics: r = await call_handler( handler=handler, - message=incoming, + message=[incoming] if handler.batch else incoming, rpc=rpc, rpc_timeout=rpc_timeout, raise_timeout=raise_timeout, @@ -109,6 +109,32 @@ async def publish( # type: ignore[override] return None + async def publish_batch( + self, + *msgs: SendableMessage, + topic: str, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + ) -> None: + for handler in self.broker.handlers.values(): # pragma: no branch + if topic in handler.topics: + await call_handler( + handler=handler, + message=[ + build_message( + message=message, + topic=topic, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + ) + for message in msgs + ], + ) + + return None + async def _fake_connect(self: KafkaBroker, *args: Any, **kwargs: Any) -> None: self._producer = FakeProducer(self) @@ -120,7 +146,7 @@ async def _fake_close( exc_val: Optional[BaseException] = None, exec_tb: Optional[TracebackType] = None, ) -> None: - for _key, p in self._publishers.items(): + for p in self._publishers.values(): p.mock.reset_mock() if getattr(p, "_fake_handler", False): self.handlers.pop(p.topic, None) diff --git a/faststream/rabbit/fastapi.py b/faststream/rabbit/fastapi.py index e036f1e822..30a3ce09c5 100644 --- a/faststream/rabbit/fastapi.py +++ b/faststream/rabbit/fastapi.py @@ -6,3 +6,11 @@ class RabbitRouter(StreamRouter[IncomingMessage]): broker_class = RabbitBroker + + @staticmethod + def _setup_log_context( + main_broker: RabbitBroker, + including_broker: RabbitBroker, + ) -> None: + for h in including_broker.handlers.values(): + main_broker._setup_log_context(h.queue, h.exchange) diff --git a/faststream/rabbit/fastapi.pyi b/faststream/rabbit/fastapi.pyi index 4908e5583a..55061d346c 100644 --- a/faststream/rabbit/fastapi.pyi +++ b/faststream/rabbit/fastapi.pyi @@ -23,7 +23,7 @@ from fastapi.utils import generate_unique_id from pamqp.common import FieldTable from starlette import routing from starlette.responses import JSONResponse, Response -from starlette.types import AppType, ASGIApp +from starlette.types import AppType, ASGIApp, Lifespan from yarl import URL from faststream._compat import override @@ -49,7 +49,6 @@ from faststream.types import AnyDict class RabbitRouter(StreamRouter[IncomingMessage]): broker_class: Type[RabbitBroker] - # nosemgrep: python.lang.security.audit.hardcoded-password-default-argument.hardcoded-password-default-argument def __init__( self, url: Union[str, URL, None] = "amqp://guest:guest@localhost:5672/", @@ -95,6 +94,7 @@ class RabbitRouter(StreamRouter[IncomingMessage]): on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + lifespan: Optional[Lifespan[Any]] = None, generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), @@ -194,3 +194,9 @@ class RabbitRouter(StreamRouter[IncomingMessage]): self, func: Callable[[AppType], Awaitable[None]], ) -> Callable[[AppType], Awaitable[None]]: ... + @override + @staticmethod + def _setup_log_context( # type: ignore[override] + main_broker: RabbitBroker, + including_broker: RabbitBroker, + ) -> None: ... diff --git a/tests/asyncapi/base/fastapi.py b/tests/asyncapi/base/fastapi.py index 6150be21bf..4e7fc896d9 100644 --- a/tests/asyncapi/base/fastapi.py +++ b/tests/asyncapi/base/fastapi.py @@ -44,14 +44,8 @@ def test_fastapi_full_information(self): "title": "CustomApp", "version": "1.1.1", "description": "Test description", - "contact": { - "name": "support", - "url": IsStr(regex=r"https\:\/\/support\.com\/?"), - }, - "license": { - "name": "some", - "url": IsStr(regex=r"https\:\/\/some\.com\/?"), - }, + "contact": {"name": "support", "url": "https://support.com/"}, + "license": {"name": "some", "url": "https://some.com/"}, }, "servers": { "development": { diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index 67514c9030..3229b0f324 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -1,5 +1,4 @@ import asyncio -from unittest.mock import Mock import pytest @@ -17,26 +16,24 @@ async def test_consume( self, queue: str, consume_broker: BrokerUsecase, - event: asyncio.Event, ): @consume_broker.subscriber(queue) def subscriber(m): - event.set() + ... await consume_broker.start() await asyncio.wait( ( asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(subscriber.wait_call()), ), timeout=3, ) - assert event.is_set() + assert subscriber.event.is_set() async def test_consume_from_multi( self, - mock: Mock, queue: str, consume_broker: BrokerUsecase, ): @@ -50,7 +47,6 @@ def subscriber(m): consume.set() else: consume2.set() - mock() await consume_broker.start() await asyncio.wait( @@ -65,11 +61,10 @@ def subscriber(m): assert consume2.is_set() assert consume.is_set() - assert mock.call_count == 2 + assert subscriber.mock.call_count == 2 async def test_consume_double( self, - mock: Mock, queue: str, consume_broker: BrokerUsecase, ): @@ -82,7 +77,6 @@ async def handler(m): consume.set() else: consume2.set() - mock() await consume_broker.start() @@ -98,28 +92,22 @@ async def handler(m): assert consume2.is_set() assert consume.is_set() - assert mock.call_count == 2 + assert handler.mock.call_count == 2 async def test_different_consume( self, - mock: Mock, queue: str, consume_broker: BrokerUsecase, ): - first_consume = asyncio.Event() - second_consume = asyncio.Event() - @consume_broker.subscriber(queue) def handler(m): - first_consume.set() - mock.method() + ... another_topic = queue + "1" @consume_broker.subscriber(another_topic) def handler2(m): - second_consume.set() - mock.method2() + ... await consume_broker.start() @@ -127,37 +115,31 @@ def handler2(m): ( asyncio.create_task(consume_broker.publish("hello", queue)), asyncio.create_task(consume_broker.publish("hello", another_topic)), - asyncio.create_task(first_consume.wait()), - asyncio.create_task(second_consume.wait()), + asyncio.create_task(handler.wait_call()), + asyncio.create_task(handler2.wait_call()), ), timeout=3, ) - assert first_consume.is_set() - assert second_consume.is_set() - mock.method.assert_called_once() - mock.method2.assert_called_once() + assert handler.event.is_set() + assert handler2.event.is_set() + handler.mock.assert_called_once() + handler2.mock.assert_called_once() async def test_consume_with_filter( self, - mock: Mock, queue: str, consume_broker: BrokerUsecase, ): - consume = asyncio.Event() - consume2 = asyncio.Event() - @consume_broker.subscriber( queue, filter=lambda m: m.content_type == "application/json" ) async def handler(m): - consume2.set() - mock.call1(m) + ... @consume_broker.subscriber(queue) async def handler2(m): - consume.set() - mock.call2(m) + ... await consume_broker.start() @@ -165,16 +147,16 @@ async def handler2(m): ( asyncio.create_task(consume_broker.publish({"msg": "hello"}, queue)), asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(consume.wait()), - asyncio.create_task(consume2.wait()), + asyncio.create_task(handler2.wait_call()), + asyncio.create_task(handler.wait_call()), ), timeout=3, ) - assert consume2.is_set() - assert consume.is_set() - mock.call1.assert_called_once_with({"msg": "hello"}) - mock.call2.assert_called_once_with("hello") + assert handler2.event.is_set() + assert handler2.event.is_set() + handler.mock.assert_called_once_with({"msg": "hello"}) + handler2.mock.assert_called_once_with("hello") @pytest.mark.asyncio @@ -183,21 +165,17 @@ class BrokerRealConsumeTestcase(BrokerConsumeTestcase): async def test_stop_consume_exc( self, queue: str, - mock: Mock, consume_broker: BrokerUsecase, - event: asyncio.Event, ): @consume_broker.subscriber(queue) def subscriber(m): - event.set() - mock() raise StopConsume() await consume_broker.start() await asyncio.wait( ( asyncio.create_task(consume_broker.publish("hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(subscriber.wait_call()), ), timeout=3, ) @@ -205,5 +183,5 @@ def subscriber(m): await consume_broker.publish("hello", queue) await asyncio.sleep(0.5) - assert event.is_set() - mock.assert_called_once() + assert subscriber.event.is_set() + subscriber.mock.assert_called_once() diff --git a/tests/brokers/base/fastapi.py b/tests/brokers/base/fastapi.py index e61af7a786..6656183df6 100644 --- a/tests/brokers/base/fastapi.py +++ b/tests/brokers/base/fastapi.py @@ -339,3 +339,39 @@ async def m(): async with self.broker_test(router.broker) as rb: await rb.publish("hello", queue) publisher.mock.assert_called_with("response") + + async def test_include(self, queue: str): + router = self.router_class() + router2 = self.router_class() + router.broker = self.broker_test(router.broker) + + app = FastAPI(lifespan=router.lifespan_context) + + @router.subscriber(queue) + async def hello(): + return "hi" + + @router2.subscriber(queue + "1") + async def hello_router2(): + return "hi" + + router.include_router(router2) + async with router.broker: + with TestClient(app) as client: + assert client.app_state["broker"] is router.broker + + r = await router.broker.publish( + "hi", + queue, + rpc=True, + rpc_timeout=0.5, + ) + assert r == "hi" + + r = await router.broker.publish( + "hi", + queue + "1", + rpc=True, + rpc_timeout=0.5, + ) + assert r == "hi" diff --git a/tests/brokers/base/parser.py b/tests/brokers/base/parser.py index 5aff790a5b..dcaa28ed37 100644 --- a/tests/brokers/base/parser.py +++ b/tests/brokers/base/parser.py @@ -20,9 +20,7 @@ def patch_broker( ) -> BrokerAsyncUsecase: return broker - async def test_local_parser( - self, event: asyncio.Event, mock: Mock, queue: str, raw_broker - ): + async def test_local_parser(self, mock: Mock, queue: str, raw_broker): broker = self.broker_class() async def custom_parser(msg, original): @@ -32,7 +30,7 @@ async def custom_parser(msg, original): @broker.subscriber(queue, parser=custom_parser) async def handle(m): - event.set() + ... broker = self.patch_broker(raw_broker, broker) async with broker: @@ -41,17 +39,15 @@ async def handle(m): await asyncio.wait( ( asyncio.create_task(broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(handle.wait_call()), ), timeout=3, ) - assert event.is_set() + assert handle.event.is_set() mock.assert_called_once_with(b"hello") - async def test_local_sync_decoder( - self, event: asyncio.Event, mock: Mock, queue: str, raw_broker - ): + async def test_local_sync_decoder(self, mock: Mock, queue: str, raw_broker): broker = self.broker_class() def custom_decoder(msg): @@ -60,7 +56,7 @@ def custom_decoder(msg): @broker.subscriber(queue, decoder=custom_decoder) async def handle(m): - event.set() + ... broker = self.patch_broker(raw_broker, broker) async with broker: @@ -69,17 +65,15 @@ async def handle(m): await asyncio.wait( ( asyncio.create_task(broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(handle.wait_call()), ), timeout=3, ) - assert event.is_set() + assert handle.event.is_set() mock.assert_called_once_with(b"hello") - async def test_global_sync_decoder( - self, event: asyncio.Event, mock: Mock, queue: str, raw_broker - ): + async def test_global_sync_decoder(self, mock: Mock, queue: str, raw_broker): def custom_decoder(msg): mock(msg.body) return msg @@ -88,7 +82,7 @@ def custom_decoder(msg): @broker.subscriber(queue) async def handle(m): - event.set() + ... broker = self.patch_broker(raw_broker, broker) async with broker: @@ -97,12 +91,12 @@ async def handle(m): await asyncio.wait( ( asyncio.create_task(broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(handle.wait_call()), ), timeout=3, ) - assert event.is_set() + assert handle.event.is_set() mock.assert_called_once_with(b"hello") async def test_local_parser_no_share_between_subscribers( @@ -143,9 +137,8 @@ async def handle(m): mock.assert_called_once_with(b"hello") async def test_local_parser_no_share_between_handlers( - self, event: asyncio.Event, mock: Mock, queue: str, raw_broker + self, mock: Mock, queue: str, raw_broker ): - event2 = asyncio.Event() broker = self.broker_class() async def custom_parser(msg, original): @@ -155,11 +148,11 @@ async def custom_parser(msg, original): @broker.subscriber(queue, filter=lambda m: m.content_type == "application/json") async def handle(m): - event2.set() + ... @broker.subscriber(queue, parser=custom_parser) async def handle2(m): - event.set() + ... broker = self.patch_broker(raw_broker, broker) async with broker: @@ -169,21 +162,19 @@ async def handle2(m): ( asyncio.create_task(broker.publish({"msg": "hello"}, queue)), asyncio.create_task(broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), - asyncio.create_task(event2.wait()), + asyncio.create_task(handle.wait_call()), + asyncio.create_task(handle2.wait_call()), ), timeout=3, ) - assert event.is_set() - assert event2.is_set() + assert handle.event.is_set() + assert handle2.event.is_set() assert mock.call_count == 2 # instead 4 class CustomParserTestcase(LocalCustomParserTestcase): - async def test_global_parser( - self, event: asyncio.Event, mock: Mock, queue: str, raw_broker - ): + async def test_global_parser(self, mock: Mock, queue: str, raw_broker): async def custom_parser(msg, original): msg = await original(msg) mock(msg.body) @@ -193,7 +184,7 @@ async def custom_parser(msg, original): @broker.subscriber(queue) async def handle(m): - event.set() + ... broker = self.patch_broker(raw_broker, broker) async with broker: @@ -202,10 +193,10 @@ async def handle(m): await asyncio.wait( ( asyncio.create_task(broker.publish(b"hello", queue)), - asyncio.create_task(event.wait()), + asyncio.create_task(handle.wait_call()), ), timeout=3, ) - assert event.is_set() + assert handle.event.is_set() mock.assert_called_once_with(b"hello") diff --git a/tests/brokers/base/publish.py b/tests/brokers/base/publish.py index e35605d608..c23ae929fc 100644 --- a/tests/brokers/base/publish.py +++ b/tests/brokers/base/publish.py @@ -3,6 +3,7 @@ from typing import Dict, List, Tuple from unittest.mock import Mock +import anyio import pytest from pydantic import BaseModel @@ -46,7 +47,7 @@ async def test_serialize( message, message_type, expected_message, - event: asyncio.Event, + event, ): @pub_broker.subscriber(queue) async def handler(m: message_type, logger: Logger): @@ -72,7 +73,7 @@ async def handler(m: message_type, logger: Logger): @pytest.mark.asyncio async def test_unwrap_dict( - self, mock: Mock, queue: str, pub_broker: BrokerUsecase, event: asyncio.Event + self, mock: Mock, queue: str, pub_broker: BrokerUsecase, event ): @pub_broker.subscriber(queue) async def m(a: int, b: int, logger: Logger): @@ -124,7 +125,8 @@ async def test_base_publisher( self, queue: str, pub_broker: BrokerUsecase, - event: asyncio.Event, + event, + mock, ): @pub_broker.subscriber(queue) @pub_broker.publisher(queue + "resp") @@ -132,8 +134,9 @@ async def m(): return "" @pub_broker.subscriber(queue + "resp") - async def resp(): + async def resp(msg): event.set() + mock(msg) async with pub_broker: await pub_broker.start() @@ -146,13 +149,15 @@ async def resp(): ) assert event.is_set() + mock.assert_called_once_with("") @pytest.mark.asyncio async def test_publisher_object( self, queue: str, pub_broker: BrokerUsecase, - event: asyncio.Event, + event, + mock, ): publisher = pub_broker.publisher(queue + "resp") @@ -162,8 +167,9 @@ async def m(): return "" @pub_broker.subscriber(queue + "resp") - async def resp(): + async def resp(msg): event.set() + mock(msg) async with pub_broker: await pub_broker.start() @@ -176,13 +182,15 @@ async def resp(): ) assert event.is_set() + mock.assert_called_once_with("") @pytest.mark.asyncio async def test_publish_manual( self, queue: str, pub_broker: BrokerUsecase, - event: asyncio.Event, + event, + mock, ): publisher = pub_broker.publisher(queue + "resp") @@ -191,8 +199,9 @@ async def m(): await publisher.publish("") @pub_broker.subscriber(queue + "resp") - async def resp(): + async def resp(msg): event.set() + mock(msg) async with pub_broker: await pub_broker.start() @@ -205,11 +214,14 @@ async def resp(): ) assert event.is_set() + mock.assert_called_once_with("") @pytest.mark.asyncio - async def test_multiple_publishers(self, queue: str, pub_broker: BrokerUsecase): - consume = asyncio.Event() - consume2 = asyncio.Event() + async def test_multiple_publishers( + self, queue: str, pub_broker: BrokerUsecase, mock + ): + event = anyio.Event() + event2 = anyio.Event() @pub_broker.publisher(queue + "resp2") @pub_broker.subscriber(queue) @@ -218,33 +230,37 @@ async def m(): return "" @pub_broker.subscriber(queue + "resp") - async def resp(): - consume.set() + async def resp(msg): + event.set() + mock.resp1(msg) @pub_broker.subscriber(queue + "resp2") - async def resp2(): - consume2.set() + async def resp2(msg): + event2.set() + mock.resp2(msg) async with pub_broker: await pub_broker.start() await asyncio.wait( ( asyncio.create_task(pub_broker.publish("", queue)), - asyncio.create_task(consume.wait()), - asyncio.create_task(consume2.wait()), + asyncio.create_task(event.wait()), + asyncio.create_task(event2.wait()), ), timeout=3, ) - assert consume.is_set() - assert consume2.is_set() + assert event.is_set() + assert event2.is_set() + mock.resp1.assert_called_once_with("") + mock.resp2.assert_called_once_with("") @pytest.mark.asyncio async def test_reusable_publishers( - self, mock: Mock, queue: str, pub_broker: BrokerUsecase + self, queue: str, pub_broker: BrokerUsecase, mock ): - consume = asyncio.Event() - consume2 = asyncio.Event() + consume = anyio.Event() + consume2 = anyio.Event() pub = pub_broker.publisher(queue + "resp") @@ -286,9 +302,9 @@ async def resp(): async def test_reply_to( self, pub_broker: BrokerUsecase, - mock: Mock, queue: str, - event: asyncio.Event, + event, + mock, ): @pub_broker.subscriber(queue + "reply") async def reply_handler(m): diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index 556ac0ee21..19662cc5a8 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -1,6 +1,3 @@ -import asyncio -from unittest.mock import Mock - import anyio import pytest @@ -58,15 +55,12 @@ async def m(m): # pragma: no cover assert r is None @pytest.mark.asyncio - async def test_rpc_with_reply( - self, queue: str, mock: Mock, rpc_broker: BrokerUsecase, event: asyncio.Event - ): + async def test_rpc_with_reply(self, queue: str, rpc_broker: BrokerUsecase): reply_queue = queue + "1" @rpc_broker.subscriber(reply_queue) async def response_hanler(m: str): - event.set() - mock(m) + ... @rpc_broker.subscriber(queue) async def m(m): # pragma: no cover @@ -75,9 +69,9 @@ async def m(m): # pragma: no cover await rpc_broker.start() await rpc_broker.publish("hello", queue, reply_to=reply_queue) - await asyncio.wait_for(event.wait(), 3) + await response_hanler.wait_call(3) - mock.assert_called_with("1") + response_hanler.mock.assert_called_with("1") class ReplyAndConsumeForbidden: diff --git a/tests/brokers/kafka/test_test_client.py b/tests/brokers/kafka/test_test_client.py index 9de7ce2e65..9eddc9567a 100644 --- a/tests/brokers/kafka/test_test_client.py +++ b/tests/brokers/kafka/test_test_client.py @@ -8,6 +8,34 @@ @pytest.mark.asyncio class TestTestclient(BrokerTestclientTestcase): + @pytest.mark.kafka + async def test_batch_pub_by_default_pub( + self, + test_broker: KafkaBroker, + queue: str, + ): + @test_broker.subscriber(queue, batch=True) + async def m(): + pass + + await test_broker.start() + await test_broker.publish("hello", queue) + m.mock.assert_called_once_with(["hello"]) + + @pytest.mark.kafka + async def test_batch_pub_by_pub_batch( + self, + test_broker: KafkaBroker, + queue: str, + ): + @test_broker.subscriber(queue, batch=True) + async def m(): + pass + + await test_broker.start() + await test_broker.publish("hello", topic=queue) + m.mock.assert_called_once_with(["hello"]) + @pytest.mark.kafka async def test_with_real_testclient( self, From 9f46f0148b84c55cc0b7c25d98cff9b20232250a Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Fri, 1 Sep 2023 17:52:50 +0300 Subject: [PATCH 2/4] test: rollback FastAPI AsyncAPI test --- tests/asyncapi/base/fastapi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/asyncapi/base/fastapi.py b/tests/asyncapi/base/fastapi.py index 4e7fc896d9..6150be21bf 100644 --- a/tests/asyncapi/base/fastapi.py +++ b/tests/asyncapi/base/fastapi.py @@ -44,8 +44,14 @@ def test_fastapi_full_information(self): "title": "CustomApp", "version": "1.1.1", "description": "Test description", - "contact": {"name": "support", "url": "https://support.com/"}, - "license": {"name": "some", "url": "https://some.com/"}, + "contact": { + "name": "support", + "url": IsStr(regex=r"https\:\/\/support\.com\/?"), + }, + "license": { + "name": "some", + "url": IsStr(regex=r"https\:\/\/some\.com\/?"), + }, }, "servers": { "development": { From cf44795b5073a48fb866650fc1b4f96682425042 Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Fri, 1 Sep 2023 17:54:17 +0300 Subject: [PATCH 3/4] fix: use publish_batch method in test --- tests/brokers/kafka/test_test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brokers/kafka/test_test_client.py b/tests/brokers/kafka/test_test_client.py index 9eddc9567a..64a67a8aab 100644 --- a/tests/brokers/kafka/test_test_client.py +++ b/tests/brokers/kafka/test_test_client.py @@ -33,7 +33,7 @@ async def m(): pass await test_broker.start() - await test_broker.publish("hello", topic=queue) + await test_broker.publish_batch("hello", topic=queue) m.mock.assert_called_once_with(["hello"]) @pytest.mark.kafka From 913eaeca1fd6bbced8777abdbfd72011ab418040 Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Fri, 1 Sep 2023 17:56:44 +0300 Subject: [PATCH 4/4] fix: fix semgrep --- faststream/rabbit/fastapi.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/faststream/rabbit/fastapi.pyi b/faststream/rabbit/fastapi.pyi index 55061d346c..713fc4d1ce 100644 --- a/faststream/rabbit/fastapi.pyi +++ b/faststream/rabbit/fastapi.pyi @@ -49,6 +49,7 @@ from faststream.types import AnyDict class RabbitRouter(StreamRouter[IncomingMessage]): broker_class: Type[RabbitBroker] + # nosemgrep: python.lang.security.audit.hardcoded-password-default-argument.hardcoded-password-default-argument def __init__( self, url: Union[str, URL, None] = "amqp://guest:guest@localhost:5672/",