Skip to content

Commit

Permalink
refactor: member clauses handle struct fields & operators & recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Sep 12, 2024
1 parent bf4ea9b commit 28baa66
Showing 1 changed file with 172 additions and 122 deletions.
294 changes: 172 additions & 122 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use self::subscriptions::entity::EntityManager;
use self::subscriptions::event_message::EventMessageManager;
use self::subscriptions::model_diff::{ModelDiffRequest, StateDiffManager};
use crate::proto::types::clause::ClauseType;
use crate::proto::types::LogicalOperator;
use crate::proto::world::world_server::WorldServer;
use crate::proto::world::{
SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsResponse,
Expand Down Expand Up @@ -497,67 +498,32 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize)
.expect("invalid comparison operator");
let (where_clause, join_clause, having_clause, comparison_value, model_id) =
build_member_clause(table).await?;

let primitive: Primitive =
member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.try_into()?;

let comparison_value = primitive.to_sql_value()?;

let (namespace, model) = member_clause
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?;

let models_query = format!(
r#"
SELECT group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
GROUP BY {table}.id
HAVING INSTR(model_ids, '{:#x}') > 0
LIMIT 1
"#,
compute_selector_from_names(namespace, model)
);
let (models_str,): (String,) = sqlx::query_as(&models_query).fetch_one(&self.pool).await?;

let model_ids = models_str
.split(',')
.map(Felt::from_str)
.collect::<Result<Vec<_>, _>>()
.map_err(ParseError::FromStr)?;
let schemas =
self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

let table_name = member_clause.model;
let column_name = format!("external_{}", member_clause.member);
let schemas = self.fetch_schemas(table, model_relation_table, model_id).await?;
let (entity_query, arrays_queries, count_query) = build_sql_query(
&schemas,
table,
entity_relation_column,
Some(&format!("[{table_name}].{column_name} {comparison_operator} ?")),
Some(&where_clause),
None,
limit,
offset,
)?;

let total_count = sqlx::query_scalar(&count_query)
.bind(comparison_value.clone())
.fetch_one(&self.pool)
.await?;
let total_count =
sqlx::query_scalar(&count_query).bind(&comparison_value).fetch_one(&self.pool).await?;

let db_entities = sqlx::query(&entity_query)
.bind(comparison_value.clone())
.bind(&comparison_value)
.bind(limit)
.bind(offset)
.fetch_all(&self.pool)
.await?;
let mut arrays_rows = HashMap::new();
for (name, query) in arrays_queries {
let rows =
sqlx::query(&query).bind(comparison_value.clone()).fetch_all(&self.pool).await?;
let rows = sqlx::query(&query).bind(&comparison_value).fetch_all(&self.pool).await?;
arrays_rows.insert(name, rows);
}

Expand All @@ -568,7 +534,7 @@ impl DojoWorld {
Ok((entities_collection, total_count))
}

async fn query_by_composite(
pub(crate) async fn query_by_composite(
&self,
table: &str,
model_relation_table: &str,
Expand All @@ -577,92 +543,29 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// different types of clauses
let mut where_clauses = Vec::new();
let mut model_clauses: HashMap<String, Vec<(String, ComparisonOperator, String)>> =
HashMap::new();
let mut having_clauses = Vec::new();

// bind valeus for prepared statement
let mut bind_values = Vec::new();

for clause in composite.clauses {
match clause.clause_type.unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id)))
})
.collect::<Result<Vec<_>, Error>>()?;
where_clauses.push(format!("({})", ids.join(" OR ")));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(&keys)?;
where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'"));
}
ClauseType::Member(member) => {
let comparison_operator =
ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value: Primitive = member.value.unwrap().try_into()?;
let comparison_value = value.to_sql_value()?;

let column_name = format!("external_{}", member.member);

model_clauses.entry(member.model.clone()).or_default().push((
column_name,
comparison_operator,
comparison_value,
));

let (namespace, model) = member
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
let model_id: Felt = compute_selector_from_names(namespace, model);
having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id));
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}
let (where_clause, having_clause, join_clause, bind_values, model_ids) =
self.build_composite_clause(table, model_relation_table, &composite)?;

let mut join_clauses = Vec::new();
for (model, clauses) in model_clauses {
let model_conditions = clauses
.into_iter()
.map(|(column, op, value)| {
bind_values.push(value);
format!("[{}].{} {} ?", model, column, op)
})
.collect::<Vec<_>>()
.join(" AND ");

join_clauses.push(format!(
"JOIN [{}] ON [{}].id = [{}].entity_id AND ({})",
model, table, model, model_conditions
));
}
let schemas = self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(" AND "))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(" AND "))
} else {
String::new()
};
let (entity_query, arrays_queries, count_query) = build_sql_query(
&schemas,
table,
entity_relation_column,
Some(&where_clause),
Some(&having_clause),
limit,
offset,
)?;

