Skip to content

Commit

Permalink
feat(wip): wasm rustls
Browse files Browse the repository at this point in the history
Signed-off-by: csh <[email protected]>
  • Loading branch information
L-jasmine committed Nov 7, 2023
1 parent 520f70b commit 3e1b203
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 19 deletions.
32 changes: 30 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,20 @@ url = "2.1"
# mio = { version = "0.8.0", features = ["os-poll", "net"] }

[target.'cfg(target_os="wasi")'.dependencies]
tokio_wasi = {version = "1", features = [ "io-util", "fs", "net", "time", "rt", "macros", "sync"] }
tokio_wasi = { version = "1", features = [
"io-util",
"fs",
"net",
"time",
"rt",
"macros",
"sync",
] }
tokio-util_wasi = { version = "0.7.2", features = ["codec", "io"] }
wasmedge_wasi_socket = "0.4.3"
wasmedge_rustls_api = { version = "0.1.0", optional = true, features = [
"tokio_async",
] }

# [target.'cfg(not(target_os="wasi"))'.dev-dependencies]
# tempfile = "3.1.0"
Expand Down Expand Up @@ -88,7 +99,14 @@ optional = true

[dev-dependencies]
tempfile = "3.1.0"
tokio_wasi = { version = "1", features = [ "io-util", "fs", "net", "time", "rt", "macros"] }
tokio_wasi = { version = "1", features = [
"io-util",
"fs",
"net",
"time",
"rt",
"macros",
] }
rand = "0.8.0"

[features]
Expand All @@ -99,6 +117,7 @@ default = [
"mysql_common/uuid",
"mysql_common/frunk",
# "native-tls-tls",
"wasmedge-tls",
]
default-rustls = [
"flate2/zlib",
Expand All @@ -109,6 +128,15 @@ default-rustls = [
"mysql_common/frunk",
"rustls-tls",
]
wasmedge-tls = [
# "flate2/zlib",
"mysql_common/bigdecimal03",
"mysql_common/rust_decimal",
"mysql_common/time03",
"mysql_common/uuid",
"mysql_common/frunk",
"wasmedge_rustls_api",
]
minimal = ["flate2/zlib"]
native-tls-tls = ["native-tls", "tokio-native-tls"]
rustls-tls = [
Expand Down
22 changes: 15 additions & 7 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,22 @@ impl Conn {

/// Returns true if io stream is encrypted.
fn is_secure(&self) -> bool {
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
if let Some(ref stream) = self.inner.stream {
stream.is_secure()
} else {
false
}

#[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
#[cfg(not(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
)))]
false
}

Expand Down Expand Up @@ -486,7 +494,7 @@ impl Conn {
};
Ok(())
}
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
if self
.inner
Expand All @@ -503,12 +511,12 @@ impl Conn {
}

let collation = if self.inner.version >= (5, 5, 3) {
UTF8MB4_GENERAL_CI
mysql_common::constants::UTF8MB4_GENERAL_CI
} else {
UTF8_GENERAL_CI
crate::consts::UTF8MB4_GENERAL_CI
};

