Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(transport): allow RetryPolicy to be set via layer #1790

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion crates/provider/src/provider/trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1119,9 +1119,10 @@ mod tests {
use alloy_network::{AnyNetwork, EthereumWallet, TransactionBuilder};
use alloy_node_bindings::Anvil;
use alloy_primitives::{address, b256, bytes, keccak256};
use alloy_rpc_client::BuiltInConnectionString;
use alloy_rpc_client::{BuiltInConnectionString, RpcClient};
use alloy_rpc_types_eth::{request::TransactionRequest, Block};
use alloy_signer_local::PrivateKeySigner;
use alloy_transport::layers::{RetryBackoffLayer, RetryPolicy};
// For layer transport tests
#[cfg(feature = "hyper")]
use alloy_transport_http::{
Expand Down Expand Up @@ -1427,6 +1428,32 @@ mod tests {
}
}

#[tokio::test]
async fn test_custom_retry_policy() {
#[derive(Debug, Clone)]
struct CustomPolicy;
impl RetryPolicy for CustomPolicy {
fn should_retry(&self, _err: &alloy_transport::TransportError) -> bool {
true
}

fn backoff_hint(
&self,
_error: &alloy_transport::TransportError,
) -> Option<std::time::Duration> {
None
}
}

let retry_layer = RetryBackoffLayer::new_with_policy(10, 100, 10000, CustomPolicy);
let anvil = Anvil::new().spawn();
let client = RpcClient::builder().layer(retry_layer).http(anvil.endpoint_url());

let provider = RootProvider::<_, Ethereum>::new(client);
let num = provider.get_block_number().await.unwrap();
assert_eq!(0, num);
}

#[tokio::test]
async fn test_send_tx() {
let provider = ProviderBuilder::new().on_anvil();
Expand Down
42 changes: 31 additions & 11 deletions crates/transport/src/layers/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,42 @@ use tokio::time::sleep;
///
/// TransportError: crate::error::TransportError
#[derive(Debug, Clone)]
pub struct RetryBackoffLayer {
pub struct RetryBackoffLayer<P: RetryPolicy = RateLimitRetryPolicy> {
/// The maximum number of retries for rate limit errors
max_rate_limit_retries: u32,
/// The initial backoff in milliseconds
initial_backoff: u64,
/// The number of compute units per second for this provider
compute_units_per_second: u64,
/// The [RetryPolicy] to use. Defaults to [RateLimitRetryPolicy]
policy: P,
}

impl RetryBackoffLayer {
/// Creates a new retry layer with the given parameters.
/// Creates a new retry layer with the given parameters and the default [RateLimitRetryPolicy].
pub const fn new(
max_rate_limit_retries: u32,
initial_backoff: u64,
compute_units_per_second: u64,
) -> Self {
Self { max_rate_limit_retries, initial_backoff, compute_units_per_second }
Self {
max_rate_limit_retries,
initial_backoff,
compute_units_per_second,
policy: RateLimitRetryPolicy,
}
}
}

impl<P: RetryPolicy> RetryBackoffLayer<P> {
/// Creates a new retry layer with the given parameters and [RetryPolicy].
pub const fn new_with_policy(
max_rate_limit_retries: u32,
initial_backoff: u64,
compute_units_per_second: u64,
policy: P,
) -> Self {
Self { max_rate_limit_retries, initial_backoff, compute_units_per_second, policy }
}
}

Expand Down Expand Up @@ -72,13 +91,13 @@ impl RetryPolicy for RateLimitRetryPolicy {
}
}

impl<S> Layer<S> for RetryBackoffLayer {
type Service = RetryBackoffService<S>;
impl<S, P: RetryPolicy + Clone> Layer<S> for RetryBackoffLayer<P> {
type Service = RetryBackoffService<S, P>;

fn layer(&self, inner: S) -> Self::Service {
RetryBackoffService {
inner,
policy: RateLimitRetryPolicy,
policy: self.policy.clone(),
max_rate_limit_retries: self.max_rate_limit_retries,
initial_backoff: self.initial_backoff,
compute_units_per_second: self.compute_units_per_second,
Expand All @@ -90,11 +109,11 @@ impl<S> Layer<S> for RetryBackoffLayer {
/// A Tower Service used by the RetryBackoffLayer that is responsible for retrying requests based
/// on the error type. See [TransportError] and [RateLimitRetryPolicy].
#[derive(Debug, Clone)]
pub struct RetryBackoffService<S> {
pub struct RetryBackoffService<S, P: RetryPolicy = RateLimitRetryPolicy> {
/// The inner service
inner: S,
/// The retry policy
policy: RateLimitRetryPolicy,
/// The [RetryPolicy] to use.
policy: P,
/// The maximum number of retries for rate limit errors
max_rate_limit_retries: u32,
/// The initial backoff in milliseconds
Expand All @@ -105,18 +124,19 @@ pub struct RetryBackoffService<S> {
requests_enqueued: Arc<AtomicU32>,
}

impl<S> RetryBackoffService<S> {
impl<S, P: RetryPolicy> RetryBackoffService<S, P> {
const fn initial_backoff(&self) -> Duration {
Duration::from_millis(self.initial_backoff)
}
}

impl<S> Service<RequestPacket> for RetryBackoffService<S>
impl<S, P> Service<RequestPacket> for RetryBackoffService<S, P>
where
S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
+ Send
+ 'static
+ Clone,
P: RetryPolicy + Clone + 'static,
{
type Response = ResponsePacket;
type Error = TransportError;
Expand Down
Loading