From 7d25d43de5bdfba3a715189b0b70f477df902a70 Mon Sep 17 00:00:00 2001 From: Vu Tran Date: Sun, 17 Mar 2024 21:44:40 +0700 Subject: [PATCH] Switch from psycopg v2 to v3 (#164) * change dependencies * Switch to psycopg^3 * refactor * disable postgresClock * cancel async-leak-task when exit test * refactor * refactor * update extra * update lock file * up * up * remove scope * up * remove cancel * re-organize functions & modules * refactor * remove clock-exception --- README.md | 6 +- noxfile.py | 2 +- poetry.lock | 97 +++++++++++++---- pyproject.toml | 8 +- pyrate_limiter/__init__.py | 2 +- pyrate_limiter/abstracts/bucket.py | 7 +- pyrate_limiter/buckets/postgres.py | 56 +++++----- pyrate_limiter/clocks.py | 14 +-- pyrate_limiter/exceptions.py | 6 -- tests/conftest.py | 166 +---------------------------- tests/demo_bucket_factory.py | 56 ++++++++++ tests/helpers.py | 103 ++++++++++++++++++ tests/test_bucket_factory.py | 7 +- tests/test_limiter.py | 31 +++--- 14 files changed, 304 insertions(+), 257 deletions(-) create mode 100644 tests/demo_bucket_factory.py create mode 100644 tests/helpers.py diff --git a/README.md b/README.md index f1813663..99ae1add 100644 --- a/README.md +++ b/README.md @@ -508,15 +508,15 @@ bucket = SQLiteBucket(rates, conn, table) #### PostgresBucket -Postgres is supported, but you have to install `psycopg2` or `asyncpg` either as an extra or as a separate package. +Postgres is supported, but you have to install `psycopg[pool]` either as an extra or as a separate package. You can use Postgres's built-in **CURRENT_TIMESTAMP** as the time source with `PostgresClock`, or use an external custom time source. ```python from pyrate_limiter import PostgresBucket, Rate, PostgresClock -from psycopg2.pool import ThreadedConnectionPool +from psycopg_pool import ConnectionPool -connection_pool = ThreadedConnectionPool(5, 10, 'postgresql://postgres:postgres@localhost:5432') +connection_pool = ConnectionPool('postgresql://postgres:postgres@localhost:5432') clock = PostgresClock(connection_pool) rates = [Rate(3, 1000), Rate(4, 1500)] diff --git a/noxfile.py b/noxfile.py index 1bf55284..b68519a5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -4,7 +4,7 @@ # Reuse virtualenv created by poetry instead of creating new ones nox.options.reuse_existing_virtualenvs = True -PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=8"] +PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=auto"] COVERAGE_ARGS = ["--cov", "--cov-report=term", "--cov-report=xml", "--cov-report=html"] diff --git a/poetry.lock b/poetry.lock index 79a0aa82..03c6907f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -58,6 +58,35 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] +[[package]] +name = "backports-zoneinfo" +version = "0.2.1" +description = "Backport of the standard library zoneinfo module" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = 
"sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, + {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, +] + +[package.extras] +tzdata = ["tzdata"] + [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -790,28 +819,46 @@ pyyaml = ">=5.1" virtualenv = ">=20.10.0" [[package]] -name = "psycopg2" -version = "2.9.9" -description = "psycopg2 - Python-PostgreSQL Database Adapter" +name = "psycopg" +version = "3.1.18" +description = "PostgreSQL database adapter for Python" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"}, - {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, - {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, - {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, - {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, - {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, - {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, - {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash 
= "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"}, - {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"}, - {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"}, - {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"}, + {file = "psycopg-3.1.18-py3-none-any.whl", hash = "sha256:4d5a0a5a8590906daa58ebd5f3cfc34091377354a1acced269dd10faf55da60e"}, + {file = "psycopg-3.1.18.tar.gz", hash = "sha256:31144d3fb4c17d78094d9e579826f047d4af1da6a10427d91dfcfb6ecdf6f12b"}, ] +[package.dependencies] +"backports.zoneinfo" = {version = ">=0.2.0", markers = "python_version < \"3.9\""} +psycopg-pool = {version = "*", optional = true, markers = "extra == \"pool\""} +typing-extensions = ">=4.1" +tzdata = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +binary = ["psycopg-binary (==3.1.18)"] +c = ["psycopg-c (==3.1.18)"] +dev = ["black (>=24.1.0)", "codespell (>=2.2)", "dnspython (>=2.1)", "flake8 (>=4.0)", "mypy (>=1.4.1)", "types-setuptools (>=57.4)", "wheel (>=0.37)"] +docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] +pool = ["psycopg-pool"] +test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] + +[[package]] +name = "psycopg-pool" +version = "3.2.1" +description = "Connection Pool for Psycopg" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "psycopg-pool-3.2.1.tar.gz", hash = "sha256:6509a75c073590952915eddbba7ce8b8332a440a31e77bba69561483492829ad"}, + {file = "psycopg_pool-3.2.1-py3-none-any.whl", hash = "sha256:060b551d1b97a8d358c668be58b637780b884de14d861f4f5ecc48b7563aafb7"}, +] + +[package.dependencies] +typing-extensions = ">=4.4" + [[package]] name = "py" version = "1.11.0" @@ -1351,13 +1398,25 @@ name = "typing-extensions" version = "4.10.0" description = "Backported and Experimental Type Hints for Python 3.8+" category = "main" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +category = "main" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "urllib3" version = "2.2.1" @@ -1414,10 +1473,10 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["filelock", "redis", "psycopg2"] +all = ["filelock", "redis", "psycopg"] docs = ["furo", "myst-parser", "sphinx", "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinxcontrib-apidoc"] 
[metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "3332ed9ac8f743ce94bda482c985168f6b9498ac382263cff977ae629b7434d7" +content-hash = "da1001531f2b92a12d3953ed4560f57d5bfd0829b4d3a05104d750048311ddc7" diff --git a/pyproject.toml b/pyproject.toml index f1c681ae..583eeaad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyrate-limiter" -version = "3.5.1" +version = "3.6.0" description = "Python Rate-Limiter using Leaky-Bucket Algorithm" authors = ["vutr "] license = "MIT" @@ -29,7 +29,7 @@ python = "^3.8" # Optional backend dependencies filelock = {optional=true, version=">=3.0"} redis = {optional=true, version="^5.0.0"} -psycopg2 = {version = "^2.9.9", optional = true} +psycopg = {extras = ["pool"], version = "^3.1.18", optional = true} # Documentation dependencies needed for Readthedocs builds furo = {optional=true, version="^2022.3.4"} @@ -40,7 +40,7 @@ sphinx-copybutton = {optional=true, version=">=0.5"} sphinxcontrib-apidoc = {optional=true, version="^0.3"} [tool.poetry.extras] -all = ["filelock", "redis", "psycopg2"] +all = ["filelock", "redis", "psycopg"] docs = ["furo", "myst-parser", "sphinx", "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinxcontrib-apidoc"] @@ -58,7 +58,7 @@ coverage = "6" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" pytest-asyncio = "^0.23.5.post1" -psycopg2 = "^2.9.9" +psycopg = {extras = ["pool"], version = "^3.1.18"} [tool.black] line-length = 120 diff --git a/pyrate_limiter/__init__.py b/pyrate_limiter/__init__.py index 84b3b155..a322da3b 100644 --- a/pyrate_limiter/__init__.py +++ b/pyrate_limiter/__init__.py @@ -3,5 +3,5 @@ from .buckets import * from .clocks import * from .exceptions import * -from .limiter import Limiter +from .limiter import * from .utils import * diff --git a/pyrate_limiter/abstracts/bucket.py b/pyrate_limiter/abstracts/bucket.py index 8f64ed7a..db962ac7 100644 --- a/pyrate_limiter/abstracts/bucket.py +++ b/pyrate_limiter/abstracts/bucket.py @@ -123,6 +123,7 @@ def __init__(self, leak_interval: int): self.async_buckets = defaultdict() self.clocks = defaultdict() self.leak_interval = leak_interval + self._task = None super().__init__() def register(self, bucket: AbstractBucket, clock: AbstractClock): @@ -171,7 +172,7 @@ async def _leak(self, sync=True) -> None: def leak_async(self): if self.async_buckets and not self.is_async_leak_started: self.is_async_leak_started = True - asyncio.create_task(self._leak(sync=False)) + self._task = asyncio.create_task(self._leak(sync=False)) def run(self) -> None: assert self.sync_buckets @@ -181,6 +182,10 @@ def start(self) -> None: if self.sync_buckets and not self.is_alive(): super().start() + def cancel(self) -> None: + if self._task: + self._task.cancel() + class BucketFactory(ABC): """Asbtract BucketFactory class. 
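The `Leaker` change above keeps the handle returned by `asyncio.create_task` in `self._task` so the background leak task can be cancelled later (the updated tests call `factory._leaker.cancel()`, and the commit message notes "cancel async-leak-task when exit test"). A minimal standalone sketch of that pattern — a toy class, not pyrate-limiter's actual `Leaker` — is:

```python
import asyncio


class Leaker:
    """Toy version of the patched Leaker: hold the task handle so it can be cancelled."""

    def __init__(self) -> None:
        self._task = None  # type: asyncio.Task | None

    async def _leak(self) -> None:
        while True:
            await asyncio.sleep(0.3)  # stand-in for the 300ms leak_interval used in the tests
            print("leaking...")

    def leak_async(self) -> None:
        if self._task is None:
            # Without storing the handle, the task could never be cancelled and
            # would keep running after the caller finishes -- the test-suite
            # leak this patch addresses.
            self._task = asyncio.create_task(self._leak())

    def cancel(self) -> None:
        if self._task:
            self._task.cancel()


async def main() -> None:
    leaker = Leaker()
    leaker.leak_async()
    await asyncio.sleep(1)  # let the leak task fire a few times
    leaker.cancel()  # mirrors factory._leaker.cancel() in the updated tests


asyncio.run(main())
```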
diff --git a/pyrate_limiter/buckets/postgres.py b/pyrate_limiter/buckets/postgres.py index 4516c510..1c82f9fe 100644 --- a/pyrate_limiter/buckets/postgres.py +++ b/pyrate_limiter/buckets/postgres.py @@ -14,7 +14,7 @@ from ..abstracts import RateItem if TYPE_CHECKING: - from psycopg2.pool import AbstractConnectionPool + from psycopg_pool import ConnectionPool class Queries: @@ -54,9 +54,9 @@ class Queries: class PostgresBucket(AbstractBucket): table: str - pool: AbstractConnectionPool + pool: ConnectionPool - def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]): + def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]): self.table = table.lower() self.pool = pool assert rates @@ -65,21 +65,15 @@ def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]): self._create_table() @contextmanager - def _get_conn(self, autocommit=False): - with self.pool._getconn() as conn: - with conn.cursor() as cur: - yield cur - - if autocommit: - conn.commit() - - self.pool._putconn(conn) + def _get_conn(self): + with self.pool.connection() as conn: + yield conn def _create_table(self): - with self._get_conn(autocommit=True) as cur: - cur.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl)) + with self._get_conn() as conn: + conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl)) index_name = f'timestampIndex_{self.table}' - cur.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name)) + conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name)) def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]: """Put an item (typically the current time) in the bucket @@ -88,12 +82,12 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]: if item.weight == 0: return True - with self._get_conn(autocommit=True) as cur: + with self._get_conn() as conn: for rate in self.rates: bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'" query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})' - cur.execute(query) - count = int(cur.fetchone()[0]) + conn = conn.execute(query) + count = int(conn.fetchone()[0]) if rate.limit - count < item.weight: self.failing_rate = rate @@ -103,7 +97,7 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]: query = Queries.PUT.format(table=self._full_tbl) arguments = [(item.name, item.weight, item.timestamp / 1000)] * item.weight - cur.executemany(query, tuple(arguments)) + conn.executemany(query, tuple(arguments)) return True @@ -120,12 +114,12 @@ def leak( count = 0 - with self._get_conn(autocommit=True) as cur: - cur.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000)) - result = cur.fetchone() + with self._get_conn() as conn: + conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000)) + result = conn.fetchone() if result: - cur.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000)) + conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000)) count = int(result[0]) return count @@ -134,8 +128,8 @@ def flush(self) -> Union[None, Awaitable[None]]: """Flush the whole bucket - Must remove `failing-rate` after flushing """ - with self._get_conn(autocommit=True) as cur: - cur.execute(Queries.FLUSH.format(table=self._full_tbl)) + with self._get_conn() as conn: + conn.execute(Queries.FLUSH.format(table=self._full_tbl)) 
self.failing_rate = None return None @@ -143,9 +137,9 @@ def flush(self) -> Union[None, Awaitable[None]]: def count(self) -> Union[int, Awaitable[int]]: """Count number of items in the bucket""" count = 0 - with self._get_conn() as cur: - cur.execute(Queries.COUNT.format(table=self._full_tbl)) - result = cur.fetchone() + with self._get_conn() as conn: + conn = conn.execute(Queries.COUNT.format(table=self._full_tbl)) + result = conn.fetchone() assert result count = int(result[0]) @@ -158,9 +152,9 @@ def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateI """ item = None - with self._get_conn() as cur: - cur.execute(Queries.PEEK.format(table=self._full_tbl, offset=index)) - result = cur.fetchone() + with self._get_conn() as conn: + conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index)) + result = conn.fetchone() if result: name, weight, timestamp = result[0], int(result[1]), int(result[2]) item = RateItem(name=name, weight=weight, timestamp=timestamp) diff --git a/pyrate_limiter/clocks.py b/pyrate_limiter/clocks.py index ba8c3521..9f3eb4cc 100644 --- a/pyrate_limiter/clocks.py +++ b/pyrate_limiter/clocks.py @@ -8,11 +8,10 @@ from typing import TYPE_CHECKING from .abstracts import AbstractClock -from .exceptions import PyrateClockException from .utils import dedicated_sqlite_clock_connection if TYPE_CHECKING: - from psycopg2.pool import AbstractConnectionPool + from psycopg_pool import ConnectionPool class MonotonicClock(AbstractClock): @@ -57,22 +56,17 @@ def now(self) -> int: class PostgresClock(AbstractClock): """Get timestamp using Postgres as remote clock backend""" - def __init__(self, pool: 'AbstractConnectionPool'): + def __init__(self, pool: 'ConnectionPool'): self.pool = pool def now(self) -> int: value = 0 - with self.pool._getconn() as conn: + with self.pool.connection() as conn: with conn.cursor() as cur: cur.execute("SELECT EXTRACT(epoch FROM current_timestamp) * 1000") result = cur.fetchone() - - if not result: - raise PyrateClockException(self, detail=f"invalid result from query current-timestamp: {result}") - + assert result, "unable to get current-timestamp from postgres" value = int(result[0]) - self.pool._putconn(conn) - return value diff --git a/pyrate_limiter/exceptions.py b/pyrate_limiter/exceptions.py index 15fee456..ac84f529 100644 --- a/pyrate_limiter/exceptions.py +++ b/pyrate_limiter/exceptions.py @@ -33,9 +33,3 @@ def __init__(self, item: RateItem, rate: Rate, actual_delay: int, max_delay: int "actual_delay": actual_delay, } super().__init__(error) - - -class PyrateClockException(Exception): - def __init__(self, clock: object, detail=None): - error = f"Clock({repr(clock)}) is failing: {detail}" - super().__init__(error) diff --git a/tests/conftest.py b/tests/conftest.py index 88eac097..bc8429ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,43 +1,28 @@ -"""Pytest config +"""Pytest config & fixtures """ import sqlite3 -from concurrent.futures import ThreadPoolExecutor -from inspect import isawaitable from logging import basicConfig from logging import getLogger from os import getenv from pathlib import Path from tempfile import gettempdir -from time import sleep -from time import time -from typing import Dict from typing import List -from typing import Optional -from typing import Tuple from typing import Union import pytest -from psycopg2.pool import ThreadedConnectionPool -from pyrate_limiter import AbstractBucket -from pyrate_limiter import AbstractClock -from pyrate_limiter import BucketFactory from 
pyrate_limiter import Duration from pyrate_limiter import id_generator from pyrate_limiter import InMemoryBucket -from pyrate_limiter import Limiter from pyrate_limiter import MonotonicClock from pyrate_limiter import PostgresBucket from pyrate_limiter import Rate -from pyrate_limiter import RateItem from pyrate_limiter import RedisBucket from pyrate_limiter import SQLiteBucket from pyrate_limiter import SQLiteClock from pyrate_limiter import SQLiteQueries as Queries from pyrate_limiter import TimeAsyncClock from pyrate_limiter import TimeClock -from pyrate_limiter import validate_rate_list -# from pyrate_limiter import PostgresClock # Make log messages visible on test failure (or with pytest -s) @@ -46,14 +31,13 @@ logger = getLogger("pyrate_limiter") logger.setLevel(getenv("LOG_LEVEL", "INFO")) -pg_pool = ThreadedConnectionPool(3, 5, 'postgresql://postgres:postgres@localhost:5432') +DEFAULT_RATES = [Rate(3, 1000), Rate(4, 1500)] clocks = [ MonotonicClock(), TimeClock(), SQLiteClock.default(), TimeAsyncClock(), - # PostgresClock(pg_pool) ] ClockSet = Union[ @@ -61,7 +45,6 @@ TimeClock, SQLiteClock, TimeAsyncClock, - # PostgresClock ] @@ -133,10 +116,11 @@ async def create_sqlite_bucket(rates: List[Rate]): async def create_postgres_bucket(rates: List[Rate]): - global pg_pool + from psycopg_pool import ConnectionPool as PgConnectionPool + pool = PgConnectionPool('postgresql://postgres:postgres@localhost:5432') table = f"test_bucket_{id_generator()}" - bucket = PostgresBucket(pg_pool, table, rates) + bucket = PostgresBucket(pool, table, rates) assert bucket.count() == 0 return bucket @@ -155,56 +139,6 @@ def create_bucket(request): return request.param -DEFAULT_RATES = [Rate(3, 1000), Rate(4, 1500)] -validate_rate_list(DEFAULT_RATES) - - -class DemoBucketFactory(BucketFactory): - """Multi-bucket factory used for testing schedule-leaks""" - - buckets: Optional[Dict[str, AbstractBucket]] = None - clock: AbstractClock - auto_leak: bool - - def __init__(self, bucket_clock: AbstractClock, auto_leak=False, **buckets: AbstractBucket): - self.auto_leak = auto_leak - self.clock = bucket_clock - self.buckets = {} - self.leak_interval = 300 - - for item_name_pattern, bucket in buckets.items(): - assert isinstance(bucket, AbstractBucket) - self.schedule_leak(bucket, bucket_clock) - self.buckets[item_name_pattern] = bucket - - def wrap_item(self, name: str, weight: int = 1): - now = self.clock.now() - - async def wrap_async(): - return RateItem(name, await now, weight=weight) - - def wrap_sync(): - return RateItem(name, now, weight=weight) - - return wrap_async() if isawaitable(now) else wrap_sync() - - def get(self, item: RateItem) -> AbstractBucket: - assert self.buckets is not None - - if item.name in self.buckets: - bucket = self.buckets[item.name] - assert isinstance(bucket, AbstractBucket) - return bucket - - bucket = self.create(self.clock, InMemoryBucket, DEFAULT_RATES) - self.buckets[item.name] = bucket - return bucket - - def schedule_leak(self, *args): - if self.auto_leak: - super().schedule_leak(*args) - - @pytest.fixture(params=[True, False]) def limiter_should_raise(request): return request.param @@ -213,93 +147,3 @@ def limiter_should_raise(request): @pytest.fixture(params=[None, 500, Duration.SECOND * 2, Duration.MINUTE]) def limiter_delay(request): return request.param - - -async def inspect_bucket_items(bucket: AbstractBucket, expected_item_count: int): - """Inspect items in the bucket - - Assert number of item == expected-item-count - - Assert that items are ordered by timestamps, from 
latest to earliest - """ - collected_items = [] - - for idx in range(expected_item_count): - item = bucket.peek(idx) - - if isawaitable(item): - item = await item - - assert isinstance(item, RateItem) - collected_items.append(item) - - item_names = [item.name for item in collected_items] - - for i in range(1, expected_item_count): - item = collected_items[i] - prev_item = collected_items[i - 1] - assert item.timestamp <= prev_item.timestamp - - return item_names - - -async def concurrent_acquire(limiter: Limiter, items: List[str]): - with ThreadPoolExecutor() as executor: - result = list(executor.map(limiter.try_acquire, items)) - for idx, coro in enumerate(result): - while isawaitable(coro): - coro = await coro - result[idx] = coro - - return result - - -async def async_acquire(limiter: Limiter, item: str, weight: int = 1) -> Tuple[bool, int]: - start = time() - acquire = limiter.try_acquire(item, weight=weight) - - if isawaitable(acquire): - acquire = await acquire - - time_cost_in_ms = int((time() - start) * 1000) - assert isinstance(acquire, bool) - return acquire, time_cost_in_ms - - -async def async_count(bucket: AbstractBucket) -> int: - count = bucket.count() - - if isawaitable(count): - count = await count - - assert isinstance(count, int) - return count - - -async def prefilling_bucket(limiter: Limiter, sleep_interval: float, item: str): - """Pre-filling bucket to the limit before testing - the time cost might vary depending on the bucket's backend - - For in-memory bucket, this should be less than a 1ms - - For external bucket's source ie Redis, this mostly depends on the network latency - """ - acquire_ok, cost = await async_acquire(limiter, item) - logger.info("cost = %s", cost) - assert cost <= 50 - assert acquire_ok - sleep(sleep_interval) - - acquire_ok, cost = await async_acquire(limiter, item) - logger.info("cost = %s", cost) - assert cost <= 50 - assert acquire_ok - sleep(sleep_interval) - - acquire_ok, cost = await async_acquire(limiter, item) - logger.info("cost = %s", cost) - assert cost <= 50 - assert acquire_ok - - -async def flushing_bucket(bucket: AbstractBucket): - flush = bucket.flush() - - if isawaitable(flush): - await flush diff --git a/tests/demo_bucket_factory.py b/tests/demo_bucket_factory.py new file mode 100644 index 00000000..ba4f5f24 --- /dev/null +++ b/tests/demo_bucket_factory.py @@ -0,0 +1,56 @@ +from inspect import isawaitable +from typing import Dict +from typing import Optional + +from .conftest import DEFAULT_RATES +from pyrate_limiter import AbstractBucket +from pyrate_limiter import AbstractClock +from pyrate_limiter import BucketFactory +from pyrate_limiter import InMemoryBucket +from pyrate_limiter import RateItem + + +class DemoBucketFactory(BucketFactory): + """Multi-bucket factory used for testing schedule-leaks""" + + buckets: Optional[Dict[str, AbstractBucket]] = None + clock: AbstractClock + auto_leak: bool + + def __init__(self, bucket_clock: AbstractClock, auto_leak=False, **buckets: AbstractBucket): + self.auto_leak = auto_leak + self.clock = bucket_clock + self.buckets = {} + self.leak_interval = 300 + + for item_name_pattern, bucket in buckets.items(): + assert isinstance(bucket, AbstractBucket) + self.schedule_leak(bucket, bucket_clock) + self.buckets[item_name_pattern] = bucket + + def wrap_item(self, name: str, weight: int = 1): + now = self.clock.now() + + async def wrap_async(): + return RateItem(name, await now, weight=weight) + + def wrap_sync(): + return RateItem(name, now, weight=weight) + + return wrap_async() if 
isawaitable(now) else wrap_sync() + + def get(self, item: RateItem) -> AbstractBucket: + assert self.buckets is not None + + if item.name in self.buckets: + bucket = self.buckets[item.name] + assert isinstance(bucket, AbstractBucket) + return bucket + + bucket = self.create(self.clock, InMemoryBucket, DEFAULT_RATES) + self.buckets[item.name] = bucket + return bucket + + def schedule_leak(self, *args): + if self.auto_leak: + super().schedule_leak(*args) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..23d1060b --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,103 @@ +"""Duh.... +""" +from concurrent.futures import ThreadPoolExecutor +from inspect import isawaitable +from time import sleep +from time import time +from typing import List +from typing import Tuple + +from .conftest import logger +from pyrate_limiter import AbstractBucket +from pyrate_limiter import Limiter +from pyrate_limiter import RateItem + + +async def inspect_bucket_items(bucket: AbstractBucket, expected_item_count: int): + """Inspect items in the bucket + - Assert number of item == expected-item-count + - Assert that items are ordered by timestamps, from latest to earliest + """ + collected_items = [] + + for idx in range(expected_item_count): + item = bucket.peek(idx) + + if isawaitable(item): + item = await item + + assert isinstance(item, RateItem) + collected_items.append(item) + + item_names = [item.name for item in collected_items] + + for i in range(1, expected_item_count): + item = collected_items[i] + prev_item = collected_items[i - 1] + assert item.timestamp <= prev_item.timestamp + + return item_names + + +async def concurrent_acquire(limiter: Limiter, items: List[str]): + with ThreadPoolExecutor() as executor: + result = list(executor.map(limiter.try_acquire, items)) + for idx, coro in enumerate(result): + while isawaitable(coro): + coro = await coro + result[idx] = coro + + return result + + +async def async_acquire(limiter: Limiter, item: str, weight: int = 1) -> Tuple[bool, int]: + start = time() + acquire = limiter.try_acquire(item, weight=weight) + + if isawaitable(acquire): + acquire = await acquire + + time_cost_in_ms = int((time() - start) * 1000) + assert isinstance(acquire, bool) + return acquire, time_cost_in_ms + + +async def async_count(bucket: AbstractBucket) -> int: + count = bucket.count() + + if isawaitable(count): + count = await count + + assert isinstance(count, int) + return count + + +async def prefilling_bucket(limiter: Limiter, sleep_interval: float, item: str): + """Pre-filling bucket to the limit before testing + the time cost might vary depending on the bucket's backend + - For in-memory bucket, this should be less than a 1ms + - For external bucket's source ie Redis, this mostly depends on the network latency + """ + acquire_ok, cost = await async_acquire(limiter, item) + logger.info("cost = %s", cost) + assert cost <= 50 + assert acquire_ok + sleep(sleep_interval) + + acquire_ok, cost = await async_acquire(limiter, item) + logger.info("cost = %s", cost) + assert cost <= 50 + assert acquire_ok + sleep(sleep_interval) + + acquire_ok, cost = await async_acquire(limiter, item) + logger.info("cost = %s", cost) + assert cost <= 50 + assert acquire_ok + + +async def flushing_bucket(bucket: AbstractBucket): + flush = bucket.flush() + + if isawaitable(flush): + await flush diff --git a/tests/test_bucket_factory.py b/tests/test_bucket_factory.py index 272cf1aa..4770e915 100644 --- a/tests/test_bucket_factory.py +++ b/tests/test_bucket_factory.py @@ -6,10 
+6,10 @@ import pytest -from .conftest import async_count from .conftest import DEFAULT_RATES -from .conftest import DemoBucketFactory from .conftest import logger +from .demo_bucket_factory import DemoBucketFactory +from .helpers import async_count from pyrate_limiter import AbstractBucket from pyrate_limiter import RateItem @@ -32,6 +32,8 @@ async def test_factory_01(clock, create_bucket): bucket = factory.get(item) assert isinstance(bucket, AbstractBucket) + if factory._leaker: + factory._leaker.cancel() @pytest.mark.asyncio @@ -82,3 +84,4 @@ async def test_factory_leak(clock, create_bucket): assert await async_count(factory.buckets[item_name]) == 0 assert len(factory.buckets) == 3 + factory._leaker.cancel() diff --git a/tests/test_limiter.py b/tests/test_limiter.py index ea369f84..12f8846b 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -4,14 +4,14 @@ import pytest -from .conftest import async_acquire -from .conftest import concurrent_acquire from .conftest import DEFAULT_RATES -from .conftest import DemoBucketFactory -from .conftest import flushing_bucket -from .conftest import inspect_bucket_items from .conftest import logger -from .conftest import prefilling_bucket +from .demo_bucket_factory import DemoBucketFactory +from .helpers import async_acquire +from .helpers import concurrent_acquire +from .helpers import flushing_bucket +from .helpers import inspect_bucket_items +from .helpers import prefilling_bucket from pyrate_limiter import AbstractBucket from pyrate_limiter import BucketAsyncWrapper from pyrate_limiter import BucketFactory @@ -20,6 +20,7 @@ from pyrate_limiter import InMemoryBucket from pyrate_limiter import Limiter from pyrate_limiter import LimiterDelayException +from pyrate_limiter import SingleBucketFactory from pyrate_limiter import TimeClock @@ -40,7 +41,6 @@ async def test_limiter_constructor_01(clock): @pytest.mark.asyncio async def test_limiter_constructor_02( - clock, create_bucket, limiter_should_raise, limiter_delay, @@ -48,20 +48,19 @@ async def test_limiter_constructor_02( bucket = await create_bucket(DEFAULT_RATES) limiter = Limiter(bucket) - assert isinstance(limiter.bucket_factory, BucketFactory) + assert isinstance(limiter.bucket_factory, SingleBucketFactory) assert isinstance(limiter.bucket_factory.clock, TimeClock) assert limiter.max_delay is None assert limiter.raise_when_fail is True limiter = Limiter( bucket, - clock=clock, + clock=TimeClock(), raise_when_fail=limiter_should_raise, max_delay=limiter_delay, ) assert isinstance(limiter.bucket_factory, BucketFactory) - assert limiter.bucket_factory.clock is clock assert limiter.raise_when_fail == limiter_should_raise assert limiter.max_delay == limiter_delay @@ -72,27 +71,25 @@ async def test_limiter_constructor_02( assert acquire_ok - factory = DemoBucketFactory(clock, demo=bucket) + factory = DemoBucketFactory(TimeClock(), demo=bucket) limiter = Limiter( factory, raise_when_fail=limiter_should_raise, max_delay=limiter_delay, ) assert limiter.bucket_factory is factory - assert limiter.bucket_factory.clock is clock assert limiter.raise_when_fail == limiter_should_raise assert limiter.max_delay == limiter_delay @pytest.mark.asyncio async def test_limiter_01( - clock, create_bucket, limiter_should_raise, limiter_delay, ): bucket = await create_bucket(DEFAULT_RATES) - factory = DemoBucketFactory(clock, demo=bucket) + factory = DemoBucketFactory(TimeClock(), demo=bucket) limiter = Limiter( factory, raise_when_fail=limiter_should_raise, @@ -170,13 +167,12 @@ async def test_limiter_01( 
@pytest.mark.asyncio async def test_limiter_concurrency( - clock, create_bucket, limiter_should_raise, limiter_delay, ): bucket: AbstractBucket = await create_bucket(DEFAULT_RATES) - factory = DemoBucketFactory(clock, demo=bucket) + factory = DemoBucketFactory(TimeClock(), demo=bucket) limiter = Limiter( factory, raise_when_fail=limiter_should_raise, @@ -218,13 +214,12 @@ async def test_limiter_concurrency( @pytest.mark.asyncio async def test_limiter_decorator( - clock, create_bucket, limiter_should_raise, limiter_delay, ): bucket = await create_bucket(DEFAULT_RATES) - factory = DemoBucketFactory(clock, demo=bucket) + factory = DemoBucketFactory(TimeClock(), demo=bucket) limiter = Limiter( factory, raise_when_fail=limiter_should_raise,
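Taken together with the README change at the top of this patch, end-to-end usage after the psycopg v3 migration looks roughly like the sketch below. The DSN, table name, and rate values are illustrative placeholders, not values taken from this diff:

```python
from psycopg_pool import ConnectionPool

from pyrate_limiter import Duration, Limiter, PostgresBucket, PostgresClock, Rate

# psycopg_pool.ConnectionPool replaces psycopg2's ThreadedConnectionPool;
# the bucket borrows connections internally via pool.connection().
pool = ConnectionPool("postgresql://postgres:postgres@localhost:5432")

clock = PostgresClock(pool)  # server-side CURRENT_TIMESTAMP as the time source
rates = [Rate(3, Duration.SECOND)]  # at most 3 items per second
bucket = PostgresBucket(pool, "my_bucket", rates)  # creates its table on construction

limiter = Limiter(bucket, clock=clock, raise_when_fail=False)

for _ in range(5):
    print(limiter.try_acquire("demo-item"))  # True for the first 3 within a second, then False

pool.close()  # release pooled connections on shutdown
```

Since `PostgresBucket` runs its table-creation query in `__init__`, the pool must be able to reach Postgres before the first `try_acquire` call.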