Skip to content

Commit

Permalink
refactor: member clauses handle struct fields & operators & recursive (
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo authored Sep 16, 2024
1 parent 435c2d1 commit 09203e8
Showing 1 changed file with 124 additions and 103 deletions.
227 changes: 124 additions & 103 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use self::subscriptions::event_message::EventMessageManager;
use self::subscriptions::model_diff::{ModelDiffRequest, StateDiffManager};
use crate::proto::types::clause::ClauseType;
use crate::proto::types::member_value::ValueType;
use crate::proto::types::LogicalOperator;
use crate::proto::world::world_server::WorldServer;
use crate::proto::world::{
SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsResponse,
Expand Down Expand Up @@ -260,7 +261,6 @@ impl DojoWorld {
// total count of rows without limit and offset
let total_count: u32 =
sqlx::query_scalar(&count_query).fetch_optional(&self.pool).await?.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand Down Expand Up @@ -382,7 +382,6 @@ impl DojoWorld {
.fetch_optional(&self.pool)
.await?
.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand Down Expand Up @@ -531,15 +530,13 @@ impl DojoWorld {
"#,
compute_selector_from_names(namespace, model)
);

let models_result: Option<(String,)> =
sqlx::query_as(&models_query).fetch_optional(&self.pool).await?;
// we return an empty array of entities if the table is empty
if models_result.is_none() {
let models_str: Option<String> =
sqlx::query_scalar(&models_query).fetch_optional(&self.pool).await?;
if models_str.is_none() {
return Ok((Vec::new(), 0));
}

let (models_str,) = models_result.unwrap();
let models_str = models_str.unwrap();

let model_ids = models_str
.split(',')
Expand All @@ -549,8 +546,14 @@ impl DojoWorld {
let schemas =
self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

let table_name = member_clause.model;
let column_name = format!("external_{}", member_clause.member);
let model = member_clause.model.clone();
let parts: Vec<&str> = member_clause.member.split('.').collect();
let (table_name, column_name) = if parts.len() > 1 {
let nested_table = parts[..parts.len() - 1].join("$");
(format!("{model}${nested_table}"), format!("external_{}", parts.last().unwrap()))
} else {
(model, format!("external_{}", member_clause.member))
};
let (entity_query, arrays_queries, count_query) = build_sql_query(
&schemas,
table,
Expand All @@ -566,7 +569,6 @@ impl DojoWorld {
.fetch_optional(&self.pool)
.await?
.unwrap_or(0);

let db_entities = sqlx::query(&entity_query)
.bind(comparison_value.clone())
.bind(limit)
Expand All @@ -587,7 +589,7 @@ impl DojoWorld {
Ok((entities_collection, total_count))
}

async fn query_by_composite(
pub(crate) async fn query_by_composite(
&self,
table: &str,
model_relation_table: &str,
Expand All @@ -596,102 +598,17 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// different types of clauses
let mut where_clauses = Vec::new();
let mut model_clauses: HashMap<String, Vec<(String, ComparisonOperator, String)>> =
HashMap::new();
let mut having_clauses = Vec::new();

// bind valeus for prepared statement
let mut bind_values = Vec::new();

for clause in composite.clauses {
match clause.clause_type.unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id)))
})
.collect::<Result<Vec<_>, Error>>()?;
where_clauses.push(format!("({})", ids.join(" OR ")));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(&keys)?;
where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'"));
}
ClauseType::Member(member) => {
let comparison_operator =
ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let comparison_value = match member
.value
.ok_or(QueryError::MissingParam("value".into()))?
.value_type
{
Some(ValueType::String(value)) => value,
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()?
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};

let column_name = format!("external_{}", member.member);

model_clauses.entry(member.model.clone()).or_default().push((
column_name,
comparison_operator,
comparison_value,
));

let (namespace, model) = member
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
let model_id: Felt = compute_selector_from_names(namespace, model);
having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id));
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}

let mut join_clauses = Vec::new();
for (model, clauses) in model_clauses {
let model_conditions = clauses
.into_iter()
.map(|(column, op, value)| {
bind_values.push(value);
format!("[{}].{} {} ?", model, column, op)
})
.collect::<Vec<_>>()
.join(" AND ");

join_clauses.push(format!(
"JOIN [{}] ON [{}].id = [{}].entity_id AND ({})",
model, table, model, model_conditions
));
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(" AND "))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(" AND "))
} else {
String::new()
};
let (where_clause, having_clause, join_clause, bind_values) =
build_composite_clause(table, model_relation_table, &composite)?;

let count_query = format!(
r#"
SELECT COUNT(DISTINCT [{table}].id)
FROM [{table}]
JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id
{join_clause}
{where_clause}
{having_clause}
"#
);

Expand All @@ -701,7 +618,6 @@ impl DojoWorld {
}

let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand All @@ -721,7 +637,7 @@ impl DojoWorld {
);

let mut db_query = sqlx::query_as(&query);
for value in bind_values {
for value in &bind_values {
db_query = db_query.bind(value);
}
db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0));
Expand Down Expand Up @@ -1042,6 +958,111 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result<String, Error
Ok(keys_pattern)
}

