From 6d27661ce975f508ca8134238f1ff8e94e0c01b0 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Mon, 13 Jan 2025 00:05:31 +0800 Subject: [PATCH 1/9] feat: impl reqpool --- reqpool/Cargo.toml | 33 +++++ reqpool/src/config.rs | 10 ++ reqpool/src/lib.rs | 17 +++ reqpool/src/macros.rs | 44 ++++++ reqpool/src/memory_pool.rs | 115 +++++++++++++++ reqpool/src/redis_pool.rs | 191 +++++++++++++++++++++++++ reqpool/src/request.rs | 286 +++++++++++++++++++++++++++++++++++++ reqpool/src/traits.rs | 45 ++++++ reqpool/src/utils.rs | 27 ++++ 9 files changed, 768 insertions(+) create mode 100644 reqpool/Cargo.toml create mode 100644 reqpool/src/config.rs create mode 100644 reqpool/src/lib.rs create mode 100644 reqpool/src/macros.rs create mode 100644 reqpool/src/memory_pool.rs create mode 100644 reqpool/src/redis_pool.rs create mode 100644 reqpool/src/request.rs create mode 100644 reqpool/src/traits.rs create mode 100644 reqpool/src/utils.rs diff --git a/reqpool/Cargo.toml b/reqpool/Cargo.toml new file mode 100644 index 000000000..5f41576a9 --- /dev/null +++ b/reqpool/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "raiko-reqpool" +version = "0.1.0" +authors = ["Taiko Labs"] +edition = "2021" + +[dependencies] +raiko-lib = { workspace = true } +raiko-core = { workspace = true } +raiko-redis-derive = { workspace = true } +num_enum = { workspace = true } +chrono = { workspace = true, features = ["serde"] } +thiserror = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +serde_with = { workspace = true } +hex = { workspace = true } +tracing = { workspace = true } +anyhow = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +redis = { workspace = true } +backoff = { workspace = true } +derive-getters = { workspace = true } +proc-macro2 = { workspace = true } +quote = { workspace = true } +syn = { workspace = true } +alloy-primitives = { workspace = true } + +[dev-dependencies] +rand = "0.9.0-alpha.1" # This is an alpha version, that has rng.gen_iter::() +rand_chacha = "0.9.0-alpha.1" +tempfile = "3.10.1" diff --git a/reqpool/src/config.rs b/reqpool/src/config.rs new file mode 100644 index 000000000..0050daa3c --- /dev/null +++ b/reqpool/src/config.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// The configuration for the redis-backend request pool +pub struct RedisPoolConfig { + /// The URL of the Redis database, e.g. "redis://localhost:6379" + pub redis_url: String, + /// The TTL of the Redis database + pub redis_ttl: u64, +} diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs new file mode 100644 index 000000000..6aa88c3b9 --- /dev/null +++ b/reqpool/src/lib.rs @@ -0,0 +1,17 @@ +mod config; +mod macros; +mod memory_pool; +mod redis_pool; +mod request; +mod traits; +mod utils; + +// Re-export +pub use config::RedisPoolConfig; +pub use redis_pool::RedisPool; +pub use request::{ + AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey, + SingleProofRequestEntity, SingleProofRequestKey, Status, StatusWithContext, +}; +pub use traits::{Pool, PoolResult, PoolWithTrace}; +pub use utils::proof_key_to_hack_request_key; diff --git a/reqpool/src/macros.rs b/reqpool/src/macros.rs new file mode 100644 index 000000000..fb36b349c --- /dev/null +++ b/reqpool/src/macros.rs @@ -0,0 +1,44 @@ +/// This macro implements the Display trait for a type by using serde_json's pretty printing. +/// If the type cannot be serialized to JSON, it falls back to using Debug formatting. +/// +/// # Example +/// +/// ```rust +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Debug, Serialize, Deserialize)] +/// struct Person { +/// name: String, +/// age: u32 +/// } +/// +/// impl_display_using_json_pretty!(Person); +/// +/// let person = Person { +/// name: "John".to_string(), +/// age: 30 +/// }; +/// +/// // Will print: +/// // { +/// // "name": "John", +/// // "age": 30 +/// // } +/// println!("{}", person); +/// ``` +/// +/// The type must implement serde's Serialize trait for JSON serialization to work. +/// If serialization fails, it will fall back to using the Debug implementation. +#[macro_export] +macro_rules! impl_display_using_json_pretty { + ($type:ty) => { + impl std::fmt::Display for $type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match serde_json::to_string_pretty(self) { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, "{:?}", self), + } + } + } + }; +} diff --git a/reqpool/src/memory_pool.rs b/reqpool/src/memory_pool.rs new file mode 100644 index 000000000..410da5d12 --- /dev/null +++ b/reqpool/src/memory_pool.rs @@ -0,0 +1,115 @@ +// use std::collections::HashMap; + +// use chrono::Utc; + +// use crate::{ +// request::{RequestEntity, RequestKey, Status, StatusWithContext}, +// traits::{Pool, PoolWithTrace}, +// }; + +// #[derive(Debug, Clone)] +// pub struct MemoryPool { +// /// The live requests in the pool +// pending: HashMap, +// /// The trace of requests +// trace: Vec<(RequestKey, RequestEntity, StatusWithContext)>, +// } + +// impl Pool for MemoryPool { +// type Config = (); + +// fn new(_config: Self::Config) -> Self { +// Self { +// lives: HashMap::new(), +// trace: Vec::new(), +// } +// } + +// fn add(&mut self, request_key: RequestKey, request_entity: RequestEntity) { +// let status = StatusWithContext::new(Status::Registered, Utc::now()); + +// let old = self.lives.insert( +// request_key.clone(), +// (request_entity.clone(), status.clone()), +// ); + +// if let Some((_, old_status)) = old { +// tracing::error!( +// "MemoryPool.add: request key already exists, {request_key:?}, old status: {old_status:?}" +// ); +// } else { +// tracing::info!("MemoryPool.add, {request_key:?}, status: {status:?}"); +// } + +// self.trace.push((request_key, request_entity, status)); +// } + +// fn remove(&mut self, request_key: &RequestKey) { +// match self.lives.remove(request_key) { +// Some((_, status)) => { +// tracing::info!("MemoryPool.remove, {request_key:?}, status: {status:?}"); +// } +// None => { +// tracing::error!("MemoryPool.remove: request key not found, {request_key:?}"); +// } +// } +// } + +// fn get(&self, request_key: &RequestKey) -> Option<(RequestEntity, StatusWithContext)> { +// self.lives.get(request_key).cloned() +// } + +// fn get_status(&self, request_key: &RequestKey) -> Option { +// self.lives +// .get(request_key) +// .map(|(_, status)| status.clone()) +// } + +// fn update_status(&mut self, request_key: &RequestKey, status: StatusWithContext) { +// match self.lives.remove(request_key) { +// Some((entity, old_status)) => { +// tracing::info!( +// "MemoryPool.update_status, {request_key:?}, old status: {old_status:?}, new status: {status:?}" +// ); +// self.lives +// .insert(request_key.clone(), (entity.clone(), status.clone())); +// self.trace.push((request_key.clone(), entity, status)); +// } +// None => { +// tracing::error!( +// "MemoryPool.update_status: request key not found, discard it, {request_key:?}" +// ); +// } +// } +// } +// } + +// impl PoolWithTrace for MemoryPool { +// fn get_all_live(&self) -> Vec<(RequestKey, RequestEntity, StatusWithContext)> { +// self.lives +// .iter() +// .map(|(k, v)| (k.clone(), v.0.clone(), v.1.clone())) +// .collect() +// } + +// fn get_all_trace(&self) -> Vec<(RequestKey, RequestEntity, StatusWithContext)> { +// self.trace.clone() +// } + +// fn trace( +// &self, +// request_key: &RequestKey, +// ) -> ( +// Option<(RequestEntity, StatusWithContext)>, +// Vec<(RequestKey, RequestEntity, StatusWithContext)>, +// ) { +// let live = self.lives.get(request_key).cloned(); +// let traces = self +// .trace +// .iter() +// .filter(|(k, _, _)| k == request_key) +// .cloned() +// .collect(); +// (live, traces) +// } +// } diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs new file mode 100644 index 000000000..f52ba218d --- /dev/null +++ b/reqpool/src/redis_pool.rs @@ -0,0 +1,191 @@ +use crate::{ + impl_display_using_json_pretty, proof_key_to_hack_request_key, Pool, PoolResult, + RedisPoolConfig, RequestEntity, RequestKey, StatusWithContext, +}; +use backoff::{exponential::ExponentialBackoff, SystemClock}; +use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; +use raiko_redis_derive::RedisValue; +use redis::{Client, Commands, RedisResult}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct RedisPool { + client: Client, + config: RedisPoolConfig, +} + +impl Pool for RedisPool { + fn add( + &mut self, + request_key: RequestKey, + request_entity: RequestEntity, + status: StatusWithContext, + ) -> PoolResult<()> { + tracing::info!("RedisPool.add: {request_key}, {status}"); + let request_entity_and_status = RequestEntityAndStatus { + entity: request_entity, + status, + }; + self.conn() + .map_err(|e| e.to_string())? + .set_ex( + request_key, + request_entity_and_status, + self.config.redis_ttl, + ) + .map_err(|e| e.to_string())?; + Ok(()) + } + + fn remove(&mut self, request_key: &RequestKey) -> PoolResult { + tracing::info!("RedisPool.remove: {request_key}"); + let result: usize = self + .conn() + .map_err(|e| e.to_string())? + .del(request_key) + .map_err(|e| e.to_string())?; + Ok(result) + } + + fn get( + &mut self, + request_key: &RequestKey, + ) -> PoolResult> { + let result: RedisResult = + self.conn().map_err(|e| e.to_string())?.get(request_key); + match result { + Ok(value) => Ok(Some(value.into())), + Err(e) if e.kind() == redis::ErrorKind::TypeError => Ok(None), + Err(e) => Err(e.to_string()), + } + } + + fn get_status(&mut self, request_key: &RequestKey) -> PoolResult> { + self.get(request_key).map(|v| v.map(|v| v.1)) + } + + fn update_status( + &mut self, + request_key: RequestKey, + status: StatusWithContext, + ) -> PoolResult { + tracing::info!("RedisPool.update_status: {request_key}, {status}"); + match self.get(&request_key)? { + Some((entity, old_status)) => { + self.add(request_key, entity, status)?; + Ok(old_status) + } + None => Err("Request not found".to_string()), + } + } +} + +#[async_trait::async_trait] +impl IdStore for RedisPool { + async fn read_id(&mut self, proof_key: ProofKey) -> ProverResult { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.read_id: {hack_request_key}"); + + let result: RedisResult = self + .conn() + .map_err(|e| e.to_string())? + .get(hack_request_key); + match result { + Ok(value) => Ok(value.into()), + Err(e) => Err(ProverError::StoreError(e.to_string())), + } + } +} + +#[async_trait::async_trait] +impl IdWrite for RedisPool { + async fn store_id(&mut self, proof_key: ProofKey, id: String) -> ProverResult<()> { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.store_id: {hack_request_key}, {id}"); + + self.conn() + .map_err(|e| e.to_string())? + .set_ex(hack_request_key, id, self.config.redis_ttl) + .map_err(|e| ProverError::StoreError(e.to_string()))?; + Ok(()) + } + + async fn remove_id(&mut self, proof_key: ProofKey) -> ProverResult<()> { + let hack_request_key = proof_key_to_hack_request_key(proof_key); + + tracing::info!("RedisPool.remove_id: {hack_request_key}"); + + self.conn() + .map_err(|e| e.to_string())? + .del(hack_request_key) + .map_err(|e| ProverError::StoreError(e.to_string()))?; + Ok(()) + } +} + +impl RedisPool { + pub fn open(config: RedisPoolConfig) -> Result { + tracing::info!("RedisPool.open: connecting to redis: {}", config.redis_url); + + let client = Client::open(config.redis_url.clone())?; + Ok(Self { client, config }) + } + + fn conn(&mut self) -> Result { + let backoff: ExponentialBackoff = ExponentialBackoff { + initial_interval: Duration::from_secs(10), + max_interval: Duration::from_secs(60), + max_elapsed_time: Some(Duration::from_secs(300)), + ..Default::default() + }; + + backoff::retry(backoff, || match self.client.get_connection() { + Ok(conn) => Ok(conn), + Err(e) => { + tracing::error!( + "RedisPool.get_connection: failed to connect to redis: {e:?}, retrying..." + ); + + self.client = redis::Client::open(self.config.redis_url.clone())?; + Err(backoff::Error::Transient { + err: e, + retry_after: None, + }) + } + }) + .map_err(|e| match e { + backoff::Error::Transient { + err, + retry_after: _, + } + | backoff::Error::Permanent(err) => err, + }) + } +} + +/// A internal wrapper for request entity and status, used for redis serialization +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)] +struct RequestEntityAndStatus { + entity: RequestEntity, + status: StatusWithContext, +} + +impl From<(RequestEntity, StatusWithContext)> for RequestEntityAndStatus { + fn from(value: (RequestEntity, StatusWithContext)) -> Self { + Self { + entity: value.0, + status: value.1, + } + } +} + +impl From for (RequestEntity, StatusWithContext) { + fn from(value: RequestEntityAndStatus) -> Self { + (value.entity, value.status) + } +} + +impl_display_using_json_pretty!(RequestEntityAndStatus); diff --git a/reqpool/src/request.rs b/reqpool/src/request.rs new file mode 100644 index 000000000..5a71aeb5b --- /dev/null +++ b/reqpool/src/request.rs @@ -0,0 +1,286 @@ +use crate::impl_display_using_json_pretty; +use alloy_primitives::Address; +use chrono::{DateTime, Utc}; +use derive_getters::Getters; +use raiko_core::interfaces::ProverSpecificOpts; +use raiko_lib::{ + input::BlobProofType, + primitives::{ChainId, B256}, + proof_type::ProofType, + prover::Proof, +}; +use raiko_redis_derive::RedisValue; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DisplayFromStr}; +use std::collections::HashMap; + +#[derive(RedisValue, PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord)] +#[serde(rename_all = "snake_case")] +/// The status of a request +pub enum Status { + // === Normal status === + /// The request is registered but not yet started + Registered, + + /// The request is in progress + WorkInProgress, + + // /// The request is in progress of proving + // WorkInProgressProving { + // /// The proof ID + // /// For SP1 and RISC0 proof type, it is the proof ID returned by the network prover, + // /// otherwise, it should be empty. + // proof_id: String, + // }, + /// The request is successful + Success { + /// The proof of the request + proof: Proof, + }, + + // === Cancelled status === + /// The request is cancelled + Cancelled, + + // === Error status === + /// The request is failed with an error + Failed { + /// The error message + error: String, + }, +} + +impl Status { + pub fn is_success(&self) -> bool { + matches!(self, Status::Success { .. }) + } +} + +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, RedisValue, Getters, +)] +/// The status of a request with context +pub struct StatusWithContext { + /// The status of the request + status: Status, + /// The timestamp of the status + timestamp: DateTime, +} + +impl StatusWithContext { + pub fn new(status: Status, timestamp: DateTime) -> Self { + Self { status, timestamp } + } + + pub fn new_registered() -> Self { + Self::new(Status::Registered, chrono::Utc::now()) + } + + pub fn new_cancelled() -> Self { + Self::new(Status::Cancelled, chrono::Utc::now()) + } + + pub fn into_status(self) -> Status { + self.status + } +} + +impl From for StatusWithContext { + fn from(status: Status) -> Self { + Self::new(status, chrono::Utc::now()) + } +} + +/// The key to identify a request in the pool +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, +)] +pub enum RequestKey { + SingleProof(SingleProofRequestKey), + Aggregation(AggregationRequestKey), +} + +impl RequestKey { + pub fn proof_type(&self) -> &ProofType { + match self { + RequestKey::SingleProof(key) => &key.proof_type, + RequestKey::Aggregation(key) => &key.proof_type, + } + } +} + +/// The key to identify a request in the pool +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters, +)] +pub struct SingleProofRequestKey { + /// The chain ID of the request + chain_id: ChainId, + /// The block number of the request + block_number: u64, + /// The block hash of the request + block_hash: B256, + /// The proof type of the request + proof_type: ProofType, + /// The prover of the request + prover_address: String, +} + +impl SingleProofRequestKey { + pub fn new( + chain_id: ChainId, + block_number: u64, + block_hash: B256, + proof_type: ProofType, + prover_address: String, + ) -> Self { + Self { + chain_id, + block_number, + block_hash, + proof_type, + prover_address, + } + } +} + +#[derive( + PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters, +)] +/// The key to identify an aggregation request in the pool +pub struct AggregationRequestKey { + // TODO add chain_id + proof_type: ProofType, + block_numbers: Vec, +} + +impl AggregationRequestKey { + pub fn new(proof_type: ProofType, block_numbers: Vec) -> Self { + Self { + proof_type, + block_numbers, + } + } +} + +impl From for RequestKey { + fn from(key: SingleProofRequestKey) -> Self { + RequestKey::SingleProof(key) + } +} + +impl From for RequestKey { + fn from(key: AggregationRequestKey) -> Self { + RequestKey::Aggregation(key) + } +} + +#[serde_as] +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)] +pub struct SingleProofRequestEntity { + /// The block number for the block to generate a proof for. + block_number: u64, + /// The l1 block number of the l2 block be proposed. + l1_inclusion_block_number: u64, + /// The network to generate the proof for. + network: String, + /// The L1 network to generate the proof for. + l1_network: String, + /// Graffiti. + graffiti: B256, + /// The protocol instance data. + #[serde_as(as = "DisplayFromStr")] + prover: Address, + /// The proof type. + proof_type: ProofType, + /// Blob proof type. + blob_proof_type: BlobProofType, + #[serde(flatten)] + /// Additional prover params. + prover_args: HashMap, +} + +impl SingleProofRequestEntity { + pub fn new( + block_number: u64, + l1_inclusion_block_number: u64, + network: String, + l1_network: String, + graffiti: B256, + prover: Address, + proof_type: ProofType, + blob_proof_type: BlobProofType, + prover_args: HashMap, + ) -> Self { + Self { + block_number, + l1_inclusion_block_number, + network, + l1_network, + graffiti, + prover, + proof_type, + blob_proof_type, + prover_args, + } + } +} + +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)] +pub struct AggregationRequestEntity { + /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. + aggregation_ids: Vec, + /// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for. + proofs: Vec, + /// The proof type. + proof_type: ProofType, + #[serde(flatten)] + /// Any additional prover params in JSON format. + prover_args: ProverSpecificOpts, +} + +impl AggregationRequestEntity { + pub fn new( + aggregation_ids: Vec, + proofs: Vec, + proof_type: ProofType, + prover_args: ProverSpecificOpts, + ) -> Self { + Self { + aggregation_ids, + proofs, + proof_type, + prover_args, + } + } +} + +/// The entity of a request +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)] +pub enum RequestEntity { + SingleProof(SingleProofRequestEntity), + Aggregation(AggregationRequestEntity), +} + +impl From for RequestEntity { + fn from(entity: SingleProofRequestEntity) -> Self { + RequestEntity::SingleProof(entity) + } +} + +impl From for RequestEntity { + fn from(entity: AggregationRequestEntity) -> Self { + RequestEntity::Aggregation(entity) + } +} + +// === impl Display using json_pretty === + +impl_display_using_json_pretty!(Status); +impl_display_using_json_pretty!(StatusWithContext); +impl_display_using_json_pretty!(RequestKey); +impl_display_using_json_pretty!(SingleProofRequestKey); +impl_display_using_json_pretty!(AggregationRequestKey); +impl_display_using_json_pretty!(RequestEntity); +impl_display_using_json_pretty!(SingleProofRequestEntity); +impl_display_using_json_pretty!(AggregationRequestEntity); diff --git a/reqpool/src/traits.rs b/reqpool/src/traits.rs new file mode 100644 index 000000000..243ecfb37 --- /dev/null +++ b/reqpool/src/traits.rs @@ -0,0 +1,45 @@ +use crate::request::{RequestEntity, RequestKey, StatusWithContext}; + +pub type PoolResult = Result; + +/// Pool maintains the requests and their statuses +pub trait Pool: Send + Sync + Clone { + /// Add a new request to the pool + fn add( + &mut self, + request_key: RequestKey, + request_entity: RequestEntity, + status: StatusWithContext, + ) -> PoolResult<()>; + + /// Remove a request from the pool, return the number of requests removed + fn remove(&mut self, request_key: &RequestKey) -> PoolResult; + + /// Get a request and status from the pool + fn get( + &mut self, + request_key: &RequestKey, + ) -> PoolResult>; + + /// Get the status of a request + fn get_status(&mut self, request_key: &RequestKey) -> PoolResult>; + + /// Update the status of a request, return the old status + fn update_status( + &mut self, + request_key: RequestKey, + status: StatusWithContext, + ) -> PoolResult; +} + +/// A pool extension that supports tracing +pub trait PoolWithTrace: Pool { + /// Get all trace of requests, with the given max depth. + fn trace_all(&self, max_depth: usize) -> Vec<(RequestKey, RequestEntity, StatusWithContext)>; + + /// Get the live entity and trace of a request + fn trace( + &self, + request_key: &RequestKey, + ) -> Vec<(RequestKey, RequestEntity, StatusWithContext)>; +} diff --git a/reqpool/src/utils.rs b/reqpool/src/utils.rs new file mode 100644 index 000000000..ba9d5c3d7 --- /dev/null +++ b/reqpool/src/utils.rs @@ -0,0 +1,27 @@ +use raiko_lib::{proof_type::ProofType, prover::ProofKey}; + +use crate::{RequestKey, SingleProofRequestKey}; + +/// Returns the proof key corresponding to the request key. +/// +/// During proving, the prover will store the network proof id into pool, which is identified by **proof key**. This +/// function is used to generate a unique proof key corresponding to the request key, so that we can store the +/// proof key into the pool. +/// +/// Note that this is a hack, and it should be removed in the future. +pub fn proof_key_to_hack_request_key(proof_key: ProofKey) -> RequestKey { + let (chain_id, block_number, block_hash, proof_type) = proof_key; + + // HACK: Use a special prover address as a mask, to distinguish from real + // RequestKeys + let hack_prover_address = String::from("0x1231231231231231231231231231231231231231"); + + SingleProofRequestKey::new( + chain_id, + block_number, + block_hash, + ProofType::try_from(proof_type).expect("unsupported proof type, it should not happen at proof_key_to_hack_request_key, please issue a bug report"), + hack_prover_address, + ) + .into() +} From 4900f2f7254e4100db0632a2a7840902de7eb60d Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Tue, 14 Jan 2025 20:29:00 +0800 Subject: [PATCH 2/9] feat(reqpool): mock pool --- reqpool/Cargo.toml | 1 + reqpool/src/lib.rs | 2 + reqpool/src/mock.rs | 95 +++++++++++++++++++++++++++++++++++++++ reqpool/src/redis_pool.rs | 17 +++++-- 4 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 reqpool/src/mock.rs diff --git a/reqpool/Cargo.toml b/reqpool/Cargo.toml index 5f41576a9..d5a1fc7b4 100644 --- a/reqpool/Cargo.toml +++ b/reqpool/Cargo.toml @@ -31,3 +31,4 @@ alloy-primitives = { workspace = true } rand = "0.9.0-alpha.1" # This is an alpha version, that has rng.gen_iter::() rand_chacha = "0.9.0-alpha.1" tempfile = "3.10.1" +lazy_static = { workspace = true } diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs index 6aa88c3b9..6b2056c21 100644 --- a/reqpool/src/lib.rs +++ b/reqpool/src/lib.rs @@ -1,6 +1,8 @@ mod config; mod macros; mod memory_pool; +#[cfg(test)] +mod mock; mod redis_pool; mod request; mod traits; diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs new file mode 100644 index 000000000..fc88eb457 --- /dev/null +++ b/reqpool/src/mock.rs @@ -0,0 +1,95 @@ +use lazy_static::lazy_static; +use redis::{RedisError, RedisResult}; +use serde::Serialize; +use serde_json::{json, Value}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +type SingleStorage = Arc>>; +type GlobalStorage = Mutex>; + +lazy_static! { + // #{redis_url => single_storage} + static ref GLOBAL_STORAGE: GlobalStorage = Mutex::new(HashMap::new()); +} + +pub struct MockRedisConnection { + storage: SingleStorage, +} + +impl MockRedisConnection { + pub(crate) fn new(redis_url: String) -> Self { + let mut global = GLOBAL_STORAGE.lock().unwrap(); + Self { + storage: global + .entry(redis_url) + .or_insert_with(|| Arc::new(Mutex::new(HashMap::new()))) + .clone(), + } + } + + pub fn set_ex( + &mut self, + key: K, + val: V, + _ttl: u64, + ) -> RedisResult<()> { + let mut lock = self.storage.lock().unwrap(); + lock.insert(json!(key), json!(val)); + Ok(()) + } + + pub fn get(&mut self, key: &K) -> RedisResult { + let lock = self.storage.lock().unwrap(); + match lock.get(&json!(key)) { + None => Err(RedisError::from((redis::ErrorKind::TypeError, "not found"))), + Some(v) => serde_json::from_value(v.clone()).map_err(|e| { + RedisError::from(( + redis::ErrorKind::TypeError, + "deserialization error", + e.to_string(), + )) + }), + } + } + + pub fn del(&mut self, key: K) -> RedisResult { + let mut lock = self.storage.lock().unwrap(); + if lock.remove(&json!(key)).is_none() { + Ok(0) + } else { + Ok(1) + } + } +} + +#[cfg(test)] +mod tests { + use redis::RedisResult; + + use crate::{RedisPool, RedisPoolConfig}; + + #[test] + fn test_mock_redis_pool() { + let config = RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6379".to_string(), + }; + let mut pool = RedisPool::open(config).unwrap(); + let mut conn = pool.conn().expect("mock conn"); + + let key = "hello".to_string(); + let val = "world".to_string(); + conn.set_ex(key.clone(), val.clone(), 111) + .expect("mock set_ex"); + + let actual: RedisResult = conn.get(&key); + assert_eq!(actual, Ok(val)); + + let _ = conn.del(&key); + let actual: RedisResult = conn.get(&key); + assert!(actual.is_err()); + } +} diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs index f52ba218d..6111acb12 100644 --- a/reqpool/src/redis_pool.rs +++ b/reqpool/src/redis_pool.rs @@ -2,12 +2,11 @@ use crate::{ impl_display_using_json_pretty, proof_key_to_hack_request_key, Pool, PoolResult, RedisPoolConfig, RequestEntity, RequestKey, StatusWithContext, }; -use backoff::{exponential::ExponentialBackoff, SystemClock}; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; use raiko_redis_derive::RedisValue; +#[allow(unused_imports)] use redis::{Client, Commands, RedisResult}; use serde::{Deserialize, Serialize}; -use std::time::Duration; #[derive(Debug, Clone)] pub struct RedisPool { @@ -91,7 +90,7 @@ impl IdStore for RedisPool { let result: RedisResult = self .conn() .map_err(|e| e.to_string())? - .get(hack_request_key); + .get(&hack_request_key); match result { Ok(value) => Ok(value.into()), Err(e) => Err(ProverError::StoreError(e.to_string())), @@ -134,7 +133,19 @@ impl RedisPool { Ok(Self { client, config }) } + #[cfg(test)] + pub(crate) fn conn(&mut self) -> Result { + let _ = self.client; + Ok(crate::mock::MockRedisConnection::new( + self.config.redis_url.clone(), + )) + } + + #[cfg(not(test))] fn conn(&mut self) -> Result { + use backoff::{exponential::ExponentialBackoff, SystemClock}; + use std::time::Duration; + let backoff: ExponentialBackoff = ExponentialBackoff { initial_interval: Duration::from_secs(10), max_interval: Duration::from_secs(60), From c4f57874702695264c3d6ec537505e8e5518ab6e Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Tue, 14 Jan 2025 20:01:50 +0800 Subject: [PATCH 3/9] feat(reqpool): remove Pool trait --- reqpool/src/lib.rs | 4 +--- reqpool/src/mock.rs | 4 ++-- reqpool/src/redis_pool.rs | 33 +++++++++++++++------------- reqpool/src/traits.rs | 45 --------------------------------------- 4 files changed, 21 insertions(+), 65 deletions(-) delete mode 100644 reqpool/src/traits.rs diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs index 6b2056c21..58140e237 100644 --- a/reqpool/src/lib.rs +++ b/reqpool/src/lib.rs @@ -5,15 +5,13 @@ mod memory_pool; mod mock; mod redis_pool; mod request; -mod traits; mod utils; // Re-export pub use config::RedisPoolConfig; -pub use redis_pool::RedisPool; +pub use redis_pool::Pool; pub use request::{ AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey, SingleProofRequestEntity, SingleProofRequestKey, Status, StatusWithContext, }; -pub use traits::{Pool, PoolResult, PoolWithTrace}; pub use utils::proof_key_to_hack_request_key; diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs index fc88eb457..f819cd4b7 100644 --- a/reqpool/src/mock.rs +++ b/reqpool/src/mock.rs @@ -69,7 +69,7 @@ impl MockRedisConnection { mod tests { use redis::RedisResult; - use crate::{RedisPool, RedisPoolConfig}; + use crate::{Pool, RedisPoolConfig}; #[test] fn test_mock_redis_pool() { @@ -77,7 +77,7 @@ mod tests { redis_ttl: 111, redis_url: "redis://localhost:6379".to_string(), }; - let mut pool = RedisPool::open(config).unwrap(); + let mut pool = Pool::open(config).unwrap(); let mut conn = pool.conn().expect("mock conn"); let key = "hello".to_string(); diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs index 6111acb12..e8c3c9b8f 100644 --- a/reqpool/src/redis_pool.rs +++ b/reqpool/src/redis_pool.rs @@ -1,6 +1,6 @@ use crate::{ - impl_display_using_json_pretty, proof_key_to_hack_request_key, Pool, PoolResult, - RedisPoolConfig, RequestEntity, RequestKey, StatusWithContext, + impl_display_using_json_pretty, proof_key_to_hack_request_key, RedisPoolConfig, RequestEntity, + RequestKey, StatusWithContext, }; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; use raiko_redis_derive::RedisValue; @@ -9,18 +9,18 @@ use redis::{Client, Commands, RedisResult}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] -pub struct RedisPool { +pub struct Pool { client: Client, config: RedisPoolConfig, } -impl Pool for RedisPool { - fn add( +impl Pool { + pub fn add( &mut self, request_key: RequestKey, request_entity: RequestEntity, status: StatusWithContext, - ) -> PoolResult<()> { + ) -> Result<(), String> { tracing::info!("RedisPool.add: {request_key}, {status}"); let request_entity_and_status = RequestEntityAndStatus { entity: request_entity, @@ -37,7 +37,7 @@ impl Pool for RedisPool { Ok(()) } - fn remove(&mut self, request_key: &RequestKey) -> PoolResult { + pub fn remove(&mut self, request_key: &RequestKey) -> Result { tracing::info!("RedisPool.remove: {request_key}"); let result: usize = self .conn() @@ -47,10 +47,10 @@ impl Pool for RedisPool { Ok(result) } - fn get( + pub fn get( &mut self, request_key: &RequestKey, - ) -> PoolResult> { + ) -> Result, String> { let result: RedisResult = self.conn().map_err(|e| e.to_string())?.get(request_key); match result { @@ -60,15 +60,18 @@ impl Pool for RedisPool { } } - fn get_status(&mut self, request_key: &RequestKey) -> PoolResult> { + pub fn get_status( + &mut self, + request_key: &RequestKey, + ) -> Result, String> { self.get(request_key).map(|v| v.map(|v| v.1)) } - fn update_status( + pub fn update_status( &mut self, request_key: RequestKey, status: StatusWithContext, - ) -> PoolResult { + ) -> Result { tracing::info!("RedisPool.update_status: {request_key}, {status}"); match self.get(&request_key)? { Some((entity, old_status)) => { @@ -81,7 +84,7 @@ impl Pool for RedisPool { } #[async_trait::async_trait] -impl IdStore for RedisPool { +impl IdStore for Pool { async fn read_id(&mut self, proof_key: ProofKey) -> ProverResult { let hack_request_key = proof_key_to_hack_request_key(proof_key); @@ -99,7 +102,7 @@ impl IdStore for RedisPool { } #[async_trait::async_trait] -impl IdWrite for RedisPool { +impl IdWrite for Pool { async fn store_id(&mut self, proof_key: ProofKey, id: String) -> ProverResult<()> { let hack_request_key = proof_key_to_hack_request_key(proof_key); @@ -125,7 +128,7 @@ impl IdWrite for RedisPool { } } -impl RedisPool { +impl Pool { pub fn open(config: RedisPoolConfig) -> Result { tracing::info!("RedisPool.open: connecting to redis: {}", config.redis_url); diff --git a/reqpool/src/traits.rs b/reqpool/src/traits.rs deleted file mode 100644 index 243ecfb37..000000000 --- a/reqpool/src/traits.rs +++ /dev/null @@ -1,45 +0,0 @@ -use crate::request::{RequestEntity, RequestKey, StatusWithContext}; - -pub type PoolResult = Result; - -/// Pool maintains the requests and their statuses -pub trait Pool: Send + Sync + Clone { - /// Add a new request to the pool - fn add( - &mut self, - request_key: RequestKey, - request_entity: RequestEntity, - status: StatusWithContext, - ) -> PoolResult<()>; - - /// Remove a request from the pool, return the number of requests removed - fn remove(&mut self, request_key: &RequestKey) -> PoolResult; - - /// Get a request and status from the pool - fn get( - &mut self, - request_key: &RequestKey, - ) -> PoolResult>; - - /// Get the status of a request - fn get_status(&mut self, request_key: &RequestKey) -> PoolResult>; - - /// Update the status of a request, return the old status - fn update_status( - &mut self, - request_key: RequestKey, - status: StatusWithContext, - ) -> PoolResult; -} - -/// A pool extension that supports tracing -pub trait PoolWithTrace: Pool { - /// Get all trace of requests, with the given max depth. - fn trace_all(&self, max_depth: usize) -> Vec<(RequestKey, RequestEntity, StatusWithContext)>; - - /// Get the live entity and trace of a request - fn trace( - &self, - request_key: &RequestKey, - ) -> Vec<(RequestKey, RequestEntity, StatusWithContext)>; -} From 65c5d7e8091a09559ad8150ddd5ab2c23969b454 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Tue, 14 Jan 2025 20:15:02 +0800 Subject: [PATCH 4/9] feat(reqpool): remove memory pool --- reqpool/src/lib.rs | 3 +- reqpool/src/memory_pool.rs | 115 ------------------------------------- reqpool/src/mock.rs | 3 + reqpool/src/redis_pool.rs | 14 +++-- 4 files changed, 12 insertions(+), 123 deletions(-) delete mode 100644 reqpool/src/memory_pool.rs diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs index 58140e237..e25021163 100644 --- a/reqpool/src/lib.rs +++ b/reqpool/src/lib.rs @@ -1,7 +1,6 @@ mod config; mod macros; -mod memory_pool; -#[cfg(test)] +#[cfg(any(test, feature = "enable-mock"))] mod mock; mod redis_pool; mod request; diff --git a/reqpool/src/memory_pool.rs b/reqpool/src/memory_pool.rs deleted file mode 100644 index 410da5d12..000000000 --- a/reqpool/src/memory_pool.rs +++ /dev/null @@ -1,115 +0,0 @@ -// use std::collections::HashMap; - -// use chrono::Utc; - -// use crate::{ -// request::{RequestEntity, RequestKey, Status, StatusWithContext}, -// traits::{Pool, PoolWithTrace}, -// }; - -// #[derive(Debug, Clone)] -// pub struct MemoryPool { -// /// The live requests in the pool -// pending: HashMap, -// /// The trace of requests -// trace: Vec<(RequestKey, RequestEntity, StatusWithContext)>, -// } - -// impl Pool for MemoryPool { -// type Config = (); - -// fn new(_config: Self::Config) -> Self { -// Self { -// lives: HashMap::new(), -// trace: Vec::new(), -// } -// } - -// fn add(&mut self, request_key: RequestKey, request_entity: RequestEntity) { -// let status = StatusWithContext::new(Status::Registered, Utc::now()); - -// let old = self.lives.insert( -// request_key.clone(), -// (request_entity.clone(), status.clone()), -// ); - -// if let Some((_, old_status)) = old { -// tracing::error!( -// "MemoryPool.add: request key already exists, {request_key:?}, old status: {old_status:?}" -// ); -// } else { -// tracing::info!("MemoryPool.add, {request_key:?}, status: {status:?}"); -// } - -// self.trace.push((request_key, request_entity, status)); -// } - -// fn remove(&mut self, request_key: &RequestKey) { -// match self.lives.remove(request_key) { -// Some((_, status)) => { -// tracing::info!("MemoryPool.remove, {request_key:?}, status: {status:?}"); -// } -// None => { -// tracing::error!("MemoryPool.remove: request key not found, {request_key:?}"); -// } -// } -// } - -// fn get(&self, request_key: &RequestKey) -> Option<(RequestEntity, StatusWithContext)> { -// self.lives.get(request_key).cloned() -// } - -// fn get_status(&self, request_key: &RequestKey) -> Option { -// self.lives -// .get(request_key) -// .map(|(_, status)| status.clone()) -// } - -// fn update_status(&mut self, request_key: &RequestKey, status: StatusWithContext) { -// match self.lives.remove(request_key) { -// Some((entity, old_status)) => { -// tracing::info!( -// "MemoryPool.update_status, {request_key:?}, old status: {old_status:?}, new status: {status:?}" -// ); -// self.lives -// .insert(request_key.clone(), (entity.clone(), status.clone())); -// self.trace.push((request_key.clone(), entity, status)); -// } -// None => { -// tracing::error!( -// "MemoryPool.update_status: request key not found, discard it, {request_key:?}" -// ); -// } -// } -// } -// } - -// impl PoolWithTrace for MemoryPool { -// fn get_all_live(&self) -> Vec<(RequestKey, RequestEntity, StatusWithContext)> { -// self.lives -// .iter() -// .map(|(k, v)| (k.clone(), v.0.clone(), v.1.clone())) -// .collect() -// } - -// fn get_all_trace(&self) -> Vec<(RequestKey, RequestEntity, StatusWithContext)> { -// self.trace.clone() -// } - -// fn trace( -// &self, -// request_key: &RequestKey, -// ) -> ( -// Option<(RequestEntity, StatusWithContext)>, -// Vec<(RequestKey, RequestEntity, StatusWithContext)>, -// ) { -// let live = self.lives.get(request_key).cloned(); -// let traces = self -// .trace -// .iter() -// .filter(|(k, _, _)| k == request_key) -// .cloned() -// .collect(); -// (live, traces) -// } -// } diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs index f819cd4b7..d1c8692ae 100644 --- a/reqpool/src/mock.rs +++ b/reqpool/src/mock.rs @@ -12,6 +12,9 @@ type GlobalStorage = Mutex>; lazy_static! { // #{redis_url => single_storage} + // + // We use redis_url to distinguish different redis database for tests, to prevent + // data race problem when running multiple tests. static ref GLOBAL_STORAGE: GlobalStorage = Mutex::new(HashMap::new()); } diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs index e8c3c9b8f..5f75ec738 100644 --- a/reqpool/src/redis_pool.rs +++ b/reqpool/src/redis_pool.rs @@ -2,11 +2,12 @@ use crate::{ impl_display_using_json_pretty, proof_key_to_hack_request_key, RedisPoolConfig, RequestEntity, RequestKey, StatusWithContext, }; +use backoff::{exponential::ExponentialBackoff, SystemClock}; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; use raiko_redis_derive::RedisValue; -#[allow(unused_imports)] use redis::{Client, Commands, RedisResult}; use serde::{Deserialize, Serialize}; +use std::time::Duration; #[derive(Debug, Clone)] pub struct Pool { @@ -136,19 +137,20 @@ impl Pool { Ok(Self { client, config }) } - #[cfg(test)] + #[cfg(any(test, feature = "enable-mock"))] pub(crate) fn conn(&mut self) -> Result { - let _ = self.client; Ok(crate::mock::MockRedisConnection::new( self.config.redis_url.clone(), )) } - #[cfg(not(test))] + #[cfg(not(any(test, feature = "enable-mock")))] fn conn(&mut self) -> Result { - use backoff::{exponential::ExponentialBackoff, SystemClock}; - use std::time::Duration; + self.redis_conn() + } + #[allow(dead_code)] + fn redis_conn(&mut self) -> Result { let backoff: ExponentialBackoff = ExponentialBackoff { initial_interval: Duration::from_secs(10), max_interval: Duration::from_secs(60), From bbb61ce04a318d45ae89651639e002333728474d Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Wed, 15 Jan 2025 11:04:44 +0800 Subject: [PATCH 5/9] test(reqpool): add case for multiple redis pools --- reqpool/src/mock.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs index d1c8692ae..27194e388 100644 --- a/reqpool/src/mock.rs +++ b/reqpool/src/mock.rs @@ -95,4 +95,51 @@ mod tests { let actual: RedisResult = conn.get(&key); assert!(actual.is_err()); } + + #[test] + fn test_mock_multiple_redis_pool() { + let mut pool1 = Pool::open(RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6379".to_string(), + }) + .unwrap(); + let mut pool2 = Pool::open(RedisPoolConfig { + redis_ttl: 111, + redis_url: "redis://localhost:6380".to_string(), + }) + .unwrap(); + + let mut conn1 = pool1.conn().expect("mock conn"); + let mut conn2 = pool2.conn().expect("mock conn"); + + let key = "hello".to_string(); + let world = "world".to_string(); + + { + conn1 + .set_ex(key.clone(), world.clone(), 111) + .expect("mock set_ex"); + let actual: RedisResult = conn1.get(&key); + assert_eq!(actual, Ok(world.clone())); + } + + { + let actual: RedisResult = conn2.get(&key); + assert!(actual.is_err()); + } + + { + let meme = "meme".to_string(); + conn2 + .set_ex(key.clone(), meme.clone(), 111) + .expect("mock set_ex"); + let actual: RedisResult = conn2.get(&key); + assert_eq!(actual, Ok(meme)); + } + + { + let actual: RedisResult = conn1.get(&key); + assert_eq!(actual, Ok(world)); + } + } } From 41b19b4eb7ed9e141442fd88358735164b121543 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Tue, 14 Jan 2025 20:55:22 +0800 Subject: [PATCH 6/9] chore: allow unused_imports --- reqpool/src/redis_pool.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs index 5f75ec738..62f29f016 100644 --- a/reqpool/src/redis_pool.rs +++ b/reqpool/src/redis_pool.rs @@ -5,6 +5,7 @@ use crate::{ use backoff::{exponential::ExponentialBackoff, SystemClock}; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; use raiko_redis_derive::RedisValue; +#[allow(unused_imports)] use redis::{Client, Commands, RedisResult}; use serde::{Deserialize, Serialize}; use std::time::Duration; From 6c4b4b1bf03763ba71c0d4552d5b80e6ae2587cf Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Wed, 15 Jan 2025 16:46:08 +0800 Subject: [PATCH 7/9] feat(reqpool): update Cargo.toml --- reqpool/Cargo.toml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/reqpool/Cargo.toml b/reqpool/Cargo.toml index d5a1fc7b4..4c417367b 100644 --- a/reqpool/Cargo.toml +++ b/reqpool/Cargo.toml @@ -8,15 +8,11 @@ edition = "2021" raiko-lib = { workspace = true } raiko-core = { workspace = true } raiko-redis-derive = { workspace = true } -num_enum = { workspace = true } chrono = { workspace = true, features = ["serde"] } -thiserror = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_with = { workspace = true } -hex = { workspace = true } tracing = { workspace = true } -anyhow = { workspace = true } tokio = { workspace = true } async-trait = { workspace = true } redis = { workspace = true } @@ -28,7 +24,4 @@ syn = { workspace = true } alloy-primitives = { workspace = true } [dev-dependencies] -rand = "0.9.0-alpha.1" # This is an alpha version, that has rng.gen_iter::() -rand_chacha = "0.9.0-alpha.1" -tempfile = "3.10.1" lazy_static = { workspace = true } From 7c9f523cfade610b0bcff2f3914232ca9f946f18 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Wed, 15 Jan 2025 18:05:30 +0800 Subject: [PATCH 8/9] chore(reqpool): impl Display for Status --- reqpool/src/request.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/reqpool/src/request.rs b/reqpool/src/request.rs index 5a71aeb5b..f02cd3e28 100644 --- a/reqpool/src/request.rs +++ b/reqpool/src/request.rs @@ -276,11 +276,29 @@ impl From for RequestEntity { // === impl Display using json_pretty === -impl_display_using_json_pretty!(Status); -impl_display_using_json_pretty!(StatusWithContext); impl_display_using_json_pretty!(RequestKey); impl_display_using_json_pretty!(SingleProofRequestKey); impl_display_using_json_pretty!(AggregationRequestKey); impl_display_using_json_pretty!(RequestEntity); impl_display_using_json_pretty!(SingleProofRequestEntity); impl_display_using_json_pretty!(AggregationRequestEntity); + +// === impl Display for Status === + +impl std::fmt::Display for Status { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Status::Registered => write!(f, "Registered"), + Status::WorkInProgress => write!(f, "WorkInProgress"), + Status::Success { .. } => write!(f, "Success"), + Status::Cancelled => write!(f, "Cancelled"), + Status::Failed { error } => write!(f, "Failed({})", error), + } + } +} + +impl std::fmt::Display for StatusWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.status()) + } +} From d8f4c6d28b041395ed1df086494a5671eccd2796 Mon Sep 17 00:00:00 2001 From: "keroroxx520@gmail.com" Date: Tue, 21 Jan 2025 20:01:54 +0800 Subject: [PATCH 9/9] feat(reqpool): remove feature "enable-mock" --- reqpool/Cargo.toml | 5 +++-- reqpool/src/lib.rs | 2 +- reqpool/src/mock.rs | 46 +++++++++++++++++++++++---------------- reqpool/src/redis_pool.rs | 20 ++++++++++++++--- 4 files changed, 48 insertions(+), 25 deletions(-) diff --git a/reqpool/Cargo.toml b/reqpool/Cargo.toml index 4c417367b..bc2721969 100644 --- a/reqpool/Cargo.toml +++ b/reqpool/Cargo.toml @@ -22,6 +22,7 @@ proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } alloy-primitives = { workspace = true } - -[dev-dependencies] lazy_static = { workspace = true } + +[features] +test-utils = [] diff --git a/reqpool/src/lib.rs b/reqpool/src/lib.rs index e25021163..8fbf3db42 100644 --- a/reqpool/src/lib.rs +++ b/reqpool/src/lib.rs @@ -1,6 +1,5 @@ mod config; mod macros; -#[cfg(any(test, feature = "enable-mock"))] mod mock; mod redis_pool; mod request; @@ -8,6 +7,7 @@ mod utils; // Re-export pub use config::RedisPoolConfig; +pub use mock::{mock_redis_pool, MockRedisConnection}; pub use redis_pool::Pool; pub use request::{ AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey, diff --git a/reqpool/src/mock.rs b/reqpool/src/mock.rs index 27194e388..ac408c7c5 100644 --- a/reqpool/src/mock.rs +++ b/reqpool/src/mock.rs @@ -1,3 +1,4 @@ +use crate::{Pool, RedisPoolConfig}; use lazy_static::lazy_static; use redis::{RedisError, RedisResult}; use serde::Serialize; @@ -23,7 +24,7 @@ pub struct MockRedisConnection { } impl MockRedisConnection { - pub(crate) fn new(redis_url: String) -> Self { + pub fn new(redis_url: String) -> Self { let mut global = GLOBAL_STORAGE.lock().unwrap(); Self { storage: global @@ -66,21 +67,37 @@ impl MockRedisConnection { Ok(1) } } + + pub fn keys(&mut self, key: &str) -> RedisResult> { + assert_eq!(key, "*", "mock redis only supports '*'"); + + let lock = self.storage.lock().unwrap(); + Ok(lock + .keys() + .map(|k| serde_json::from_value(k.clone()).unwrap()) + .collect()) + } +} + +/// Return the mock redis pool with the given id. +/// +/// This is used for testing. Please use the test case name as the id to prevent data race. +pub fn mock_redis_pool(id: S) -> Pool { + let config = RedisPoolConfig { + redis_ttl: 111, + redis_url: format!("redis://{}:6379", id.to_string()), + }; + Pool::open(config).unwrap() } #[cfg(test)] mod tests { + use super::*; use redis::RedisResult; - use crate::{Pool, RedisPoolConfig}; - #[test] fn test_mock_redis_pool() { - let config = RedisPoolConfig { - redis_ttl: 111, - redis_url: "redis://localhost:6379".to_string(), - }; - let mut pool = Pool::open(config).unwrap(); + let mut pool = mock_redis_pool("test_mock_redis_pool"); let mut conn = pool.conn().expect("mock conn"); let key = "hello".to_string(); @@ -98,17 +115,8 @@ mod tests { #[test] fn test_mock_multiple_redis_pool() { - let mut pool1 = Pool::open(RedisPoolConfig { - redis_ttl: 111, - redis_url: "redis://localhost:6379".to_string(), - }) - .unwrap(); - let mut pool2 = Pool::open(RedisPoolConfig { - redis_ttl: 111, - redis_url: "redis://localhost:6380".to_string(), - }) - .unwrap(); - + let mut pool1 = mock_redis_pool("test_mock_multiple_redis_pool_1"); + let mut pool2 = mock_redis_pool("test_mock_multiple_redis_pool_2"); let mut conn1 = pool1.conn().expect("mock conn"); let mut conn2 = pool2.conn().expect("mock conn"); diff --git a/reqpool/src/redis_pool.rs b/reqpool/src/redis_pool.rs index 62f29f016..c869917f1 100644 --- a/reqpool/src/redis_pool.rs +++ b/reqpool/src/redis_pool.rs @@ -8,7 +8,7 @@ use raiko_redis_derive::RedisValue; #[allow(unused_imports)] use redis::{Client, Commands, RedisResult}; use serde::{Deserialize, Serialize}; -use std::time::Duration; +use std::{collections::HashMap, time::Duration}; #[derive(Debug, Clone)] pub struct Pool { @@ -83,6 +83,20 @@ impl Pool { None => Err("Request not found".to_string()), } } + + pub fn list(&mut self) -> Result, String> { + let mut conn = self.conn().map_err(|e| e.to_string())?; + let keys: Vec = conn.keys("*").map_err(|e| e.to_string())?; + + let mut result = HashMap::new(); + for key in keys { + if let Ok(Some((_, status))) = self.get(&key) { + result.insert(key, status); + } + } + + Ok(result) + } } #[async_trait::async_trait] @@ -138,14 +152,14 @@ impl Pool { Ok(Self { client, config }) } - #[cfg(any(test, feature = "enable-mock"))] + #[cfg(any(test, feature = "test-utils"))] pub(crate) fn conn(&mut self) -> Result { Ok(crate::mock::MockRedisConnection::new( self.config.redis_url.clone(), )) } - #[cfg(not(any(test, feature = "enable-mock")))] + #[cfg(not(any(test, feature = "test-utils")))] fn conn(&mut self) -> Result { self.redis_conn() }