diff --git a/benches/bench/main.rs b/benches/bench/main.rs
index 2f457c4..ea66c99 100644
--- a/benches/bench/main.rs
+++ b/benches/bench/main.rs
@@ -217,7 +217,6 @@ fn config() -> Config {
                 format!("ws://{}", SERVER_TWO_ENDPOINT),
             ],
             shuffle_endpoints: false,
-            health_check: None,
         }),
         server: Some(ServerConfig {
             listen_address: SUBWAY_SERVER_ADDR.to_string(),
diff --git a/configs/config.yml b/configs/config.yml
index c0f6511..2a062b4 100644
--- a/configs/config.yml
+++ b/configs/config.yml
@@ -3,14 +3,6 @@ extensions:
     endpoints:
       - wss://acala-rpc.dwellir.com
       - wss://acala-rpc-0.aca-api.network
-    health_check:
-      interval_sec: 10 # check interval, default is 10s
-      healthy_response_time_ms: 500 # max response time to be considered healthy, default is 500ms
-      health_method: system_health
-      response: # response contains { isSyncing: false }
-        !contains
-        - - isSyncing
-          - !eq false
   event_bus:
   substrate_api:
     stale_timeout_seconds: 180 # rotate endpoint if no new blocks for 3 minutes
diff --git a/configs/eth_config.yml b/configs/eth_config.yml
index 85d6955..3363d77 100644
--- a/configs/eth_config.yml
+++ b/configs/eth_config.yml
@@ -2,18 +2,6 @@ extensions:
   client:
     endpoints:
       - wss://eth-rpc-karura-testnet.aca-staging.network
-    health_check:
-      interval_sec: 10 # check interval, default is 10s
-      healthy_response_time_ms: 500 # max response time to be considered healthy, default is 500ms
-      health_method: net_health # eth-rpc-adapter bodhijs
-      response: # response contains { isHealthy: true, isRPCOK: true }
-        !contains
-        - - isHealthy
-          - !eq true
-        - - isRPCOK
-          - !eq true
-# health_method: eth_syncing # eth node
-# response: !eq false
   event_bus:
   eth_api:
     stale_timeout_seconds: 180 # rotate endpoint if no new blocks for 3 minutes
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 1d30aec..5424157 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -204,6 +204,8 @@ fn render_template(templated_config_str: &str) -> Result<String, anyhow::Error>
 }
 
 pub async fn validate(config: &Config) -> Result<(), anyhow::Error> {
+    tracing::debug!("Validating config");
+
     // validate use garde::Validate
     config.validate(&())?;
     // since endpoints connection test is async
@@ -214,6 +216,9 @@ pub async fn validate(config: &Config) -> Result<(), anyhow::Error> {
             anyhow::bail!("Unable to connect to all endpoints");
         }
     }
+
+    tracing::debug!("Validation completed");
+
     Ok(())
 }
 
diff --git a/src/extensions/api/tests.rs b/src/extensions/api/tests.rs
index a82b730..1c0ff0c 100644
--- a/src/extensions/api/tests.rs
+++ b/src/extensions/api/tests.rs
@@ -1,6 +1,6 @@
 use jsonrpsee::server::ServerHandle;
 use serde_json::json;
-use std::{net::SocketAddr, sync::Arc, time::Duration};
+use std::{net::SocketAddr, sync::Arc};
 use tokio::sync::mpsc;
 
 use super::eth::EthApi;
@@ -61,14 +61,7 @@ async fn create_client() -> (
 ) {
     let (addr, server, head_rx, finalized_head_rx, block_hash_rx) = create_server().await;
 
-    let client = Client::new(
-        [format!("ws://{addr}")],
-        Duration::from_secs(1),
-        Duration::from_secs(1),
-        None,
-        None,
-    )
-    .unwrap();
+    let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap();
 
     (client, server, head_rx, finalized_head_rx, block_hash_rx)
 }
