Skip to content

Commit 97c2b9f

Browse files
authored
Support multiple async runtimes but not only tokio (#52)
Resolves #26. changelog: changed breaking: running with tokio needs additional setups now
1 parent c71c44d commit 97c2b9f

11 files changed

+161
-146
lines changed

Cargo.toml

+12-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ sasl-gssapi = ["rsasl/gssapi"]
2323

2424
[dependencies]
2525
bytes = "1.1.0"
26-
tokio = {version = "1.15.0", features = ["full"]}
2726
thiserror = "1.0.30"
2827
strum = { version = "0.23", features = ["derive"] }
2928
num_enum = "0.5.6"
@@ -39,13 +38,18 @@ rustls = { version = "0.23.2", optional = true }
3938
rustls-pemfile = { version = "2", optional = true }
4039
webpki-roots = { version = "0.26.1", optional = true }
4140
derive-where = "1.2.7"
42-
tokio-rustls = "0.26.0"
4341
fastrand = "2.0.2"
4442
tracing = "0.1.40"
4543
rsasl = { version = "2.0.1", default-features = false, features = ["provider", "config_builder", "registry_static", "std"], optional = true }
4644
md5 = { version = "0.7.0", optional = true }
4745
hex = { version = "0.4.3", optional = true }
4846
linkme = { version = "0.2", optional = true }
47+
async-io = "2.3.2"
48+
futures = "0.3.30"
49+
async-net = "2.0.0"
50+
futures-rustls = "0.26.0"
51+
futures-lite = "2.3.0"
52+
asyncs = "0.3.0"
4953

5054
[dev-dependencies]
5155
test-log = { version = "0.2.15", features = ["log", "trace"] }
@@ -59,9 +63,15 @@ assert_matches = "1.5.0"
5963
tempfile = "3.6.0"
6064
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
6165
serial_test = "3.0.0"
66+
asyncs = { version = "0.3.0", features = ["test"] }
67+
blocking = "1.6.0"
6268

6369
[package.metadata.cargo-all-features]
6470
skip_optional_dependencies = true
6571

6672
[package.metadata.docs.rs]
6773
all-features = true
74+
75+
[profile.dev]
76+
# Need this for linkme crate to work for spawns in macOS
77+
lto = "thin"

src/client/mod.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use std::time::Duration;
88

99
use const_format::formatcp;
1010
use either::{Either, Left, Right};
11+
use futures::channel::mpsc;
1112
use ignore_result::Ignore;
1213
use thiserror::Error;
13-
use tokio::sync::mpsc;
1414
use tracing::instrument;
1515

1616
pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
@@ -322,9 +322,9 @@ impl Client {
322322

323323
fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
324324
let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
325-
if let Err(mpsc::error::SendError(operation)) = self.requester.send(operation) {
325+
if let Err(err) = self.requester.unbounded_send(operation) {
326326
let state = self.state();
327-
operation.responser.send(Err(state.to_error()));
327+
err.into_inner().responser.send(Err(state.to_error()));
328328
}
329329
receiver
330330
}
@@ -514,7 +514,7 @@ impl Client {
514514

515515
// TODO: move these to session side so to eliminate owned Client and String.
516516
fn delete_background(self, path: String) {
517-
tokio::spawn(async move {
517+
asyncs::spawn(async move {
518518
self.delete_foreground(&path).await;
519519
});
520520
}
@@ -524,7 +524,7 @@ impl Client {
524524
}
525525

526526
fn delete_ephemeral_background(self, prefix: String, unique: bool) {
527-
tokio::spawn(async move {
527+
asyncs::spawn(async move {
528528
let (parent, tree, name) = util::split_path(&prefix);
529529
let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
530530
if unique {
@@ -1673,13 +1673,13 @@ impl Connector {
16731673
let mut buf = Vec::with_capacity(4096);
16741674
let mut depot = Depot::new();
16751675
let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
1676-
let (sender, receiver) = mpsc::unbounded_channel();
1676+
let (sender, receiver) = mpsc::unbounded();
16771677
let session_info = session.session.clone();
16781678
let session_timeout = session.session_timeout;
16791679
let mut state_watcher = StateWatcher::new(state_receiver);
16801680
// Consume all state changes so far.
16811681
state_watcher.state();
1682-
tokio::spawn(async move {
1682+
asyncs::spawn(async move {
16831683
session.serve(endpoints, conn, buf, depot, receiver).await;
16841684
});
16851685
let client =
@@ -2270,7 +2270,7 @@ mod tests {
22702270
.is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
22712271
}
22722272

2273-
#[test_log::test(tokio::test)]
2273+
#[test_log::test(asyncs::test)]
22742274
async fn session_last_zxid_seen() {
22752275
use testcontainers::clients::Cli as DockerCli;
22762276
use testcontainers::core::{Healthcheck, WaitFor};

src/client/watcher.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use tokio::sync::watch;
1+
use asyncs::sync::watch;
22

33
use crate::chroot::OwnedChroot;
44
use crate::error::Error;
@@ -25,11 +25,11 @@ impl StateWatcher {
2525
///
2626
/// This method will block indefinitely after one of terminal states consumed.
2727
pub async fn changed(&mut self) -> SessionState {
28-
if self.receiver.changed().await.is_err() {
28+
match self.receiver.changed().await {
29+
Ok(changed) => *changed,
2930
// Terminal state must be delivered.
30-
std::future::pending().await
31+
Err(_) => std::future::pending().await,
3132
}
32-
self.state()
3333
}
3434

3535
/// Returns but not consumes most recently state.

src/deadline.rs

+12-13
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,34 @@
11
use std::future::Future;
22
use std::pin::Pin;
33
use std::task::{Context, Poll};
4-
use std::time::Duration;
4+
use std::time::{Duration, Instant};
55

6-
use tokio::time::{self, Instant, Sleep};
6+
use async_io::Timer;
7+
use futures::future::{Fuse, FusedFuture, FutureExt};
78

89
pub struct Deadline {
9-
sleep: Option<Sleep>,
10+
timer: Fuse<Timer>,
11+
deadline: Option<Instant>,
1012
}
1113

1214
impl Deadline {
1315
pub fn never() -> Self {
14-
Self { sleep: None }
16+
Self { timer: Timer::never().fuse(), deadline: None }
1517
}
1618

1719
pub fn until(deadline: Instant) -> Self {
18-
Self { sleep: Some(time::sleep_until(deadline)) }
20+
Self { timer: Timer::at(deadline).fuse(), deadline: Some(deadline) }
1921
}
2022

2123
pub fn elapsed(&self) -> bool {
22-
self.sleep.as_ref().map(|f| f.is_elapsed()).unwrap_or(false)
24+
self.timer.is_terminated()
2325
}
2426

2527
/// Remaining timeout.
2628
pub fn timeout(&self) -> Duration {
27-
match self.sleep.as_ref() {
29+
match self.deadline.as_ref() {
2830
None => Duration::MAX,
29-
Some(sleep) => sleep.deadline().saturating_duration_since(Instant::now()),
31+
Some(deadline) => deadline.saturating_duration_since(Instant::now()),
3032
}
3133
}
3234
}
@@ -35,10 +37,7 @@ impl Future for Deadline {
3537
type Output = ();
3638

3739
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
38-
if self.sleep.is_none() {
39-
return Poll::Pending;
40-
}
41-
let sleep = unsafe { self.map_unchecked_mut(|deadline| deadline.sleep.as_mut().unwrap_unchecked()) };
42-
sleep.poll(cx)
40+
let timer = unsafe { self.map_unchecked_mut(|deadline| &mut deadline.timer) };
41+
timer.poll(cx).map(|_| ())
4342
}
4443
}

src/endpoint.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::fmt::{self, Display, Formatter};
22
use std::time::Duration;
33

4+
use async_io::Timer;
5+
46
use crate::chroot::Chroot;
57
use crate::error::Error;
68
use crate::util::{Ref, ToRef};
@@ -219,7 +221,7 @@ impl IterableEndpoints {
219221
async fn delay(&self, index: Index, max_delay: Duration) {
220222
let timeout = max_delay.min(Self::timeout(index, self.endpoints.len()));
221223
if timeout != Duration::ZERO {
222-
tokio::time::sleep(timeout).await;
224+
Timer::after(timeout).await;
223225
}
224226
}
225227

@@ -336,7 +338,7 @@ mod tests {
336338
);
337339
}
338340

339-
#[tokio::test]
341+
#[asyncs::test]
340342
async fn test_iterable_endpoints_next() {
341343
use std::time::Duration;
342344

src/session/connection.rs

+26-22
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,24 @@ use std::pin::Pin;
33
use std::task::{Context, Poll};
44
use std::time::Duration;
55

6+
use async_io::Timer;
7+
use async_net::TcpStream;
8+
use asyncs::select;
69
use bytes::buf::BufMut;
10+
use futures::io::BufReader;
11+
use futures::prelude::*;
12+
use futures_lite::AsyncReadExt;
713
use ignore_result::Ignore;
8-
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
9-
use tokio::net::TcpStream;
10-
use tokio::{select, time};
1114
use tracing::{debug, trace};
1215

1316
#[cfg(feature = "tls")]
1417
mod tls {
1518
pub use std::sync::Arc;
1619

20+
pub use futures_rustls::client::TlsStream;
21+
pub use futures_rustls::TlsConnector;
1722
pub use rustls::pki_types::ServerName;
1823
pub use rustls::ClientConfig;
19-
pub use tokio_rustls::client::TlsStream;
20-
pub use tokio_rustls::TlsConnector;
2124
}
2225
#[cfg(feature = "tls")]
2326
use tls::*;
@@ -51,7 +54,7 @@ pub trait AsyncReadToBuf: AsyncReadExt {
5154
impl<T> AsyncReadToBuf for T where T: AsyncReadExt {}
5255

5356
impl AsyncRead for Connection {
54-
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
57+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
5558
match self.get_mut() {
5659
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
5760
#[cfg(feature = "tls")]
@@ -85,11 +88,11 @@ impl AsyncWrite for Connection {
8588
}
8689
}
8790

88-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
91+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
8992
match self.get_mut() {
90-
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
93+
Self::Raw(stream) => Pin::new(stream).poll_close(cx),
9194
#[cfg(feature = "tls")]
92-
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
95+
Self::Tls(stream) => Pin::new(stream).poll_close(cx),
9396
}
9497
}
9598
}
@@ -99,7 +102,7 @@ pub struct ConnReader<'a> {
99102
}
100103

101104
impl AsyncRead for ConnReader<'_> {
102-
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
105+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
103106
Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
104107
}
105108
}
@@ -121,8 +124,8 @@ impl AsyncWrite for ConnWriter<'_> {
121124
Pin::new(&mut self.get_mut().conn).poll_flush(cx)
122125
}
123126

124-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
125-
Pin::new(&mut self.get_mut().conn).poll_shutdown(cx)
127+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
128+
Pin::new(&mut self.get_mut().conn).poll_close(cx)
126129
}
127130
}
128131

@@ -142,13 +145,14 @@ impl Connection {
142145
Self::Tls(stream)
143146
}
144147

145-
pub async fn command(self, cmd: &str) -> Result<String> {
146-
let mut stream = BufStream::new(self);
147-
stream.write_all(cmd.as_bytes()).await?;
148-
stream.flush().await?;
148+
pub async fn command(mut self, cmd: &str) -> Result<String> {
149+
// let mut stream = BufStream::new(self);
150+
self.write_all(cmd.as_bytes()).await?;
151+
self.flush().await?;
149152
let mut line = String::new();
150-
stream.read_line(&mut line).await?;
151-
stream.shutdown().await.ignore();
153+
let mut reader = BufReader::new(self);
154+
reader.read_line(&mut line).await?;
155+
reader.close().await.ignore();
152156
Ok(line)
153157
}
154158

@@ -212,7 +216,7 @@ impl Connector {
212216
}
213217
select! {
214218
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
215-
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
219+
_ = Timer::after(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
216220
r = TcpStream::connect((endpoint.host, endpoint.port)) => {
217221
match r {
218222
Err(err) => Err(err),
@@ -255,10 +259,10 @@ impl Connector {
255259
"fails to contact writable server from endpoints {:?}",
256260
endpoints.endpoints()
257261
);
258-
time::sleep(timeout).await;
262+
Timer::after(timeout).await;
259263
timeout = max_timeout.min(timeout * 2);
260264
} else {
261-
time::sleep(Duration::from_millis(5)).await;
265+
Timer::after(Duration::from_millis(5)).await;
262266
}
263267
}
264268
None
@@ -273,7 +277,7 @@ mod tests {
273277
use crate::deadline::Deadline;
274278
use crate::endpoint::EndpointRef;
275279

276-
#[tokio::test]
280+
#[asyncs::test]
277281
async fn raw() {
278282
let connector = Connector::new();
279283
let endpoint = EndpointRef::new("host1", 2181, true);

src/session/depot.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use std::collections::VecDeque;
22
use std::io::IoSlice;
33

4+
use futures_lite::io::AsyncWriteExt;
45
use hashbrown::HashMap;
56
use strum::IntoEnumIterator;
6-
use tokio::io::AsyncWriteExt;
77
use tracing::debug;
88

99
use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};

0 commit comments

Comments
 (0)