Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(layers): FillTxLayer #374

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/provider/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use std::marker::PhantomData;
/// A layering abstraction in the vein of [`tower::Layer`]
///
/// [`tower::Layer`]: https://docs.rs/tower/latest/tower/trait.Layer.html
pub trait ProviderLayer<P: Provider<N, T>, N: Network, T: Transport + Clone> {
pub trait ProviderLayer<P: Provider<N, T> + Clone, N: Network, T: Transport + Clone> {
/// The provider constructed by this layer.
type Provider: Provider<N, T>;
type Provider: Provider<N, T> + Clone;

/// Wrap the given provider in the layer's provider.
fn layer(&self, inner: P) -> Self::Provider;
Expand All @@ -23,7 +23,7 @@ impl<P, N, T> ProviderLayer<P, N, T> for Identity
where
T: Transport + Clone,
N: Network,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
type Provider = P;

Expand All @@ -50,7 +50,7 @@ impl<P, N, T, Inner, Outer> ProviderLayer<P, N, T> for Stack<Inner, Outer>
where
T: Transport + Clone,
N: Network,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
Inner: ProviderLayer<P, N, T>,
Outer: ProviderLayer<Inner::Provider, N, T>,
{
Expand Down Expand Up @@ -128,7 +128,7 @@ impl<L, N> ProviderBuilder<L, N> {
pub fn provider<P, T>(self, provider: P) -> L::Provider
where
L: ProviderLayer<P, N, T>,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
T: Transport + Clone,
N: Network,
{
Expand Down
213 changes: 213 additions & 0 deletions crates/provider/src/layers/fill_tx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use crate::{
layers::{GasEstimatorProvider, ManagedNonceProvider},
PendingTransactionBuilder, Provider, ProviderLayer, RootProvider,
};
use alloy_network::{Network, TransactionBuilder};
use alloy_transport::{Transport, TransportError, TransportErrorKind, TransportResult};
use async_trait::async_trait;
use futures::FutureExt;
use std::marker::PhantomData;

/// A layer that fills in missing transaction fields.
#[derive(Debug, Clone, Copy)]
pub struct FillTxLayer;

impl<P, N, T> ProviderLayer<P, N, T> for FillTxLayer
where
P: Provider<N, T> + Clone,
N: Network,
T: Transport + Clone,
{
type Provider = FillTxProvider<N, T, P>;

fn layer(&self, inner: P) -> Self::Provider {
let nonce_provider = ManagedNonceProvider::new(inner.clone());
let gas_estimation_provider = GasEstimatorProvider::new(inner.clone());
FillTxProvider { inner, nonce_provider, gas_estimation_provider, _phantom: PhantomData }
}
}

/// A provider that fills in missing transaction fields.
#[derive(Debug, Clone)]
pub struct FillTxProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T> + Clone,
{
inner: P,
nonce_provider: ManagedNonceProvider<N, T, P>,
gas_estimation_provider: GasEstimatorProvider<N, T, P>,
_phantom: PhantomData<(N, T)>,
}

impl<N, T, P> FillTxProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T> + Clone,
{
/// Fills in missing transaction fields.
pub async fn fill_tx(&self, tx: &mut N::TransactionRequest) -> TransportResult<()> {
let chain_id_fut = if let Some(chain_id) = tx.chain_id() {
async move { Ok(chain_id) }.left_future()
} else {
async move { self.inner.get_chain_id().await.map(|ci| ci.to::<u64>()) }.right_future()
};

// Check if `from` is set
if tx.nonce().is_none() && tx.from().is_none() {
return Err(TransportError::Transport(TransportErrorKind::Custom(
"`from` field must be set in transaction request to populate the `nonce` field"
.into(),
)));
}

let nonce_fut = if let Some(nonce) = tx.nonce() {
async move { Ok(nonce) }.left_future()
} else {
let from = tx.from().unwrap();
async move { self.nonce_provider.get_next_nonce(from).await }.right_future()
};

let gas_estimation_fut = if tx.gas_price().is_none() {
async { self.gas_estimation_provider.handle_eip1559_tx(tx).await }.left_future()
} else {
async { self.gas_estimation_provider.handle_legacy_tx(tx).await }.right_future()
};

match futures::try_join!(chain_id_fut, nonce_fut, gas_estimation_fut) {
Ok((chain_id, nonce, _)) => {
tx.set_chain_id(chain_id);
tx.set_nonce(nonce);
Ok(())
}
Err(e) => Err(e),
yash-atreya marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<N, T, P> Provider<N, T> for FillTxProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T> + Clone,
{
#[inline]
fn root(&self) -> &RootProvider<N, T> {
self.inner.root()
}

async fn send_transaction(
&self,
mut tx: N::TransactionRequest,
) -> TransportResult<PendingTransactionBuilder<'_, N, T>> {
self.fill_tx(&mut tx).await?;
self.inner.send_transaction(tx).await
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::ProviderBuilder;
use alloy_network::EthereumSigner;
use alloy_node_bindings::Anvil;
use alloy_primitives::{address, U256};
use alloy_rpc_client::RpcClient;
use alloy_rpc_types::TransactionRequest;
use alloy_transport_http::Http;
use reqwest::Client;

#[tokio::test]
async fn test_1559_tx_no_nonce_no_chain_id() {
let anvil = Anvil::new().spawn();
let url = anvil.endpoint_url();

let http = Http::<Client>::new(url);

let wallet = alloy_signer_wallet::Wallet::from(anvil.keys()[0].clone());

let provider = ProviderBuilder::new()
.layer(FillTxLayer)
.signer(EthereumSigner::from(wallet))
.provider(RootProvider::new(RpcClient::new(http, true)));

let tx = TransactionRequest {
from: Some(anvil.addresses()[0]),
value: Some(U256::from(100)),
to: address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").into(),
..Default::default()
};

let tx = provider.send_transaction(tx).await.unwrap();

let tx = provider.get_transaction_by_hash(tx.tx_hash().to_owned()).await.unwrap();

assert_eq!(tx.max_fee_per_gas, Some(U256::from(0x77359400)));
assert_eq!(tx.max_priority_fee_per_gas, Some(U256::from(0x0)));
assert_eq!(tx.gas, U256::from(21000));
assert_eq!(tx.nonce, 0);
assert_eq!(tx.chain_id, Some(31337));
}

#[tokio::test]
async fn test_legacy_tx_no_nonce_chain_id() {
let anvil = Anvil::new().spawn();
let url = anvil.endpoint_url();

let http = Http::<Client>::new(url);

let wallet = alloy_signer_wallet::Wallet::from(anvil.keys()[0].clone());

let provider = ProviderBuilder::new()
.layer(FillTxLayer)
.signer(EthereumSigner::from(wallet))
.provider(RootProvider::new(RpcClient::new(http, true)));

let gas_price = provider.get_gas_price().await.unwrap();
let tx = TransactionRequest {
from: Some(anvil.addresses()[0]),
value: Some(U256::from(100)),
to: address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").into(),
chain_id: Some(31337),
gas_price: Some(gas_price),
..Default::default()
};

let tx = provider.send_transaction(tx).await.unwrap();

let tx = provider.get_transaction_by_hash(tx.tx_hash().to_owned()).await.unwrap();

assert_eq!(tx.gas_price, Some(gas_price));
assert_eq!(tx.gas, U256::from(21000));
assert_eq!(tx.nonce, 0);
assert_eq!(tx.chain_id, Some(31337));
}

#[tokio::test]
#[should_panic]
async fn test_no_from() {
let anvil = Anvil::new().spawn();
let url = anvil.endpoint_url();

let http = Http::<Client>::new(url);

let wallet = alloy_signer_wallet::Wallet::from(anvil.keys()[0].clone());

let provider = ProviderBuilder::new()
.layer(FillTxLayer)
.signer(EthereumSigner::from(wallet))
.provider(RootProvider::new(RpcClient::new(http, true)));

let tx = TransactionRequest {
value: Some(U256::from(100)),
to: address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").into(),
..Default::default()
};

let _ = provider.send_transaction(tx).await.unwrap();
}
}
20 changes: 14 additions & 6 deletions crates/provider/src/layers/gas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct GasEstimatorLayer;

impl<P, N, T> ProviderLayer<P, N, T> for GasEstimatorLayer
where
P: Provider<N, T>,
P: Provider<N, T> + Clone,
N: Network,
T: Transport + Clone,
{
Expand All @@ -67,7 +67,7 @@ pub struct GasEstimatorProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
inner: P,
_phantom: PhantomData<(N, T)>,
Expand All @@ -77,8 +77,13 @@ impl<N, T, P> GasEstimatorProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
/// Creates a new GasEstimatorProvider.
pub(crate) const fn new(inner: P) -> Self {
Self { inner, _phantom: PhantomData }
}

/// Gets the gas_price to be used in legacy txs.
async fn get_gas_price(&self) -> TransportResult<U256> {
self.inner.get_gas_price().await
Expand All @@ -97,7 +102,7 @@ where
/// Populates the gas_limit, max_fee_per_gas and max_priority_fee_per_gas fields if unset.
/// Requires the chain_id to be set in the transaction request to be processed as a EIP-1559 tx.
/// If the network does not support EIP-1559, it will process it as a legacy tx.
async fn handle_eip1559_tx(
pub async fn handle_eip1559_tx(
&self,
tx: &mut N::TransactionRequest,
) -> Result<(), TransportError> {
Expand Down Expand Up @@ -130,7 +135,10 @@ where

/// Populates the gas_price and only populates the gas_limit field if unset.
/// This method always assumes that the gas_price is unset.
async fn handle_legacy_tx(&self, tx: &mut N::TransactionRequest) -> Result<(), TransportError> {
pub async fn handle_legacy_tx(
&self,
tx: &mut N::TransactionRequest,
) -> Result<(), TransportError> {
let gas_price_fut = self.get_gas_price();
let gas_limit_fut = if let Some(gas_limit) = tx.gas_limit() {
async move { Ok(gas_limit) }.left_future()
Expand All @@ -154,7 +162,7 @@ impl<N, T, P> Provider<N, T> for GasEstimatorProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
fn root(&self) -> &RootProvider<N, T> {
self.inner.root()
Expand Down
3 changes: 3 additions & 0 deletions crates/provider/src/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ pub use nonce::{ManagedNonceLayer, ManagedNonceProvider};

mod gas;
pub use gas::{GasEstimatorLayer, GasEstimatorProvider};

mod fill_tx;
pub use fill_tx::{FillTxLayer, FillTxProvider};
16 changes: 11 additions & 5 deletions crates/provider/src/layers/nonce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct ManagedNonceLayer;

impl<P, N, T> ProviderLayer<P, N, T> for ManagedNonceLayer
where
P: Provider<N, T>,
P: Provider<N, T> + Clone,
N: Network,
T: Transport + Clone,
{
Expand Down Expand Up @@ -69,7 +69,7 @@ pub struct ManagedNonceProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
inner: P,
nonces: DashMap<Address, Arc<Mutex<Option<u64>>>>,
Expand All @@ -80,9 +80,15 @@ impl<N, T, P> ManagedNonceProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
async fn get_next_nonce(&self, from: Address) -> TransportResult<u64> {
/// Creates a new ManagedNonceProvider.
pub(crate) fn new(inner: P) -> Self {
Self { inner, nonces: DashMap::default(), _phantom: PhantomData }
}

/// Gets the next nonce for the given account.
pub async fn get_next_nonce(&self, from: Address) -> TransportResult<u64> {
// locks dashmap internally for a short duration to clone the `Arc`
let mutex = Arc::clone(self.nonces.entry(from).or_default().value());

Expand All @@ -109,7 +115,7 @@ impl<N, T, P> Provider<N, T> for ManagedNonceProvider<N, T, P>
where
N: Network,
T: Transport + Clone,
P: Provider<N, T>,
P: Provider<N, T> + Clone,
{
#[inline]
fn root(&self) -> &RootProvider<N, T> {
Expand Down
Loading
Loading