diff --git a/crates/torii/libp2p/src/typed_data.rs b/crates/torii/libp2p/src/typed_data.rs index 6c6903723c..584737fd27 100644 --- a/crates/torii/libp2p/src/typed_data.rs +++ b/crates/torii/libp2p/src/typed_data.rs @@ -1,12 +1,14 @@ -use core::slice::SlicePattern; use std::collections::HashMap; -use crypto_bigint::U256; +use crypto_bigint::{Encoding, U256}; use dojo_types::schema::Ty; +use serde::{Deserialize, Serialize}; use starknet_core::utils::{cairo_short_string_to_felt, starknet_keccak}; use starknet_crypto::{pedersen_hash, poseidon_hash, poseidon_hash_many}; use starknet_ff::FieldElement; +use crate::errors::Error; + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SimpleField { pub name: String, @@ -27,6 +29,39 @@ pub struct TokenAmount { pub amount: U256, } +impl TokenAmount { + pub fn encode( + &self, + name: String, + types: HashMap>, + ) -> Result { + let mut hashes = Vec::new(); + + let type_hash = encode_type( + name, + vec![ + Field::SimpleType(SimpleField { + name: "token_address".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "amount".to_string(), + r#type: "u256".to_string(), + }), + ], + types, + ); + hashes.push(type_hash); + + hashes.push(self.token_address); + + let amount_hash = PrimitiveType::U256(self.amount).encode("amount".to_string(), types)?; + hashes.push(amount_hash); + + Ok(poseidon_hash_many(hashes.as_slice())) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct NftId { // ContractAddress @@ -34,6 +69,39 @@ pub struct NftId { pub token_id: U256, } +impl NftId { + pub fn encode( + &self, + name: String, + types: HashMap>, + ) -> Result { + let mut hashes = Vec::new(); + + let type_hash = encode_type( + name, + vec![ + Field::SimpleType(SimpleField { + name: "collection_address".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "token_id".to_string(), + r#type: "u256".to_string(), + }), + ], + types, + ); + hashes.push(type_hash); + + hashes.push(self.collection_address); + let token_id = PrimitiveType::U256(self.token_id).encode("token_id".to_string(), types)?; + + hashes.push(token_id); + + Ok(poseidon_hash_many(hashes.as_slice())) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub enum Field { SimpleType(SimpleField), @@ -56,151 +124,149 @@ pub enum PrimitiveType { U256(U256), // Maximum of 31 ascii characters ShortString(String), - // Enum(HashMap>), + Enum(HashMap>), NftId(NftId), TokenAmount(TokenAmount), } -impl PrimitiveType { - fn encode_type(&self, name: String, fields: Vec, types: HashMap>) -> FieldElement { - let mut type_hash = String::new(); +pub fn encode_type( + name: String, + fields: Vec, + types: HashMap>, +) -> FieldElement { + let mut type_hash = String::new(); - type_hash += &format!("\"{}\":", name); + type_hash += &format!("\"{}\":", name); - // add fields - type_hash += "("; - for field in types[name] { - match field { - Field::SimpleType(simple_field) => { - type_hash += &format!("\"{}\":\"{}\"", simple_field.name, simple_field.r#type); - } - Field::ParentType(parent_field) => { - // type_hash += &format!("\"{}\":\"{}\",", parent_field.name, parent_field.r#type); - // type_hash += &format!("\"{}\":\"{}\",", parent_field.name, parent_field.contains); - } + // add fields + type_hash += "("; + for (idx, field) in fields.iter().enumerate() { + match field { + Field::SimpleType(simple_field) => { + type_hash += &format!("\"{}\":\"{}\"", simple_field.name, simple_field.r#type); } - - if field != fields.last().unwrap() { - type_hash += ","; + Field::ParentType(parent_field) => { + return encode_type( + parent_field.contains, + types[&parent_field.contains].clone(), + types, + ); } } - type_hash += ")"; - - starknet_keccak(type_hash.as_bytes()) + if idx < fields.len() - 1 { + type_hash += ","; + } } - pub fn encode(&self, name: String, types: HashMap>) -> FieldElement { + type_hash += ")"; + + starknet_keccak(type_hash.as_bytes()) +} + +impl PrimitiveType { + pub fn encode( + &self, + name: String, + types: HashMap>, + ) -> Result { match self { PrimitiveType::Object(obj) => { let mut hashes = Vec::new(); - - let type_hash = self.encode_type(name, types[name], types); + + let type_hash = encode_type(name, types[&name], types); hashes.push(type_hash); for (field_name, value) in obj { - let field_hash = value.encode(field_name, types); + let field_hash = value.encode(*field_name, types)?; hashes.push(field_hash); } - poseidon_hash_many(hashes.as_slice()) + Ok(poseidon_hash_many(hashes.as_slice())) } - PrimitiveType::Array(array) => { - poseidon_hash_many(array.iter().map(|x| x.encode(name.clone(), types)).collect()) + PrimitiveType::Array(array) => Ok(poseidon_hash_many( + array + .iter() + .map(|x| x.encode(name.clone(), types)) + .collect::, _>>()? + .as_slice(), + )), + PrimitiveType::Enum(enum_map) => { + let mut hashes = Vec::new(); + + let type_hash = encode_type(name, types[&name], types); + hashes.push(type_hash); + + for (field_name, value) in enum_map { + let field_hash = poseidon_hash_many( + value + .iter() + .map(|x| x.encode(field_name.clone(), types)) + .collect::, _>>()? + .as_slice(), + ); + hashes.push(field_hash); + } + + Ok(poseidon_hash_many(hashes.as_slice())) } - PrimitiveType::FieldElement(field_element) => *field_element, + PrimitiveType::FieldElement(field_element) => Ok(*field_element), PrimitiveType::Bool(boolean) => { if *boolean { - FieldElement::from(1) + Ok(FieldElement::from(1 as u32)) } else { - FieldElement::from(0) + Ok(FieldElement::from(0 as u32)) } } - PrimitiveType::String(string) => poseidon_hash_many( + PrimitiveType::String(string) => Ok(poseidon_hash_many( string .as_bytes() .iter() .map(|x| FieldElement::from(*x as u128)) - .collect(), - ), - PrimitiveType::Selector(selector) => { - starknet_keccak(selector.as_bytes()) - } - PrimitiveType::U128(u128) => FieldElement::from(*u128), - PrimitiveType::I128(i128) => FieldElement::from(*i128 as u128), - PrimitiveType::ContractAddress(contract_address) => *contract_address, - PrimitiveType::ClassHash(class_hash) => *class_hash, - PrimitiveType::Timestamp(timestamp) => FieldElement::from(*timestamp), + .collect::>() + .as_slice(), + )), + PrimitiveType::Selector(selector) => Ok(starknet_keccak(selector.as_bytes())), + PrimitiveType::U128(u128) => Ok(FieldElement::from(*u128)), + PrimitiveType::I128(i128) => Ok(FieldElement::from(*i128 as u128)), + PrimitiveType::ContractAddress(contract_address) => Ok(*contract_address), + PrimitiveType::ClassHash(class_hash) => Ok(*class_hash), + PrimitiveType::Timestamp(timestamp) => Ok(FieldElement::from(*timestamp)), PrimitiveType::U256(u256) => { let mut hashes = Vec::new(); - - let type_hash = self.encode_type(name, vec![ - Field::SimpleType(SimpleField { - name: "low".to_string(), - r#type: "u128".to_string(), - }), - Field::SimpleType(SimpleField { - name: "high".to_string(), - r#type: "u128".to_string(), - }), - ], types); + + let type_hash = encode_type( + name, + vec![ + Field::SimpleType(SimpleField { + name: "low".to_string(), + r#type: "u128".to_string(), + }), + Field::SimpleType(SimpleField { + name: "high".to_string(), + r#type: "u128".to_string(), + }), + ], + types, + ); hashes.push(type_hash); - let low_hash = u256.low.encode("low".to_string(), types); + // lower half + let bytes = u256.to_be_bytes(); + let low_hash = + FieldElement::from(u128::from_be_bytes(bytes[0..16].try_into().unwrap())); hashes.push(low_hash); - let high_hash = u256.high.encode("high".to_string(), types); + let high_hash = + FieldElement::from(u128::from_be_bytes(bytes[16..32].try_into().unwrap())); hashes.push(high_hash); - poseidon_hash_many(hashes.as_slice()) - } - PrimitiveType::ShortString(short_string) => cairo_short_string_to_felt(&short_string), - PrimitiveType::NftId(nft_id) => { - let mut hashes = Vec::new(); - - let type_hash = self.encode_type(name, vec![ - Field::SimpleType(SimpleField { - name: "collection_address".to_string(), - r#type: "FieldElement".to_string(), - }), - Field::SimpleType(SimpleField { - name: "token_id".to_string(), - r#type: "U256".to_string(), - }), - ], types); - hashes.push(type_hash); - - let collection_address_hash = nft_id.collection_address.encode("collection_address".to_string(), types); - hashes.push(collection_address_hash); - - let token_id_hash = nft_id.token_id.encode("token_id".to_string(), types); - hashes.push(token_id_hash); - - poseidon_hash_many(hashes.as_slice()) - } - PrimitiveType::TokenAmount(token_amount) => { - let mut hashes = Vec::new(); - - let type_hash = self.encode_type(name, vec![ - Field::SimpleType(SimpleField { - name: "token_address".to_string(), - r#type: "FieldElement".to_string(), - }), - Field::SimpleType(SimpleField { - name: "amount".to_string(), - r#type: "U256".to_string(), - }), - ], types); - hashes.push(type_hash); - - let token_address_hash = token_amount.token_address.encode("token_address".to_string(), types); - hashes.push(token_address_hash); - - let amount_hash = token_amount.amount.encode("amount".to_string(), types); - hashes.push(amount_hash); - - poseidon_hash_many(hashes.as_slice()) + Ok(poseidon_hash_many(hashes.as_slice())) } + PrimitiveType::ShortString(short_string) => cairo_short_string_to_felt(&short_string) + .map_err(|_| Error::MessageValidationError("Invalid short string".to_string())), + PrimitiveType::NftId(nft_id) => nft_id.encode(name, types), + PrimitiveType::TokenAmount(token_amount) => token_amount.encode(name, types), } } } @@ -209,10 +275,23 @@ impl PrimitiveType { pub struct Domain { pub name: String, pub version: String, - pub chain_id: FieldElement, + pub chain_id: String, pub revision: String, } +impl Domain { + pub fn encode(&self, types: HashMap>) -> Result { + let mut object = HashMap::new(); + + object.insert("name".to_string(), PrimitiveType::ShortString(self.name.clone())); + object.insert("version".to_string(), PrimitiveType::ShortString(self.version.clone())); + object.insert("chain_id".to_string(), PrimitiveType::ShortString(self.chain_id.clone())); + object.insert("revision".to_string(), PrimitiveType::ShortString(self.revision.clone())); + + PrimitiveType::Object(object).encode("StarknetDomain".to_string(), types) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TypedData { pub types: HashMap>, @@ -222,14 +301,21 @@ pub struct TypedData { } impl TypedData { - pub fn encode(&self) -> Result, Error> { - if self.domain.revision == 0 { - return Err(Error::InvalidMessage("Invalid revision".to_string())); + pub fn encode(&self, account: FieldElement) -> Result { + if self.domain.revision == "0" { + return Err(Error::MessageValidationError("Invalid revision".to_string())); } let prefix_message = starknet_keccak("StarkNet Message".as_bytes()); // encode domain separator - types["StarknetDomain"] + let domain_hash = self.domain.encode(self.types.clone())?; + + // encode message + let message_hash = PrimitiveType::Object(self.message) + .encode(self.primary_type.clone(), self.types.clone())?; + + // return full hash + Ok(poseidon_hash_many(vec![prefix_message, domain_hash, account, message_hash].as_slice())) } }