Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow creating a sea_orm::Databse from sqlx::ConnectOptions #434

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx
with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"]
postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros?/postgres-array"]
sqlx-dep = []
sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"]
sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", "sqlx-any"]
sqlx-mysql = ["sqlx-dep", "sea-query-binder/sqlx-mysql", "sqlx/mysql"]
sqlx-any = ["sqlx-dep", "sqlx/any"]
sqlx-postgres = ["sqlx-dep", "sea-query-binder/sqlx-postgres", "sqlx/postgres"]
sqlx-sqlite = ["sqlx-dep", "sea-query-binder/sqlx-sqlite", "sqlx/sqlite"]
runtime-async-std = []
Expand Down
2 changes: 1 addition & 1 deletion examples/rocket_example/api/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl sea_orm_rocket::Pool for SeaOrmPool {

async fn init(figment: &Figment) -> Result<Self, Self::Error> {
let config = figment.extract::<Config>().unwrap();
let mut options: ConnectOptions = config.url.into();
let mut options: ConnectOptions = config.url.try_into()?;
options
.max_connections(config.max_connections as u32)
.min_connections(config.min_connections.unwrap_or_default())
Expand Down
2 changes: 1 addition & 1 deletion examples/rocket_okapi_example/api/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl sea_orm_rocket::Pool for SeaOrmPool {

async fn init(figment: &Figment) -> Result<Self, Self::Error> {
let config = figment.extract::<Config>().unwrap();
let mut options: ConnectOptions = config.url.into();
let mut options: ConnectOptions = config.url.try_into()?;
options
.max_connections(config.max_connections as u32)
.min_connections(config.min_connections.unwrap_or_default())
Expand Down
300 changes: 263 additions & 37 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
#[cfg(feature = "sqlx-any")]
use sqlx::any::AnyKind;
#[cfg(feature = "sqlx-mysql")]
use sqlx::mysql::MySqlConnectOptions;
#[cfg(feature = "sqlx-postgres")]
use sqlx::postgres::PgConnectOptions;
#[cfg(feature = "sqlx-sqlite")]
use sqlx::sqlite::SqliteConnectOptions;
use std::fmt::Debug;
use std::time::Duration;

mod connection;
Expand All @@ -24,11 +33,58 @@ use crate::{DbErr, RuntimeErr};
#[derive(Debug, Default)]
pub struct Database;

/// Supported database kinds of [sqlx::ConnectOptions]'.
#[derive(Debug, Clone)]
pub enum SqlxConnectOptions {
#[cfg(feature = "sqlx-mysql")]
/// Variant for [MySqlConnectOptions]
MySql(MySqlConnectOptions),
#[cfg(feature = "sqlx-postgres")]
/// Variant for [PgConnectOptions]
Postgres(PgConnectOptions),
#[cfg(feature = "sqlx-sqlite")]
/// Variant for [SqliteConnectOptions]
Sqlite(SqliteConnectOptions),
#[cfg(feature = "mock")]
/// Variant for a mock connection
Mock(DbBackend),
}

impl SqlxConnectOptions {
/// The database backend type
pub fn get_db_backend_type(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
SqlxConnectOptions::MySql(_) => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")]
SqlxConnectOptions::Postgres(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")]
SqlxConnectOptions::Sqlite(_) => DbBackend::Sqlite,
#[cfg(feature = "mock")]
SqlxConnectOptions::Mock(db_backend) => *db_backend,
}
}

#[cfg(feature = "mock")]
/// Create a mock database connection options
pub fn mock(db_backend: DbBackend) -> SqlxConnectOptions {
Self::Mock(db_backend)
}

#[cfg(feature = "mock")]
/// Is this for mock connection?
pub fn is_mock(&self) -> bool {
matches!(self, SqlxConnectOptions::Mock(_))
}
}

/// Defines the configuration options of a database
#[derive(Debug, Clone)]
pub struct ConnectOptions {
/// The URI of the database
pub(crate) url: String,
/// The database sqlx::ConnectOptions used to connect to the database.
pub(crate) connect_options: SqlxConnectOptions,
/// The URI of the database, if this struct was created from an URI string, otherwise None
pub(crate) url: Option<String>,
/// Maximum number of connections for a pool
pub(crate) max_connections: Option<u32>,
/// Minimum number of connections for a pool
Expand All @@ -55,58 +111,177 @@ pub struct ConnectOptions {
impl Database {
/// Method to create a [DatabaseConnection] on a database
#[instrument(level = "trace", skip(opt))]
pub async fn connect<C>(opt: C) -> Result<DatabaseConnection, DbErr>
pub async fn connect<C, E>(opt: C) -> Result<DatabaseConnection, DbErr>
where
C: Into<ConnectOptions>,
C: TryInto<ConnectOptions, Error = E> + Debug,
E: std::error::Error,
{
let opt: ConnectOptions = opt.into();
let describe = format!("{:?}", opt);
let opt: ConnectOptions = opt
.try_into()
.map_err(|e| DbErr::Conn(
RuntimeErr::Internal(format!("Couldn't parse connection options {} {}", describe, e))
))?;

#[cfg(feature = "sqlx-mysql")]
if DbBackend::MySql.is_prefix_of(&opt.url) {
return crate::SqlxMySqlConnector::connect(opt).await;
}
#[cfg(feature = "sqlx-postgres")]
if DbBackend::Postgres.is_prefix_of(&opt.url) {
return crate::SqlxPostgresConnector::connect(opt).await;
}
#[cfg(feature = "sqlx-sqlite")]
if DbBackend::Sqlite.is_prefix_of(&opt.url) {
return crate::SqlxSqliteConnector::connect(opt).await;
}
#[cfg(feature = "mock")]
if crate::MockDatabaseConnector::accepts(&opt.url) {
return crate::MockDatabaseConnector::connect(&opt.url).await;
if opt.connect_options.is_mock() {
return crate::MockDatabaseConnector::connect(opt).await;
}

let backend = opt.connect_options.get_db_backend_type();

match backend {
#[cfg(feature = "sqlx-mysql")]
DbBackend::MySql => crate::SqlxMySqlConnector::connect(opt).await,
#[cfg(feature = "sqlx-postgres")]
DbBackend::Postgres => crate::SqlxPostgresConnector::connect(opt).await,
#[cfg(feature = "sqlx-sqlite")]
DbBackend::Sqlite => crate::SqlxSqliteConnector::connect(opt).await,
#[cfg(not(all(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))]
_ => unreachable!(),
}
Err(DbErr::Conn(RuntimeErr::Internal(format!(
"The connection string '{}' has no supporting driver.",
opt.url
))))
}
}

impl From<&str> for ConnectOptions {
fn from(string: &str) -> ConnectOptions {
impl TryFrom<&str> for ConnectOptions {
type Error = DbErr;

fn try_from(string: &str) -> Result<Self, Self::Error> {
ConnectOptions::from_str(string)
}
}

impl From<&String> for ConnectOptions {
fn from(string: &String) -> ConnectOptions {
impl TryFrom<&String> for ConnectOptions {
type Error = DbErr;

fn try_from(string: &String) -> Result<Self, Self::Error> {
ConnectOptions::from_str(string.as_str())
}
}

impl From<String> for ConnectOptions {
fn from(string: String) -> ConnectOptions {
ConnectOptions::new(string)
impl TryFrom<String> for ConnectOptions {
type Error = DbErr;

fn try_from(string: String) -> Result<Self, Self::Error> {
ConnectOptions::new_from_url(string)
}
}

#[cfg(feature = "sqlx-mysql")]
impl TryFrom<MySqlConnectOptions> for ConnectOptions {
type Error = DbErr;

fn try_from(connect_options: MySqlConnectOptions) -> Result<Self, Self::Error> {
Ok(ConnectOptions::new(SqlxConnectOptions::MySql(
connect_options,
)))
}
}

#[cfg(feature = "sqlx-postgres")]
impl TryFrom<PgConnectOptions> for ConnectOptions {
type Error = DbErr;

fn try_from(connect_options: PgConnectOptions) -> Result<Self, Self::Error> {
Ok(ConnectOptions::new(SqlxConnectOptions::Postgres(
connect_options,
)))
}
}

#[cfg(feature = "sqlx-sqlite")]
impl TryFrom<SqliteConnectOptions> for ConnectOptions {
type Error = DbErr;

fn try_from(connect_options: SqliteConnectOptions) -> Result<Self, Self::Error> {
Ok(ConnectOptions::new(SqlxConnectOptions::Sqlite(
connect_options,
)))
}
}

#[cfg(feature = "sqlx-any")]
impl TryFrom<sqlx::any::AnyConnectOptions> for ConnectOptions {
type Error = DbErr;

fn try_from(connect_options: sqlx::any::AnyConnectOptions) -> Result<Self, Self::Error> {
Ok(ConnectOptions::new(connect_options.try_into()?))
}
}

#[cfg(feature = "sqlx-mysql")]
impl TryFrom<MySqlConnectOptions> for SqlxConnectOptions {
type Error = DbErr;

fn try_from(connect_options: MySqlConnectOptions) -> Result<Self, Self::Error> {
Ok(SqlxConnectOptions::MySql(connect_options))
}
}

#[cfg(feature = "sqlx-postgres")]
impl TryFrom<PgConnectOptions> for SqlxConnectOptions {
type Error = DbErr;

fn try_from(connect_options: PgConnectOptions) -> Result<Self, Self::Error> {
Ok(SqlxConnectOptions::Postgres(connect_options))
}
}

#[cfg(feature = "sqlx-sqlite")]
impl TryFrom<SqliteConnectOptions> for SqlxConnectOptions {
type Error = DbErr;

fn try_from(connect_options: SqliteConnectOptions) -> Result<Self, Self::Error> {
Ok(SqlxConnectOptions::Sqlite(connect_options))
}
}

#[cfg(feature = "sqlx-any")]
impl TryFrom<sqlx::any::AnyConnectOptions> for SqlxConnectOptions {
type Error = DbErr;

fn try_from(connect_options: sqlx::any::AnyConnectOptions) -> Result<Self, Self::Error> {
match connect_options.kind() {
#[cfg(feature = "sqlx-postgres")]
AnyKind::Postgres => Ok(SqlxConnectOptions::Postgres(
connect_options.as_postgres().unwrap().clone(),
)),
#[cfg(feature = "sqlx-mysql")]
AnyKind::MySql => Ok(SqlxConnectOptions::MySql(
connect_options.as_mysql().unwrap().clone(),
)),
#[cfg(feature = "sqlx-sqlite")]
AnyKind::Sqlite => Ok(SqlxConnectOptions::Sqlite(
connect_options.as_sqlite().unwrap().clone(),
)),
}
}
}

impl ConnectOptions {
/// Create new [ConnectOptions] for a [Database] by passing in a URI string
pub fn new(url: String) -> Self {
/// Create new [ConnectOptions] for a [Database] by passing in a [sqlx::ConnectOptions]
pub fn new(connect_options: SqlxConnectOptions) -> Self {
Self {
url,
connect_options,
url: None,
max_connections: None,
min_connections: None,
connect_timeout: None,
idle_timeout: None,
acquire_timeout: None,
max_lifetime: None,
sqlx_logging: true,
sqlx_logging_level: log::LevelFilter::Info,
sqlcipher_key: None,
schema_search_path: None
}
}

/// Create new [ConnectOptions] for a [Database] by passing in a URI string
pub fn new_from_url(url: String) -> Result<Self, DbErr> {
Ok(Self {
connect_options: Self::url_to_sqlx_connect_options(url.clone())?,
url: Some(url),
max_connections: None,
min_connections: None,
connect_timeout: None,
Expand All @@ -117,11 +292,54 @@ impl ConnectOptions {
sqlx_logging_level: log::LevelFilter::Info,
sqlcipher_key: None,
schema_search_path: None,
})
}

fn url_to_sqlx_connect_options(url: String) -> Result<SqlxConnectOptions, DbErr> {
#[cfg(feature = "sqlx-mysql")]
if DbBackend::MySql.is_prefix_of(&url) {
return url
.parse::<MySqlConnectOptions>()
.map_err(crate::sqlx_error_to_conn_err)?
.try_into();
}
#[cfg(feature = "sqlx-postgres")]
if DbBackend::Postgres.is_prefix_of(&url) {
return url
.parse::<PgConnectOptions>()
.map_err(crate::sqlx_error_to_conn_err)?
.try_into();
}
#[cfg(feature = "sqlx-sqlite")]
if DbBackend::Sqlite.is_prefix_of(&url) {
return url
.parse::<SqliteConnectOptions>()
.map_err(crate::sqlx_error_to_conn_err)?
.try_into();
}
#[cfg(feature = "mock")]
if crate::MockDatabaseConnector::accepts(&url) {
if DbBackend::MySql.is_prefix_of(&url) {
return Ok(SqlxConnectOptions::Mock(DbBackend::MySql));
}
#[cfg(feature = "sqlx-postgres")]
if DbBackend::Postgres.is_prefix_of(&url) {
return Ok(SqlxConnectOptions::Mock(DbBackend::Postgres));
}
#[cfg(feature = "sqlx-sqlite")]
if DbBackend::Sqlite.is_prefix_of(&url) {
return Ok(SqlxConnectOptions::Mock(DbBackend::Sqlite));
}
return Ok(SqlxConnectOptions::Mock(DbBackend::Postgres));
}
Err(DbErr::Conn(RuntimeErr::Internal(format!(
"The connection string '{}' has no supporting driver.",
url
))))
}

fn from_str(url: &str) -> Self {
Self::new(url.to_owned())
fn from_str(url: &str) -> Result<Self, DbErr> {
Self::new_from_url(url.to_owned())
}

#[cfg(feature = "sqlx-dep")]
Expand Down Expand Up @@ -152,11 +370,19 @@ impl ConnectOptions {
opt
}

/// Get the database URL of the pool
pub fn get_url(&self) -> &str {
/// Get the database URL of the pool. This is only present if the pool was created from an URL.
/// If it was created from some sqlx::ConnectOptions then this method returns None.
///
/// To get the actual ConnectOptions used to connect to the database see: [Self::get_connect_options].
pub fn get_url(&self) -> &Option<String> {
&self.url
}

/// Get the ConnectOptions used to connect to the database
pub fn get_connect_options(&self) -> &SqlxConnectOptions {
&self.connect_options
}

/// Set the maximum number of connections of the pool
pub fn max_connections(&mut self, value: u32) -> &mut Self {
self.max_connections = Some(value);
Expand Down
2 changes: 1 addition & 1 deletion src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ impl DatabaseTransaction {
if let Err(sqlx::Error::RowNotFound) = err {
Ok(None)
} else {
err.map_err(|e| sqlx_error_to_query_err(e))
err.map_err(sqlx_error_to_query_err)
}
}
}
Expand Down
Loading