Skip to content

Commit

Permalink
Feat(torii-grpc): Add total_count on RetrieveEntitiesResponse (#1545)
Browse files Browse the repository at this point in the history
* Feat: Add total_count on RetrieveEntitiesResponse

* Feat: add count_rows

* fmt

* Fix sql count_query
  • Loading branch information
gianalarcon authored Feb 24, 2024
1 parent 640e94d commit 8c5eecd
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 12 deletions.
3 changes: 2 additions & 1 deletion crates/torii/client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ impl Client {
/// type of entites matching keys and/or models.
pub async fn entities(&self, query: Query) -> Result<Vec<Entity>, Error> {
let mut grpc_client = self.inner.write().await;
let RetrieveEntitiesResponse { entities } = grpc_client.retrieve_entities(query).await?;
let RetrieveEntitiesResponse { entities, total_count: _ } =
grpc_client.retrieve_entities(query).await?;
Ok(entities.into_iter().map(TryInto::try_into).collect::<Result<Vec<Entity>, _>>()?)
}

Expand Down
1 change: 1 addition & 0 deletions crates/torii/grpc/proto/world.proto
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ message RetrieveEntitiesRequest {

message RetrieveEntitiesResponse {
repeated types.Entity entities = 1;
uint32 total_count = 2;
}
60 changes: 49 additions & 11 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl DojoWorld {
&self,
limit: u32,
offset: u32,
) -> Result<Vec<proto::types::Entity>, Error> {
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
self.entities_by_hashed_keys(None, limit, offset).await
}

Expand All @@ -131,7 +131,7 @@ impl DojoWorld {
hashed_keys: Option<proto::types::HashedKeysClause>,
limit: u32,
offset: u32,
) -> Result<Vec<proto::types::Entity>, Error> {
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// TODO: use prepared statement for where clause
let filter_ids = match hashed_keys {
Some(hashed_keys) => {
Expand All @@ -150,6 +150,18 @@ impl DojoWorld {
None => String::new(),
};

// count query that matches filter_ids
let count_query = format!(
r#"
SELECT count(*)
FROM entities
{filter_ids}
"#
);
// total count of rows without limit and offset
let total_count: u32 = sqlx::query_scalar(&count_query).fetch_one(&self.pool).await?;

// query to filter with limit and offset
let query = format!(
r#"
SELECT entities.id, group_concat(entity_model.model_id) as model_names
Expand Down Expand Up @@ -190,15 +202,15 @@ impl DojoWorld {
})
}

Ok(entities)
Ok((entities, total_count))
}

async fn entities_by_keys(
&self,
keys_clause: proto::types::KeysClause,
limit: u32,
offset: u32,
) -> Result<Vec<proto::types::Entity>, Error> {
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let keys = keys_clause
.keys
.iter()
Expand All @@ -213,6 +225,20 @@ impl DojoWorld {
.collect::<Result<Vec<_>, Error>>()?;
let keys_pattern = keys.join("/") + "/%";

let count_query = format!(
r#"
SELECT count(*)
FROM entities
JOIN entity_model ON entities.id = entity_model.entity_id
WHERE entity_model.model_id = '{}' and entities.keys LIKE ?
"#,
keys_clause.model
);

// total count of rows that matches keys_pattern without limit and offset
let total_count =
sqlx::query_scalar(&count_query).bind(&keys_pattern).fetch_one(&self.pool).await?;

let models_query = format!(
r#"
SELECT group_concat(entity_model.model_id) as model_names
Expand All @@ -231,6 +257,7 @@ impl DojoWorld {
let model_names = models_str.split(',').collect::<Vec<&str>>();
let schemas = self.model_cache.schemas(model_names).await?;

// query to filter with limit and offset
let entities_query = format!(
"{} WHERE entities.keys LIKE ? ORDER BY entities.event_id DESC LIMIT ? OFFSET ?",
build_sql_query(&schemas)?
Expand All @@ -242,15 +269,21 @@ impl DojoWorld {
.fetch_all(&self.pool)
.await?;

db_entities.iter().map(|row| Self::map_row_to_entity(row, &schemas)).collect()
Ok((
db_entities
.iter()
.map(|row| Self::map_row_to_entity(row, &schemas))
.collect::<Result<Vec<_>, Error>>()?,
total_count,
))
}

async fn entities_by_member(
&self,
member_clause: proto::types::MemberClause,
_limit: u32,
_offset: u32,
) -> Result<Vec<proto::types::Entity>, Error> {
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize)
.expect("invalid comparison operator");

Expand Down Expand Up @@ -299,16 +332,21 @@ impl DojoWorld {

let db_entities =
sqlx::query(&member_query).bind(comparison_value).fetch_all(&self.pool).await?;

db_entities.iter().map(|row| Self::map_row_to_entity(row, &schemas)).collect()
let entities_collection = db_entities
.iter()
.map(|row| Self::map_row_to_entity(row, &schemas))
.collect::<Result<Vec<_>, Error>>()?;
// Since there is not limit and offset, total_count is same as number of entities
let total_count = entities_collection.len() as u32;
Ok((entities_collection, total_count))
}

async fn entities_by_composite(
&self,
_composite: proto::types::CompositeClause,
_limit: u32,
_offset: u32,
) -> Result<Vec<proto::types::Entity>, Error> {
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// TODO: Implement
Err(QueryError::UnsupportedQuery.into())
}
Expand Down Expand Up @@ -378,7 +416,7 @@ impl DojoWorld {
&self,
query: proto::types::Query,
) -> Result<proto::world::RetrieveEntitiesResponse, Error> {
let entities = match query.clause {
let (entities, total_count) = match query.clause {
None => self.entities_all(query.limit, query.offset).await?,
Some(clause) => {
let clause_type =
Expand Down Expand Up @@ -414,7 +452,7 @@ impl DojoWorld {
}
};

Ok(RetrieveEntitiesResponse { entities })
Ok(RetrieveEntitiesResponse { entities, total_count })
}

fn map_row_to_entity(row: &SqliteRow, schemas: &[Ty]) -> Result<proto::types::Entity, Error> {
Expand Down

0 comments on commit 8c5eecd

Please sign in to comment.