diff --git a/.github/workflows/aiosql-package.yml b/.github/workflows/aiosql-package.yml index d7624bbd..1a8a1a2a 100644 --- a/.github/workflows/aiosql-package.yml +++ b/.github/workflows/aiosql-package.yml @@ -21,9 +21,9 @@ jobs: # https://downloads.python.org/pypy/versions.json python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/Makefile b/Makefile index 75dc9d2e..f15f4469 100644 --- a/Makefile +++ b/Makefile @@ -323,7 +323,7 @@ check.coverage.combine: $(VENV) else $(COVERAGE) html fi - $(COVERAGE) report --fail-under=100 --include='$(MODULE)/*' + $(COVERAGE) report --show-missing --precision=1 --fail-under=100.0 --include='$(MODULE)/*' # # Docker runs diff --git a/aiosql/adapters/duckdb.py b/aiosql/adapters/duckdb.py index 76c3c350..a26c7ae8 100644 --- a/aiosql/adapters/duckdb.py +++ b/aiosql/adapters/duckdb.py @@ -37,7 +37,7 @@ def insert_returning(self, conn, _query_name, sql, parameters): # pragma: no co res = res[0] return res[0] if res and len(res) == 1 else res - def select(self, conn, _query_name, sql, parameters, record_class=None): + def select(self, conn, _query_name: str, sql: str, parameters, record_class=None): column_names: List[str] = [] cur = self._cursor(conn) try: diff --git a/aiosql/adapters/generic.py b/aiosql/adapters/generic.py index 90c584a0..998a511b 100644 --- a/aiosql/adapters/generic.py +++ b/aiosql/adapters/generic.py @@ -5,6 +5,8 @@ class GenericAdapter: """ Generic AioSQL Adapter suitable for `named` parameter style and no with support. + + This class also serves as the base class for other adapters. """ def __init__(self, driver=None): @@ -18,7 +20,8 @@ def _cursor(self, conn): """Get a cursor from a connection.""" return conn.cursor() - def select(self, conn, _query_name, sql, parameters, record_class=None): + def select(self, conn, _query_name: str, sql: str, parameters, record_class=None): + """Handle a relation-returning SELECT (no suffix).""" column_names: List[str] = [] cur = self._cursor(conn) try: @@ -37,6 +40,9 @@ def select(self, conn, _query_name, sql, parameters, record_class=None): cur.close() def select_one(self, conn, _query_name, sql, parameters, record_class=None): + """Handle a tuple-returning (one row) SELECT (``^`` suffix). + + Return None if empty.""" cur = self._cursor(conn) try: cur.execute(sql, parameters) @@ -50,6 +56,9 @@ def select_one(self, conn, _query_name, sql, parameters, record_class=None): return result def select_value(self, conn, _query_name, sql, parameters): + """Handle a scalar-returning (one value) SELECT (``$`` suffix). + + Return None if empty.""" cur = self._cursor(conn) try: cur.execute(sql, parameters) @@ -68,6 +77,7 @@ def select_value(self, conn, _query_name, sql, parameters): @contextmanager def select_cursor(self, conn, _query_name, sql, parameters): + """Return the raw cursor after a SELECT exec.""" cur = self._cursor(conn) cur.execute(sql, parameters) try: @@ -76,6 +86,7 @@ def select_cursor(self, conn, _query_name, sql, parameters): cur.close() def insert_update_delete(self, conn, _query_name, sql, parameters): + """Handle affected row counts (INSERT UPDATE DELETE) (``!`` suffix).""" cur = self._cursor(conn) cur.execute(sql, parameters) rc = cur.rowcount if hasattr(cur, "rowcount") else -1 @@ -83,13 +94,16 @@ def insert_update_delete(self, conn, _query_name, sql, parameters): return rc def insert_update_delete_many(self, conn, _query_name, sql, parameters): + """Handle affected row counts (INSERT UPDATE DELETE) (``*!`` suffix).""" cur = self._cursor(conn) cur.executemany(sql, parameters) rc = cur.rowcount if hasattr(cur, "rowcount") else -1 cur.close() return rc + # FIXME this made sense when SQLite had no RETURNING prefix (v3.35, 2021-03-12) def insert_returning(self, conn, _query_name, sql, parameters): + """Special case for RETURNING (``= 3.10 from typing import Any, Callable, List, Optional, Set, Tuple, Union, Dict, cast from .types import DriverAdapterProtocol, QueryDatum, QueryDataTree, QueryFn, SQLOperationType -def _params(args, kwargs) -> Union[List[Any], Dict[str, Any]]: - if len(kwargs) > 0: - return kwargs - else: - return args - - -def _query_fn( - fn: Callable[..., Any], - name: str, - doc: Optional[str], - sql: str, - operation: SQLOperationType, - signature: Optional[inspect.Signature], - floc: Optional[Tuple[Path, int]] = None, -) -> QueryFn: - # TODO remove version workaround and pragmas when 3.7 support is dropped - # FIXME should get the lineno as well? - if floc and sys.version_info >= (3, 8, 0): # pragma: no cover - fname, lineno = floc - fn.__code__ = fn.__code__.replace(co_filename=str(fname), co_firstlineno=lineno) # type: ignore - qfn = cast(QueryFn, fn) - qfn.__name__ = name - qfn.__doc__ = doc - qfn.__signature__ = signature - qfn.sql = sql - qfn.operation = operation - return qfn - - -# NOTE about coverage: because __code__ is set to reflect the actual SQL file -# source, coverage does note detect that the "fn" functions are actually called, -# hence the "no cover" hints. -def _make_sync_fn(query_datum: QueryDatum) -> QueryFn: - query_name, doc_comments, operation_type, sql, record_class, signature, floc = query_datum - if operation_type == SQLOperationType.INSERT_RETURNING: - - def fn(self: Queries, conn, *args, **kwargs): # pragma: no cover - return self.driver_adapter.insert_returning( - conn, query_name, sql, _params(args, kwargs) - ) +class Queries: + """Container object with dynamic methods built from SQL queries. - elif operation_type == SQLOperationType.INSERT_UPDATE_DELETE: + The ``-- name:`` definition comments in the content of the SQL determine what the dynamic + methods of this class will be named. - def fn(self: Queries, conn, *args, **kwargs): # type: ignore # pragma: no cover - return self.driver_adapter.insert_update_delete( - conn, query_name, sql, _params(args, kwargs) - ) + **Parameters:** - elif operation_type == SQLOperationType.INSERT_UPDATE_DELETE_MANY: + - **driver_adapter**: Either a string to designate one of the aiosql built-in database driver + adapters (e.g. "sqlite3", "psycopg"). + If you have defined your own adapter class, you can pass its constructor. + - **kwargs_only**: whether to reject positional parameters. + """ - def fn(self: Queries, conn, *args, **kwargs): # type: ignore # pragma: no cover - assert not kwargs # help type checker - return self.driver_adapter.insert_update_delete_many(conn, query_name, sql, *args) + def __init__(self, driver_adapter: DriverAdapterProtocol, kwargs_only: bool = False): + self.driver_adapter: DriverAdapterProtocol = driver_adapter + self.is_aio: bool = getattr(driver_adapter, "is_aio_driver", False) + self._kwargs_only = kwargs_only + self._available_queries: Set[str] = set() - elif operation_type == SQLOperationType.SCRIPT: + # + # INTERNAL UTILS + # + def _params( + self, args: Union[List[Any], Tuple[Any]], kwargs: Dict[str, Any] + ) -> Union[List[Any], Tuple[Any], Dict[str, Any]]: + """Execute parameter handling.""" + if self._kwargs_only: + if args: + raise ValueError("cannot use positional parameters under kwargs_only") + return kwargs + elif kwargs: + # FIXME is this true? + if args: + raise ValueError("cannot mix positional and named parameters in query") + return kwargs + else: + return args + + def _query_fn( + self, + fn: Callable[..., Any], + name: str, + doc: Optional[str], + sql: str, + operation: SQLOperationType, + signature: Optional[inspect.Signature], + floc: Tuple[Union[Path, str], int] = ("", 0), + ) -> QueryFn: + """Add custom-made metadata to a dynamically generated function.""" + fname, lineno = floc + fn.__code__ = fn.__code__.replace(co_filename=str(fname), co_firstlineno=lineno) # type: ignore + qfn = cast(QueryFn, fn) + qfn.__name__ = name + qfn.__doc__ = doc + qfn.__signature__ = signature + qfn.sql = sql + qfn.operation = operation + return qfn - def fn(self: Queries, conn, *args, **kwargs): # type: ignore # pragma: no cover - return self.driver_adapter.execute_script(conn, sql) + # NOTE about coverage: because __code__ is set to reflect the actual SQL file + # source, coverage does note detect that the "fn" functions are actually called, + # hence the "no cover" hints. + def _make_sync_fn(self, query_datum: QueryDatum) -> QueryFn: + """Build a dynamic method from a parsed query.""" + query_name, doc_comments, operation_type, sql, record_class, signature, floc = query_datum + if operation_type == SQLOperationType.INSERT_RETURNING: - elif operation_type == SQLOperationType.SELECT: + def fn(self, conn, *args, **kwargs): # pragma: no cover + return self.driver_adapter.insert_returning( + conn, query_name, sql, self._params(args, kwargs) + ) - def fn(self: Queries, conn, *args, **kwargs): # type: ignore # pragma: no cover - return self.driver_adapter.select( - conn, query_name, sql, _params(args, kwargs), record_class - ) + elif operation_type == SQLOperationType.INSERT_UPDATE_DELETE: - elif operation_type == SQLOperationType.SELECT_ONE: + def fn(self, conn, *args, **kwargs): # type: ignore # pragma: no cover + return self.driver_adapter.insert_update_delete( + conn, query_name, sql, self._params(args, kwargs) + ) - def fn(self: Queries, conn, *args, **kwargs): # pragma: no cover - return self.driver_adapter.select_one( - conn, query_name, sql, _params(args, kwargs), record_class - ) + elif operation_type == SQLOperationType.INSERT_UPDATE_DELETE_MANY: - elif operation_type == SQLOperationType.SELECT_VALUE: + def fn(self, conn, *args, **kwargs): # type: ignore # pragma: no cover + assert not kwargs, "cannot use named parameters in many query" # help type checker + return self.driver_adapter.insert_update_delete_many(conn, query_name, sql, *args) - def fn(self: Queries, conn, *args, **kwargs): # pragma: no cover - return self.driver_adapter.select_value(conn, query_name, sql, _params(args, kwargs)) + elif operation_type == SQLOperationType.SCRIPT: - else: - raise ValueError(f"Unknown operation_type: {operation_type}") + def fn(self, conn, *args, **kwargs): # type: ignore # pragma: no cover + # FIXME parameters are ignored? + return self.driver_adapter.execute_script(conn, sql) - return _query_fn(fn, query_name, doc_comments, sql, operation_type, signature, floc) + elif operation_type == SQLOperationType.SELECT: + def fn(self, conn, *args, **kwargs): # type: ignore # pragma: no cover + return self.driver_adapter.select( + conn, query_name, sql, self._params(args, kwargs), record_class + ) -def _make_async_fn(fn: QueryFn) -> QueryFn: - async def afn(self: Queries, conn, *args, **kwargs): - return await fn(self, conn, *args, **kwargs) + elif operation_type == SQLOperationType.SELECT_ONE: - return _query_fn(afn, fn.__name__, fn.__doc__, fn.sql, fn.operation, fn.__signature__) + def fn(self, conn, *args, **kwargs): # pragma: no cover + return self.driver_adapter.select_one( + conn, query_name, sql, self._params(args, kwargs), record_class + ) + elif operation_type == SQLOperationType.SELECT_VALUE: -def _make_ctx_mgr(fn: QueryFn) -> QueryFn: - def ctx_mgr(self, conn, *args, **kwargs): - return self.driver_adapter.select_cursor(conn, fn.__name__, fn.sql, _params(args, kwargs)) + def fn(self, conn, *args, **kwargs): # pragma: no cover + return self.driver_adapter.select_value( + conn, query_name, sql, self._params(args, kwargs) + ) - return _query_fn( - ctx_mgr, f"{fn.__name__}_cursor", fn.__doc__, fn.sql, fn.operation, fn.__signature__ - ) + else: + raise ValueError(f"Unknown operation_type: {operation_type}") + return self._query_fn(fn, query_name, doc_comments, sql, operation_type, signature, floc) -def _create_methods(query_datum: QueryDatum, is_aio: bool) -> List[QueryFn]: - """Internal function to feed add_queries.""" - fn = _make_sync_fn(query_datum) - if is_aio: - fn = _make_async_fn(fn) + # NOTE does this make sense? + def _make_async_fn(self, fn: QueryFn) -> QueryFn: + """Wrap in an async function.""" - ctx_mgr = _make_ctx_mgr(fn) + async def afn(self, conn, *args, **kwargs): # pragma: no cover + return await fn(self, conn, *args, **kwargs) - if query_datum.operation_type == SQLOperationType.SELECT: - return [fn, ctx_mgr] - else: - return [fn] + return self._query_fn(afn, fn.__name__, fn.__doc__, fn.sql, fn.operation, fn.__signature__) + def _make_ctx_mgr(self, fn: QueryFn) -> QueryFn: + """Wrap in a context manager function.""" -class Queries: - """Container object with dynamic methods built from SQL queries. + def ctx_mgr(self, conn, *args, **kwargs): # pragma: no cover + return self.driver_adapter.select_cursor( + conn, fn.__name__, fn.sql, self._params(args, kwargs) + ) - The ``-- name`` definition comments in the content of the SQL determine what the dynamic - methods of this class will be named. + return self._query_fn( + ctx_mgr, f"{fn.__name__}_cursor", fn.__doc__, fn.sql, fn.operation, fn.__signature__ + ) - **Parameters:** + def _create_methods(self, query_datum: QueryDatum, is_aio: bool) -> List[QueryFn]: + """Internal function to feed add_queries.""" + fn = self._make_sync_fn(query_datum) + if is_aio: + fn = self._make_async_fn(fn) - - **driver_adapter** - Either a string to designate one of the aiosql built-in database driver - adapters (e.g. "sqlite3", "psycopg"). - If you have defined your own adapter class, you can pass its constructor. - """ + ctx_mgr = self._make_ctx_mgr(fn) - def __init__(self, driver_adapter: DriverAdapterProtocol): - self.driver_adapter: DriverAdapterProtocol = driver_adapter - self.is_aio: bool = getattr(driver_adapter, "is_aio_driver", False) - self._available_queries: Set[str] = set() + if query_datum.operation_type == SQLOperationType.SELECT: + return [fn, ctx_mgr] + else: + return [fn] + # + # PUBLIC INTERFACE + # @property def available_queries(self) -> List[str]: """Returns listing of all the available query methods loaded in this class. @@ -148,10 +171,10 @@ def available_queries(self) -> List[str]: """ return sorted(self._available_queries) - def __repr__(self): + def __repr__(self) -> str: return "Queries(" + self.available_queries.__repr__() + ")" - def add_query(self, query_name: str, fn: Callable): + def add_query(self, query_name: str, fn: Callable) -> None: """Adds a new dynamic method to this class. **Parameters:** @@ -162,13 +185,13 @@ def add_query(self, query_name: str, fn: Callable): setattr(self, query_name, fn) self._available_queries.add(query_name) - def add_queries(self, queries: List[QueryFn]): + def add_queries(self, queries: List[QueryFn]) -> None: """Add query methods to `Queries` instance.""" for fn in queries: query_name = fn.__name__.rpartition(".")[2] self.add_query(query_name, MethodType(fn, self)) - def add_child_queries(self, child_name: str, child_queries: "Queries"): + def add_child_queries(self, child_name: str, child_queries: "Queries") -> None: """Adds a Queries object as a property. **Parameters:** @@ -183,7 +206,7 @@ def add_child_queries(self, child_name: str, child_queries: "Queries"): def load_from_list(self, query_data: List[QueryDatum]): """Load Queries from a list of `QuaryDatum`""" for query_datum in query_data: - self.add_queries(_create_methods(query_datum, self.is_aio)) + self.add_queries(self._create_methods(query_datum, self.is_aio)) return self def load_from_tree(self, query_data_tree: QueryDataTree): @@ -192,5 +215,5 @@ def load_from_tree(self, query_data_tree: QueryDataTree): if isinstance(value, dict): self.add_child_queries(key, Queries(self.driver_adapter).load_from_tree(value)) else: - self.add_queries(_create_methods(value, self.is_aio)) + self.add_queries(self._create_methods(value, self.is_aio)) return self diff --git a/aiosql/query_loader.py b/aiosql/query_loader.py index 70895da2..edd95090 100644 --- a/aiosql/query_loader.py +++ b/aiosql/query_loader.py @@ -1,7 +1,7 @@ import re import inspect from pathlib import Path -from typing import Dict, List, Optional, Tuple, Type, Sequence, Any +from typing import Dict, List, Optional, Tuple, Type, Sequence, Any, Union from .utils import SQLParseException, SQLLoadException, VAR_REF, log from .types import QueryDatum, QueryDataTree, SQLOperationType, DriverAdapterProtocol @@ -72,7 +72,7 @@ def __init__( self.record_classes = record_classes if record_classes is not None else {} def _make_query_datum( - self, query: str, ns_parts: List[str], floc: Optional[Tuple[Path, int]] = None + self, query: str, ns_parts: List[str], floc: Tuple[Union[Path, str], int] ) -> QueryDatum: # Build a query datum # - query: the spec and name ("query-name!\n-- comments\nSQL;\n") @@ -134,7 +134,7 @@ def _build_signature(self, sql: str) -> inspect.Signature: return inspect.Signature(parameters=params) def load_query_data_from_sql( - self, sql: str, ns_parts: List[str] = [], fname: Optional[Path] = None + self, sql: str, ns_parts: List[str], fname: Union[Path, str] = "" ) -> List[QueryDatum]: usql = _remove_ml_comments(sql) qdefs = _QUERY_DEF.split(usql) @@ -143,7 +143,7 @@ def load_query_data_from_sql( data = [] # first item is anything before the first query definition, drop it! for qdef in qdefs[1:]: - data.append(self._make_query_datum(qdef, ns_parts, (fname, lineno) if fname else None)) + data.append(self._make_query_datum(qdef, ns_parts, (fname, lineno))) lineno += qdef.count("\n") return data diff --git a/aiosql/types.py b/aiosql/types.py index bd36375e..e281f4e5 100644 --- a/aiosql/types.py +++ b/aiosql/types.py @@ -38,9 +38,9 @@ class QueryDatum(NamedTuple): doc_comments: str operation_type: SQLOperationType sql: str - record_class: Any = None - signature: Optional[inspect.Signature] = None - floc: Optional[Tuple[Path, int]] = None + record_class: Any + signature: Optional[inspect.Signature] + floc: Tuple[Union[Path, str], int] class QueryFn(Protocol): @@ -49,8 +49,7 @@ class QueryFn(Protocol): sql: str operation: SQLOperationType - def __call__(self, *args: Any, **kwargs: Any) -> Any: - ... # pragma: no cover + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... # pragma: no cover # Can't make this a recursive type in terms of itself @@ -59,8 +58,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: class SyncDriverAdapterProtocol(Protocol): - def process_sql(self, query_name: str, op_type: SQLOperationType, sql: str) -> str: - ... # pragma: no cover + def process_sql( + self, query_name: str, op_type: SQLOperationType, sql: str + ) -> str: ... # pragma: no cover def select( self, @@ -69,8 +69,7 @@ def select( sql: str, parameters: Union[List, Dict], record_class: Optional[Callable], - ) -> List: - ... # pragma: no cover + ) -> List: ... # pragma: no cover def select_one( self, @@ -79,43 +78,37 @@ def select_one( sql: str, parameters: Union[List, Dict], record_class: Optional[Callable], - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover def select_value( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover def select_cursor( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> ContextManager[Any]: - ... # pragma: no cover + ) -> ContextManager[Any]: ... # pragma: no cover # TODO: Next major version introduce a return? Optional return? def insert_update_delete( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> int: - ... # pragma: no cover + ) -> int: ... # pragma: no cover # TODO: Next major version introduce a return? Optional return? def insert_update_delete_many( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> int: - ... # pragma: no cover + ) -> int: ... # pragma: no cover def insert_returning( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover - def execute_script(self, conn: Any, sql: str) -> str: - ... # pragma: no cover + def execute_script(self, conn: Any, sql: str) -> str: ... # pragma: no cover class AsyncDriverAdapterProtocol(Protocol): - def process_sql(self, query_name: str, op_type: SQLOperationType, sql: str) -> str: - ... # pragma: no cover + def process_sql( + self, query_name: str, op_type: SQLOperationType, sql: str + ) -> str: ... # pragma: no cover async def select( self, @@ -124,8 +117,7 @@ async def select( sql: str, parameters: Union[List, Dict], record_class: Optional[Callable], - ) -> List: - ... # pragma: no cover + ) -> List: ... # pragma: no cover async def select_one( self, @@ -134,38 +126,31 @@ async def select_one( sql: str, parameters: Union[List, Dict], record_class: Optional[Callable], - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover async def select_value( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover async def select_cursor( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> AsyncContextManager[Any]: - ... # pragma: no cover + ) -> AsyncContextManager[Any]: ... # pragma: no cover # TODO: Next major version introduce a return? Optional return? async def insert_update_delete( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> None: - ... # pragma: no cover + ) -> None: ... # pragma: no cover # TODO: Next major version introduce a return? Optional return? async def insert_update_delete_many( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> None: - ... # pragma: no cover + ) -> None: ... # pragma: no cover async def insert_returning( self, conn: Any, query_name: str, sql: str, parameters: Union[List, Dict] - ) -> Optional[Any]: - ... # pragma: no cover + ) -> Optional[Any]: ... # pragma: no cover - async def execute_script(self, conn: Any, sql: str) -> str: - ... # pragma: no cover + async def execute_script(self, conn: Any, sql: str) -> str: ... # pragma: no cover DriverAdapterProtocol = Union[SyncDriverAdapterProtocol, AsyncDriverAdapterProtocol] diff --git a/aiosql/utils.py b/aiosql/utils.py index 600fa28f..401dd38e 100644 --- a/aiosql/utils.py +++ b/aiosql/utils.py @@ -10,7 +10,7 @@ # NOTE beware of overlapping re r"(?P[^:]):(?P[\w-]+)(?=[^:]?)" ) -"""Pattern to identifies colon-variables in SQL code""" +"""Pattern to identifies colon-variables (aka _named_ style) in SQL code""" log = logging.getLogger("aiosql") # log.setLevel(logging.DEBUG) diff --git a/pyproject.toml b/pyproject.toml index e6980205..e0fe1ad7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aiosql" -version = "9.3" +version = "9.4" authors = [ { name = "William Vaughn et al.", email = "vaughnwilld@gmail.com" } ] description = "Simple SQL in Python" readme = "README.rst" @@ -18,7 +18,7 @@ classifiers = [ [project.optional-dependencies] dev = [ - "pytest", "pytest-asyncio", + "pytest<8", "pytest-asyncio", "rstcheck", "black", "coverage", "flake8", "mypy", "pyright", "types-setuptools", "build" ] diff --git a/tests/conf_schema.py b/tests/conf_schema.py index d203c858..84ee3a05 100644 --- a/tests/conf_schema.py +++ b/tests/conf_schema.py @@ -12,9 +12,7 @@ def create_user_blogs(db): serial = ( "SERIAL" if db == "pgsql" - else "INTEGER" - if db in ("sqlite", "duckdb") - else "INTEGER auto_increment" + else "INTEGER" if db in ("sqlite", "duckdb") else "INTEGER auto_increment" ) ddl_statements = [ f"""CREATE TABLE IF NOT EXISTS users ( diff --git a/tests/run_tests.py b/tests/run_tests.py index 8d17e0b0..a3b4dfd3 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -106,18 +106,16 @@ def run_parameterized_query(conn, queries, db=None): def run_parameterized_record_query(conn, queries, db, todate): - # this black-generated indentation is a jokeā€¦ - fun = ( - queries.blogs.sqlite_get_blogs_published_after - if _DB[db] == "sqlite3" - else queries.blogs.duckdb_get_blogs_published_after - if _DB[db] == "duckdb" - else queries.blogs.pg_get_blogs_published_after - if _DB[db] == "postgres" - else queries.blogs.my_get_blogs_published_after - if _DB[db] in ("mysql", "mariadb") - else None - ) + if _DB[db] == "sqlite3": + fun = queries.blogs.sqlite_get_blogs_published_after + elif _DB[db] == "duckdb": + fun = queries.blogs.duckdb_get_blogs_published_after + elif _DB[db] == "postgres": + fun = queries.blogs.pg_get_blogs_published_after + elif _DB[db] in ("mysql", "mariadb"): + fun = queries.blogs.my_get_blogs_published_after + else: + raise Exception(f"unexpected driver: {db}") raw_actual = fun(conn, published=todate(2018, 1, 1)) assert isinstance(raw_actual, Iterable) @@ -174,17 +172,17 @@ def run_select_one(conn, queries, db=None): def run_insert_returning(conn, queries, db, todate): - fun = ( - queries.blogs.publish_blog - if _DB[db] in ("sqlite3") - else queries.blogs.duckdb_publish_blog - if _DB[db] in ("duckdb") - else queries.blogs.pg_publish_blog - if _DB[db] in ("postgres", "mariadb") - else queries.blogs.my_publish_blog - if _DB[db] == "mysql" - else None - ) + if _DB[db] in ("sqlite3"): + fun = queries.blogs.publish_blog + elif _DB[db] in ("duckdb"): + fun = queries.blogs.duckdb_publish_blog + elif _DB[db] in ("postgres", "mariadb"): + fun = queries.blogs.pg_publish_blog + elif _DB[db] == "mysql": + fun = queries.blogs.my_publish_blog + else: + raise Exception(f"unexpected driver: {db}") + if db == "duckdb": blogid = fun( conn, @@ -338,9 +336,7 @@ async def run_async_parameterized_record_query(conn, queries, db, todate): fun = ( queries.blogs.pg_get_blogs_published_after if _DB[db] == "postgres" - else queries.blogs.sqlite_get_blogs_published_after - if _DB[db] == "sqlite3" - else None + else queries.blogs.sqlite_get_blogs_published_after if _DB[db] == "sqlite3" else None ) records = await fun(conn, published=todate(2018, 1, 1)) diff --git a/tests/test_loading.py b/tests/test_loading.py index c48dfc75..8a1d57c2 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -179,7 +179,8 @@ def test_file_loading(sql_file): def test_misc(sql_file): try: - aiosql.queries._make_sync_fn(("hello", None, -1, "SELECT NULL;", None, None, None)) + queries = aiosql.queries.Queries("sqlite3") + queries._make_sync_fn(("hello", None, -1, "SELECT NULL;", None, None, None)) assert False, "must raise an exception" # pragma: no cover except ValueError as e: assert "Unknown operation_type" in str(e) @@ -194,3 +195,27 @@ def test_misc(sql_file): assert False, "must raise en exception" # pragma: no cover except ValueError as e: assert "must be a directory" in str(e) + + +def test_kwargs(): + # kwargs_only == True + queries = aiosql.from_str("-- name: plus_one$\nSELECT 1 + :val;\n", "sqlite3", kwargs_only=True) + import sqlite3 + + conn = sqlite3.connect(":memory:") + assert 42 == queries.plus_one(conn, val=41) + try: + queries.plus_one(conn, 2) + assert False, "must raise an exception" # pragma: no cover + except ValueError as e: + assert "kwargs" in str(e) + # kwargs_only == False + queries = aiosql.from_str( + "-- name: plus_two$\nSELECT 2 + :val;\n", "sqlite3", kwargs_only=False + ) + assert 42 == queries.plus_two(conn, val=40) + try: + queries.plus_two(conn, 2, val=41) + assert False, "must raise an exception" # pragma: no cover + except ValueError as e: + assert "mix" in str(e)