let count_query = format!(
r#"
SELECT COUNT(DISTINCT [{table}].id)
FROM [{table}]
JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id
{join_clause}
{where_clause}
{having_clause}
"#
);

Expand Down Expand Up @@ -692,7 +595,7 @@ impl DojoWorld {
);

let mut db_query = sqlx::query_as(&query);
for value in bind_values {
for value in &bind_values {
db_query = db_query.bind(value);
}
db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0));
Expand Down Expand Up @@ -955,6 +858,39 @@ impl DojoWorld {
) -> Result<Receiver<Result<proto::world::SubscribeEventsResponse, tonic::Status>>, Error> {
self.event_manager.add_subscriber(clause.into()).await
}

async fn fetch_schemas(
&self,
table: &str,
model_relation_table: &str,
model_id: Felt,
) -> Result<Vec<Ty>, Error> {
let models_query = format!(
r#"
SELECT group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
GROUP BY {table}.id
HAVING INSTR(model_ids, '{:#x}') > 0
LIMIT 1
"#,
model_id
);
let (models_str,): (String,) =
sqlx::query_as(&models_query).fetch_optional(&self.pool).await?;
if models_str.is_none() {
return Ok(vec![]);
}

let models_str = models_str.unwrap();
let model_ids = models_str
.split(',')
.map(Felt::from_str)
.collect::<Result<Vec<_>, _>>()
.map_err(ParseError::FromStr)?;

Ok(self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect())
}
}

fn process_event_field(data: &str) -> Result<Vec<Vec<u8>>, Error> {
Expand Down Expand Up @@ -1013,6 +949,120 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result<String, Error
Ok(keys_pattern)
}

fn build_member_clause(
member_clause: &proto::types::MemberClause,
) -> Result<(String, String, String, String, Felt), Error> {
let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize)
.expect("invalid comparison operator");

let primitive: Primitive = member_clause
.value
.as_ref()
.ok_or(QueryError::MissingParam("value".into()))?
.clone()
.try_into()?;

let comparison_value = primitive.to_sql_value()?;

let (namespace, model) = member_clause
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member_clause.model.clone()))?;

let model_id = compute_selector_from_names(namespace, model);

let table_name = &member_clause.model;
let parts: Vec<&str> = member_clause.member.split('.').collect();
let (join_table_name, column_name) = if parts.len() > 1 {
let nested_table = parts[..parts.len() - 1].join("$");
(format!("[{table_name}${nested_table}]"), format!("external_{}", parts.last().unwrap()))
} else {
(format!("[{table_name}]"), format!("external_{}", member_clause.member))
};

let where_clause = format!("{join_table_name}.{column_name} {comparison_operator} ?");
let join_clause =
format!("LEFT JOIN {join_table_name} ON [{{table}}].id = {join_table_name}.entity_id");
let having_clause =
format!("INSTR(group_concat({{model_relation_table}}.model_id), '{:#x}') > 0", model_id);

Ok((where_clause, join_clause, having_clause, comparison_value, model_id))
}

fn build_composite_clause(
&self,
composite: &proto::types::CompositeClause,
) -> Result<(String, String, String, Vec<String>, Vec<Felt>), Error> {
let is_or = composite.operator == LogicalOperator::Or as i32;
let mut where_clauses = Vec::new();
let mut join_clauses = Vec::new();
let mut having_clauses = Vec::new();
let mut bind_values = Vec::new();
let mut model_ids = Vec::new();

for clause in &composite.clauses {
match clause.clause_type.as_ref().unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
bind_values.push(Felt::from_bytes_be_slice(id).to_string());
"?".to_string()
})
.collect::<Vec<_>>()
.join(", ");
where_clauses.push(format!("{{table}}.id IN ({})", ids));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(keys)?;
bind_values.push(keys_pattern);
where_clauses.push(format!("{{table}}.keys REGEXP ?"));
}
ClauseType::Member(member) => {
let (member_where, member_join, member_having, member_value, member_model_id) =
self.build_member_clause(member)?;
where_clauses.push(member_where);
join_clauses.push(member_join);
having_clauses.push(member_having);
bind_values.push(member_value);
model_ids.push(member_model_id);
}
ClauseType::Composite(nested_composite) => {
let (nested_where, nested_having, nested_join, nested_values, nested_model_ids) =
self.build_composite_clause(nested_composite)?;
where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE ")));
if !nested_having.is_empty() {
having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string());
}
join_clauses.extend(
nested_join
.split_whitespace()
.filter(|&s| s.starts_with("LEFT"))
.map(String::from),
);
bind_values.extend(nested_values);
model_ids.extend(nested_model_ids);
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};

Ok((where_clause, having_clause, join_clause, bind_values, model_ids))
}

type ServiceResult<T> = Result<Response<T>, Status>;
type SubscribeModelsResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeModelsResponse, Status>> + Send>>;
Expand Down

0 comments on commit 28baa66

Please sign in to comment.