diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index f50bbc5c6e..82ffc8bacb 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -162,6 +162,7 @@ enum ComparisonOperator { } message Token { + string token_id = 1; string contract_address = 2; string name = 3; string symbol = 4; diff --git a/crates/torii/grpc/proto/world.proto b/crates/torii/grpc/proto/world.proto index adfadd1797..a2a7fd27ab 100644 --- a/crates/torii/grpc/proto/world.proto +++ b/crates/torii/grpc/proto/world.proto @@ -40,6 +40,9 @@ service World { // Update token balance subscription rpc UpdateTokenBalancesSubscription (UpdateTokenBalancesSubscriptionRequest) returns (google.protobuf.Empty); + // Subscribe to token updates. + rpc SubscribeTokens (RetrieveTokensRequest) returns (stream SubscribeTokensResponse); + // Retrieve entities rpc RetrieveEventMessages (RetrieveEventMessagesRequest) returns (RetrieveEntitiesResponse); @@ -96,6 +99,11 @@ message RetrieveTokensResponse { repeated types.Token tokens = 1; } +// A response containing token updates +message SubscribeTokensResponse { + types.Token token = 1; +} + // A request to retrieve token balances message RetrieveTokenBalancesRequest { // The account addresses to retrieve balances for diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 1c32d78bda..db384829c2 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -34,6 +34,7 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; use subscriptions::event::EventManager; use subscriptions::indexer::IndexerManager; +use subscriptions::token::TokenManager; use subscriptions::token_balance::TokenBalanceManager; use tokio::net::TcpListener; use tokio::sync::mpsc::{channel, Receiver}; @@ -61,8 +62,8 @@ use crate::proto::world::{ RetrieveTokensRequest, RetrieveTokensResponse, SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventMessagesRequest, SubscribeEventsResponse, SubscribeIndexerRequest, SubscribeIndexerResponse, SubscribeTokenBalancesResponse, - UpdateEventMessagesSubscriptionRequest, UpdateTokenBalancesSubscriptionRequest, - WorldMetadataRequest, WorldMetadataResponse, + SubscribeTokensResponse, UpdateEventMessagesSubscriptionRequest, + UpdateTokenBalancesSubscriptionRequest, WorldMetadataRequest, WorldMetadataResponse, }; use crate::proto::{self}; use crate::types::schema::SchemaError; @@ -96,6 +97,7 @@ impl From for Error { impl From for proto::types::Token { fn from(value: Token) -> Self { Self { + token_id: value.id, contract_address: value.contract_address, name: value.name, symbol: value.symbol, @@ -127,6 +129,7 @@ pub struct DojoWorld { state_diff_manager: Arc, indexer_manager: Arc, token_balance_manager: Arc, + token_manager: Arc, } impl DojoWorld { @@ -143,6 +146,7 @@ impl DojoWorld { let state_diff_manager = Arc::new(StateDiffManager::default()); let indexer_manager = Arc::new(IndexerManager::default()); let token_balance_manager = Arc::new(TokenBalanceManager::default()); + let token_manager = Arc::new(TokenManager::default()); tokio::task::spawn(subscriptions::model_diff::Service::new_with_block_rcv( block_rx, @@ -165,6 +169,8 @@ impl DojoWorld { &token_balance_manager, ))); + tokio::task::spawn(subscriptions::token::Service::new(Arc::clone(&token_manager))); + Self { pool, world_address, @@ -175,6 +181,7 @@ impl DojoWorld { state_diff_manager, indexer_manager, token_balance_manager, + token_manager, } } } @@ -790,6 +797,13 @@ impl DojoWorld { Ok(RetrieveTokensResponse { tokens }) } + async fn subscribe_tokens( + &self, + contract_addresses: Vec, + ) -> Result>, Error> { + self.token_manager.add_subscriber(contract_addresses).await + } + async fn retrieve_token_balances( &self, account_addresses: Vec, @@ -1281,6 +1295,8 @@ type RetrieveEntitiesStreamingResponseStream = Pin> + Send>>; type SubscribeTokenBalancesResponseStream = Pin> + Send>>; +type SubscribeTokensResponseStream = + Pin> + Send>>; #[tonic::async_trait] impl proto::world::world_server::World for DojoWorld { @@ -1291,6 +1307,7 @@ impl proto::world::world_server::World for DojoWorld { type SubscribeIndexerStream = SubscribeIndexerResponseStream; type RetrieveEntitiesStreamingStream = RetrieveEntitiesStreamingResponseStream; type SubscribeTokenBalancesStream = SubscribeTokenBalancesResponseStream; + type SubscribeTokensStream = SubscribeTokensResponseStream; async fn world_metadata( &self, @@ -1338,6 +1355,23 @@ impl proto::world::world_server::World for DojoWorld { Ok(Response::new(tokens)) } + async fn subscribe_tokens( + &self, + request: Request, + ) -> ServiceResult { + let RetrieveTokensRequest { contract_addresses } = request.into_inner(); + let contract_addresses = contract_addresses + .iter() + .map(|address| Felt::from_bytes_be_slice(address)) + .collect::>(); + + let rx = self + .subscribe_tokens(contract_addresses) + .await + .map_err(|e| Status::internal(e.to_string()))?; + Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeTokensStream)) + } + async fn retrieve_token_balances( &self, request: Request, diff --git a/crates/torii/grpc/src/server/subscriptions/mod.rs b/crates/torii/grpc/src/server/subscriptions/mod.rs index caaa38736e..8d33bd8ef4 100644 --- a/crates/torii/grpc/src/server/subscriptions/mod.rs +++ b/crates/torii/grpc/src/server/subscriptions/mod.rs @@ -9,6 +9,7 @@ pub mod event; pub mod event_message; pub mod indexer; pub mod model_diff; +pub mod token; pub mod token_balance; pub(crate) fn match_entity_keys( diff --git a/crates/torii/grpc/src/server/subscriptions/token.rs b/crates/torii/grpc/src/server/subscriptions/token.rs new file mode 100644 index 0000000000..d223a21b3c --- /dev/null +++ b/crates/torii/grpc/src/server/subscriptions/token.rs @@ -0,0 +1,167 @@ +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::{Stream, StreamExt}; +use rand::Rng; +use starknet_crypto::Felt; +use tokio::sync::mpsc::{ + channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, +}; +use tokio::sync::RwLock; +use torii_sqlite::error::{Error, ParseError}; +use torii_sqlite::simple_broker::SimpleBroker; +use torii_sqlite::types::Token; +use tracing::{error, trace}; + +use crate::proto; +use crate::proto::world::SubscribeTokensResponse; + +pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::balance"; + +#[derive(Debug)] +pub struct TokenSubscriber { + /// Contract addresses that the subscriber is interested in + /// If empty, subscriber receives updates for all contracts + pub contract_addresses: HashSet, + /// The channel to send the response back to the subscriber. + pub sender: Sender>, +} + +#[derive(Debug, Default)] +pub struct TokenManager { + subscribers: RwLock>, +} + +impl TokenManager { + pub async fn add_subscriber( + &self, + contract_addresses: Vec, + ) -> Result>, Error> { + let subscription_id = rand::thread_rng().gen::(); + let (sender, receiver) = channel(1); + + // Send initial empty response + let _ = sender.send(Ok(SubscribeTokensResponse { token: None })).await; + + self.subscribers.write().await.insert( + subscription_id, + TokenSubscriber { + contract_addresses: contract_addresses.into_iter().collect(), + sender, + }, + ); + + Ok(receiver) + } + + pub async fn update_subscriber(&self, id: u64, contract_addresses: Vec) { + let sender = { + let subscribers = self.subscribers.read().await; + if let Some(subscriber) = subscribers.get(&id) { + subscriber.sender.clone() + } else { + return; // Subscriber not found, exit early + } + }; + + self.subscribers.write().await.insert( + id, + TokenSubscriber { + contract_addresses: contract_addresses.into_iter().collect(), + sender, + }, + ); + } + + pub(super) async fn remove_subscriber(&self, id: u64) { + self.subscribers.write().await.remove(&id); + } +} + +#[must_use = "Service does nothing unless polled"] +#[allow(missing_debug_implementations)] +pub struct Service { + simple_broker: Pin + Send>>, + balance_sender: UnboundedSender, +} + +impl Service { + pub fn new(subs_manager: Arc) -> Self { + let (balance_sender, balance_receiver) = unbounded_channel(); + let service = + Self { simple_broker: Box::pin(SimpleBroker::::subscribe()), balance_sender }; + + tokio::spawn(Self::publish_updates(subs_manager, balance_receiver)); + + service + } + + async fn publish_updates( + subs: Arc, + mut balance_receiver: UnboundedReceiver, + ) { + while let Some(balance) = balance_receiver.recv().await { + if let Err(e) = Self::process_balance_update(&subs, &balance).await { + error!(target = LOG_TARGET, error = %e, "Processing balance update."); + } + } + } + + async fn process_balance_update(subs: &Arc, token: &Token) -> Result<(), Error> { + let mut closed_stream = Vec::new(); + + for (idx, sub) in subs.subscribers.read().await.iter() { + let contract_address = + Felt::from_str(&token.contract_address).map_err(ParseError::FromStr)?; + + // Skip if contract address filter doesn't match + if !sub.contract_addresses.is_empty() + && !sub.contract_addresses.contains(&contract_address) + { + continue; + } + + let resp = SubscribeTokensResponse { + token: Some(proto::types::Token { + token_id: token.id.clone(), + contract_address: token.contract_address.clone(), + name: token.name.clone(), + symbol: token.symbol.clone(), + decimals: token.decimals as u32, + metadata: token.metadata.clone(), + }), + }; + + if sub.sender.send(Ok(resp)).await.is_err() { + closed_stream.push(*idx); + } + } + + for id in closed_stream { + trace!(target = LOG_TARGET, id = %id, "Closing balance stream."); + subs.remove_subscriber(id).await + } + + Ok(()) + } +} + +impl Future for Service { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + while let Poll::Ready(Some(balance)) = this.simple_broker.poll_next_unpin(cx) { + if let Err(e) = this.balance_sender.send(balance) { + error!(target = LOG_TARGET, error = %e, "Sending balance update to processor."); + } + } + + Poll::Pending + } +}