Skip to content

Commit

Permalink
feat: zksyncMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Jrigada authored and Karrq committed Nov 26, 2024
1 parent 639b6a2 commit 402d191
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 9 deletions.
101 changes: 92 additions & 9 deletions src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use crate::{
use alloy_primitives::{keccak256, Address, Bytes, B256, U256};
use alloy_provider::{
network::{AnyNetwork, AnyRpcBlock, AnyRpcTransaction, AnyTxEnvelope},
Provider,
Provider, RootProvider,
};
use alloy_rpc_types::{BlockId, Transaction};
use alloy_serde::WithOtherFields;
use alloy_transport::Transport;
use alloy_transport::{Transport, TransportResult};
use eyre::WrapErr;
use futures::{
channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
Expand Down Expand Up @@ -55,11 +55,15 @@ type FullBlockFuture<Err> = Pin<
type TransactionFuture<Err> =
Pin<Box<dyn Future<Output = (TransactionSender, Result<AnyRpcTransaction, Err>, B256)> + Send>>;

type BytecodeHashFuture<Err> =
Pin<Box<dyn Future<Output = (ByteCodeHashSender, Result<Option<Bytecode>, Err>, B256)> + Send>>;

type AccountInfoSender = OneshotSender<DatabaseResult<AccountInfo>>;
type StorageSender = OneshotSender<DatabaseResult<U256>>;
type BlockHashSender = OneshotSender<DatabaseResult<B256>>;
type FullBlockSender = OneshotSender<DatabaseResult<AnyRpcBlock>>;
type TransactionSender = OneshotSender<DatabaseResult<AnyRpcTransaction>>;
type ByteCodeHashSender = OneshotSender<DatabaseResult<Bytecode>>;

type AddressData = AddressHashMap<AccountInfo>;
type StorageData = AddressHashMap<StorageInfo>;
Expand All @@ -72,6 +76,7 @@ enum ProviderRequest<Err> {
BlockHash(BlockHashFuture<Err>),
FullBlock(FullBlockFuture<Err>),
Transaction(TransactionFuture<Err>),
ByteCodeHash(BytecodeHashFuture<Err>),
}

/// The Request type the Backend listens for
Expand All @@ -89,13 +94,22 @@ enum BackendRequest {
Transaction(B256, TransactionSender),
/// Sets the pinned block to fetch data from
SetPinnedBlock(BlockId),

/// Update Address data
UpdateAddress(AddressData),
/// Update Storage data
UpdateStorage(StorageData),
/// Update Block Hashes
UpdateBlockHash(BlockHashData),
/// Get the bytecode for the given hash
ByteCodeHash(B256, ByteCodeHashSender),
}

pub trait ZkSyncMiddleware: Send + Sync {
fn get_bytecode_by_hash(
&self,
hash: B256,
) -> impl std::future::Future<Output = alloy_transport::TransportResult<Option<Bytecode>>>
+ std::marker::Send;
}

/// Handles an internal provider and listens for requests.
Expand Down Expand Up @@ -128,7 +142,7 @@ pub struct BackendHandler<T, P> {
impl<T, P> BackendHandler<T, P>
where
T: Transport + Clone,
P: Provider<T, AnyNetwork> + Clone + Unpin + 'static,
P: ZkSyncMiddleware + Provider<T, AnyNetwork> + Clone + Unpin + 'static,
{
fn new(
provider: P,
Expand Down Expand Up @@ -210,6 +224,9 @@ where
self.db.block_hashes().write().insert(block, hash);
}
}
BackendRequest::ByteCodeHash(code_hash, sender) => {
self.request_bytecode_by_hash(code_hash, sender);
}
}
}

Expand Down Expand Up @@ -334,12 +351,25 @@ where
}
}
}

fn request_bytecode_by_hash(&mut self, code_hash: B256, sender: ByteCodeHashSender) {
let provider = self.provider.clone();
let fut = Box::pin(async move {
let bytecode = provider
.get_bytecode_by_hash(code_hash)
.await
.wrap_err("could not get bytecode {code_hash}");
(sender, bytecode, code_hash)
});

self.pending_requests.push(ProviderRequest::ByteCodeHash(fut));
}
}