@@ -175,14 +168,7 @@ async fn rotate_endpoint_on_stale() {
     let (addr, server, mut head_rx, _, mut block_rx) = create_server().await;
     let (addr2, server2, mut head_rx2, _, mut block_rx2) = create_server().await;
 
-    let client = Client::new(
-        [format!("ws://{addr}"), format!("ws://{addr2}")],
-        Duration::from_secs(1),
-        Duration::from_secs(1),
-        None,
-        None,
-    )
-    .unwrap();
+ let client = Client::with_endpoints([format!("ws://{addr}"), format!("ws://{addr2}")]).unwrap(); let api = SubstrateApi::new(Arc::new(client), std::time::Duration::from_millis(100)); let head = api.get_head(); @@ -245,14 +231,7 @@ async fn rotate_endpoint_on_head_mismatch() { let (addr1, server1, mut head_rx1, mut finalized_head_rx1, mut block_rx1) = create_server().await; let (addr2, server2, mut head_rx2, mut finalized_head_rx2, mut block_rx2) = create_server().await; - let client = Client::new( - [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr1}"), format!("ws://{addr2}")]).unwrap(); let client = Arc::new(client); let api = SubstrateApi::new(client.clone(), std::time::Duration::from_millis(100)); @@ -353,16 +332,7 @@ async fn rotate_endpoint_on_head_mismatch() { #[tokio::test] async fn substrate_background_tasks_abort_on_drop() { let (addr, _server, mut head_rx, mut finalized_head_rx, _) = create_server().await; - let client = Arc::new( - Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(), - ); + let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); let api = SubstrateApi::new(client, std::time::Duration::from_millis(100)); // background tasks started @@ -382,16 +352,7 @@ async fn substrate_background_tasks_abort_on_drop() { #[tokio::test] async fn eth_background_tasks_abort_on_drop() { let (addr, _server, mut subscription_rx, mut block_rx) = create_eth_server().await; - let client = Arc::new( - Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(), - ); + let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); let api = EthApi::new(client, std::time::Duration::from_millis(100)); diff --git a/src/extensions/client/endpoint.rs b/src/extensions/client/endpoint.rs deleted file mode 100644 index 5d6118b..0000000 --- a/src/extensions/client/endpoint.rs +++ /dev/null @@ -1,448 +0,0 @@ -use super::health::{self, Event, Health}; -use crate::extensions::client::{get_backoff_time, HealthCheckConfig}; -use jsonrpsee::{ - async_client::Client, - core::client::{ClientT, Subscription, SubscriptionClientT}, - core::JsonValue, - ws_client::WsClientBuilder, -}; -use std::{ - fmt::{Debug, Formatter}, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, - time::Duration, -}; - -enum Message { - Request { - method: String, - params: Vec, - response: tokio::sync::oneshot::Sender>, - timeout: Duration, - }, - Subscribe { - subscribe: String, - params: Vec, - unsubscribe: String, - response: tokio::sync::oneshot::Sender, jsonrpsee::core::client::Error>>, - timeout: Duration, - }, - Reconnect, -} - -impl Debug for Message { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Message::Request { - method, - params, - response: _, - timeout, - } => write!(f, "Request({method}, {params:?}, _, {timeout:?})"), - Message::Subscribe { - subscribe, - params, - unsubscribe, - response: _, - timeout, - } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {timeout:?})"), - Message::Reconnect => write!(f, "Reconnect"), - } - } -} - -enum State { - Initial, - OnError(health::Event), - Connect(Option), - HandleMessage(Arc, Message), - WaitForMessage(Arc), -} - -impl Debug for State { - fn fmt(&self, f: &mut Formatter<'_>) -> 
std::fmt::Result { - match self { - State::Initial => write!(f, "Initial"), - State::OnError(e) => write!(f, "OnError({e:?})"), - State::Connect(m) => write!(f, "Connect({m:?})"), - State::HandleMessage(_c, m) => write!(f, "HandleMessage(_, {m:?})"), - State::WaitForMessage(_c) => write!(f, "WaitForMessage(_)"), - } - } -} - -pub struct Endpoint { - url: String, - health: Arc, - message_tx: tokio::sync::mpsc::Sender, - background_tasks: Vec>, - connect_counter: Arc, -} - -impl Drop for Endpoint { - fn drop(&mut self) { - self.background_tasks.drain(..).for_each(|handle| handle.abort()); - } -} - -impl Endpoint { - pub fn new( - url: String, - request_timeout: Duration, - connection_timeout: Duration, - health_config: Option, - ) -> Self { - tracing::info!("New endpoint: {url}"); - - let health = Arc::new(Health::new(url.clone())); - let connect_counter = Arc::new(AtomicU32::new(0)); - let (message_tx, message_rx) = tokio::sync::mpsc::channel::(4096); - - let mut endpoint = Self { - url: url.clone(), - health: health.clone(), - message_tx, - background_tasks: vec![], - connect_counter: connect_counter.clone(), - }; - - endpoint.start_background_task( - url, - request_timeout, - connection_timeout, - connect_counter, - message_rx, - health, - ); - if let Some(config) = health_config { - endpoint.start_health_monitor_task(config); - } - - endpoint - } - - fn start_background_task( - &mut self, - url: String, - request_timeout: Duration, - connection_timeout: Duration, - connect_counter: Arc, - mut message_rx: tokio::sync::mpsc::Receiver, - health: Arc, - ) { - let handler = tokio::spawn(async move { - let connect_backoff_counter = Arc::new(AtomicU32::new(0)); - - let mut state = State::Initial; - - loop { - tracing::trace!("{url} {state:?}"); - - let new_state = match state { - State::Initial => { - connect_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - // wait for messages before connecting - let msg = match message_rx.recv().await { - Some(Message::Reconnect) => None, - Some(msg @ Message::Request { .. } | msg @ Message::Subscribe { .. }) => Some(msg), - None => { - let url = url.clone(); - // channel is closed? 
exit - tracing::debug!("Endpoint {url} channel closed"); - return; - } - }; - State::Connect(msg) - } - State::OnError(evt) => { - health.update(evt); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; - State::Initial - } - State::Connect(msg) => { - // TODO: make the params configurable - let client = WsClientBuilder::default() - .request_timeout(request_timeout) - .connection_timeout(connection_timeout) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(url.clone()) - .await; - - match client { - Ok(client) => { - connect_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - health.update(Event::ConnectionSuccessful); - if let Some(msg) = msg { - State::HandleMessage(Arc::new(client), msg) - } else { - State::WaitForMessage(Arc::new(client)) - } - } - Err(err) => { - tracing::debug!("Endpoint {url} connection error: {err}"); - State::OnError(health::Event::ConnectionClosed) - } - } - } - State::HandleMessage(client, msg) => match msg { - Message::Request { - method, - params, - response, - timeout, - } => { - // don't block on making the request - let url = url.clone(); - let health = health.clone(); - let client2 = client.clone(); - tokio::spawn(async move { - let resp = match tokio::time::timeout( - timeout, - client2.request::>(&method, params), - ) - .await - { - Ok(resp) => resp, - Err(_) => { - tracing::warn!("Endpoint {url} request timeout: {method} timeout: {timeout:?}"); - health.update(Event::RequestTimeout); - Err(jsonrpsee::core::client::Error::RequestTimeout) - } - }; - if let Err(err) = &resp { - health.on_error(err); - } - - if response.send(resp).is_err() { - tracing::error!("Unable to send response to message channel"); - } - }); - - State::WaitForMessage(client) - } - Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - timeout, - } => { - // don't block on making the request - let url = url.clone(); - let health = health.clone(); - let client2 = client.clone(); - tokio::spawn(async move { - let resp = match tokio::time::timeout( - timeout, - client2.subscribe::>( - &subscribe, - params, - &unsubscribe, - ), - ) - .await - { - Ok(resp) => resp, - Err(_) => { - tracing::warn!("Endpoint {url} subscription timeout: {subscribe}"); - health.update(Event::RequestTimeout); - Err(jsonrpsee::core::client::Error::RequestTimeout) - } - }; - if let Err(err) = &resp { - health.on_error(err); - } - - if response.send(resp).is_err() { - tracing::error!("Unable to send response to message channel"); - } - }); - - State::WaitForMessage(client) - } - Message::Reconnect => State::Initial, - }, - State::WaitForMessage(client) => { - tokio::select! { - msg = message_rx.recv() => { - match msg { - Some(msg) => State::HandleMessage(client, msg), - None => { - // channel is closed? 
exit - tracing::debug!("Endpoint {url} channel closed"); - return - } - } - - }, - () = client.on_disconnect() => { - tracing::debug!("Endpoint {url} disconnected"); - State::OnError(health::Event::ConnectionClosed) - } - } - } - }; - - state = new_state; - } - }); - - self.background_tasks.push(handler); - } - - fn start_health_monitor_task(&mut self, config: HealthCheckConfig) { - let message_tx = self.message_tx.clone(); - let health = self.health.clone(); - let url = self.url.clone(); - - let handler = tokio::spawn(async move { - let health_response = config.response.clone(); - let interval = Duration::from_secs(config.interval_sec); - let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); - let max_response_time: Duration = Duration::from_millis(config.healthy_response_time_ms * 2); - - loop { - // Wait for the next interval - tokio::time::sleep(interval).await; - - let request_start = std::time::Instant::now(); - - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = message_tx - .send(Message::Request { - method: config.health_method.clone(), - params: vec![], - response: response_tx, - timeout: max_response_time, - }) - .await; - - if let Err(err) = res { - tracing::error!("{url} Unexpected error in message channel: {err}"); - } - - let res = match response_rx.await { - Ok(resp) => resp, - Err(err) => { - tracing::error!("{url} Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) - } - }; - - match res { - Ok(response) => { - let duration = request_start.elapsed(); - - // Check response - if let Some(ref health_response) = health_response { - if !health_response.validate(&response) { - health.update(Event::Unhealthy); - continue; - } - } - - // Check response time - if duration > healthy_response_time { - tracing::warn!("{url} response time is too long: {duration:?}"); - health.update(Event::SlowResponse); - continue; - } - - health.update(Event::ResponseOk); - } - Err(err) => { - health.on_error(&err); - } - } - } - }); - - self.background_tasks.push(handler); - } - - pub fn url(&self) -> &str { - &self.url - } - - pub fn health(&self) -> &Health { - self.health.as_ref() - } - - pub fn connect_counter(&self) -> u32 { - self.connect_counter.load(Ordering::Relaxed) - } - - pub async fn request( - &self, - method: &str, - params: Vec, - timeout: Duration, - ) -> Result { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = self - .message_tx - .send(Message::Request { - method: method.into(), - params, - response: response_tx, - timeout, - }) - .await; - - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } - - match response_rx.await { - Ok(resp) => resp, - Err(err) => { - tracing::error!("Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) - } - } - } - - pub async fn subscribe( - &self, - subscribe_method: &str, - params: Vec, - unsubscribe_method: &str, - timeout: Duration, - ) -> Result, jsonrpsee::core::client::Error> { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = self - .message_tx - .send(Message::Subscribe { - subscribe: subscribe_method.into(), - params, - unsubscribe: unsubscribe_method.into(), - response: response_tx, - timeout, - }) - .await; - - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } - - match response_rx.await { - 
Ok(resp) => resp, - Err(err) => { - tracing::error!("Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) - } - } - } - - pub async fn reconnect(&self) { - let res = self.message_tx.send(Message::Reconnect).await; - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } - } -} diff --git a/src/extensions/client/health.rs b/src/extensions/client/health.rs deleted file mode 100644 index c69adc1..0000000 --- a/src/extensions/client/health.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::sync::atomic::{AtomicU32, Ordering}; - -const MAX_SCORE: u32 = 100; -const THRESHOLD: u32 = 50; - -#[derive(Debug)] -pub enum Event { - ResponseOk, - ConnectionSuccessful, - SlowResponse, - RequestTimeout, - ServerError, - Unhealthy, - ConnectionClosed, -} - -impl Event { - pub fn update_score(&self, current: u32) -> u32 { - u32::min( - match self { - Event::ConnectionSuccessful => current.saturating_add(60), - Event::ResponseOk => current.saturating_add(2), - Event::SlowResponse => current.saturating_sub(20), - Event::RequestTimeout => current.saturating_sub(40), - Event::ConnectionClosed => current.saturating_sub(30), - Event::ServerError | Event::Unhealthy => 0, - }, - MAX_SCORE, - ) - } -} - -#[derive(Debug, Default)] -pub struct Health { - url: String, - score: AtomicU32, - unhealthy: tokio::sync::Notify, -} - -impl Health { - pub fn new(url: String) -> Self { - Self { - url, - score: AtomicU32::new(0), - unhealthy: tokio::sync::Notify::new(), - } - } - - pub fn score(&self) -> u32 { - self.score.load(Ordering::Relaxed) - } - - pub fn update(&self, event: Event) { - let current_score = self.score.load(Ordering::Relaxed); - let new_score = event.update_score(current_score); - if new_score == current_score { - return; - } - self.score.store(new_score, Ordering::Relaxed); - tracing::trace!( - "{:?} score updated from: {current_score} to: {new_score} because {event:?}", - self.url - ); - - // Notify waiters if the score has dropped below the threshold - if current_score >= THRESHOLD && new_score < THRESHOLD { - tracing::warn!("{:?} became unhealthy", self.url); - self.unhealthy.notify_waiters(); - } - } - - pub fn on_error(&self, err: &jsonrpsee::core::client::Error) { - match err { - jsonrpsee::core::client::Error::Call(_) => { - // NOT SERVER ERROR - } - jsonrpsee::core::client::Error::RequestTimeout => { - tracing::warn!("{:?} request timeout", self.url); - self.update(Event::RequestTimeout); - } - _ => { - tracing::warn!("{:?} responded with error: {err:?}", self.url); - self.update(Event::ServerError); - } - }; - } - - pub async fn unhealthy(&self) { - self.unhealthy.notified().await; - } -} diff --git a/src/extensions/client/mock.rs b/src/extensions/client/mock.rs index d999093..fbe0864 100644 --- a/src/extensions/client/mock.rs +++ b/src/extensions/client/mock.rs @@ -153,16 +153,6 @@ pub async fn dummy_server() -> ( (addr, handle, rx, sub_rx) } -pub async fn dummy_server_extend(extend: Box) -> (SocketAddr, ServerHandle) { - let mut builder = TestServerBuilder::new(); - - extend(&mut builder); - - let (addr, handle) = builder.build().await; - - (addr, handle) -} - pub enum SinkTask { Sleep(u64), Send(JsonValue), diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 865be01..c931d5e 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,21 +1,25 @@ use std::{ - fmt::{Debug, Formatter}, - sync::{atomic::AtomicU32, Arc}, + sync::{ + 
atomic::{AtomicU32, AtomicUsize}, + Arc, + }, time::Duration, }; use anyhow::anyhow; use async_trait::async_trait; -use futures::FutureExt as _; +use futures::TryFutureExt; use garde::Validate; -use jsonrpsee::core::{ - client::{Error, Subscription}, - JsonValue, +use jsonrpsee::{ + core::{ + client::{ClientT, Error, Subscription, SubscriptionClientT}, + JsonValue, + }, + ws_client::{WsClient, WsClientBuilder}, }; -use jsonrpsee::ws_client::WsClientBuilder; use opentelemetry::trace::FutureExt; use rand::{seq::SliceRandom, thread_rng}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio::sync::Notify; use super::ExtensionRegistry; @@ -25,10 +29,6 @@ use crate::{ utils::{self, errors}, }; -mod endpoint; -mod health; -use endpoint::Endpoint; - #[cfg(test)] pub mod mock; #[cfg(test)] @@ -37,7 +37,7 @@ mod tests; const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); pub struct Client { - endpoints: Vec>, + endpoints: Vec, sender: tokio::sync::mpsc::Sender, rotation_notify: Arc, retries: u32, @@ -57,7 +57,6 @@ pub struct ClientConfig { pub endpoints: Vec, #[serde(default = "bool_true")] pub shuffle_endpoints: bool, - pub health_check: Option, } fn validate_endpoint(endpoint: &str, _context: &()) -> garde::Result { @@ -112,54 +111,7 @@ pub fn bool_true() -> bool { true } -#[derive(Deserialize, Debug, Clone)] -pub struct HealthCheckConfig { - #[serde(default = "interval_sec")] - pub interval_sec: u64, - #[serde(default = "healthy_response_time_ms")] - pub healthy_response_time_ms: u64, - pub health_method: String, - pub response: Option, -} - -pub fn interval_sec() -> u64 { - 300 -} - -pub fn healthy_response_time_ms() -> u64 { - 500 -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum HealthResponse { - Eq(JsonValue), - NotEq(JsonValue), - Contains(Vec<(String, Box)>), -} - -impl HealthResponse { - pub fn validate(&self, response: &JsonValue) -> bool { - match self { - HealthResponse::Eq(value) => value.eq(response), - HealthResponse::NotEq(value) => !value.eq(response), - HealthResponse::Contains(items) => { - for (key, expected) in items { - if let Some(response) = response.get(key) { - if !expected.validate(response) { - return false; - } - } else { - // key missing - return false; - } - } - true - } - } - } -} - +#[derive(Debug)] enum Message { Request { method: String, @@ -177,59 +129,27 @@ enum Message { RotateEndpoint, } -impl Debug for Message { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Message::Request { - method, - params, - response: _, - retries, - } => write!(f, "Request({method}, {params:?}, _, {retries})"), - Message::Subscribe { - subscribe, - params, - unsubscribe, - response: _, - retries, - } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {retries})"), - Message::RotateEndpoint => write!(f, "RotateEndpoint"), - } - } -} - #[async_trait] impl Extension for Client { type Config = ClientConfig; async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { - let health_check = config.health_check.clone(); - let endpoints = if config.shuffle_endpoints { + if config.shuffle_endpoints { let mut endpoints = config.endpoints.clone(); endpoints.shuffle(&mut thread_rng()); - endpoints + Ok(Self::new(endpoints, None, None, None)?) 
} else { - config.endpoints.clone() - }; - - // TODO: make the params configurable - Ok(Self::new( - endpoints, - Duration::from_secs(30), - Duration::from_secs(30), - None, - health_check, - )?) + Ok(Self::new(config.endpoints.clone(), None, None, None)?) + } } } impl Client { pub fn new( endpoints: impl IntoIterator>, - request_timeout: Duration, - connection_timeout: Duration, + request_timeout: Option, + connection_timeout: Option, retries: Option, - health_config: Option, ) -> Result { let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); @@ -237,23 +157,7 @@ impl Client { return Err(anyhow!("No endpoints provided")); } - if let Some(0) = retries { - return Err(anyhow!("Retries need to be at least 1")); - } - - tracing::debug!("New client with endpoints: {endpoints:?}"); - - let endpoints = endpoints - .into_iter() - .map(|e| { - Arc::new(Endpoint::new( - e, - request_timeout, - connection_timeout, - health_config.clone(), - )) - }) - .collect::>(); + tracing::debug!("New client with endpoints: {:?}", endpoints); let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); @@ -261,48 +165,61 @@ impl Client { let rotation_notify = Arc::new(Notify::new()); let rotation_notify_bg = rotation_notify.clone(); - let endpoints2 = endpoints.clone(); - let has_health_method = health_config.is_some(); - - let mut current_endpoint_idx = 0; - let mut selected_endpoint = endpoints[0].clone(); + let endpoints_ = endpoints.clone(); let background_task = tokio::spawn(async move { + let connect_backoff_counter = Arc::new(AtomicU32::new(0)); let request_backoff_counter = Arc::new(AtomicU32::new(0)); - // Select next endpoint with the highest health score, excluding the current one if possible - let select_healtiest = |endpoints: Vec>, current_idx: usize| async move { - if endpoints.len() == 1 { - let selected_endpoint = endpoints[0].clone(); - return (selected_endpoint, 0); - } - - let (idx, endpoint) = endpoints - .iter() - .enumerate() - .filter(|(idx, _)| *idx != current_idx) - .max_by_key(|(_, endpoint)| endpoint.health().score()) - .expect("No endpoints"); - (endpoint.clone(), idx) - }; - - let select_next = |endpoints: Vec>, current_idx: usize| async move { - let idx = (current_idx + 1) % endpoints.len(); - (endpoints[idx].clone(), idx) - }; + let current_endpoint = AtomicUsize::new(0); + + let connect_backoff_counter2 = connect_backoff_counter.clone(); + let build_ws = || async { + let build = || { + let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let url = &endpoints[current_endpoint % endpoints.len()]; + + tracing::info!("Connecting to endpoint: {}", url); + + // TODO: make those configurable + WsClientBuilder::default() + .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) + .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(url) + .map_err(|e| (e, url.to_string())) + }; - let next_endpoint = |current_idx| { - if has_health_method { - select_healtiest(endpoints2.clone(), current_idx).boxed() - } else { - select_next(endpoints2.clone(), current_idx).boxed() + loop { + match build().await { + Ok(ws) => { + let ws = Arc::new(ws); + tracing::info!("Endpoint connected"); + connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); + break ws; + } + Err((e, url)) => { + tracing::warn!("Unable to connect to endpoint: '{url}' error: 
{e}"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; + } + } } }; - let handle_message = |message: Message, endpoint: Arc| { + let mut ws = build_ws().await; + + let handle_message = |message: Message, ws: Arc| { let tx = message_tx_bg.clone(); let request_backoff_counter = request_backoff_counter.clone(); + // total timeout for a request + let task_timeout = request_timeout + .unwrap_or(Duration::from_secs(30)) + // buffer 5 seconds for the request to be processed + .saturating_add(Duration::from_secs(5)); + tokio::spawn(async move { match message { Message::Request { @@ -318,54 +235,71 @@ impl Client { return; } - match endpoint.request(&method, params.clone(), request_timeout).await { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; + if let Ok(result) = + tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await + { + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Request failed: {err:?}"); - match err { - Error::RequestTimeout - | Error::Transport(_) - | Error::RestartNeeded(_) - | Error::MaxSlotsExceeded => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; + Err(err) => { + tracing::debug!("Request failed: {:?}", err); + match err { + Error::RequestTimeout + | Error::Transport(_) + | Error::RestartNeeded(_) + | Error::MaxSlotsExceeded => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Request { + method, + params, + response, + retries, + }) + .await + .expect("Failed to send request message"); } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); } - - tx.send(Message::Request { - method, - params, - response, - retries, - }) - .await - .expect("Failed to send request message"); - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); } } } + } else { + tracing::error!("request timed out method: {} params: {:?}", method, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); } } Message::Subscribe { @@ -377,58 +311,75 @@ impl Client { } => { retries = retries.saturating_sub(1); - match endpoint - .subscribe(&subscribe, params.clone(), &unsubscribe, request_timeout) - .await + if let Ok(result) = tokio::time::timeout( + task_timeout, + ws.subscribe(&subscribe, params.clone(), &unsubscribe), + ) + .await { - result @ 
Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Subscribe failed: {err:?}"); - match err { - Error::RequestTimeout - | Error::Transport(_) - | Error::RestartNeeded(_) - | Error::MaxSlotsExceeded => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; + Err(err) => { + tracing::debug!("Subscribe failed: {:?}", err); + match err { + Error::RequestTimeout + | Error::Transport(_) + | Error::RestartNeeded(_) + | Error::MaxSlotsExceeded => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + retries, + }) + .await + .expect("Failed to send subscribe message") } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); } - - tx.send(Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - retries, - }) - .await - .expect("Failed to send subscribe message") - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); } } } + } else { + tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); } } Message::RotateEndpoint => { @@ -440,30 +391,20 @@ impl Client { loop { tokio::select! { - _ = selected_endpoint.health().unhealthy() => { - // Current selected endpoint is unhealthy, try to rotate to another one. - // In case of all endpoints are unhealthy, we don't want to keep rotating but stick with the healthiest one. 
- - // The ws client maybe in a state that requires a reconnect - selected_endpoint.reconnect().await; - - let (new_selected_endpoint, new_current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; - if new_current_endpoint_idx != current_endpoint_idx { - tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url()); - selected_endpoint = new_selected_endpoint; - current_endpoint_idx = new_current_endpoint_idx; - } - rotation_notify_bg.notify_waiters(); + _ = ws.on_disconnect() => { + tracing::info!("Endpoint disconnected"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; + ws = build_ws().await; } message = message_rx.recv() => { tracing::trace!("Received message {message:?}"); match message { Some(Message::RotateEndpoint) => { - tracing::info!("Rotating endpoint ..."); - (selected_endpoint, current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; rotation_notify_bg.notify_waiters(); + tracing::info!("Rotate endpoint"); + ws = build_ws().await; } - Some(message) => handle_message(message, selected_endpoint.clone()), + Some(message) => handle_message(message, ws.clone()), None => { tracing::debug!("Client dropped"); break; @@ -474,8 +415,12 @@ impl Client { } }); + if let Some(0) = retries { + return Err(anyhow!("Retries need to be at least 1")); + } + Ok(Self { - endpoints, + endpoints: endpoints_, sender: message_tx, rotation_notify, retries: retries.unwrap_or(3), @@ -483,8 +428,12 @@ impl Client { }) } - pub fn endpoints(&self) -> &Vec> { - self.endpoints.as_ref() + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { + Self::new(endpoints, None, None, None) + } + + pub fn endpoints(&self) -> &Vec { + &self.endpoints } pub async fn request(&self, method: &str, params: Vec) -> CallResult { @@ -544,7 +493,7 @@ impl Client { } } -pub fn get_backoff_time(counter: &Arc) -> Duration { +fn get_backoff_time(counter: &Arc) -> Duration { let min_time = 100u64; let step = 100u64; let max_count = 10u32; @@ -574,112 +523,3 @@ fn test_get_backoff_time() { vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] ); } - -#[test] -fn health_response_serialize_deserialize_works() { - let response = HealthResponse::Contains(vec![( - "isSyncing".to_string(), - Box::new(HealthResponse::Eq(false.into())), - )]); - - let expected = serde_yaml::from_str::( - r" - !contains - - - isSyncing - - !eq false - ", - ) - .unwrap(); - - assert_eq!(response, expected); -} - -#[test] -fn health_response_validation_works() { - use serde_json::json; - - let expected = serde_yaml::from_str::( - r" - !eq true - ", - ) - .unwrap(); - assert!(expected.validate(&json!(true))); - assert!(!expected.validate(&json!(false))); - - let expected = serde_yaml::from_str::( - r" - !contains - - - isSyncing - - !eq false - ", - ) - .unwrap(); - let cases = [ - (json!({ "isSyncing": false }), true), - (json!({ "isSyncing": true }), false), - (json!({ "isSyncing": false, "peers": 2 }), true), - (json!({ "isSyncing": true, "peers": 2 }), false), - (json!({}), false), - (json!(true), false), - ]; - for (input, output) in cases { - assert_eq!(expected.validate(&input), output); - } - - // multiple items - let expected = serde_yaml::from_str::( - r" - !contains - - - isSyncing - - !eq false - - - peers - - !eq 3 - ", - ) - .unwrap(); - let cases = [ - (json!({ "isSyncing": false, "peers": 3 }), true), - (json!({ "isSyncing": false, "peers": 2 }), false), - (json!({ "isSyncing": true, "peers": 3 }), false), - ]; - for (input, output) in cases { - 
assert_eq!(expected.validate(&input), output); - } - - // works with strings - let expected = serde_yaml::from_str::( - r" - !contains - - - foo - - !eq bar - ", - ) - .unwrap(); - assert!(expected.validate(&json!({ "foo": "bar" }))); - assert!(!expected.validate(&json!({ "foo": "bar bar" }))); - - // multiple nested items - let expected = serde_yaml::from_str::( - r" - !contains - - - foo - - !contains - - - one - - !eq subway - - - two - - !not_eq subway - ", - ) - .unwrap(); - let cases = [ - (json!({ "foo": { "one": "subway", "two": "not_subway" } }), true), - (json!({ "foo": { "one": "subway", "two": "subway" } }), false), - (json!({ "foo": { "subway": "one" } }), false), - (json!({ "bar" : { "foo": { "subway": "one", "two": "subway" } }}), false), - (json!({ "foo": "subway" }), false), - ]; - for (input, output) in cases { - assert_eq!(expected.validate(&input), output); - } -} diff --git a/src/extensions/client/tests.rs b/src/extensions/client/tests.rs index cf229a8..c8c9c7b 100644 --- a/src/extensions/client/tests.rs +++ b/src/extensions/client/tests.rs @@ -11,14 +11,7 @@ use tokio::sync::mpsc; async fn basic_request() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let req = rx.recv().await.unwrap(); @@ -38,14 +31,7 @@ async fn basic_request() { async fn basic_subscription() { let (addr, handle, _, mut rx) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let sub = rx.recv().await.unwrap(); @@ -75,22 +61,11 @@ async fn multiple_endpoints() { let (addr2, handle2, rx2, _) = dummy_server().await; let (addr3, handle3, rx3, _) = dummy_server().await; - let client = Client::new( - [ - format!("ws://{addr1}"), - format!("ws://{addr2}"), - format!("ws://{addr3}"), - ], - Duration::from_secs(1), - Duration::from_secs(1), - None, - Some(HealthCheckConfig { - interval_sec: 1, - healthy_response_time_ms: 250, - health_method: "mock_rpc".into(), - response: None, - }), - ) + let client = Client::with_endpoints([ + format!("ws://{addr1}"), + format!("ws://{addr2}"), + format!("ws://{addr3}"), + ]) .unwrap(); let handle_requests = |mut rx: mpsc::Receiver, n: u32| { @@ -113,7 +88,7 @@ async fn multiple_endpoints() { let result = client.request("mock_rpc", vec![22.into()]).await.unwrap(); - assert_eq!(result.to_string(), "3"); + assert_eq!(result.to_string(), "2"); client.rotate_endpoint().await; @@ -121,7 +96,7 @@ async fn multiple_endpoints() { let result = client.request("mock_rpc", vec![33.into()]).await.unwrap(); - assert_eq!(result.to_string(), "2"); + assert_eq!(result.to_string(), "3"); handle3.stop().unwrap(); @@ -141,37 +116,27 @@ async fn multiple_endpoints() { async fn concurrent_requests() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let req1 = rx.recv().await.unwrap(); let req2 = rx.recv().await.unwrap(); let req3 = rx.recv().await.unwrap(); - let 
p1 = req1.params.clone(); - let p2 = req2.params.clone(); - let p3 = req3.params.clone(); - req1.respond(p1); - req2.respond(p2); - req3.respond(p3); + req1.respond(JsonValue::from_str("1").unwrap()); + req2.respond(JsonValue::from_str("2").unwrap()); + req3.respond(JsonValue::from_str("3").unwrap()); }); - let res1 = client.request("mock_rpc", vec![json!(1)]); - let res2 = client.request("mock_rpc", vec![json!(2)]); - let res3 = client.request("mock_rpc", vec![json!(3)]); + let res1 = client.request("mock_rpc", vec![]); + let res2 = client.request("mock_rpc", vec![]); + let res3 = client.request("mock_rpc", vec![]); let res = tokio::join!(res1, res2, res3); - assert_eq!(res.0.unwrap(), json!([1])); - assert_eq!(res.1.unwrap(), json!([2])); - assert_eq!(res.2.unwrap(), json!([3])); + assert_eq!(res.0.unwrap().to_string(), "1"); + assert_eq!(res.1.unwrap().to_string(), "2"); + assert_eq!(res.2.unwrap().to_string(), "3"); handle.stop().unwrap(); task.await.unwrap(); @@ -184,10 +149,9 @@ async fn retry_requests_successful() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), - Some(2), + Some(Duration::from_millis(100)), None, + Some(2), ) .unwrap(); @@ -222,10 +186,9 @@ async fn retry_requests_out_of_retries() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), - Some(2), + Some(Duration::from_millis(100)), None, + Some(2), ) .unwrap(); @@ -253,118 +216,3 @@ async fn retry_requests_out_of_retries() { handle1.stop().unwrap(); handle2.stop().unwrap(); } - -#[tokio::test] -async fn health_check_works() { - let (addr1, handle1) = dummy_server_extend(Box::new(|builder| { - let mut system_health = builder.register_method("system_health"); - tokio::spawn(async move { - loop { - tokio::select! { - Some(req) = system_health.recv() => { - req.respond(json!({ "isSyncing": true, "peers": 1, "shouldHavePeers": true })); - } - } - } - }); - })) - .await; - - let (addr2, handle2) = dummy_server_extend(Box::new(|builder| { - let mut system_health = builder.register_method("system_health"); - tokio::spawn(async move { - loop { - tokio::select! 
{ - Some(req) = system_health.recv() => { - req.respond(json!({ "isSyncing": false, "peers": 1, "shouldHavePeers": true })); - } - } - } - }); - })) - .await; - - let client = Client::new( - [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - Some(HealthCheckConfig { - interval_sec: 1, - healthy_response_time_ms: 250, - health_method: "system_health".into(), - response: Some(HealthResponse::Contains(vec![( - "isSyncing".to_string(), - Box::new(HealthResponse::Eq(false.into())), - )])), - }), - ) - .unwrap(); - - // first endpoint is stale - let res = client.request("system_health", vec![]).await; - assert_eq!( - res.unwrap(), - json!({ "isSyncing": true, "peers": 1, "shouldHavePeers": true }) - ); - - // wait for the health check to run - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(1_050)).await; - }) - .await - .unwrap(); - - // second endpoint is healthy - let res = client.request("system_health", vec![]).await; - assert_eq!( - res.unwrap(), - json!({ "isSyncing": false, "peers": 1, "shouldHavePeers": true }) - ); - - handle1.stop().unwrap(); - handle2.stop().unwrap(); -} - -#[tokio::test] -async fn reconnect_on_disconnect() { - let (addr1, handle1, mut rx1, _) = dummy_server().await; - let (addr2, handle2, mut rx2, _) = dummy_server().await; - - let client = Client::new( - [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), - Some(2), - None, - ) - .unwrap(); - - let h1 = tokio::spawn(async move { - let _req = rx1.recv().await.unwrap(); - // no response, let it timeout - tokio::time::sleep(Duration::from_millis(200)).await; - }); - - let h2 = tokio::spawn(async move { - let req = rx2.recv().await.unwrap(); - req.respond(json!(1)); - }); - - let h3 = tokio::spawn(async move { - let res = client.request("mock_rpc", vec![]).await; - assert_eq!(res.unwrap(), json!(1)); - - tokio::time::sleep(Duration::from_millis(2000)).await; - - assert_eq!(client.endpoints()[0].connect_counter(), 2); - assert_eq!(client.endpoints()[1].connect_counter(), 1); - }); - - h3.await.unwrap(); - h1.await.unwrap(); - h2.await.unwrap(); - - handle1.stop().unwrap(); - handle2.stop().unwrap(); -} diff --git a/src/extensions/validator/mod.rs b/src/extensions/validator/mod.rs index 00ee5e8..742ae93 100644 --- a/src/extensions/validator/mod.rs +++ b/src/extensions/validator/mod.rs @@ -1,15 +1,14 @@ use crate::extensions::client::Client; use crate::middlewares::{CallRequest, CallResult}; -use crate::utils::errors; use async_trait::async_trait; use serde::Deserialize; use std::sync::Arc; use super::{Extension, ExtensionRegistry}; -#[derive(Default)] pub struct Validator { - pub config: ValidateConfig, + config: ValidateConfig, + clients: Vec>, } #[derive(Deserialize, Default, Debug, Clone)] @@ -21,32 +20,38 @@ pub struct ValidateConfig { impl Extension for Validator { type Config = ValidateConfig; - async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { - Ok(Self::new(config.clone())) + async fn from_config(config: &Self::Config, registry: &ExtensionRegistry) -> Result { + let client = registry.get::().await.expect("Client extension not found"); + + let clients = client + .endpoints() + .iter() + .map(|e| Arc::new(Client::with_endpoints([e]).expect("Unable to create client"))) + .collect(); + + Ok(Self::new(config.clone(), clients)) } } impl Validator { - pub fn new(config: ValidateConfig) -> Self { - Self { config } + pub fn new(config: 
ValidateConfig, clients: Vec>) -> Self { + Self { config, clients } } pub fn ignore(&self, method: &String) -> bool { self.config.ignore_methods.contains(method) } - pub fn validate(&self, client: Arc, request: CallRequest, response: CallResult) { + pub fn validate(&self, request: CallRequest, response: CallResult) { + let clients = self.clients.clone(); tokio::spawn(async move { - let healthy_endpoints = client.endpoints().iter().filter(|x| x.health().score() > 0); - futures::future::join_all(healthy_endpoints.map(|endpoint| async { - let expected = endpoint + futures::future::join_all(clients.iter().map(|client| async { + let expected = client .request( &request.method, request.params.clone(), - std::time::Duration::from_secs(30), ) - .await - .map_err(errors::map_error); + .await; if response != expected { let request = serde_json::to_string_pretty(&request).unwrap_or_default(); @@ -58,7 +63,7 @@ impl Validator { Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(), Err(e) => e.to_string() }; - let endpoint_url = endpoint.url(); + let endpoint_url = client.endpoints()[0].clone(); tracing::error!("Response mismatch for request:\n{request}\nSubway response:\n{actual}\nEndpoint {endpoint_url} response:\n{expected}"); } })).await; diff --git a/src/middlewares/methods/block_tag.rs b/src/middlewares/methods/block_tag.rs index 00aecd7..1e0c6cf 100644 --- a/src/middlewares/methods/block_tag.rs +++ b/src/middlewares/methods/block_tag.rs @@ -165,14 +165,7 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let api = EthApi::new(Arc::new(client), Duration::from_secs(100)); ( diff --git a/src/middlewares/methods/inject_params.rs b/src/middlewares/methods/inject_params.rs index bbfc0f1..bcbca9f 100644 --- a/src/middlewares/methods/inject_params.rs +++ b/src/middlewares/methods/inject_params.rs @@ -211,14 +211,7 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let api = SubstrateApi::new(Arc::new(client), Duration::from_secs(100)); ExecutionContext { diff --git a/src/middlewares/methods/validate.rs b/src/middlewares/methods/validate.rs index 24e1971..d4edf10 100644 --- a/src/middlewares/methods/validate.rs +++ b/src/middlewares/methods/validate.rs @@ -2,19 +2,18 @@ use async_trait::async_trait; use std::sync::Arc; use crate::{ - extensions::{client::Client, validator::Validator}, + extensions::validator::Validator, middlewares::{CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod}, utils::{TypeRegistry, TypeRegistryRef}, }; pub struct ValidateMiddleware { validator: Arc, - client: Arc, } impl ValidateMiddleware { - pub fn new(validator: Arc, client: Arc) -> Self { - Self { validator, client } + pub fn new(validator: Arc) -> Self { + Self { validator } } } @@ -24,14 +23,13 @@ impl MiddlewareBuilder for ValidateMiddlewar _method: &RpcMethod, extensions: &TypeRegistryRef, ) -> Option>> { - let validate = extensions.read().await.get::().unwrap_or_default(); - - let client = extensions + let validate = extensions .read() .await - .get::() - .expect("Client extension not found"); - 
Some(Box::new(ValidateMiddleware::new(validate, client)))
+            .get::<Validator>()
+            .expect("Validator extension not found");
+
+        Some(Box::new(ValidateMiddleware::new(validate)))
     }
 }
@@ -45,7 +43,7 @@ impl Middleware<CallRequest, CallResult> for ValidateMiddleware {
     ) -> CallResult {
         let result = next(request.clone(), context).await;
         if !self.validator.ignore(&request.method) {
-            self.validator.validate(self.client.clone(), request, result.clone());
+            self.validator.validate(request, result.clone());
         }
         result
     }
diff --git a/src/server.rs b/src/server.rs
index 4776b26..0404396 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -250,7 +250,6 @@ mod tests {
             client: Some(ClientConfig {
                 endpoints: vec![endpoint],
                 shuffle_endpoints: false,
-                health_check: None,
             }),
             server: Some(ServerConfig {
                 listen_address: "127.0.0.1".to_string(),
diff --git a/src/tests/merge_subscription.rs b/src/tests/merge_subscription.rs
index 89054b0..a41bb4b 100644
--- a/src/tests/merge_subscription.rs
+++ b/src/tests/merge_subscription.rs
@@ -1,5 +1,3 @@
-use std::time::Duration;
-
 use serde_json::json;
 
 use crate::{
@@ -51,7 +49,6 @@ async fn merge_subscription_works() {
         client: Some(ClientConfig {
             endpoints: vec![format!("ws://{addr}")],
             shuffle_endpoints: false,
-            health_check: None,
         }),
         server: Some(ServerConfig {
             listen_address: "0.0.0.0".to_string(),
@@ -99,14 +96,7 @@ async fn merge_subscription_works() {
     let subway_server = server::build(config).await.unwrap();
     let addr = subway_server.addr;
 
-    let client = Client::new(
-        [format!("ws://{addr}")],
-        Duration::from_secs(1),
-        Duration::from_secs(1),
-        None,
-        None,
-    )
-    .unwrap();
+    let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap();
     let mut first_sub = client
         .subscribe(subscribe_mock, vec![], unsubscribe_mock)
         .await
diff --git a/src/tests/upstream.rs b/src/tests/upstream.rs
index ab63169..84c38f5 100644
--- a/src/tests/upstream.rs
+++ b/src/tests/upstream.rs
@@ -1,5 +1,3 @@
-use std::time::Duration;
-
 use crate::{
     config::{Config, MergeStrategy, MiddlewaresConfig, RpcDefinitions, RpcSubscription},
     extensions::{
@@ -33,7 +31,6 @@ async fn upstream_error_propagate() {
         client: Some(ClientConfig {
             endpoints: vec![format!("ws://{addr}")],
             shuffle_endpoints: false,
-            health_check: None,
         }),
         server: Some(ServerConfig {
             listen_address: "0.0.0.0".to_string(),
@@ -75,14 +72,7 @@ async fn upstream_error_propagate() {
     let subway_server = server::build(config).await.unwrap();
     let addr = subway_server.addr;
 
-    let client = Client::new(
-        [format!("ws://{addr}")],
-        Duration::from_secs(1),
-        Duration::from_secs(1),
-        None,
-        None,
-    )
-    .unwrap();
+    let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap();
 
     let result = client.subscribe(subscribe_mock, vec![], unsubscribe_mock).await;
     assert!(result