diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index f291e58652..70680462ad 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -75,6 +75,7 @@ message Query { uint32 offset = 3; bool dont_include_hashed_keys = 4; repeated OrderBy order_by = 5; + repeated string entity_models = 6; } message EventQuery { diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index b88dcd66ec..62c16a304b 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -13,6 +13,7 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use dojo_types::naming::compute_selector_from_tag; use dojo_types::primitive::{Primitive, PrimitiveError}; use dojo_types::schema::Ty; use dojo_world::contracts::naming::compute_selector_from_names; @@ -227,6 +228,7 @@ impl DojoWorld { offset: u32, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result<(Vec, u32), Error> { self.query_by_hashed_keys( table, @@ -237,6 +239,7 @@ impl DojoWorld { Some(offset), dont_include_hashed_keys, order_by, + entity_models, ) .await } @@ -262,7 +265,11 @@ impl DojoWorld { entities: Vec<(String, String)>, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result, Error> { + let entity_models = + entity_models.iter().map(|tag| compute_selector_from_tag(tag)).collect::>(); + tracing::debug!( "Fetching entities from table {table} with {} entity/model pairs", entities.len() @@ -315,10 +322,27 @@ impl DojoWorld { let query_start = std::time::Instant::now(); for (models_str, entity_ids) in &model_groups { tracing::debug!("Processing model group with {} entities", entity_ids.len()); - let model_ids = - models_str.split(',').map(|id| Felt::from_str(id).unwrap()).collect::>(); - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); + let model_ids = models_str + .split(',') + .filter_map(|id| { + let model_id = Felt::from_str(id).unwrap(); + if entity_models.is_empty() || entity_models.contains(&model_id) { + Some(model_id) + } else { + None + } + }) + .collect::>(); + let schemas = self + .model_cache + .models(&model_ids) + .await? + .into_iter() + .map(|m| m.schema) + .collect::>(); + if schemas.is_empty() { + continue; + } let (entity_query, _) = build_sql_query( &schemas, @@ -405,6 +429,7 @@ impl DojoWorld { offset: Option, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result<(Vec, u32), Error> { // TODO: use prepared statement for where clause let filter_ids = match hashed_keys { @@ -484,6 +509,7 @@ impl DojoWorld { db_entities, dont_include_hashed_keys, order_by, + entity_models, ) .await?; Ok((entities, total_count)) @@ -500,6 +526,7 @@ impl DojoWorld { offset: Option, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result<(Vec, u32), Error> { let keys_pattern = build_keys_pattern(keys_clause)?; @@ -627,6 +654,7 @@ impl DojoWorld { db_entities, dont_include_hashed_keys, order_by, + entity_models, ) .await?; Ok((entities, total_count)) @@ -670,7 +698,10 @@ impl DojoWorld { offset: Option, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result<(Vec, u32), Error> { + let entity_models = + entity_models.iter().map(|model| compute_selector_from_tag(model)).collect::>(); let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) .expect("invalid comparison operator"); @@ -710,9 +741,15 @@ impl DojoWorld { let model_ids = models_str .split(',') - .map(Felt::from_str) - .collect::, _>>() - .map_err(ParseError::FromStr)?; + .filter_map(|id| { + let model_id = Felt::from_str(id).unwrap(); + if entity_models.is_empty() || entity_models.contains(&model_id) { + Some(model_id) + } else { + None + } + }) + .collect::>(); let schemas = self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); @@ -760,6 +797,7 @@ impl DojoWorld { offset: Option, dont_include_hashed_keys: bool, order_by: Option<&str>, + entity_models: Vec, ) -> Result<(Vec, u32), Error> { let (where_clause, having_clause, join_clause, bind_values) = build_composite_clause(table, model_relation_table, &composite)?; @@ -829,6 +867,7 @@ impl DojoWorld { db_entities, dont_include_hashed_keys, order_by, + entity_models, ) .await?; Ok((entities, total_count)) @@ -996,6 +1035,7 @@ impl DojoWorld { query.offset, query.dont_include_hashed_keys, order_by, + query.entity_models, ) .await? } @@ -1018,6 +1058,7 @@ impl DojoWorld { Some(query.offset), query.dont_include_hashed_keys, order_by, + query.entity_models, ) .await? } @@ -1031,6 +1072,7 @@ impl DojoWorld { Some(query.offset), query.dont_include_hashed_keys, order_by, + query.entity_models, ) .await? } @@ -1044,6 +1086,7 @@ impl DojoWorld { Some(query.offset), query.dont_include_hashed_keys, order_by, + query.entity_models, ) .await? } @@ -1057,6 +1100,7 @@ impl DojoWorld { Some(query.offset), query.dont_include_hashed_keys, order_by, + query.entity_models, ) .await? } diff --git a/crates/torii/grpc/src/server/tests/entities_test.rs b/crates/torii/grpc/src/server/tests/entities_test.rs index e76d502a10..8de8a92f55 100644 --- a/crates/torii/grpc/src/server/tests/entities_test.rs +++ b/crates/torii/grpc/src/server/tests/entities_test.rs @@ -142,6 +142,7 @@ async fn test_entities_queries(sequencer: &RunnerCtx) { None, false, None, + vec![], ) .await .unwrap() diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index aae2ae2150..13ff998900 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -105,6 +105,7 @@ pub struct Query { pub offset: u32, pub dont_include_hashed_keys: bool, pub order_by: Vec, + pub entity_models: Vec, } #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] @@ -269,6 +270,7 @@ impl From for proto::types::Query { offset: value.offset, dont_include_hashed_keys: value.dont_include_hashed_keys, order_by: value.order_by.into_iter().map(|o| o.into()).collect(), + entity_models: value.entity_models, } } }