impl<T, P> Future for BackendHandler<T, P>
where
T: Transport + Clone + Unpin,
P: Provider<T, AnyNetwork> + Clone + Unpin + 'static,
P: ZkSyncMiddleware + Provider<T, AnyNetwork> + Clone + Unpin + 'static,
{
type Output = ();

Expand Down Expand Up @@ -505,6 +535,20 @@ where
continue;
}
}
ProviderRequest::ByteCodeHash(fut) => {
if let Poll::Ready((sender, bytecode, code_hash)) = fut.poll_unpin(cx) {
let msg = match bytecode {
Ok(Some(bytecode)) => Ok(bytecode),
Ok(None) => Err(DatabaseError::MissingCode(code_hash)),
Err(err) => {
let err = Arc::new(err);
Err(DatabaseError::GetBytecode(code_hash, err))
}
};
let _ = sender.send(msg);
continue;
}
}
}
// not ready, insert and poll again
pin.pending_requests.push(request);
Expand Down Expand Up @@ -604,7 +648,7 @@ impl SharedBackend {
) -> Self
where
T: Transport + Clone + Unpin,
P: Provider<T, AnyNetwork> + Unpin + 'static + Clone,
P: ZkSyncMiddleware + Provider<T, AnyNetwork> + Unpin + 'static + Clone,
{
let (shared, handler) = Self::new(provider, db, pin_block);
// spawn the provider handler to a task
Expand All @@ -622,7 +666,7 @@ impl SharedBackend {
) -> Self
where
T: Transport + Clone + Unpin,
P: Provider<T, AnyNetwork> + Unpin + 'static + Clone,
P: ZkSyncMiddleware + Provider<T, AnyNetwork> + Unpin + 'static + Clone,
{
let (shared, handler) = Self::new(provider, db, pin_block);

Expand Down Expand Up @@ -652,7 +696,7 @@ impl SharedBackend {
) -> (Self, BackendHandler<T, P>)
where
T: Transport + Clone + Unpin,
P: Provider<T, AnyNetwork> + Unpin + 'static + Clone,
P: ZkSyncMiddleware + Provider<T, AnyNetwork> + Unpin + 'static + Clone,
{
let (backend, backend_rx) = unbounded();
let cache = Arc::new(FlushJsonBlockCacheDB(Arc::clone(db.cache())));
Expand Down Expand Up @@ -756,6 +800,14 @@ impl SharedBackend {
}
}
}
fn do_get_bytecode(&self, hash: B256) -> DatabaseResult<Bytecode> {
tokio::task::block_in_place(|| {
let (sender, rx) = oneshot_channel();
let req = BackendRequest::ByteCodeHash(hash, sender);
self.backend.clone().unbounded_send(req)?;
rx.recv()?
})
}

/// Flushes the DB to disk if caching is enabled
pub fn flush_cache(&self) {
Expand Down Expand Up @@ -818,7 +870,14 @@ impl DatabaseRef for SharedBackend {
}

fn code_by_hash_ref(&self, hash: B256) -> Result<Bytecode, Self::Error> {
Err(DatabaseError::MissingCode(hash))
trace!(target: "sharedbackend", %hash, "request codehash");
self.do_get_bytecode(hash).map_err(|err| {
error!(target: "sharedbackend", %err, %hash, "Failed to send/recv `code_by_hash`");
if err.is_possibly_non_archive_node_error() {
error!(target: "sharedbackend", "{NON_ARCHIVE_NODE_WARNING}");
}
err
})
}

fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
Expand All @@ -844,6 +903,30 @@ impl DatabaseRef for SharedBackend {
}
}

impl<T: alloy_transport::Transport + Clone> super::backend::ZkSyncMiddleware
for RootProvider<T, alloy_provider::network::AnyNetwork>
{
async fn get_bytecode_by_hash(
&self,
hash: alloy_primitives::B256,
) -> TransportResult<Option<revm::primitives::Bytecode>> {
let bytecode: Option<alloy_primitives::Bytes> =
self.raw_request("zks_getBytecodeByHash".into(), vec![hash]).await?;
Ok(bytecode.map(revm::primitives::Bytecode::new_raw))
}
}

impl<T: alloy_transport::Transport + Clone> super::backend::ZkSyncMiddleware
for Arc<RootProvider<T, alloy_provider::network::AnyNetwork>>
{
async fn get_bytecode_by_hash(
&self,
hash: alloy_primitives::B256,
) -> TransportResult<Option<revm::primitives::Bytecode>> {
self.as_ref().get_bytecode_by_hash(hash).await
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub enum DatabaseError {
BlockNotFound(BlockId),
#[error("failed to get transaction {0}: {1}")]
GetTransaction(B256, Arc<eyre::Error>),
#[error("failed to get bytecode for {0:?}: {1}")]
GetBytecode(B256, Arc<eyre::Error>),
}

impl DatabaseError {
Expand All @@ -43,6 +45,7 @@ impl DatabaseError {
Self::GetTransaction(_, err) => Some(err),
// Enumerate explicitly to make sure errors are updated if a new one is added.
Self::MissingCode(_) | Self::Recv(_) | Self::Send(_) | Self::BlockNotFound(_) => None,
Self::GetBytecode(_, err) => Some(err),
}
}

Expand Down

0 comments on commit 402d191

Please sign in to comment.