// builds a composite clause for a query
fn build_composite_clause(
table: &str,
model_relation_table: &str,
composite: &proto::types::CompositeClause,
) -> Result<(String, String, String, Vec<String>), Error> {
let is_or = composite.operator == LogicalOperator::Or as i32;
let mut where_clauses = Vec::new();
let mut join_clauses = Vec::new();
let mut having_clauses = Vec::new();
let mut bind_values = Vec::new();

for clause in &composite.clauses {
match clause.clause_type.as_ref().unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
bind_values.push(Felt::from_bytes_be_slice(id).to_string());
"?".to_string()
})
.collect::<Vec<_>>()
.join(", ");
where_clauses.push(format!("{table}.id IN ({})", ids));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(keys)?;
bind_values.push(keys_pattern);
where_clauses.push(format!("{table}.keys REGEXP ?"));
}
ClauseType::Member(member) => {
let comparison_operator = ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value = member.value.clone();
let comparison_value =
match value.ok_or(QueryError::MissingParam("value".into()))?.value_type {
Some(ValueType::String(value)) => value,
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()?
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};
bind_values.push(comparison_value);

let model = member.model.clone();
let parts: Vec<&str> = member.member.split('.').collect();
let (table_name, column_name) = if parts.len() > 1 {
let nested_table = parts[..parts.len() - 1].join("$");
(
format!("[{model}${nested_table}]"),
format!("external_{}", parts.last().unwrap()),
)
} else {
(format!("[{model}]"), format!("external_{}", member.member))
};

let (namespace, model) = member
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
let model_id = compute_selector_from_names(namespace, model);
join_clauses.push(format!(
"LEFT JOIN {table_name} ON [{table}].id = {table_name}.entity_id"
));
where_clauses.push(format!("{table_name}.{column_name} {comparison_operator} ?"));
having_clauses.push(format!(
"INSTR(group_concat({model_relation_table}.model_id), '{:#x}') > 0",
model_id
));
}
ClauseType::Composite(nested_composite) => {
let (nested_where, nested_having, nested_join, nested_values) =
build_composite_clause(table, model_relation_table, nested_composite)?;
where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE ")));
if !nested_having.is_empty() {
having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string());
}
join_clauses.extend(
nested_join
.split_whitespace()
.filter(|&s| s.starts_with("LEFT"))
.map(String::from),
);
bind_values.extend(nested_values);
}
}
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};

Ok((where_clause, having_clause, join_clause, bind_values))
}

type ServiceResult<T> = Result<Response<T>, Status>;
type SubscribeModelsResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeModelsResponse, Status>> + Send>>;
Expand Down

0 comments on commit 09203e8

Please sign in to comment.