let ssl_request = SslRequest::new(
let ssl_request = mysql_common::packets::SslRequest::new(
self.inner.capabilities,
DEFAULT_MAX_ALLOWED_PACKET as u32,
collation as u8,
Expand Down Expand Up @@ -846,7 +854,7 @@ impl Conn {
conn.inner.stream = Some(stream);
conn.setup_stream()?;
conn.handle_handshake().await?;
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
conn.switch_to_ssl_if_needed().await?;
conn.do_handshake_response().await?;
conn.continue_auth().await?;
Expand Down
6 changes: 5 additions & 1 deletion src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ impl Error {
pub enum IoError {
#[error("Input/output error: {}", _0)]
Io(#[source] io::Error),
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
#[error("TLS error: `{}'", _0)]
Tls(#[source] tls::TlsError),
}
Expand Down
9 changes: 8 additions & 1 deletion src/error/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#![cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#![cfg(any(
feature = "native-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]

pub mod native_tls_error;
pub mod rustls_error;
Expand All @@ -8,3 +12,6 @@ pub use native_tls_error::TlsError;

#[cfg(feature = "rustls")]
pub use rustls_error::TlsError;

#[cfg(feature = "wasmedge-tls")]
pub use wasmedge_rustls_api::TlsError;
56 changes: 50 additions & 6 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ pub(crate) enum Endpoint {
Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "rustls-tls")]
Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
#[cfg(feature = "wasmedge-tls")]
Secure(#[pin] wasmedge_rustls_api::stream::async_stream::TlsStream<tokio::net::TcpStream>),
#[cfg(unix)]
Socket(#[pin] Socket),
}
Expand Down Expand Up @@ -150,7 +152,14 @@ impl Future for CheckTcpStream<'_> {
}

impl Endpoint {
#[cfg(all(any(feature = "native-tls-tls", feature = "rustls-tls"), unix))]
#[cfg(all(
any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
),
unix
))]
fn is_socket(&self) -> bool {
match self {
Self::Socket(_) => true,
Expand All @@ -177,6 +186,12 @@ impl Endpoint {
CheckTcpStream(stream).await?;
Ok(())
}
#[cfg(feature = "wasmedge-tls")]
Endpoint::Secure(tls_stream) => {
let stream = tls_stream.get_mut().0;
CheckTcpStream(stream).await?;
Ok(())
}
#[cfg(unix)]
Endpoint::Socket(socket) => {
socket.write(&[]).await?;
Expand All @@ -186,12 +201,20 @@ impl Endpoint {
}
}

#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
pub fn is_secure(&self) -> bool {
matches!(self, Endpoint::Secure(_))
}

#[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))]
#[cfg(all(
not(feature = "native-tls"),
not(feature = "rustls"),
not(feature = "wasmedge-tls")
))]
pub async fn _make_secure(
&mut self,
_domain: String,
Expand All @@ -216,6 +239,11 @@ impl Endpoint {
let stream = stream.get_ref().0;
stream.set_nodelay(val)?;
}
#[cfg(feature = "wasmedge-tls")]
Endpoint::Secure(ref stream) => {
let stream = stream.get_ref().0;
stream.set_nodelay(val)?;
}
#[cfg(unix)]
Endpoint::Socket(_) => (/* inapplicable */),
}
Expand Down Expand Up @@ -262,6 +290,8 @@ impl AsyncRead for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
})
Expand All @@ -283,6 +313,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
})
Expand All @@ -301,6 +333,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
})
Expand All @@ -319,6 +353,8 @@ impl AsyncWrite for Endpoint {
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(feature = "rustls-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(feature = "wasmedge-tls")]
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
#[cfg(unix)]
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
})
Expand Down Expand Up @@ -409,12 +445,14 @@ impl Stream {
pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
}
#[cfg(not(target_os = "wasi"))]
#[cfg(any(not(target_os = "wasi"), feature = "wasmedge-tls"))]
pub(crate) async fn make_secure(
&mut self,
domain: String,
ssl_opts: SslOpts,
ssl_opts: crate::SslOpts,
) -> crate::error::Result<()> {
use tokio_util::codec::FramedParts;

let codec = self.codec.take().unwrap();
let FramedParts { mut io, codec, .. } = codec.into_parts();
io.make_secure(domain, ssl_opts).await?;
Expand All @@ -423,7 +461,11 @@ impl Stream {
Ok(())
}

#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
#[cfg(any(
feature = "native-tls-tls",
feature = "rustls-tls",
feature = "wasmedge-tls"
))]
pub(crate) fn is_secure(&self) -> bool {
self.codec.as_ref().unwrap().get_ref().is_secure()
}
Expand Down Expand Up @@ -506,6 +548,8 @@ mod test {
super::Endpoint::Plain(Some(stream)) => stream,
#[cfg(feature = "rustls-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
#[cfg(feature = "wasmedge-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
#[cfg(feature = "native-tls")]
super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
_ => unreachable!(),
Expand Down
3 changes: 2 additions & 1 deletion src/io/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(any(feature = "native-tls", feature = "rustls"))]
#![cfg(any(feature = "native-tls", feature = "rustls", feature = "wasmedge-tls"))]

mod native_tls_io;
mod rustls_io;
mod wasmedge_rustls_io;
34 changes: 34 additions & 0 deletions src/io/tls/wasmedge_rustls_io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#![cfg(feature = "wasmedge-tls")]

use wasmedge_rustls_api::{stream::async_stream::TlsStream, ClientConfig};

use crate::{io::Endpoint, Result};

impl Endpoint {
pub async fn make_secure(&mut self, domain: String, _ssl_opts: crate::SslOpts) -> Result<()> {
#[cfg(unix)]
if self.is_socket() {
// won't secure socket connection
return Ok(());
}

let config = ClientConfig::default();

*self = match self {
Endpoint::Plain(ref mut stream) => {
let stream = stream.take().unwrap();

let connection = TlsStream::connect(&config, domain, stream)
.await
.map_err(|e| e.0)?;

Endpoint::Secure(connection)
}
Endpoint::Secure(_) => unreachable!(),
#[cfg(unix)]
Endpoint::Socket(_) => unreachable!(),
};

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl HostPortOrUrl {
/// ```
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "rustls-tls",))]
client_identity: Option<ClientIdentity>,
root_cert_path: Option<Cow<'static, Path>>,
skip_domain_validation: bool,
Expand Down

0 comments on commit 3e1b203

Please sign in to comment.