Skip to content

Commit b71aa2c

Browse files
committed
refactor!: refactor OtherError to CustomError
1 parent 19bea9e commit b71aa2c

File tree

5 files changed

+56
-32
lines changed

5 files changed

+56
-32
lines changed

src/error.rs

+40-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use std::borrow::Cow;
2+
use std::fmt::{self, Display, Formatter};
13
use std::sync::Arc;
24

3-
use derive_where::derive_where;
45
use static_assertions::assert_impl_all;
56
use thiserror::Error;
67

@@ -87,18 +88,41 @@ pub enum Error {
8788
RuntimeInconsistent,
8889

8990
#[error(transparent)]
90-
Other(OtherError),
91+
Custom(CustomError),
9192
}
9293

9394
#[derive(Error, Clone, Debug)]
94-
#[derive_where(Eq, PartialEq)]
95-
#[error("{message}")]
96-
pub struct OtherError {
97-
message: Arc<String>,
98-
#[derive_where(skip(EqHashOrd))]
95+
pub struct CustomError {
96+
message: Option<Arc<Cow<'static, str>>>,
9997
source: Option<Arc<dyn std::error::Error + Send + Sync + 'static>>,
10098
}
10199

100+
impl Display for CustomError {
101+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102+
match (self.message.as_ref(), self.source.as_ref()) {
103+
(Some(message), None) => f.write_str(message),
104+
(Some(message), Some(err)) => write!(f, "{}: {}", message, err),
105+
(None, Some(err)) => err.fmt(f),
106+
_ => unreachable!("no error message or source"),
107+
}
108+
}
109+
}
110+
111+
impl PartialEq for CustomError {
112+
fn eq(&self, other: &Self) -> bool {
113+
if !self.message.eq(&other.message) {
114+
return false;
115+
}
116+
match (self.source.as_ref(), other.source.as_ref()) {
117+
(Some(lhs), Some(rhs)) => Arc::ptr_eq(lhs, rhs),
118+
(None, None) => true,
119+
_ => false,
120+
}
121+
}
122+
}
123+
124+
impl Eq for CustomError {}
125+
102126
impl Error {
103127
pub(crate) fn is_terminated(&self) -> bool {
104128
matches!(self, Self::NoHosts | Self::SessionExpired | Self::AuthFailed | Self::ClientClosed)
@@ -131,20 +155,20 @@ impl Error {
131155
}
132156
}
133157

134-
pub(crate) fn new_other(
135-
message: impl Into<Arc<String>>,
136-
source: Option<Arc<dyn std::error::Error + Send + Sync + 'static>>,
137-
) -> Self {
138-
Self::Other(OtherError { message: message.into(), source })
158+
pub(crate) fn with_message(message: impl Into<Cow<'static, str>>) -> Self {
159+
Self::Custom(CustomError { message: Some(Arc::new(message.into())), source: None })
139160
}
140161

141162
#[allow(dead_code)]
142-
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
143-
Self::new_other(message.into(), Some(Arc::new(source)))
163+
pub(crate) fn with_other(
164+
message: impl Into<Cow<'static, str>>,
165+
source: impl std::error::Error + Send + Sync + 'static,
166+
) -> Self {
167+
Self::Custom(CustomError { message: Some(Arc::new(message.into())), source: Some(Arc::new(source)) })
144168
}
145169

146-
pub(crate) fn other_from(source: impl std::error::Error + Send + Sync + 'static) -> Self {
147-
Self::new_other(source.to_string(), Some(Arc::new(source)))
170+
pub(crate) fn other(source: impl std::error::Error + Send + Sync + 'static) -> Self {
171+
Self::Custom(CustomError { message: None, source: Some(Arc::new(source)) })
148172
}
149173
}
150174

src/sasl/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ impl SaslSession {
9898
return Err(Error::UnexpectedError(format!("SASL {} session already finished", self.name())));
9999
}
100100
self.output.clear();
101-
match self.session.step(Some(challenge), &mut self.output).map_err(|e| Error::other(format!("{e}"), e))? {
101+
match self.session.step(Some(challenge), &mut self.output).map_err(Error::other)? {
102102
State::Running => Ok(Some(&self.output)),
103103
State::Finished(MessageSent::Yes) => {
104104
self.finished = true;

src/session/depot.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl Depot {
248248
if err.kind() == io::ErrorKind::WouldBlock {
249249
return Ok(());
250250
}
251-
return Err(Error::other_from(err));
251+
return Err(Error::other(err));
252252
}
253253
return Ok(());
254254
}
@@ -258,7 +258,7 @@ impl Depot {
258258
if err.kind() == io::ErrorKind::WouldBlock {
259259
return Ok(());
260260
}
261-
return Err(Error::other_from(err));
261+
return Err(Error::other(err));
262262
},
263263
Ok(written_bytes) => written_bytes,
264264
};

src/session/mod.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ impl Builder {
113113
let session = match self.session {
114114
Some(session) => {
115115
if session.is_readonly() {
116-
return Err(Error::new_other(
117-
format!("can't reestablish readonly and hence local session {}", session.id()),
118-
None,
119-
));
116+
return Err(Error::with_message(format!(
117+
"can't reestablish readonly and hence local session {}",
118+
session.id()
119+
)));
120120
}
121121
Span::current().record("session", display(session.id()));
122122
session
@@ -490,7 +490,7 @@ impl Session {
490490
},
491491
Err(err) => {
492492
if err.kind() != io::ErrorKind::WouldBlock {
493-
return Err(Error::other_from(err));
493+
return Err(Error::other(err));
494494
}
495495
},
496496
_ => {},
@@ -539,7 +539,7 @@ impl Session {
539539
},
540540
now = tick.tick() => {
541541
if now >= self.last_recv + self.connector.timeout() {
542-
return Err(Error::new_other(format!("no response from connection in {}ms", self.connector.timeout().as_millis()), None));
542+
return Err(Error::with_message(format!("no response from connection in {}ms", self.connector.timeout().as_millis())));
543543
}
544544
},
545545
}
@@ -583,7 +583,7 @@ impl Session {
583583
select! {
584584
Some(endpoint) = Self::poll(&mut seek_for_writable), if seek_for_writable.is_some() => {
585585
seek_for_writable = None;
586-
err = Some(Error::new_other(format!("encounter writable server {}", endpoint), None));
586+
err = Some(Error::with_message(format!("encounter writable server {}", endpoint)));
587587
channel_halted = true;
588588
},
589589
_ = conn.readable() => {
@@ -613,7 +613,7 @@ impl Session {
613613
},
614614
now = tick.tick() => {
615615
if now >= self.last_recv + self.connector.timeout() {
616-
return Err(Error::new_other(format!("no response from connection in {}ms", self.connector.timeout().as_millis()), None));
616+
return Err(Error::with_message(format!("no response from connection in {}ms", self.connector.timeout().as_millis())));
617617
}
618618
if self.last_ping.is_none() && now >= self.last_send + self.ping_timeout {
619619
self.send_ping(depot, now);

src/tls.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ impl TlsOptions {
126126
for r in rustls_pemfile::certs(&mut certs.as_bytes()) {
127127
let cert = match r {
128128
Ok(cert) => cert,
129-
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
129+
Err(err) => return Err(Error::with_other("fail to read cert", err)),
130130
};
131131
if let Err(err) = self.ca_certs.add(cert) {
132-
return Err(Error::other(format!("fail to add cert {}", err), err));
132+
return Err(Error::with_other("fail to add cert", err));
133133
}
134134
}
135135
Ok(self)
@@ -139,11 +139,11 @@ impl TlsOptions {
139139
pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result<Self> {
140140
let r: std::result::Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
141141
let certs = match r {
142-
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
142+
Err(err) => return Err(Error::with_other("fail to read cert", err)),
143143
Ok(certs) => certs,
144144
};
145145
let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
146-
Err(err) => return Err(Error::other(format!("fail to read client private key {err}"), err)),
146+
Err(err) => return Err(Error::with_other("fail to read client private key", err)),
147147
Ok(None) => return Err(Error::BadArguments(&"no client private key")),
148148
Ok(Some(key)) => key,
149149
};
@@ -163,7 +163,7 @@ impl TlsOptions {
163163
if let Some((client_cert, client_key)) = self.identity.take() {
164164
match builder.with_client_auth_cert(client_cert, client_key) {
165165
Ok(config) => Ok(config),
166-
Err(err) => Err(Error::other(format!("invalid client private key {err}"), err)),
166+
Err(err) => Err(Error::with_other("invalid client private key", err)),
167167
}
168168
} else {
169169
Ok(builder.with_no_client_auth())

0 commit comments

Comments
 (0)