diff --git a/crates/torii/core/src/model.rs b/crates/torii/core/src/model.rs index bf7d70de0d..5030128013 100644 --- a/crates/torii/core/src/model.rs +++ b/crates/torii/core/src/model.rs @@ -398,6 +398,155 @@ pub fn map_row_to_ty( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub async fn fetch_entities( + pool: &Pool, + schemas: &[Ty], + table_name: &str, + model_relation_table: &str, + entity_relation_column: &str, + where_clause: Option<&str>, + having_clause: Option<&str>, + order_by: Option<&str>, + limit: Option, + offset: Option, + bind_values: Vec, +) -> Result<(Vec, u32), Error> { + // Helper function to collect columns (existing implementation) + fn collect_columns(table_prefix: &str, path: &str, ty: &Ty, selections: &mut Vec) { + match ty { + Ty::Struct(s) => { + for child in &s.children { + let new_path = if path.is_empty() { + child.name.clone() + } else { + format!("{}.{}", path, child.name) + }; + collect_columns(table_prefix, &new_path, &child.ty, selections); + } + } + Ty::Tuple(t) => { + for (i, child) in t.iter().enumerate() { + let new_path = + if path.is_empty() { format!("{}", i) } else { format!("{}.{}", path, i) }; + collect_columns(table_prefix, &new_path, child, selections); + } + } + Ty::Enum(e) => { + // Add the enum variant column with table prefix and alias + selections.push(format!("[{table_prefix}].[{path}] as \"{table_prefix}.{path}\"",)); + + // Add columns for each variant's value (if not empty tuple) + for option in &e.options { + if let Ty::Tuple(t) = &option.ty { + if t.is_empty() { + continue; + } + } + let variant_path = format!("{}.{}", path, option.name); + collect_columns(table_prefix, &variant_path, &option.ty, selections); + } + } + Ty::Array(_) | Ty::Primitive(_) | Ty::ByteArray(_) => { + selections.push(format!("[{table_prefix}].[{path}] as \"{table_prefix}.{path}\"",)); + } + } + } + + const MAX_JOINS: usize = 64; + let schema_chunks = schemas.chunks(MAX_JOINS); + let mut total_count = 0; + let mut all_rows = Vec::new(); + + for chunk in schema_chunks { + let mut selections = Vec::new(); + let mut joins = Vec::new(); + + // Add base table columns + selections.push(format!("{}.id", table_name)); + selections.push(format!("{}.keys", table_name)); + selections.push(format!("group_concat({model_relation_table}.model_id) as model_ids")); + + // Process each model schema in the chunk + for model in chunk { + let model_table = model.name(); + joins.push(format!( + "LEFT JOIN [{model_table}] ON {table_name}.id = \ + [{model_table}].{entity_relation_column}" + )); + collect_columns(&model_table, "", model, &mut selections); + } + + joins.push(format!( + "JOIN {model_relation_table} ON {table_name}.id = {model_relation_table}.entity_id" + )); + + let selections_clause = selections.join(", "); + let joins_clause = joins.join(" "); + + // Build count query + let count_query = format!( + "SELECT COUNT(*) FROM (SELECT {}.id, group_concat({}.model_id) as model_ids FROM [{}] \ + {} {} GROUP BY {}.id {})", + table_name, + model_relation_table, + table_name, + joins_clause, + where_clause.map_or(String::new(), |w| format!(" WHERE {}", w)), + table_name, + having_clause.map_or(String::new(), |h| format!(" HAVING {}", h)) + ); + + // Execute count query + let mut count_stmt = sqlx::query_scalar(&count_query); + for value in &bind_values { + count_stmt = count_stmt.bind(value); + } + let chunk_count: u32 = count_stmt.fetch_one(pool).await?; + total_count += chunk_count; + + if chunk_count > 0 { + // Build main query + let mut query = + format!("SELECT {} FROM [{}] {}", selections_clause, table_name, joins_clause); + + if let Some(where_clause) = where_clause { + query += &format!(" WHERE {}", where_clause); + } + + query += &format!(" GROUP BY {table_name}.id"); + + if let Some(having_clause) = having_clause { + query += &format!(" HAVING {}", having_clause); + } + + if let Some(order_clause) = order_by { + query += &format!(" ORDER BY {}", order_clause); + } else { + query += &format!(" ORDER BY {}.event_id DESC", table_name); + } + + if let Some(limit) = limit { + query += &format!(" LIMIT {}", limit); + } + + if let Some(offset) = offset { + query += &format!(" OFFSET {}", offset); + } + + // Execute main query + let mut stmt = sqlx::query(&query); + for value in &bind_values { + stmt = stmt.bind(value); + } + let chunk_rows = stmt.fetch_all(pool).await?; + all_rows.extend(chunk_rows); + } + } + + Ok((all_rows, total_count)) +} + #[cfg(test)] mod tests { use dojo_types::schema::{Enum, EnumOption, Member, Struct, Ty}; diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index b0e7d4ac45..54cccf0b9e 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -42,7 +42,7 @@ use tonic::transport::Server; use tonic::{Request, Response, Status}; use tonic_web::GrpcWebLayer; use torii_core::error::{Error, ParseError, QueryError}; -use torii_core::model::{build_sql_query, map_row_to_ty}; +use torii_core::model::{fetch_entities, map_row_to_ty}; use torii_core::sql::cache::ModelCache; use torii_core::types::{Token, TokenBalance}; use tower_http::cors::{AllowOrigin, CorsLayer}; @@ -312,13 +312,10 @@ impl DojoWorld { ) -> Result<(Vec, u32), Error> { let where_clause = match &hashed_keys { Some(hashed_keys) => { - let ids = hashed_keys - .hashed_keys - .iter() - .map(|_| Ok("{table}.id = ?")) - .collect::, Error>>()?; + let ids = + hashed_keys.hashed_keys.iter().map(|_| "{table}.id = ?").collect::>(); format!( - "WHERE {} {}", + "{} {}", ids.join(" OR "), if entity_updated_after.is_some() { format!("AND {table}.updated_at >= ?") @@ -329,19 +326,20 @@ impl DojoWorld { } None => { if entity_updated_after.is_some() { - format!("WHERE {table}.updated_at >= ?") + format!("{table}.updated_at >= ?") } else { String::new() } } }; + let mut bind_values = vec![]; if let Some(hashed_keys) = hashed_keys { bind_values = hashed_keys .hashed_keys .iter() .map(|key| format!("{:#x}", Felt::from_bytes_be_slice(key))) - .collect::>() + .collect::>(); } if let Some(entity_updated_after) = entity_updated_after.clone() { bind_values.push(entity_updated_after); @@ -363,49 +361,59 @@ impl DojoWorld { .collect::>() .join(" OR "); - let (query, count_query) = build_sql_query( - &schemas, - table, - model_relation_table, - entity_relation_column, - if where_clause.is_empty() { None } else { Some(&where_clause) }, - if !having_clause.is_empty() { Some(&having_clause) } else { None }, - order_by, - limit, - offset, - )?; - - let mut count_query = sqlx::query_scalar(&count_query); - for value in &bind_values { - count_query = count_query.bind(value); - } - let total_count = count_query.fetch_one(&self.pool).await?; - if total_count == 0 { - return Ok((Vec::new(), 0)); - } - if table == EVENT_MESSAGES_HISTORICAL_TABLE { - let entities = - self.fetch_historical_event_messages(&format!( + let count_query = format!( + r#" + SELECT COUNT(*) FROM {table} + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + WHERE {where_clause} + GROUP BY {table}.event_id + "# + ); + let mut total_count = sqlx::query_scalar(&count_query); + for value in &bind_values { + total_count = total_count.bind(value); + } + let total_count = total_count.fetch_one(&self.pool).await?; + if total_count == 0 { + return Ok((Vec::new(), 0)); + } + + let entities = self.fetch_historical_event_messages( + &format!( r#" SELECT {table}.id, {table}.data, {table}.model_id, group_concat({model_relation_table}.model_id) as model_ids FROM {table} JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - {where_clause} + WHERE {where_clause} GROUP BY {table}.event_id ORDER BY {table}.event_id DESC "# - ), bind_values, limit, offset).await?; + ), + bind_values, + limit, + offset + ).await?; return Ok((entities, total_count)); } - let mut query = sqlx::query(&query); - for value in &bind_values { - query = query.bind(value); - } - let entities = query.fetch_all(&self.pool).await?; - let entities = entities - .iter() + let (rows, total_count) = fetch_entities( + &self.pool, + &schemas, + table, + model_relation_table, + entity_relation_column, + if !where_clause.is_empty() { Some(&where_clause) } else { None }, + if !having_clause.is_empty() { Some(&having_clause) } else { None }, + order_by, + limit, + offset, + bind_values, + ) + .await?; + + let entities = rows + .par_iter() .map(|row| map_row_to_entity(row, &schemas, dont_include_hashed_keys)) .collect::, Error>>()?; @@ -457,28 +465,25 @@ impl DojoWorld { .map(|model| format!("INSTR(model_ids, '{:#x}') > 0", model)) .collect::>() .join(" OR "); - let (query, count_query) = build_sql_query( - &schemas, - table, - model_relation_table, - entity_relation_column, - Some(&where_clause), - if !having_clause.is_empty() { Some(&having_clause) } else { None }, - order_by, - limit, - offset, - )?; - - let mut count_query = sqlx::query_scalar(&count_query); - for value in &bind_values { - count_query = count_query.bind(value); - } - let total_count = count_query.fetch_one(&self.pool).await?; - if total_count == 0 { - return Ok((Vec::new(), 0)); - } if table == EVENT_MESSAGES_HISTORICAL_TABLE { + let count_query = format!( + r#" + SELECT COUNT(*) FROM {table} + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + WHERE {where_clause} + GROUP BY {table}.event_id + "# + ); + let mut total_count = sqlx::query_scalar(&count_query); + for value in &bind_values { + total_count = total_count.bind(value); + } + let total_count = total_count.fetch_one(&self.pool).await?; + if total_count == 0 { + return Ok((Vec::new(), 0)); + } + let entities = self.fetch_historical_event_messages( &format!( r#" @@ -497,13 +502,23 @@ impl DojoWorld { return Ok((entities, total_count)); } - let mut query = sqlx::query(&query); - for value in &bind_values { - query = query.bind(value); - } - let entities = query.fetch_all(&self.pool).await?; - let entities = entities - .iter() + let (rows, total_count) = fetch_entities( + &self.pool, + &schemas, + table, + model_relation_table, + entity_relation_column, + Some(&where_clause), + if !having_clause.is_empty() { Some(&having_clause) } else { None }, + order_by, + limit, + offset, + bind_values, + ) + .await?; + + let entities = rows + .par_iter() .map(|row| map_row_to_entity(row, &schemas, dont_include_hashed_keys)) .collect::, Error>>()?; @@ -618,8 +633,13 @@ impl DojoWorld { } }) .collect::>(); - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); + let schemas = self + .model_cache + .models(&model_ids) + .await? + .into_iter() + .map(|m| m.schema) + .collect::>(); // Use the member name directly as the column name since it's already flattened let mut bind_values = Vec::new(); @@ -633,9 +653,11 @@ impl DojoWorld { ); if entity_updated_after.is_some() { where_clause += &format!(" AND {table}.updated_at >= ?"); + bind_values.push(entity_updated_after.unwrap()); } - let (entity_query, count_query) = build_sql_query( + let (rows, total_count) = fetch_entities( + &self.pool, &schemas, table, model_relation_table, @@ -645,31 +667,16 @@ impl DojoWorld { order_by, limit, offset, - )?; - let mut count_query = sqlx::query_scalar(&count_query); - for value in &bind_values { - count_query = count_query.bind(value); - } - if let Some(entity_updated_after) = entity_updated_after.clone() { - count_query = count_query.bind(entity_updated_after); - } - let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); - - let mut query = sqlx::query(&entity_query); - for value in &bind_values { - query = query.bind(value); - } - if let Some(entity_updated_after) = entity_updated_after.clone() { - query = query.bind(entity_updated_after); - } - query = query.bind(limit).bind(offset); - let db_entities = query.fetch_all(&self.pool).await?; + bind_values, + ) + .await?; - let entities_collection: Result, Error> = db_entities + let entities = rows .par_iter() .map(|row| map_row_to_entity(row, &schemas, dont_include_hashed_keys)) - .collect(); - Ok((entities_collection?, total_count)) + .collect::, Error>>()?; + + Ok((entities, total_count)) } #[allow(clippy::too_many_arguments)] @@ -691,8 +698,13 @@ impl DojoWorld { let entity_models = entity_models.iter().map(|model| compute_selector_from_tag(model)).collect::>(); - let schemas = - self.model_cache.models(&entity_models).await?.into_iter().map(|m| m.schema).collect(); + let schemas = self + .model_cache + .models(&entity_models) + .await? + .into_iter() + .map(|m| m.schema) + .collect::>(); let having_clause = entity_models .iter() @@ -700,7 +712,8 @@ impl DojoWorld { .collect::>() .join(" OR "); - let (query, count_query) = build_sql_query( + let (rows, total_count) = fetch_entities( + &self.pool, &schemas, table, model_relation_table, @@ -710,29 +723,15 @@ impl DojoWorld { order_by, limit, offset, - )?; - - let mut count_query = sqlx::query_scalar(&count_query); - for value in &bind_values { - count_query = count_query.bind(value); - } - - let total_count = count_query.fetch_one(&self.pool).await?; - if total_count == 0 { - return Ok((Vec::new(), 0)); - } - - println!("query: {}", query); - let mut query = sqlx::query(&query); - for value in &bind_values { - query = query.bind(value); - } - let db_entities = query.fetch_all(&self.pool).await?; + bind_values, + ) + .await?; - let entities = db_entities + let entities = rows .par_iter() .map(|row| map_row_to_entity(row, &schemas, dont_include_hashed_keys)) .collect::, Error>>()?; + Ok((entities, total_count)) }