diff --git a/README.md b/README.md index 6712f14c..275e1d9f 100644 --- a/README.md +++ b/README.md @@ -109,8 +109,6 @@ db_pool = PSQLPool( ) async def main() -> None: - await db_pool.startup() - res: QueryResult = await db_pool.execute( "SELECT * FROM users", ) @@ -147,10 +145,12 @@ As connection can be closed in different situations on various sides you can sel rendered ineffective. ## Results from querying + You have some options to get results from the query. `execute()` method, for example, returns `QueryResult` and this class can be converted into `list` of `dict`s - `list[dict[Any, Any]]` or into any Python class (`pydantic` model, as an example). Let's see some code: + ```python from typing import Any @@ -169,8 +169,6 @@ db_pool = PSQLPool( ) async def main() -> None: - await db_pool.startup() - res: QueryResult = await db_pool.execute( "SELECT * FROM users", ) @@ -213,8 +211,6 @@ db_pool = PSQLPool( ) async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() res: QueryResult = await connection.execute( @@ -252,8 +248,6 @@ from psqlpy import PSQLPool, IsolationLevel, QueryResult db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() async with connection.transaction() as transaction: res: QueryResult = await transaction.execute( @@ -276,8 +270,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -310,8 +302,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -339,8 +329,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -367,8 +355,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -383,6 +369,7 @@ async def main() -> None: ``` ### Transaction pipelining + When you have a lot of independent queries and want to execute them concurrently, you can use `pipeline`. Pipelining can improve performance in use cases in which multiple, independent queries need to be executed. @@ -390,6 +377,7 @@ In a traditional workflow, each query is sent to the server after the previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up front, minimizing time spent by one side waiting for the other to finish sending data: + ``` Sequential Pipelined | Client | Server | | Client | Server | @@ -404,9 +392,11 @@ minimizing time spent by one side waiting for the other to finish sending data: | | process query 3 | | receive rows 3 | | ``` + Read more: https://docs.rs/tokio-postgres/latest/tokio_postgres/#pipelining Let's see some code: + ```python import asyncio @@ -415,8 +405,6 @@ from psqlpy import PSQLPool, QueryResult async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - transaction = await db_pool.transaction() results: list[QueryResult] = await transaction.pipeline( @@ -450,8 +438,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -484,8 +470,6 @@ from psqlpy import PSQLPool, IsolationLevel db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -524,8 +508,6 @@ from psqlpy import PSQLPool, IsolationLevel, QueryResult db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction( isolation_level=IsolationLevel.Serializable, @@ -553,6 +535,7 @@ async def main() -> None: ``` ### Cursor as an async context manager + ```python from typing import Any @@ -563,8 +546,6 @@ db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - connection = await db_pool.connection() transaction: Transaction cursor: Cursor @@ -624,8 +605,6 @@ from psqlpy.extra_types import ( db_pool = PSQLPool() async def main() -> None: - await db_pool.startup() - res: list[dict[str, Any]] = await db_pool.execute( "INSERT INTO users VALUES ($1, $2, $3, $4, $5)", [ diff --git a/docs/examples/aiohttp/start_example.py b/docs/examples/aiohttp/start_example.py index f7d3b7ff..1ac2e1f0 100644 --- a/docs/examples/aiohttp/start_example.py +++ b/docs/examples/aiohttp/start_example.py @@ -1,6 +1,7 @@ # Start example import asyncio -from typing import cast +from typing import Any, cast + from aiohttp import web from psqlpy import PSQLPool @@ -11,7 +12,6 @@ async def start_db_pool(app: web.Application) -> None: dsn="postgres://postgres:postgres@localhost:5432/postgres", max_db_pool_size=2, ) - await db_pool.startup() app["db_pool"] = db_pool @@ -22,7 +22,7 @@ async def stop_db_pool(app: web.Application) -> None: await db_pool.close() -async def pg_pool_example(request: web.Request): +async def pg_pool_example(request: web.Request) -> Any: db_pool = cast(PSQLPool, request.app["db_pool"]) connection = await db_pool.connection() await asyncio.sleep(10) @@ -37,7 +37,7 @@ async def pg_pool_example(request: web.Request): application = web.Application() application.on_startup.append(start_db_pool) -application.add_routes([web.get('/', pg_pool_example)]) +application.add_routes([web.get("/", pg_pool_example)]) if __name__ == "__main__": diff --git a/docs/examples/fastapi/advanced_example.py b/docs/examples/fastapi/advanced_example.py index d3a50a27..5976fadb 100644 --- a/docs/examples/fastapi/advanced_example.py +++ b/docs/examples/fastapi/advanced_example.py @@ -1,12 +1,12 @@ # Start example import asyncio from contextlib import asynccontextmanager -from typing import Annotated, AsyncGenerator, cast -from fastapi import Depends, FastAPI, Request -from fastapi.responses import JSONResponse -from psqlpy import PSQLPool, Connection -import uvicorn +from typing import AsyncGenerator +import uvicorn +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from psqlpy import PSQLPool db_pool = PSQLPool( dsn="postgres://postgres:postgres@localhost:5432/postgres", @@ -17,7 +17,6 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Startup database connection pool and close it on shutdown.""" - await db_pool.startup() app.state.db_pool = db_pool yield await db_pool.close() @@ -34,7 +33,7 @@ async def some_long_func() -> None: @app.get("/") -async def pg_pool_example(): +async def pg_pool_example() -> JSONResponse: await some_long_func() db_connection = await db_pool.connection() query_result = await db_connection.execute( @@ -47,4 +46,4 @@ async def pg_pool_example(): uvicorn.run( "advanced_example:app", port=8001, - ) \ No newline at end of file + ) diff --git a/docs/examples/fastapi/start_example.py b/docs/examples/fastapi/start_example.py index 13280c11..7bc3a06b 100644 --- a/docs/examples/fastapi/start_example.py +++ b/docs/examples/fastapi/start_example.py @@ -1,10 +1,12 @@ # Start example from contextlib import asynccontextmanager -from typing import Annotated, AsyncGenerator, cast +from typing import AsyncGenerator, cast + +import uvicorn from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse -from psqlpy import PSQLPool, Connection -import uvicorn +from psqlpy import Connection, PSQLPool +from typing_extensions import Annotated @asynccontextmanager @@ -14,7 +16,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: dsn="postgres://postgres:postgres@localhost:5432/postgres", max_db_pool_size=2, ) - await db_pool.startup() app.state.db_pool = db_pool yield await db_pool.close() @@ -31,7 +32,7 @@ async def db_connection(request: Request) -> Connection: @app.get("/") async def pg_pool_example( db_connection: Annotated[Connection, Depends(db_connection)], -): +) -> JSONResponse: query_result = await db_connection.execute( "SELECT * FROM users", ) @@ -41,4 +42,4 @@ async def pg_pool_example( if __name__ == "__main__": uvicorn.run( "start_example:app", - ) \ No newline at end of file + ) diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 46f5723f..e2abdd8b 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -39,8 +39,6 @@ class QueryResult: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - query_result: QueryResult = await db_pool.execute( "SELECT username FROM users WHERE id = $1", [100], @@ -82,8 +80,6 @@ class SingleQueryResult: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() async with connection.transaction() as trans: query_result: SingleQueryResult = await trans.fetch_row( @@ -343,8 +339,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() await transaction.begin() @@ -383,8 +377,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() await transaction.begin() @@ -425,8 +417,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() await transaction.begin() @@ -465,12 +455,10 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() await transaction.begin() - value: Any | None = await transaction.execute( + value: Any = await transaction.fetch_val( "SELECT username FROM users WHERE id = $1", [100], ) @@ -520,8 +508,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() @@ -565,8 +551,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() @@ -590,8 +574,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() await transaction.execute(...) @@ -616,8 +598,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() @@ -644,8 +624,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = connection.transaction() @@ -685,8 +663,6 @@ class Transaction: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() transaction = await connection.transaction() @@ -740,8 +716,6 @@ class Connection: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() - connection = await db_pool.connection() query_result: QueryResult = await connection.execute( "SELECT username FROM users WHERE id = $1", @@ -750,6 +724,109 @@ class Connection: dict_result: List[Dict[Any, Any]] = query_result.result() ``` """ + async def execute_many( + self: Self, + querystring: str, + parameters: list[list[Any]] | None = None, + prepared: bool = True, + ) -> None: ... + """Execute query multiple times with different parameters. + + Querystring can contain `$` parameters + for converting them in the driver side. + + ### Parameters: + - `querystring`: querystring to execute. + - `parameters`: list of list of parameters to pass in the query. + - `prepared`: should the querystring be prepared before the request. + By default any querystring will be prepared. + + ### Example: + ```python + import asyncio + + from psqlpy import PSQLPool, QueryResult + + + async def main() -> None: + db_pool = PSQLPool() + connection = await db_pool.connection() + query_result: QueryResult = await connection.execute_many( + "INSERT INTO users (name, age) VALUES ($1, $2)", + [["boba", 10], ["boba", 20]], + ) + dict_result: List[Dict[Any, Any]] = query_result.result() + ``` + """ + async def fetch_row( + self: Self, + querystring: str, + parameters: list[Any] | None = None, + prepared: bool = True, + ) -> SingleQueryResult: + """Fetch exaclty single row from query. + + Query must return exactly one row, otherwise error will be raised. + Querystring can contain `$` parameters + for converting them in the driver side. + + + ### Parameters: + - `querystring`: querystring to execute. + - `parameters`: list of parameters to pass in the query. + - `prepared`: should the querystring be prepared before the request. + By default any querystring will be prepared. + + ### Example: + ```python + import asyncio + + from psqlpy import PSQLPool, QueryResult + + + async def main() -> None: + db_pool = PSQLPool() + + connection = await db_pool.connection() + fetched_row: SingleQueryResult = await connection.fetch_row( + "SELECT * FROM users LIMIT 1", + [], + ) + ``` + """ + async def fetch_val( + self: Self, + querystring: str, + parameters: list[Any] | None = None, + prepared: bool = True, + ) -> Any | None: + """Execute the query and return first value of the first row. + + Querystring can contain `$` parameters + for converting them in the driver side. + + ### Parameters: + - `querystring`: querystring to execute. + - `parameters`: list of parameters to pass in the query. + - `prepared`: should the querystring be prepared before the request. + By default any querystring will be prepared. + + ### Example: + ```python + import asyncio + + from psqlpy import PSQLPool, QueryResult + + + async def main() -> None: + db_pool = PSQLPool() + connection = await db_pool.connection() + value: Any = await connection.fetch_val( + "SELECT username FROM users WHERE id = $1", + [100], + ) + ``` + """ def transaction( self, isolation_level: IsolationLevel | None = None, @@ -778,7 +855,7 @@ class PSQLPool: host: Optional[str] = None, port: Optional[int] = None, db_name: Optional[str] = None, - max_db_pool_size: Optional[str] = None, + max_db_pool_size: int = 2, conn_recycling_method: Optional[ConnRecyclingMethod] = None, ) -> None: """Create new PostgreSQL connection pool. @@ -804,11 +881,6 @@ class PSQLPool: - `max_db_pool_size`: maximum size of the connection pool - `conn_recycling_method`: how a connection is recycled. """ - async def startup(self: Self) -> None: - """Startup the connection pool. - - You must call it before start making queries. - """ async def close(self: Self) -> None: """Close the connection pool. @@ -841,7 +913,6 @@ class PSQLPool: async def main() -> None: db_pool = PSQLPool() - await db_pool.startup() query_result: QueryResult = await psqlpy.execute( "SELECT username FROM users WHERE id = $1", [100], @@ -864,7 +935,7 @@ def create_connection_pool( host: Optional[str] = None, port: Optional[int] = None, db_name: Optional[str] = None, - max_db_pool_size: Optional[str] = None, + max_db_pool_size: int = 2, conn_recycling_method: Optional[ConnRecyclingMethod] = None, ) -> PSQLPool: """Create new connection pool. diff --git a/python/tests/helpers.py b/python/tests/helpers.py index 254e62a1..3e1336b2 100644 --- a/python/tests/helpers.py +++ b/python/tests/helpers.py @@ -1,11 +1,14 @@ +from __future__ import annotations + import typing -from psqlpy import Transaction +if typing.TYPE_CHECKING: + from psqlpy import Connection, Transaction async def count_rows_in_test_table( table_name: str, - transaction: Transaction, + transaction: Transaction | Connection, ) -> int: query_result: typing.Final = await transaction.execute( f"SELECT COUNT(*) FROM {table_name}", diff --git a/python/tests/test_connection.py b/python/tests/test_connection.py index 6407b7cb..e89ba025 100644 --- a/python/tests/test_connection.py +++ b/python/tests/test_connection.py @@ -1,6 +1,12 @@ +from __future__ import annotations + +import typing + import pytest +from tests.helpers import count_rows_in_test_table from psqlpy import PSQLPool, QueryResult, Transaction +from psqlpy.exceptions import DBTransactionError, RustPSQLDriverPyBaseError pytestmark = pytest.mark.anyio @@ -20,7 +26,7 @@ async def test_connection_execute( assert len(conn_result.result()) == number_database_records -async def test_connection_transaction( +async def test_connection_connection( psql_pool: PSQLPool, ) -> None: """Test that connection can create transactions.""" @@ -28,3 +34,82 @@ async def test_connection_transaction( transaction = connection.transaction() assert isinstance(transaction, Transaction) + + +@pytest.mark.parametrize( + ("insert_values"), + [ + [[1, "name1"], [2, "name2"]], + [[10, "name1"], [20, "name2"], [30, "name3"]], + [[1, "name1"]], + [], + ], +) +async def test_connection_execute_many( + psql_pool: PSQLPool, + table_name: str, + number_database_records: int, + insert_values: list[list[typing.Any]], +) -> None: + connection = await psql_pool.connection() + try: + await connection.execute_many( + f"INSERT INTO {table_name} VALUES ($1, $2)", + insert_values, + ) + except DBTransactionError: + assert not insert_values + else: + assert await count_rows_in_test_table( + table_name, + connection, + ) - number_database_records == len(insert_values) + + +async def test_connection_fetch_row( + psql_pool: PSQLPool, + table_name: str, +) -> None: + connection = await psql_pool.connection() + database_single_query_result: typing.Final = await connection.fetch_row( + f"SELECT * FROM {table_name} LIMIT 1", + [], + ) + result = database_single_query_result.result() + assert isinstance(result, dict) + + +async def test_connection_fetch_row_more_than_one_row( + psql_pool: PSQLPool, + table_name: str, +) -> None: + connection = await psql_pool.connection() + with pytest.raises(RustPSQLDriverPyBaseError): + await connection.fetch_row( + f"SELECT * FROM {table_name}", + [], + ) + + +async def test_connection_fetch_val( + psql_pool: PSQLPool, + table_name: str, +) -> None: + connection = await psql_pool.connection() + value: typing.Final = await connection.fetch_val( + f"SELECT COUNT(*) FROM {table_name}", + [], + ) + assert isinstance(value, int) + + +async def test_connection_fetch_val_more_than_one_row( + psql_pool: PSQLPool, + table_name: str, +) -> None: + connection = await psql_pool.connection() + with pytest.raises(RustPSQLDriverPyBaseError): + await connection.fetch_row( + f"SELECT * FROM {table_name}", + [], + ) diff --git a/python/tests/test_cursor.py b/python/tests/test_cursor.py index 200d2cb8..1fe86a13 100644 --- a/python/tests/test_cursor.py +++ b/python/tests/test_cursor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math import pytest diff --git a/src/driver/connection.rs b/src/driver/connection.rs index a39cd26f..4bf11bb7 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -1,28 +1,234 @@ use deadpool_postgres::Object; -use pyo3::{pyclass, pymethods, PyAny, Python}; +use pyo3::{pyclass, pymethods, types::PyList, PyAny, Python}; use std::{collections::HashSet, sync::Arc, vec}; -use tokio_postgres::types::ToSql; use crate::{ common::rustdriver_future, exceptions::rust_errors::RustPSQLDriverPyResult, - query_result::PSQLDriverPyQueryResult, - value_converter::{convert_parameters, PythonDTO}, + query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, + value_converter::{convert_parameters, postgres_to_py, PythonDTO, QueryParameter}, }; +use tokio_postgres::Row; use super::{ transaction::{RustTransaction, Transaction}, transaction_options::{IsolationLevel, ReadVariant}, }; +#[allow(clippy::module_name_repetitions)] +pub struct RustConnection { + pub db_client: Arc>, +} + +impl RustConnection { + #[must_use] + pub fn new(db_client: Arc>) -> Self { + RustConnection { db_client } + } + /// Execute statement with or witout parameters. + /// + /// # Errors + /// + /// May return Err Result if + /// 1) Cannot convert incoming parameters + /// 2) Cannot prepare statement + /// 3) Cannot execute query + pub async fn inner_execute( + &self, + querystring: String, + parameters: Vec, + prepared: bool, + ) -> RustPSQLDriverPyResult { + let db_client = self.db_client.read().await; + let vec_parameters: Vec<&QueryParameter> = parameters + .iter() + .map(|param| param as &QueryParameter) + .collect(); + + let result = if prepared { + db_client + .query( + &db_client.prepare_cached(&querystring).await?, + &vec_parameters.into_boxed_slice(), + ) + .await? + } else { + db_client + .query(&querystring, &vec_parameters.into_boxed_slice()) + .await? + }; + + Ok(PSQLDriverPyQueryResult::new(result)) + } + + /// Execute querystring with many parameters. + /// + /// Method doesn't acquire lock on any structure fields. + /// It prepares and caches querystring in the inner Object object. + /// + /// Then execute the query. + /// + /// # Errors + /// May return Err Result if: + /// 1) Transaction is not started + /// 2) Transaction is done already + /// 3) Can not create/retrieve prepared statement + /// 4) Can not execute statement + pub async fn inner_execute_many( + &self, + querystring: String, + parameters: Vec>, + prepared: bool, + ) -> RustPSQLDriverPyResult<()> { + let mut db_client = self.db_client.write().await; + let transaction = db_client.transaction().await?; + for single_parameters in parameters { + if prepared { + transaction + .query( + &transaction.prepare_cached(&querystring).await?, + &single_parameters + .iter() + .map(|p| p as &QueryParameter) + .collect::>(), + ) + .await?; + } else { + transaction + .query( + &querystring, + &single_parameters + .iter() + .map(|p| p as &QueryParameter) + .collect::>(), + ) + .await?; + } + } + + transaction.commit().await?; + + Ok(()) + } + + /// Fetch exaclty single row from query. + /// + /// Method doesn't acquire lock on any structure fields. + /// It prepares and caches querystring in the inner Object object. + /// + /// Then execute the query. + /// + /// # Errors + /// May return Err Result if: + /// 1) Transaction is not started + /// 2) Transaction is done already + /// 3) Can not create/retrieve prepared statement + /// 4) Can not execute statement + /// 5) Query returns more than one row + pub async fn inner_fetch_row( + &self, + querystring: String, + parameters: Vec, + prepared: bool, + ) -> RustPSQLDriverPyResult { + let vec_parameters: Vec<&QueryParameter> = parameters + .iter() + .map(|param| param as &QueryParameter) + .collect(); + let db_client_guard = self.db_client.read().await; + + let result = if prepared { + db_client_guard + .query_one( + &db_client_guard.prepare_cached(&querystring).await?, + &vec_parameters.into_boxed_slice(), + ) + .await? + } else { + db_client_guard + .query_one(&querystring, &vec_parameters.into_boxed_slice()) + .await? + }; + + Ok(PSQLDriverSinglePyQueryResult::new(result)) + } + + /// Execute querystring with parameters. + /// + /// Method doesn't acquire lock on any structure fields. + /// It prepares and caches querystring in the inner Object object. + /// + /// Then execute the query. + /// + /// It returns `Vec` instead of `PSQLDriverPyQueryResult`. + /// + /// # Errors + /// May return Err Result if: + /// 1) Transaction is not started + /// 2) Transaction is done already + /// 3) Can not create/retrieve prepared statement + /// 4) Can not execute statement + pub async fn inner_execute_raw( + &self, + querystring: String, + parameters: Vec, + prepared: bool, + ) -> RustPSQLDriverPyResult> { + let db_client_guard = self.db_client.read().await; + let vec_parameters: Vec<&QueryParameter> = parameters + .iter() + .map(|param| param as &QueryParameter) + .collect(); + + let result = if prepared { + db_client_guard + .query( + &db_client_guard.prepare_cached(&querystring).await?, + &vec_parameters.into_boxed_slice(), + ) + .await? + } else { + db_client_guard + .query(&querystring, &vec_parameters.into_boxed_slice()) + .await? + }; + + Ok(result) + } + + /// Return new instance of transaction. + #[must_use] + pub fn inner_transaction( + &self, + isolation_level: Option, + read_variant: Option, + deferrable: Option, + ) -> Transaction { + let inner_transaction = RustTransaction::new( + Arc::new(RustConnection::new(self.db_client.clone())), + false, + false, + Arc::new(tokio::sync::RwLock::new(HashSet::new())), + isolation_level, + read_variant, + deferrable, + ); + + Transaction::new( + Arc::new(tokio::sync::RwLock::new(inner_transaction)), + Default::default(), + ) + } +} + #[pyclass()] pub struct Connection { - pub inner_connection: Arc, + pub inner_connection: Arc, } impl Connection { #[must_use] - pub fn new(inner_connection: Arc) -> Self { + pub fn new(inner_connection: Arc) -> Self { Connection { inner_connection } } } @@ -41,36 +247,117 @@ impl Connection { &'a self, py: Python<'a>, querystring: String, - parameters: Option<&'a PyAny>, + parameters: Option<&PyAny>, prepared: Option, ) -> RustPSQLDriverPyResult<&PyAny> { - let connection_arc = self.inner_connection.clone(); - let mut params: Vec = vec![]; if let Some(parameters) = parameters { params = convert_parameters(parameters)?; } - let is_prepared = prepared.unwrap_or(true); + let inner_connection_arc = self.inner_connection.clone(); rustdriver_future(py, async move { - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(params.len()); - for param in ¶ms { - vec_parameters.push(param); + inner_connection_arc + .inner_execute(querystring, params, prepared.unwrap_or(true)) + .await + }) + } + + /// Execute querystring with parameters. + /// + /// It converts incoming parameters to rust readable + /// and then execute the query with them. + /// + /// # Errors + /// + /// May return Err Result if: + /// 1) Cannot convert python parameters + /// 2) Cannot execute querystring. + pub fn execute_many<'a>( + &'a self, + py: Python<'a>, + querystring: String, + parameters: Option<&'a PyList>, + prepared: Option, + ) -> RustPSQLDriverPyResult<&PyAny> { + let transaction_arc = self.inner_connection.clone(); + let mut params: Vec> = vec![]; + if let Some(parameters) = parameters { + for single_parameters in parameters { + params.push(convert_parameters(single_parameters)?); } + } - let result = if is_prepared { - connection_arc - .query( - &connection_arc.prepare_cached(&querystring).await?, - &vec_parameters.into_boxed_slice(), - ) - .await? - } else { - connection_arc - .query(&querystring, &vec_parameters.into_boxed_slice()) - .await? - }; + rustdriver_future(py, async move { + transaction_arc + .inner_execute_many(querystring, params, prepared.unwrap_or(true)) + .await + }) + } - Ok(PSQLDriverPyQueryResult::new(result)) + /// Execute querystring with parameters and return first row. + /// + /// It converts incoming parameters to rust readable, + /// executes query with them and returns first row of response. + /// + /// # Errors + /// + /// May return Err Result if: + /// 1) Cannot convert python parameters + /// 2) Cannot execute querystring. + /// 3) Query returns more than one row. + pub fn fetch_row<'a>( + &'a self, + py: Python<'a>, + querystring: String, + parameters: Option<&'a PyList>, + prepared: Option, + ) -> RustPSQLDriverPyResult<&PyAny> { + let transaction_arc = self.inner_connection.clone(); + let mut params: Vec = vec![]; + if let Some(parameters) = parameters { + params = convert_parameters(parameters)?; + } + + rustdriver_future(py, async move { + transaction_arc + .inner_fetch_row(querystring, params, prepared.unwrap_or(true)) + .await + }) + } + + /// Execute querystring with parameters and return first value in the first row. + /// + /// It converts incoming parameters to rust readable, + /// executes query with them and returns first row of response. + /// + /// # Errors + /// + /// May return Err Result if: + /// 1) Cannot convert python parameters + /// 2) Cannot execute querystring. + /// 3) Query returns more than one row + pub fn fetch_val<'a>( + &'a self, + py: Python<'a>, + querystring: String, + parameters: Option<&'a PyList>, + prepared: Option, + ) -> RustPSQLDriverPyResult<&PyAny> { + let transaction_arc = self.inner_connection.clone(); + let mut params: Vec = vec![]; + if let Some(parameters) = parameters { + params = convert_parameters(parameters)?; + } + + rustdriver_future(py, async move { + let first_row = transaction_arc + .inner_fetch_row(querystring, params, prepared.unwrap_or(true)) + .await? + .get_inner(); + Python::with_gil(|py| match first_row.columns().first() { + Some(first_column) => postgres_to_py(py, &first_row, first_column, 0), + None => Ok(py.None()), + }) }) } diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 751bfea8..e50eed25 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,16 +1,19 @@ use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use pyo3::{pyclass, pymethods, PyAny, Python}; use std::{str::FromStr, sync::Arc, vec}; -use tokio_postgres::{types::ToSql, NoTls}; +use tokio_postgres::NoTls; use crate::{ common::rustdriver_future, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, query_result::PSQLDriverPyQueryResult, - value_converter::{convert_parameters, PythonDTO}, + value_converter::{convert_parameters, PythonDTO, QueryParameter}, }; -use super::{common_options::ConnRecyclingMethod, connection::Connection}; +use super::{ + common_options::ConnRecyclingMethod, + connection::{Connection, RustConnection}, +}; /// `PSQLPool` is for internal use only. /// @@ -70,7 +73,9 @@ impl RustPSQLPool { .get() .await?; - Ok(Connection::new(Arc::new(db_pool_manager))) + Ok(Connection::new(Arc::new(RustConnection::new(Arc::new( + tokio::sync::RwLock::new(db_pool_manager), + ))))) } /// Execute querystring with parameters. /// @@ -94,10 +99,10 @@ impl RustPSQLPool { .get() .await?; - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(parameters.len()); - for param in ¶meters { - vec_parameters.push(param); - } + let vec_parameters: Vec<&QueryParameter> = parameters + .iter() + .map(|param| param as &QueryParameter) + .collect(); let result = if prepared { db_pool_manager diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index 322176af..7db482a6 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -3,7 +3,7 @@ use pyo3::{ PyRefMut, Python, }; use std::sync::Arc; -use tokio_postgres::{types::ToSql, Row}; +use tokio_postgres::Row; use crate::{ common::rustdriver_future, @@ -60,12 +60,6 @@ impl InnerCursor { pub async fn inner_start(&mut self) -> RustPSQLDriverPyResult<()> { let db_transaction_arc = self.db_transaction.clone(); - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = - Vec::with_capacity(self.parameters.len()); - for param in &self.parameters { - vec_parameters.push(param); - } - let mut cursor_init_query = format!("DECLARE {}", self.cursor_name); if let Some(scroll) = self.scroll { if scroll { @@ -80,7 +74,7 @@ impl InnerCursor { db_transaction_arc .read() .await - .inner_execute(cursor_init_query, &self.parameters, self.prepared) + .inner_execute(cursor_init_query, self.parameters.clone(), self.prepared) .await?; self.is_started = true; diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 6ae8b70d..fcca2921 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -6,9 +6,8 @@ use crate::{ common::rustdriver_future, exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{convert_parameters, postgres_to_py, PythonDTO, ValueOrReferenceTo}, + value_converter::{convert_parameters, postgres_to_py, PythonDTO}, }; -use deadpool_postgres::Object; use futures_util::future; use pyo3::{ pyclass, pymethods, @@ -17,14 +16,16 @@ use pyo3::{ }; use std::{collections::HashSet, sync::Arc, vec}; use tokio::sync::RwLock; -use tokio_postgres::{types::ToSql, Row}; +use tokio_postgres::Row; + +use super::connection::RustConnection; /// Transaction for internal use only. /// /// It is not exposed to python. #[allow(clippy::module_name_repetitions)] pub struct RustTransaction { - pub db_client: Arc, + connection: Arc, is_started: bool, is_done: bool, rollback_savepoint: Arc>>, @@ -37,7 +38,7 @@ pub struct RustTransaction { impl RustTransaction { #[allow(clippy::too_many_arguments)] pub fn new( - db_client: Arc, + connection: Arc, is_started: bool, is_done: bool, rollback_savepoint: Arc>>, @@ -46,7 +47,7 @@ impl RustTransaction { deferable: Option, ) -> Self { Self { - db_client, + connection, is_started, is_done, rollback_savepoint, @@ -56,6 +57,21 @@ impl RustTransaction { } } + fn check_is_transaction_ready(&self) -> RustPSQLDriverPyResult<()> { + if !self.is_started { + return Err(RustPSQLDriverError::DataBaseTransactionError( + "Transaction is not started, please call begin() on transaction".into(), + )); + } + if self.is_done { + return Err(RustPSQLDriverError::DataBaseTransactionError( + "Transaction is already committed or rolled back".into(), + )); + } + + Ok(()) + } + /// Execute querystring with parameters. /// /// Method doesn't acquire lock on any structure fields. @@ -69,46 +85,16 @@ impl RustTransaction { /// 2) Transaction is done already /// 3) Can not create/retrieve prepared statement /// 4) Can not execute statement - pub async fn inner_execute( + pub async fn inner_execute( &self, querystring: String, - parameters: T, + parameters: Vec, prepared: bool, - ) -> RustPSQLDriverPyResult - where - T: ValueOrReferenceTo>, - { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is not started, please call begin() on transaction".into(), - )); - } - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - } - - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = - Vec::with_capacity(parameters.as_ref().len()); - for param in parameters.as_ref() { - vec_parameters.push(param); - } - - let result = if prepared { - self.db_client - .query( - &self.db_client.prepare_cached(&querystring).await?, - &vec_parameters.into_boxed_slice(), - ) - .await? - } else { - self.db_client - .query(&querystring, &vec_parameters.into_boxed_slice()) - .await? - }; - - Ok(PSQLDriverPyQueryResult::new(result)) + ) -> RustPSQLDriverPyResult { + self.check_is_transaction_ready()?; + self.connection + .inner_execute(querystring, parameters, prepared) + .await } /// Execute querystring with parameters. @@ -126,46 +112,16 @@ impl RustTransaction { /// 2) Transaction is done already /// 3) Can not create/retrieve prepared statement /// 4) Can not execute statement - pub async fn inner_execute_raw( + pub async fn inner_execute_raw( &self, querystring: String, - parameters: T, + parameters: Vec, prepared: bool, - ) -> RustPSQLDriverPyResult> - where - T: ValueOrReferenceTo>, - { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is not started, please call begin() on transaction".into(), - )); - } - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - } - - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = - Vec::with_capacity(parameters.as_ref().len()); - for param in parameters.as_ref() { - vec_parameters.push(param); - } - - let result = if prepared { - self.db_client - .query( - &self.db_client.prepare_cached(&querystring).await?, - &vec_parameters.into_boxed_slice(), - ) - .await? - } else { - self.db_client - .query(&querystring, &vec_parameters.into_boxed_slice()) - .await? - }; - - Ok(result) + ) -> RustPSQLDriverPyResult> { + self.check_is_transaction_ready()?; + self.connection + .inner_execute_raw(querystring, parameters, prepared) + .await } /// Execute querystring with many parameters. @@ -187,43 +143,16 @@ impl RustTransaction { parameters: Vec>, prepared: bool, ) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is not started, please call begin() on transaction".into(), - )); - } - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - } + self.check_is_transaction_ready()?; if parameters.is_empty() { return Err(RustPSQLDriverError::DataBaseTransactionError( "No parameters passed to execute_many".into(), )); } for single_parameters in parameters { - if prepared { - self.db_client - .query( - &self.db_client.prepare_cached(&querystring).await?, - &single_parameters - .iter() - .map(|p| p as &(dyn ToSql + Sync)) - .collect::>(), - ) - .await?; - } else { - self.db_client - .query( - &querystring, - &single_parameters - .iter() - .map(|p| p as &(dyn ToSql + Sync)) - .collect::>(), - ) - .await?; - } + self.connection + .inner_execute(querystring.clone(), single_parameters, prepared) + .await?; } Ok(()) @@ -249,36 +178,10 @@ impl RustTransaction { parameters: Vec, prepared: bool, ) -> RustPSQLDriverPyResult { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is not started, please call begin() on transaction".into(), - )); - } - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - } - - let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(parameters.len()); - for param in ¶meters { - vec_parameters.push(param); - } - - let result = if prepared { - self.db_client - .query_one( - &self.db_client.prepare_cached(&querystring).await?, - &vec_parameters.into_boxed_slice(), - ) - .await? - } else { - self.db_client - .query_one(&querystring, &vec_parameters.into_boxed_slice()) - .await? - }; - - Ok(PSQLDriverSinglePyQueryResult::new(result)) + self.check_is_transaction_ready()?; + self.connection + .inner_fetch_row(querystring, parameters, prepared) + .await } /// Run many queries as pipeline. @@ -329,8 +232,8 @@ impl RustTransaction { Some(false) => " NOT DEFERRABLE", None => "", }); - - self.db_client.batch_execute(&querystring).await?; + let db_client_guard = self.connection.db_client.read().await; + db_client_guard.batch_execute(&querystring).await?; Ok(()) } @@ -375,18 +278,9 @@ impl RustTransaction { /// 2) Transaction is done /// 3) Cannot execute `COMMIT` command pub async fn inner_commit(&mut self) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Can not commit not started transaction".into(), - )); - } - - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - } - self.db_client.batch_execute("COMMIT;").await?; + self.check_is_transaction_ready()?; + let db_client_guard = self.connection.db_client.read().await; + db_client_guard.batch_execute("COMMIT;").await?; self.is_done = true; Ok(()) @@ -404,17 +298,7 @@ impl RustTransaction { /// 3) Specified savepoint name is exists /// 4) Can not execute SAVEPOINT command pub async fn inner_savepoint(&self, savepoint_name: String) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Can not commit not started transaction".into(), - )); - } - - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - }; + self.check_is_transaction_ready()?; let is_savepoint_name_exists = { let rollback_savepoint_read_guard = self.rollback_savepoint.read().await; @@ -426,7 +310,8 @@ impl RustTransaction { "SAVEPOINT name {savepoint_name} is already taken by this transaction", ))); } - self.db_client + let db_client_guard = self.connection.db_client.read().await; + db_client_guard .batch_execute(format!("SAVEPOINT {savepoint_name}").as_str()) .await?; let mut rollback_savepoint_guard = self.rollback_savepoint.write().await; @@ -444,18 +329,9 @@ impl RustTransaction { /// 2) Transaction is done /// 3) Can not execute ROLLBACK command pub async fn inner_rollback(&mut self) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Can not commit not started transaction".into(), - )); - }; - - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - }; - self.db_client.batch_execute("ROLLBACK").await?; + self.check_is_transaction_ready()?; + let db_client_guard = self.connection.db_client.read().await; + db_client_guard.batch_execute("ROLLBACK").await?; self.is_done = true; Ok(()) } @@ -471,16 +347,7 @@ impl RustTransaction { /// 3) Specified savepoint name doesn't exist /// 4) Can not execute ROLLBACK TO SAVEPOINT command pub async fn inner_rollback_to(&self, rollback_name: String) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Can not commit not started transaction".into(), - )); - }; - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - }; + self.check_is_transaction_ready()?; let rollback_savepoint_arc = self.rollback_savepoint.clone(); let is_rollback_exists = { @@ -492,7 +359,8 @@ impl RustTransaction { "Don't have rollback with this name".into(), )); } - self.db_client + let db_client_guard = self.connection.db_client.read().await; + db_client_guard .batch_execute(format!("ROLLBACK TO SAVEPOINT {rollback_name}").as_str()) .await?; @@ -513,16 +381,7 @@ impl RustTransaction { &self, rollback_name: String, ) -> RustPSQLDriverPyResult<()> { - if !self.is_started { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Can not commit not started transaction".into(), - )); - }; - if self.is_done { - return Err(RustPSQLDriverError::DataBaseTransactionError( - "Transaction is already committed or rolled back".into(), - )); - }; + self.check_is_transaction_ready()?; let mut rollback_savepoint_guard = self.rollback_savepoint.write().await; let is_rollback_exists = rollback_savepoint_guard.remove(&rollback_name); @@ -532,7 +391,8 @@ impl RustTransaction { "Don't have rollback with this name".into(), )); } - self.db_client + let db_client_guard = self.connection.db_client.read().await; + db_client_guard .batch_execute(format!("RELEASE SAVEPOINT {rollback_name}").as_str()) .await?; diff --git a/src/value_converter.rs b/src/value_converter.rs index a84a45b0..da8fd92c 100644 --- a/src/value_converter.rs +++ b/src/value_converter.rs @@ -22,23 +22,7 @@ use crate::{ extra_types::{BigInt, Integer, PyJSON, PyUUID, SmallInt}, }; -/// Trait allows pass variable by value and by reference. -pub trait ValueOrReferenceTo { - fn as_ref(&self) -> &T; -} - -impl ValueOrReferenceTo for &T { - #[allow(clippy::explicit_auto_deref)] - fn as_ref(&self) -> &T { - *self - } -} - -impl ValueOrReferenceTo for T { - fn as_ref(&self) -> &T { - self - } -} +pub type QueryParameter = (dyn ToSql + Sync); /// Additional type for types come from Python. ///