From 28baa660a7464ef1a6241738295e047abce4ef13 Mon Sep 17 00:00:00 2001 From: Nasr Date: Thu, 12 Sep 2024 13:58:47 -0400 Subject: [PATCH] refactor: member clauses handle struct fields & operators & recursive --- crates/torii/grpc/src/server/mod.rs | 294 ++++++++++++++++------------ 1 file changed, 172 insertions(+), 122 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index d6f6894d67..09c586b97d 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -41,6 +41,7 @@ use self::subscriptions::entity::EntityManager; use self::subscriptions::event_message::EventMessageManager; use self::subscriptions::model_diff::{ModelDiffRequest, StateDiffManager}; use crate::proto::types::clause::ClauseType; +use crate::proto::types::LogicalOperator; use crate::proto::world::world_server::WorldServer; use crate::proto::world::{ SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsResponse, @@ -497,67 +498,32 @@ impl DojoWorld { limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) - .expect("invalid comparison operator"); + let (where_clause, join_clause, having_clause, comparison_value, model_id) = + build_member_clause(table).await?; - let primitive: Primitive = - member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.try_into()?; - - let comparison_value = primitive.to_sql_value()?; - - let (namespace, model) = member_clause - .model - .split_once('-') - .ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?; - - let models_query = format!( - r#" - SELECT group_concat({model_relation_table}.model_id) as model_ids - FROM {table} - JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - GROUP BY {table}.id - HAVING INSTR(model_ids, '{:#x}') > 0 - LIMIT 1 - "#, - compute_selector_from_names(namespace, model) - ); - let (models_str,): (String,) = sqlx::query_as(&models_query).fetch_one(&self.pool).await?; - - let model_ids = models_str - .split(',') - .map(Felt::from_str) - .collect::, _>>() - .map_err(ParseError::FromStr)?; - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - - let table_name = member_clause.model; - let column_name = format!("external_{}", member_clause.member); + let schemas = self.fetch_schemas(table, model_relation_table, model_id).await?; let (entity_query, arrays_queries, count_query) = build_sql_query( &schemas, table, entity_relation_column, - Some(&format!("[{table_name}].{column_name} {comparison_operator} ?")), + Some(&where_clause), None, limit, offset, )?; - let total_count = sqlx::query_scalar(&count_query) - .bind(comparison_value.clone()) - .fetch_one(&self.pool) - .await?; + let total_count = + sqlx::query_scalar(&count_query).bind(&comparison_value).fetch_one(&self.pool).await?; let db_entities = sqlx::query(&entity_query) - .bind(comparison_value.clone()) + .bind(&comparison_value) .bind(limit) .bind(offset) .fetch_all(&self.pool) .await?; let mut arrays_rows = HashMap::new(); for (name, query) in arrays_queries { - let rows = - sqlx::query(&query).bind(comparison_value.clone()).fetch_all(&self.pool).await?; + let rows = sqlx::query(&query).bind(&comparison_value).fetch_all(&self.pool).await?; arrays_rows.insert(name, rows); } @@ -568,7 +534,7 @@ impl DojoWorld { Ok((entities_collection, total_count)) } - async fn query_by_composite( + pub(crate) async fn query_by_composite( &self, table: &str, model_relation_table: &str, @@ -577,92 +543,29 @@ impl DojoWorld { limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - // different types of clauses - let mut where_clauses = Vec::new(); - let mut model_clauses: HashMap> = - HashMap::new(); - let mut having_clauses = Vec::new(); - - // bind valeus for prepared statement - let mut bind_values = Vec::new(); - - for clause in composite.clauses { - match clause.clause_type.unwrap() { - ClauseType::HashedKeys(hashed_keys) => { - let ids = hashed_keys - .hashed_keys - .iter() - .map(|id| { - Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id))) - }) - .collect::, Error>>()?; - where_clauses.push(format!("({})", ids.join(" OR "))); - } - ClauseType::Keys(keys) => { - let keys_pattern = build_keys_pattern(&keys)?; - where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'")); - } - ClauseType::Member(member) => { - let comparison_operator = - ComparisonOperator::from_repr(member.operator as usize) - .expect("invalid comparison operator"); - let value: Primitive = member.value.unwrap().try_into()?; - let comparison_value = value.to_sql_value()?; - - let column_name = format!("external_{}", member.member); - - model_clauses.entry(member.model.clone()).or_default().push(( - column_name, - comparison_operator, - comparison_value, - )); - - let (namespace, model) = member - .model - .split_once('-') - .ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?; - let model_id: Felt = compute_selector_from_names(namespace, model); - having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id)); - } - _ => return Err(QueryError::UnsupportedQuery.into()), - } - } + let (where_clause, having_clause, join_clause, bind_values, model_ids) = + self.build_composite_clause(table, model_relation_table, &composite)?; - let mut join_clauses = Vec::new(); - for (model, clauses) in model_clauses { - let model_conditions = clauses - .into_iter() - .map(|(column, op, value)| { - bind_values.push(value); - format!("[{}].{} {} ?", model, column, op) - }) - .collect::>() - .join(" AND "); - - join_clauses.push(format!( - "JOIN [{}] ON [{}].id = [{}].entity_id AND ({})", - model, table, model, model_conditions - )); - } + let schemas = self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - let join_clause = join_clauses.join(" "); - let where_clause = if !where_clauses.is_empty() { - format!("WHERE {}", where_clauses.join(" AND ")) - } else { - String::new() - }; - let having_clause = if !having_clauses.is_empty() { - format!("HAVING {}", having_clauses.join(" AND ")) - } else { - String::new() - }; + let (entity_query, arrays_queries, count_query) = build_sql_query( + &schemas, + table, + entity_relation_column, + Some(&where_clause), + Some(&having_clause), + limit, + offset, + )?; let count_query = format!( r#" SELECT COUNT(DISTINCT [{table}].id) FROM [{table}] + JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id {join_clause} {where_clause} + {having_clause} "# ); @@ -692,7 +595,7 @@ impl DojoWorld { ); let mut db_query = sqlx::query_as(&query); - for value in bind_values { + for value in &bind_values { db_query = db_query.bind(value); } db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0)); @@ -955,6 +858,39 @@ impl DojoWorld { ) -> Result>, Error> { self.event_manager.add_subscriber(clause.into()).await } + + async fn fetch_schemas( + &self, + table: &str, + model_relation_table: &str, + model_id: Felt, + ) -> Result, Error> { + let models_query = format!( + r#" + SELECT group_concat({model_relation_table}.model_id) as model_ids + FROM {table} + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + GROUP BY {table}.id + HAVING INSTR(model_ids, '{:#x}') > 0 + LIMIT 1 + "#, + model_id + ); + let (models_str,): (String,) = + sqlx::query_as(&models_query).fetch_optional(&self.pool).await?; + if models_str.is_none() { + return Ok(vec![]); + } + + let models_str = models_str.unwrap(); + let model_ids = models_str + .split(',') + .map(Felt::from_str) + .collect::, _>>() + .map_err(ParseError::FromStr)?; + + Ok(self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect()) + } } fn process_event_field(data: &str) -> Result>, Error> { @@ -1013,6 +949,120 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result Result<(String, String, String, String, Felt), Error> { + let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) + .expect("invalid comparison operator"); + + let primitive: Primitive = member_clause + .value + .as_ref() + .ok_or(QueryError::MissingParam("value".into()))? + .clone() + .try_into()?; + + let comparison_value = primitive.to_sql_value()?; + + let (namespace, model) = member_clause + .model + .split_once('-') + .ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?; + + let model_id = compute_selector_from_names(namespace, model); + + let table_name = &member_clause.model; + let parts: Vec<&str> = member_clause.member.split('.').collect(); + let (join_table_name, column_name) = if parts.len() > 1 { + let nested_table = parts[..parts.len() - 1].join("$"); + (format!("[{table_name}${nested_table}]"), format!("external_{}", parts.last().unwrap())) + } else { + (format!("[{table_name}]"), format!("external_{}", member_clause.member)) + }; + + let where_clause = format!("{join_table_name}.{column_name} {comparison_operator} ?"); + let join_clause = + format!("LEFT JOIN {join_table_name} ON [{{table}}].id = {join_table_name}.entity_id"); + let having_clause = + format!("INSTR(group_concat({{model_relation_table}}.model_id), '{:#x}') > 0", model_id); + + Ok((where_clause, join_clause, having_clause, comparison_value, model_id)) +} + +fn build_composite_clause( + &self, + composite: &proto::types::CompositeClause, +) -> Result<(String, String, String, Vec, Vec), Error> { + let is_or = composite.operator == LogicalOperator::Or as i32; + let mut where_clauses = Vec::new(); + let mut join_clauses = Vec::new(); + let mut having_clauses = Vec::new(); + let mut bind_values = Vec::new(); + let mut model_ids = Vec::new(); + + for clause in &composite.clauses { + match clause.clause_type.as_ref().unwrap() { + ClauseType::HashedKeys(hashed_keys) => { + let ids = hashed_keys + .hashed_keys + .iter() + .map(|id| { + bind_values.push(Felt::from_bytes_be_slice(id).to_string()); + "?".to_string() + }) + .collect::>() + .join(", "); + where_clauses.push(format!("{{table}}.id IN ({})", ids)); + } + ClauseType::Keys(keys) => { + let keys_pattern = build_keys_pattern(keys)?; + bind_values.push(keys_pattern); + where_clauses.push(format!("{{table}}.keys REGEXP ?")); + } + ClauseType::Member(member) => { + let (member_where, member_join, member_having, member_value, member_model_id) = + self.build_member_clause(member)?; + where_clauses.push(member_where); + join_clauses.push(member_join); + having_clauses.push(member_having); + bind_values.push(member_value); + model_ids.push(member_model_id); + } + ClauseType::Composite(nested_composite) => { + let (nested_where, nested_having, nested_join, nested_values, nested_model_ids) = + self.build_composite_clause(nested_composite)?; + where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE "))); + if !nested_having.is_empty() { + having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string()); + } + join_clauses.extend( + nested_join + .split_whitespace() + .filter(|&s| s.starts_with("LEFT")) + .map(String::from), + ); + bind_values.extend(nested_values); + model_ids.extend(nested_model_ids); + } + _ => return Err(QueryError::UnsupportedQuery.into()), + } + } + + let join_clause = join_clauses.join(" "); + let where_clause = if !where_clauses.is_empty() { + format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + let having_clause = if !having_clauses.is_empty() { + format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + + Ok((where_clause, having_clause, join_clause, bind_values, model_ids)) +} + type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>;