diff --git a/aio_pika/abc.py b/aio_pika/abc.py index 00283d6f..686f2336 100644 --- a/aio_pika/abc.py +++ b/aio_pika/abc.py @@ -5,18 +5,19 @@ from types import TracebackType from typing import ( Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict, - FrozenSet, Generator, Iterator, MutableMapping, NamedTuple, Optional, Set, + FrozenSet, Iterator, MutableMapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, ) -import aiormq +import aiormq.abc from aiormq.abc import ExceptionType from pamqp.common import Arguments from yarl import URL from .pool import PoolInstance -from .tools import CallbackCollection, CallbackSetType, CallbackType - +from .tools import ( + CallbackCollection, CallbackSetType, CallbackType, OneShotCallback +) TimeoutType = Optional[Union[int, float]] @@ -219,7 +220,6 @@ async def __aexit__( class AbstractQueue: channel: "AbstractChannel" - connection: "AbstractConnection" name: str durable: bool exclusive: bool @@ -228,6 +228,19 @@ class AbstractQueue: passive: bool declaration_result: aiormq.spec.Queue.DeclareOk + @abstractmethod + def __init__( + self, + channel: "AbstractChannel", + name: Optional[str], + durable: bool, + exclusive: bool, + auto_delete: bool, + arguments: Arguments, + passive: bool = False, + ): + raise NotImplementedError + @abstractmethod async def declare( self, timeout: TimeoutType = None, @@ -341,6 +354,21 @@ async def __anext__(self) -> AbstractIncomingMessage: class AbstractExchange(ABC): + @abstractmethod + def __init__( + self, + channel: "AbstractChannel", + name: str, + type: Union[ExchangeType, str] = ExchangeType.DIRECT, + *, + auto_delete: bool = False, + durable: bool = False, + internal: bool = False, + passive: bool = False, + arguments: Arguments = None + ): + raise NotImplementedError + @property @abstractmethod def channel(self) -> "AbstractChannel": @@ -392,20 +420,46 @@ async def delete( raise NotImplementedError +class UnderlayChannel(NamedTuple): + channel: aiormq.abc.AbstractChannel + close_callback: OneShotCallback + + @classmethod + async def create_channel( + cls, transport: "UnderlayConnection", + close_callback: Callable[..., Awaitable[Any]], **kwargs: Any + ) -> "UnderlayChannel": + close_callback = OneShotCallback(close_callback) + + await transport.connection.ready() + transport.connection.closing.add_done_callback(close_callback) + channel = await transport.connection.channel(**kwargs) + channel.closing.add_done_callback(close_callback) + + return cls( + channel=channel, + close_callback=close_callback, + ) + + async def close(self, exc: Optional[ExceptionType] = None) -> Any: + result: Any + result, _ = await asyncio.gather( + self.channel.close(exc), self.close_callback.wait() + ) + return result + + class AbstractChannel(PoolInstance, ABC): QUEUE_CLASS: Type[AbstractQueue] EXCHANGE_CLASS: Type[AbstractExchange] close_callbacks: CallbackCollection return_callbacks: CallbackCollection - connection: "AbstractConnection" + ready: asyncio.Event loop: asyncio.AbstractEventLoop default_exchange: AbstractExchange - @property - @abstractmethod - def done_callbacks(self) -> CallbackCollection: - raise NotImplementedError + publisher_confirms: bool @property @abstractmethod @@ -431,10 +485,6 @@ def channel(self) -> aiormq.abc.AbstractChannel: def number(self) -> Optional[int]: raise NotImplementedError - @abstractmethod - def __await__(self) -> Generator[Any, Any, "AbstractChannel"]: - raise NotImplementedError - @abstractmethod async def __aenter__(self) -> "AbstractChannel": raise NotImplementedError @@ -537,19 +587,50 @@ def transaction(self) -> AbstractTransaction: async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk: raise NotImplementedError + @abstractmethod + def __await__(self) -> Awaitable["AbstractChannel"]: + raise NotImplementedError + + +class UnderlayConnection(NamedTuple): + connection: aiormq.abc.AbstractConnection + close_callback: OneShotCallback + + @classmethod + async def connect( + cls, url: URL, close_callback: Callable[..., Awaitable[Any]], + timeout: TimeoutType = None, **kwargs: Any + ) -> "UnderlayConnection": + connection: aiormq.abc.AbstractConnection = await asyncio.wait_for( + aiormq.connect(url, **kwargs), timeout=timeout, + ) + close_callback = OneShotCallback(close_callback) + connection.closing.add_done_callback(close_callback) + await connection.ready() + return cls( + connection=connection, + close_callback=close_callback + ) + + async def close(self, exc: Optional[aiormq.abc.ExceptionType]): + result, _ = await asyncio.gather( + self.connection.close(exc), self.close_callback.wait() + ) + return result + class AbstractConnection(PoolInstance, ABC): loop: asyncio.AbstractEventLoop close_callbacks: CallbackCollection connected: asyncio.Event - connection: aiormq.abc.AbstractConnection + transport: UnderlayConnection @abstractmethod def __init__( self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ): - NotImplementedError( + raise NotImplementedError( f"Method not implemented, passed: url={url}, loop={loop!r}", ) @@ -748,5 +829,7 @@ def channel( "MILLISECONDS", "TimeoutType", "TransactionState", + "UnderlayChannel", + "UnderlayConnection", "ZERO_TIME", ) diff --git a/aio_pika/channel.py b/aio_pika/channel.py index e7b109db..cfea828c 100644 --- a/aio_pika/channel.py +++ b/aio_pika/channel.py @@ -1,7 +1,10 @@ import asyncio +from abc import ABC from logging import getLogger from types import TracebackType -from typing import Any, Awaitable, Generator, Optional, Type, Union +from typing import ( + Any, Generator, Optional, Type, Union, AsyncContextManager, NamedTuple +) from warnings import warn import aiormq @@ -10,7 +13,7 @@ from .abc import ( AbstractChannel, AbstractConnection, AbstractExchange, AbstractQueue, - ChannelCloseCallback, TimeoutType, + TimeoutType, UnderlayConnection, UnderlayChannel, ) from .exchange import Exchange, ExchangeType from .message import ReturnCallback # noqa @@ -23,13 +26,32 @@ log = getLogger(__name__) -class Channel(AbstractChannel): +class ChannelContext(AsyncContextManager, AbstractChannel, ABC): + async def __aenter__(self) -> "AbstractChannel": + if not self.is_initialized: + await self.initialize() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return await self.close(exc_val) + + def __await__(self) -> Generator[Any, Any, "AbstractChannel"]: + yield from self.initialize().__await__() + return self + + +class Channel(ChannelContext): """ Channel abstraction """ QUEUE_CLASS = Queue EXCHANGE_CLASS = Exchange - _channel: aiormq.abc.AbstractChannel + _channel: UnderlayChannel def __init__( self, @@ -55,33 +77,30 @@ def __init__( ) self.loop = connection.loop + self.ready = asyncio.Event() + + self.__operation_lock = asyncio.Lock() + + # That's means user closed channel instance explicitly + self._is_closed_by_user: bool = False - self._channel: aiormq.abc.AbstractChannel self._channel_number = channel_number + self._connection = connection - self.connection = connection self.close_callbacks = CallbackCollection(self) self.return_callbacks = CallbackCollection(self) + self.publisher_confirms = publisher_confirms self._on_return_raises = on_return_raises - self._publisher_confirms = publisher_confirms - self._delivery_tag = 0 - # That's means user closed channel instance explicitly - self._is_closed_by_user: bool = False - self.default_exchange: Exchange - @property - def done_callbacks(self) -> CallbackCollection: - return self.close_callbacks - @property def is_initialized(self) -> bool: """ Returns True when the channel has been opened and ready for interaction """ - return hasattr(self, "_channel") + return self.ready.is_set() @property def is_closed(self) -> bool: @@ -89,23 +108,24 @@ def is_closed(self) -> bool: side or after the close() method has been called. """ if not self.is_initialized or self._is_closed_by_user: return True - return ( - self._channel.is_closed or not self._channel.connection.is_opened - ) + return self._channel.channel.is_closed @task async def close(self, exc: aiormq.abc.ExceptionType = None) -> None: - if not self.is_initialized: - log.warning("Channel not opened") - return + async with self.__operation_lock: + if not self.is_initialized: + log.warning("Channel not opened") + return - channel: aiormq.abc.AbstractChannel = self._channel - del self._channel - self._is_closed_by_user = True - await channel.close() + log.debug("Closing channel %r", self) + try: + await self._channel.close() + finally: + self._is_closed_by_user = True @property def channel(self) -> aiormq.abc.AbstractChannel: + if not self.is_initialized: raise aiormq.exceptions.ChannelInvalidStateError( "Channel was not opened", @@ -116,109 +136,38 @@ def channel(self) -> aiormq.abc.AbstractChannel: "Channel has been closed", ) - return self._channel + return self._channel.channel @property def number(self) -> Optional[int]: - return self._channel.number if self.is_initialized else None + return ( + self.channel.number + if self.is_initialized + else self._channel_number + ) def __str__(self) -> str: return "{}".format(self.number or "Not initialized channel") def __repr__(self) -> str: + channel = getattr(self, "_channel", None) conn = None - - if self.is_initialized: - conn = self._channel.connection - - return '<%s #%s "%s">' % (self.__class__.__name__, self, conn) - - def __await__(self) -> Generator[Any, Any, "AbstractChannel"]: - yield from self.initialize().__await__() - return self - - async def __aenter__(self) -> "AbstractChannel": - if not self.is_initialized: - await self.initialize() - return self - - def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Awaitable[Any]: - return self.close() - - def add_close_callback( - self, callback: ChannelCloseCallback, weak: bool = False, - ) -> None: - warn( - "This method will be removed from future release. " - f"Use {self.__class__.__name__}.close_callbacks.add instead", - DeprecationWarning, - stacklevel=2, - ) - self.close_callbacks.add(callback, weak=weak) - - def remove_close_callback( - self, callback: ChannelCloseCallback, - ) -> None: - warn( - "This method will be removed from future release. " - f"Use {self.__class__.__name__}.close_callbacks.remove instead", - DeprecationWarning, - stacklevel=2, - ) - self.close_callbacks.remove(callback) - - def add_on_return_callback( - self, callback: ReturnCallback, weak: bool = False, - ) -> None: - warn( - "This method will be removed from future release. " - f"Use {self.__class__.__name__}.return_callbacks.add instead", - DeprecationWarning, - stacklevel=2, - ) - self.return_callbacks.add(callback, weak=weak) - - def remove_on_return_callback(self, callback: ReturnCallback) -> None: - warn( - "This method will be removed from future release. " - f"Use {self.__class__.__name__}.return_callbacks.remove instead", - DeprecationWarning, - stacklevel=2, - ) - self.return_callbacks.remove(callback) - - async def _create_channel( - self, timeout: TimeoutType = None, - ) -> aiormq.abc.AbstractChannel: - await self.connection.ready() - - return await self.connection.connection.channel( - publisher_confirms=self._publisher_confirms, + if channel is not None: + conn = channel.channel.connection + return '<%s #%s "%s">' % (self.__class__.__name__, self.number, conn) + + async def _open(self) -> None: + await self._connection.connected.wait() + self._transport = self._connection.transport + self._channel = await UnderlayChannel.create_channel( + self._connection.transport, self.__on_close, + publisher_confirms=self.publisher_confirms, on_return_raises=self._on_return_raises, channel_number=self._channel_number, - timeout=timeout, ) - async def initialize(self, timeout: TimeoutType = None) -> None: - if self.is_initialized: - raise RuntimeError("Already initialized") - elif self._is_closed_by_user: - raise RuntimeError("Can't initialize closed channel") - - channel: aiormq.abc.AbstractChannel = await self._create_channel( - timeout=timeout, - ) - - self._channel = channel self._delivery_tag = 0 - self.default_exchange = self.EXCHANGE_CLASS( - connection=self.connection, channel=self, arguments=None, auto_delete=False, @@ -229,24 +178,45 @@ async def initialize(self, timeout: TimeoutType = None) -> None: type=ExchangeType.DIRECT, ) - self._on_initialized() + async def initialize(self, timeout: TimeoutType = None) -> None: + if self.is_initialized: + raise RuntimeError("Already initialized") + elif self._is_closed_by_user: + raise RuntimeError("Can't initialize closed channel") + + async with self.__operation_lock: + await self._open() + self.ready.set() + self._on_initialized() - def _on_channel_closed(self, closing: asyncio.Future) -> None: - self.close_callbacks(closing.exception()) + async def __on_close(self, closing: asyncio.Future) -> None: + try: + await self.close_callbacks(closing.exception()) + finally: + self.ready.clear() def _on_initialized(self) -> None: self.channel.on_return_callbacks.add(self._on_return) - self.channel.closing.add_done_callback(self._on_channel_closed) def _on_return(self, message: aiormq.abc.DeliveredMessage) -> None: self.return_callbacks(IncomingMessage(message, no_ack=True)) async def reopen(self) -> None: - if hasattr(self, "_channel"): - del self._channel + log.debug("Start reopening channel %r", self) + async with self.__operation_lock: + if hasattr(self, "_channel"): + del self._channel + + if hasattr(self, "_transport_connection"): + del self._transport_connection + + self._is_closed_by_user = False + log.debug("Reopening channel %r", self) + await self._open() + self.ready.set() - self._is_closed_by_user = False - await self.initialize() + def __del__(self) -> None: + log.debug("%r deleted", self) async def declare_exchange( self, @@ -284,7 +254,6 @@ async def declare_exchange( durable = False exchange = self.EXCHANGE_CLASS( - connection=self.connection, channel=self, name=name, type=type, @@ -325,7 +294,6 @@ async def get_exchange( return await self.declare_exchange(name=name, passive=True) else: return self.EXCHANGE_CLASS( - connection=self.connection, channel=self, name=name, durable=False, @@ -461,7 +429,7 @@ async def exchange_delete( ) def transaction(self) -> Transaction: - if self._publisher_confirms: + if self.publisher_confirms: raise RuntimeError( "Cannot create transaction when publisher " "confirms are enabled", @@ -472,8 +440,5 @@ def transaction(self) -> Transaction: async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk: return await self.channel.flow(active=active) - def __del__(self) -> None: - log.debug("%r deleted", self) - __all__ = ("Channel",) diff --git a/aio_pika/connection.py b/aio_pika/connection.py index 4415732b..95b57459 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -3,7 +3,6 @@ from functools import partial from types import TracebackType from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import aiormq import aiormq.abc @@ -13,13 +12,12 @@ # This needed only for migration from 6.x to 7.x # TODO: Remove this in 8.x release -from .abc import ConnectionType # noqa from .abc import ( AbstractChannel, AbstractConnection, ConnectionCloseCallback, TimeoutType, + UnderlayConnection, ) from .channel import Channel -from .tools import CallbackCollection - +from .tools import CallbackCollection, OneShotCallback log = logging.getLogger(__name__) T = TypeVar("T") @@ -38,16 +36,18 @@ def is_closed(self) -> bool: async def close( self, exc: Optional[aiormq.abc.ExceptionType] = asyncio.CancelledError, ) -> None: - if not self.closing.done(): - self.closing.set_result(exc) - - if not hasattr(self, "connection"): - return None + async with self.__operation_lock: + if not self.closing.done(): + self.closing.set_result(exc) - await self.connection.close(exc) + transport: Optional[UnderlayConnection] = getattr( + self, "transport", None + ) + if not transport: + return - if hasattr(self, "connection"): - del self.connection + del self.transport + await asyncio.wait_for(transport.close(exc), timeout=1) @classmethod def _parse_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: @@ -60,16 +60,15 @@ def __init__( self, url: URL, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ): - super().__init__(url, loop, **kwargs) - self.loop = loop or asyncio.get_event_loop() + self.__operation_lock = asyncio.Lock() + self.url = URL(url) self.kwargs: Dict[str, Any] = self._parse_kwargs( kwargs or dict(self.url.query), ) - self.connection: aiormq.abc.AbstractConnection self.close_callbacks = CallbackCollection(self) self.connected: asyncio.Event = asyncio.Event() self.closing: asyncio.Future = self.loop.create_future() @@ -80,45 +79,16 @@ def __str__(self) -> str: def __repr__(self) -> str: return f'<{self.__class__.__name__}: "{self}">' - def add_close_callback( - self, callback: ConnectionCloseCallback, weak: bool = False, - ) -> None: - warn( - "This method will be removed from future release. " - f"Use {self.__class__.__name__}.close_callbacks.add instead", - DeprecationWarning, - stacklevel=2, - ) - self.close_callbacks.add(callback, weak=weak) - - def _on_connection_close( - self, connection: aiormq.abc.AbstractConnection, - closing: asyncio.Future, - ) -> None: - log.debug("Closing AMQP connection %r", connection) + async def _on_connection_close(self, closing: asyncio.Future) -> None: exc: Optional[BaseException] = closing.exception() - self.close_callbacks(exc) - + self.connected.clear() + await self.close_callbacks(exc) if self.closing.done(): return - if exc is not None: self.closing.set_exception(exc) - return - - self.closing.set_result(closing.result()) - - async def _make_connection( - self, *, timeout: TimeoutType = None, **kwargs: Any - ) -> aiormq.abc.AbstractConnection: - connection: aiormq.abc.AbstractConnection = await asyncio.wait_for( - aiormq.connect(self.url, **kwargs), timeout=timeout, - ) - connection.closing.add_done_callback( - partial(self._on_connection_close, connection), - ) - await connection.ready() - return connection + else: + self.closing.set_result(closing.result()) async def connect( self, timeout: TimeoutType = None, **kwargs: Any @@ -131,13 +101,11 @@ async def connect( You shouldn't call it explicitly. """ - self.connection = ( - await self._make_connection(timeout=timeout, **kwargs) - ) - self.connected.set() - self.connection.closing.add_done_callback( - lambda _: self.connected.clear(), - ) + async with self.__operation_lock: + self.transport = await UnderlayConnection.connect( + self.url, self._on_connection_close, timeout=timeout, **kwargs + ) + self.connected.set() def channel( self, diff --git a/aio_pika/exchange.py b/aio_pika/exchange.py index 265da3c0..9644ac7d 100644 --- a/aio_pika/exchange.py +++ b/aio_pika/exchange.py @@ -5,7 +5,7 @@ from pamqp.common import Arguments from .abc import ( - AbstractChannel, AbstractConnection, AbstractExchange, AbstractMessage, + AbstractChannel, AbstractExchange, AbstractMessage, ExchangeParamType, ExchangeType, TimeoutType, ) @@ -19,7 +19,6 @@ class Exchange(AbstractExchange): def __init__( self, - connection: AbstractConnection, channel: AbstractChannel, name: str, type: Union[ExchangeType, str] = ExchangeType.DIRECT, @@ -33,7 +32,6 @@ def __init__( if not arguments: arguments = {} - self.connection = connection self._channel = channel self.__type = type.value if isinstance(type, ExchangeType) else type self.name = name diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 0ea3b60e..eb609cc7 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -51,7 +51,6 @@ def __init__( self.declaration_result: aiormq.spec.Queue.DeclareOk self.loop = channel.loop self.channel = channel - self.connection = channel.connection self.name = name or "" self.durable = durable self.exclusive = exclusive diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index bcfeeba4..32e29713 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -1,4 +1,5 @@ import asyncio +import traceback from collections import defaultdict from itertools import chain from logging import getLogger @@ -16,7 +17,7 @@ from .queue import Queue from .robust_exchange import RobustExchange from .robust_queue import RobustQueue -from .tools import CallbackCollection, create_task +from .tools import CallbackCollection log = getLogger(__name__) @@ -63,12 +64,18 @@ def __init__( self._prefetch_size: int = 0 self._global_qos: bool = False self.reopen_callbacks: CallbackCollection = CallbackCollection(self) + self.close_callbacks.add(self.__close_callback) + + async def __close_callback(self, *_) -> None: + if self._is_closed_by_user or not self.ready.is_set(): + return + await self.reopen() async def reopen(self) -> None: - log.debug("Reopening channel %r", self) await super().reopen() + await self.ready.wait() await self.restore() - self.reopen_callbacks() + await self.reopen_callbacks() async def restore(self) -> None: await self.set_qos( @@ -88,24 +95,6 @@ async def restore(self) -> None: for queue in queues: await queue.restore(self) - def _on_channel_closed(self, closing: asyncio.Future) -> None: - super()._on_channel_closed(closing) - - exc = closing.exception() - if ( - not self._is_closed_by_user and - not self.connection.is_closed and - self._channel.connection.is_opened - ): - create_task(self.reopen) - if exc: - log.exception( - "Robust channel %r has been closed.", - self, exc_info=exc, - ) - - log.debug("Robust channel %r has been closed.", self) - async def set_qos( self, prefetch_count: int = 0, @@ -118,7 +107,7 @@ async def set_qos( warn('Use "global_" instead of "all_channels"', DeprecationWarning) global_ = all_channels - await self.connection.connected.wait() + await self._transport_connection.ready() self._prefetch_count = prefetch_count self._prefetch_size = prefetch_size @@ -143,7 +132,7 @@ async def declare_exchange( timeout: TimeoutType = None, robust: bool = True, ) -> AbstractRobustExchange: - await self.connection.connected.wait() + await self._transport_connection.ready() exchange = ( await super().declare_exchange( name=name, @@ -170,7 +159,7 @@ async def exchange_delete( if_unused: bool = False, nowait: bool = False, ) -> aiormq.spec.Exchange.DeleteOk: - await self.connection.connected.wait() + await self._transport_connection.ready() result = await super().exchange_delete( exchange_name=exchange_name, timeout=timeout, @@ -194,7 +183,7 @@ async def declare_queue( timeout: TimeoutType = None, robust: bool = True ) -> AbstractRobustQueue: - await self.connection.connected.wait() + await self._transport_connection.ready() queue: RobustQueue = await super().declare_queue( # type: ignore name=name, durable=durable, @@ -218,7 +207,7 @@ async def queue_delete( if_empty: bool = False, nowait: bool = False, ) -> aiormq.spec.Queue.DeleteOk: - await self.connection.connected.wait() + await self._transport_connection.ready() result = await super().queue_delete( queue_name=queue_name, timeout=timeout, diff --git a/aio_pika/robust_exchange.py b/aio_pika/robust_exchange.py index 7d07920f..3642c820 100644 --- a/aio_pika/robust_exchange.py +++ b/aio_pika/robust_exchange.py @@ -5,7 +5,7 @@ from pamqp.common import Arguments from .abc import ( - AbstractChannel, AbstractConnection, AbstractExchange, + AbstractChannel, AbstractExchange, AbstractRobustChannel, AbstractRobustExchange, ExchangeParamType, TimeoutType, ) @@ -22,7 +22,6 @@ class RobustExchange(Exchange, AbstractRobustExchange): def __init__( self, - connection: AbstractConnection, channel: AbstractChannel, name: str, type: Union[ExchangeType, str] = ExchangeType.DIRECT, @@ -35,7 +34,6 @@ def __init__( ): super().__init__( - connection=connection, channel=channel, name=name, type=type, @@ -68,7 +66,7 @@ async def bind( timeout: TimeoutType = None, robust: bool = True ) -> aiormq.spec.Exchange.BindOk: - await self.connection.connected.wait() + await self.channel.ready.wait() result = await super().bind( exchange, @@ -92,7 +90,7 @@ async def unbind( arguments: Arguments = None, timeout: TimeoutType = None, ) -> aiormq.spec.Exchange.UnbindOk: - await self.connection.connected.wait() + await self.channel.ready.wait() result = await super().unbind( exchange, routing_key, arguments=arguments, timeout=timeout, diff --git a/aio_pika/robust_queue.py b/aio_pika/robust_queue.py index 17e3217d..264e1774 100644 --- a/aio_pika/robust_queue.py +++ b/aio_pika/robust_queue.py @@ -76,7 +76,7 @@ async def bind( timeout: TimeoutType = None, robust: bool = True ) -> aiormq.spec.Queue.BindOk: - await self.connection.connected.wait() + await self.channel.ready.wait() if routing_key is None: routing_key = self.name @@ -99,7 +99,7 @@ async def unbind( arguments: Arguments = None, timeout: TimeoutType = None, ) -> aiormq.spec.Queue.UnbindOk: - await self.connection.connected.wait() + await self.channel.ready.wait() if routing_key is None: routing_key = self.name @@ -120,7 +120,7 @@ async def consume( timeout: TimeoutType = None, robust: bool = True, ) -> ConsumerTag: - await self.connection.connected.wait() + await self.channel.ready.wait() consumer_tag = await super().consume( consumer_tag=consumer_tag, timeout=timeout, @@ -146,7 +146,7 @@ async def cancel( timeout: TimeoutType = None, nowait: bool = False, ) -> aiormq.spec.Basic.CancelOk: - await self.connection.connected.wait() + await self.channel.ready.wait() result = await super().cancel(consumer_tag, timeout, nowait) self._consumers.pop(consumer_tag, None) return result diff --git a/aio_pika/tools.py b/aio_pika/tools.py index 8cf571f5..6d95c60f 100644 --- a/aio_pika/tools.py +++ b/aio_pika/tools.py @@ -1,5 +1,6 @@ import asyncio import logging +import typing from functools import wraps from itertools import chain from threading import Lock @@ -81,13 +82,22 @@ def wrap(*args: Any, **kwargs: Any) -> Awaitable[T]: return wrap -CallbackType = Callable[..., Any] +CallbackType = Callable[..., Union[T, Awaitable[T]]] CallbackSetType = Union[AbstractSet[CallbackType]] +class StubAwaitable: + __slots__ = () + + def __await__(self): + yield + + class CallbackCollection(MutableSet): __slots__ = "__sender", "__callbacks", "__weak_callbacks", "__lock" + STUB_AWAITABLE = StubAwaitable() + def __init__(self, sender: Union[T, ReferenceType]): self.__sender: ReferenceType if isinstance(sender, ReferenceType): @@ -99,9 +109,7 @@ def __init__(self, sender: Union[T, ReferenceType]): self.__weak_callbacks: MutableSet[CallbackType] = WeakSet() self.__lock: Lock = Lock() - def add( - self, callback: Callable[..., Any], weak: bool = False, - ) -> None: + def add(self, callback: CallbackType, weak: bool = False) -> None: if self.is_frozen: raise RuntimeError("Collection frozen") if not callable(callback): @@ -113,7 +121,7 @@ def add( else: self.__callbacks.add(callback) # type: ignore - def discard(self, callback: Callable[..., Any]) -> None: + def discard(self, callback: CallbackType) -> None: if self.is_frozen: raise RuntimeError("Collection frozen") @@ -178,22 +186,58 @@ def __copy__(self) -> "CallbackCollection": return instance - def __call__(self, *args: Any, **kwargs: Any) -> None: + def __call__(self, *args: Any, **kwargs: Any) -> typing.Awaitable[Any]: + futures: typing.List[asyncio.Future] = [] + with self.__lock: sender = self.__sender() for cb in self: try: - cb(sender, *args, **kwargs) + result = cb(sender, *args, **kwargs) + if hasattr(result, '__await__'): + futures.append(asyncio.ensure_future(result)) except Exception: log.exception("Callback %r error", cb) + if not futures: + return self.STUB_AWAITABLE + return asyncio.gather(*futures, return_exceptions=True) + def __hash__(self) -> int: return id(self) +class OneShotCallback: + __slots__ = ('loop', 'finished', '__lock', "callback") + + def __init__(self, callback: Callable[..., Awaitable[T]]): + self.callback = callback + self.loop = asyncio.get_event_loop() + self.finished: asyncio.Event = asyncio.Event() + self.__lock: asyncio.Lock = asyncio.Lock() + + def wait(self) -> Awaitable[Any]: + return self.finished.wait() + + async def __closer(self, *args, **kwargs) -> None: + async with self.__lock: + if self.finished.is_set(): + return + try: + return await self.callback(*args, **kwargs) + finally: + self.finished.set() + + def __call__(self, *args, **kwargs) -> asyncio.Task: + return self.loop.create_task(self.__closer(*args, **kwargs)) + + __all__ = ( "CallbackCollection", + "CallbackType", + "CallbackSetType", + "OneShotCallback", "create_task", "iscoroutinepartial", "shield", diff --git a/tests/test_amqp.py b/tests/test_amqp.py index a8d98806..124a9fd6 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -425,7 +425,7 @@ async def test_context_process( routing_key, ) - if not channel._publisher_confirms: + if not channel.publisher_confirms: await asyncio.sleep(1) incoming_message: AbstractIncomingMessage = await queue.get(timeout=5) @@ -507,7 +507,7 @@ async def test_context_process_redelivery( routing_key, ) - if not channel._publisher_confirms: + if not channel.publisher_confirms: await asyncio.sleep(1) incoming_message = await queue.get(timeout=5) @@ -561,7 +561,7 @@ async def test_no_ack_redelivery( msg = Message(body) await exchange.publish(msg, routing_key) - if not channel._publisher_confirms: + if not channel.publisher_confirms: await asyncio.sleep(1) # ack 1 message out of 2 @@ -614,7 +614,7 @@ async def test_ack_multiple( msg = Message(body) await exchange.publish(msg, routing_key) - if not channel._publisher_confirms: + if not channel.publisher_confirms: await asyncio.sleep(1) # ack only last mesage with multiple flag, first @@ -1568,7 +1568,7 @@ async def test_heartbeat_disabling( connection: AbstractConnection = await connection_fabric(url) async with connection: - assert connection.connection.connection_tune.heartbeat == 0 + assert connection.transport.connection.connection_tune.heartbeat == 0 class TestCaseAmqpNoConfirms(TestCaseAmqp): diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index 35389c10..8501921a 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -1,11 +1,12 @@ import asyncio -import time from functools import partial import pytest +from aiormq import ChannelNotFoundEntity import aio_pika from aio_pika import RobustChannel +from tests import get_random_name from tests.test_amqp import ( TestCaseAmqp, TestCaseAmqpNoConfirms, TestCaseAmqpWithConfirms, ) @@ -61,6 +62,8 @@ async def test_channel_blocking_timeout_reopen(self, connection): reopen_event = asyncio.Event() channel.reopen_callbacks.add(lambda _: reopen_event.set()) + queue_name = get_random_name("test_channel_blocking_timeout_reopen") + def on_done(*args): close_reasons.append(args) close_event.set() @@ -68,14 +71,8 @@ def on_done(*args): channel.close_callbacks.add(on_done) - async def run(sleep_time=1): - await channel.set_qos(1) - if sleep_time: - time.sleep(sleep_time) - await channel.set_qos(0) - - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(run(), timeout=0.2) + with pytest.raises(ChannelNotFoundEntity): + await channel.declare_queue(queue_name, passive=True) await close_event.wait() @@ -87,8 +84,8 @@ async def run(sleep_time=1): # Ensure close callback has been called assert close_reasons - await asyncio.wait_for(reopen_event.wait(), timeout=2) - await asyncio.wait_for(run(sleep_time=0), timeout=2) + await asyncio.wait_for(reopen_event.wait(), timeout=60) + await channel.declare_queue(queue_name, auto_delete=True) class TestCaseAmqpNoConfirmsRobust(TestCaseAmqpNoConfirms): diff --git a/tests/test_amqp_robust_proxy.py b/tests/test_amqp_robust_proxy.py index ef68bde1..61c274c6 100644 --- a/tests/test_amqp_robust_proxy.py +++ b/tests/test_amqp_robust_proxy.py @@ -151,16 +151,16 @@ async def test_robust_reconnect( ) assert isinstance(read_conn, aio_pika.RobustConnection) + write_channel = await direct_connection.channel() - async with read_conn, direct_connection: + async with read_conn: read_channel = await read_conn.channel() - write_channel = await direct_connection.channel() assert isinstance(read_channel, aio_pika.RobustChannel) qname = get_random_name("robust", "proxy", "shared") - async with read_channel, write_channel: + async with read_channel: shared = [] # Declaring temporary queue @@ -229,7 +229,10 @@ async def reader(queue_name): await queue.get(timeout=0.5) finally: await queue.purge() - await queue.delete() + + # Waiting for rabbitmq queue not in use + await asyncio.sleep(1) + await queue.delete() async def test_channel_locked_resource2(connection: aio_pika.RobustConnection): @@ -405,7 +408,7 @@ async def reader(queue: aio_pika.Queue): await asyncio.wait_for(reconnect_event.wait(), timeout=5) logging.info("Waiting connections") - await channel.connection.ready() + await channel.ready.wait() async with shared_condition: await asyncio.wait_for( diff --git a/tests/test_tools.py b/tests/test_tools.py index 767da903..8fcb1feb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,3 +1,4 @@ +import asyncio import logging from copy import copy from unittest import mock @@ -85,3 +86,59 @@ def test_callback_call(self, collection): assert l1 == l2 assert l1 == [1, 2] + + async def test_blank_awaitable_callback(self, collection): + await collection() + + async def test_awaitable_callback(self, loop, collection, instance): + future = loop.create_future() + + shared = [] + + async def coro(arg): + nonlocal shared + shared.append(arg) + + def task_maker(arg): + return loop.create_task(coro(arg)) + + collection.add(future.set_result) + collection.add(coro) + collection.add(task_maker) + + await collection() + + assert shared == [instance, instance] + assert await future == instance + + async def test_collection_create_tasks(self, loop, collection, instance): + future = loop.create_future() + + async def coro(arg): + await asyncio.sleep(0.5) + future.set_result(arg) + + collection.add(coro) + + # noinspection PyAsyncCall + collection() + + assert await future == instance + + async def test_collection_run_tasks_parallel(self, collection): + class Callable: + def __init__(self): + self.counter = 0 + + async def __call__(self, *args, **kwargs): + await asyncio.sleep(1) + self.counter += 1 + + callables = [Callable() for _ in range(100)] + + for callable in callables: + collection.add(callable) + + await asyncio.wait_for(collection(), timeout=2) + + assert [c.counter for c in callables] == [1] * 100