Skip to content

Commit 70a22e2

Browse files
committed
Refactor TLS code to be a bit easier to read
1 parent 2eb7c06 commit 70a22e2

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

sqlx-core/Cargo.toml

+4-3
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ decimal = [ "rust_decimal", "num-bigint" ]
3434
json = [ "serde", "serde_json" ]
3535

3636
# runtimes
37-
runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_rt-actix" ]
38-
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_rt-async-std" ]
39-
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_rt-tokio" ]
37+
runtime-actix-native-tls = [ "sqlx-rt/runtime-actix-native-tls", "_tls-native-tls", "_rt-actix" ]
38+
runtime-async-std-native-tls = [ "sqlx-rt/runtime-async-std-native-tls", "_tls-native-tls", "_rt-async-std" ]
39+
runtime-tokio-native-tls = [ "sqlx-rt/runtime-tokio-native-tls", "_tls-native-tls", "_rt-tokio" ]
4040

4141
# for conditional compilation
4242
_rt-actix = []
4343
_rt-async-std = []
4444
_rt-tokio = []
45+
_tls-native-tls = []
4546

4647
# support offline/decoupled building (enables serialization of `Describe`)
4748
offline = [ "serde", "either/serde" ]

sqlx-core/src/error.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,6 @@ impl Error {
129129
pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self {
130130
Error::Configuration(err.into())
131131
}
132-
133-
#[allow(dead_code)]
134-
#[inline]
135-
pub(crate) fn tls(err: impl StdError + Send + Sync + 'static) -> Self {
136-
Error::Tls(err.into())
137-
}
138132
}
139133

140134
pub(crate) fn mismatched_types<DB: Database, T: Type<DB>>(ty: &DB::TypeInfo) -> BoxDynError {
@@ -240,6 +234,14 @@ impl From<crate::migrate::MigrateError> for Error {
240234
}
241235
}
242236

237+
#[cfg(feature = "_tls-native-tls")]
238+
impl From<sqlx_rt::native_tls::Error> for Error {
239+
#[inline]
240+
fn from(error: sqlx_rt::native_tls::Error) -> Self {
241+
Error::Tls(Box::new(error))
242+
}
243+
}
244+
243245
// Format an error message as a `Protocol` error
244246
macro_rules! err_protocol {
245247
($expr:expr) => {

sqlx-core/src/net/tls.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,17 @@ where
4848
if !accept_invalid_certs {
4949
if let Some(ca) = root_cert_path {
5050
let data = fs::read(ca).await?;
51-
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
51+
let cert = Certificate::from_pem(&data)?;
5252

5353
builder.add_root_certificate(cert);
5454
}
5555
}
5656

5757
#[cfg(not(feature = "_rt-async-std"))]
58-
let connector = builder.build().map_err(Error::tls)?;
58+
let connector = sqlx_rt::TlsConnector::from(builder.build()?);
5959

6060
#[cfg(feature = "_rt-async-std")]
61-
let connector = builder;
61+
let connector = sqlx_rt::TlsConnector::from(builder);
6262

6363
let stream = match replace(self, MaybeTlsStream::Upgrading) {
6464
MaybeTlsStream::Raw(stream) => stream,
@@ -75,12 +75,7 @@ where
7575
}
7676
};
7777

78-
*self = MaybeTlsStream::Tls(
79-
sqlx_rt::TlsConnector::from(connector)
80-
.connect(host, stream)
81-
.await
82-
.map_err(|err| Error::Tls(err.into()))?,
83-
);
78+
*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
8479

8580
Ok(())
8681
}

0 commit comments

Comments
 (0)