Skip to content

Commit

Permalink
Merge pull request #109 from nihohit/check-all-connections
Browse files Browse the repository at this point in the history
Add `check_node_connections` function.
  • Loading branch information
shachlanAmazon authored Jan 28, 2024
2 parents 2e31b36 + d8083d3 commit 728dac2
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 47 deletions.
4 changes: 2 additions & 2 deletions redis/src/cluster_async/connections_container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::cluster_topology::TopologyHash;
type IdentifierType = ArcStr;

#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ClusterNode<Connection> {
pub struct ClusterNode<Connection> {
pub user_connection: Connection,
pub management_connection: Option<Connection>,
pub ip: Option<IpAddr>,
Expand All @@ -21,7 +21,7 @@ impl<Connection> ClusterNode<Connection>
where
Connection: Clone,
{
pub(crate) fn new(
pub fn new(
user_connection: Connection,
management_connection: Option<Connection>,
ip: Option<IpAddr>,
Expand Down
87 changes: 81 additions & 6 deletions redis/src/cluster_async/connections_logic.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
use std::{
iter::Iterator,
net::{IpAddr, SocketAddr},
};

use super::{AsyncClusterNode, Connect};
use crate::{
aio::{get_socket_addrs, ConnectionLike},
cluster::get_connection_info,
cluster_client::ClusterParams,
RedisResult,
};

use futures_time::future::FutureExt;
use std::{
iter::Iterator,
net::{IpAddr, SocketAddr},
};
use futures_util::join;
use tracing::warn;

#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RefreshConnectionType {
// Refresh only user connections
OnlyUserConnection,
// Refresh only management connections
OnlyManagementConnection,
// Refresh all connections: both management and user connections.
AllConnections,
}

/// Return true if a DNS change is detected, otherwise return false.
/// This function takes a node's address, examines if its host has encountered a DNS change, where the node's endpoint now leads to a different IP address.
/// If no socket addresses are discovered for the node's host address, or if it's a non-DNS address, it returns false.
/// In case the node's host address resolves to socket addresses and none of them match the current connection's IP,
/// a DNS change is detected, so the current connection isn't valid anymore and a new connection should be made.
async fn is_dns_changed(addr: &str, curr_ip: &IpAddr) -> bool {
async fn has_dns_changed(addr: &str, curr_ip: &IpAddr) -> bool {
let (host, port) = match get_host_and_port_from_addr(addr) {
Some((host, port)) => (host, port),
None => return false,
Expand All @@ -40,7 +55,7 @@ where
if let Some(node) = node {
let mut conn = node.user_connection.await;
if let Some(ref ip) = node.ip {
if is_dns_changed(addr, ip).await {
if has_dns_changed(addr, ip).await {
return connect_and_check(addr, params.clone(), None).await;
}
};
Expand Down Expand Up @@ -77,6 +92,66 @@ where
Ok((conn, ip))
}

/// The function returns None if the checked connection/s are healthy. Otherwise, it returns the type of the unhealthy connection/s.
#[allow(dead_code)]
#[doc(hidden)]
pub async fn check_node_connections<C>(
node: &AsyncClusterNode<C>,
params: &ClusterParams,
conn_type: RefreshConnectionType,
address: &str,
) -> Option<RefreshConnectionType>
where
C: ConnectionLike + Send + 'static + Clone,
{
let timeout = params.connection_timeout.into();
let (check_mgmt_connection, check_user_connection) = match conn_type {
RefreshConnectionType::OnlyUserConnection => (false, true),
RefreshConnectionType::OnlyManagementConnection => (true, false),
RefreshConnectionType::AllConnections => (true, true),
};
let check = |conn, timeout, conn_type| async move {
match check_connection(&mut conn.await, timeout).await {
Ok(_) => false,
Err(err) => {
warn!(
"The {} connection for node {} is unhealthy. Error: {:?}",
conn_type, address, err
);
true
}
}
};
let (mgmt_failed, user_failed) = join!(
async {
if !check_mgmt_connection {
return false;
}
match node.management_connection.clone() {
Some(conn) => check(conn, timeout, "management").await,
None => {
warn!("The management connection for node {} isn't set", address);
true
}
}
},
async {
if !check_user_connection {
return false;
}
let conn = node.user_connection.clone();
check(conn, timeout, "user").await
},
);

match (mgmt_failed, user_failed) {
(true, true) => Some(RefreshConnectionType::AllConnections),
(true, false) => Some(RefreshConnectionType::OnlyManagementConnection),
(false, true) => Some(RefreshConnectionType::OnlyUserConnection),
(false, false) => None,
}
}

async fn check_connection<C>(conn: &mut C, timeout: futures_time::time::Duration) -> RedisResult<()>
where
C: ConnectionLike + Send + 'static,
Expand Down
3 changes: 2 additions & 1 deletion redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ where
}

type ConnectionFuture<C> = future::Shared<BoxFuture<'static, C>>;
type AsyncClusterNode<C> = ClusterNode<ConnectionFuture<C>>;
/// Cluster node for async connections
pub type AsyncClusterNode<C> = ClusterNode<ConnectionFuture<C>>;
type ConnectionMap<C> = connections_container::ConnectionsMap<ConnectionFuture<C>>;
type ConnectionsContainer<C> =
self::connections_container::ConnectionsContainer<ConnectionFuture<C>>;
Expand Down
45 changes: 31 additions & 14 deletions redis/tests/support/mock_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct MockConnectionBehavior {
}

impl MockConnectionBehavior {
pub fn new(id: &str, handler: Handler) -> Self {
fn new(id: &str, handler: Handler) -> Self {
Self {
id: id.to_string(),
handler,
Expand All @@ -49,23 +49,20 @@ impl MockConnectionBehavior {
}
}

#[must_use]
pub fn register_new(id: &str, handler: Handler) -> RemoveHandler {
get_behaviors().insert(id.to_string(), Self::new(id, handler));
RemoveHandler(vec![id.to_string()])
}

fn get_handler(&self) -> Handler {
self.handler.clone()
}
}

pub fn add_new_mock_connection_behavior(name: &str, handler: Handler) {
MOCK_CONN_BEHAVIORS
.write()
.unwrap()
.insert(name.to_string(), MockConnectionBehavior::new(name, handler));
}

pub fn modify_mock_connection_behavior(name: &str, func: impl FnOnce(&mut MockConnectionBehavior)) {
func(
MOCK_CONN_BEHAVIORS
.write()
.unwrap()
get_behaviors()
.get_mut(name)
.expect("Handler `{name}` was not installed"),
);
Expand All @@ -80,9 +77,26 @@ pub fn get_mock_connection_handler(name: &str) -> Handler {
.get_handler()
}

pub fn get_mock_connection(name: &str, id: usize) -> MockConnection {
get_mock_connection_with_port(name, id, 6379)
}

pub fn get_mock_connection_with_port(name: &str, id: usize, port: u16) -> MockConnection {
MockConnection {
id,
handler: get_mock_connection_handler(name),
port,
}
}

static MOCK_CONN_BEHAVIORS: Lazy<RwLock<HashMap<String, MockConnectionBehavior>>> =
Lazy::new(Default::default);

fn get_behaviors() -> std::sync::RwLockWriteGuard<'static, HashMap<String, MockConnectionBehavior>>
{
MOCK_CONN_BEHAVIORS.write().unwrap()
}

#[derive(Default)]
pub enum ConnectionIPReturnType {
/// New connections' IP will be returned as None
Expand Down Expand Up @@ -410,7 +424,7 @@ pub struct RemoveHandler(Vec<String>);
impl Drop for RemoveHandler {
fn drop(&mut self) {
for id in &self.0 {
MOCK_CONN_BEHAVIORS.write().unwrap().remove(id);
get_behaviors().remove(id);
}
}
}
Expand Down Expand Up @@ -440,7 +454,10 @@ impl MockEnv {
.unwrap();

let id = id.to_string();
add_new_mock_connection_behavior(&id, Arc::new(move |cmd, port| handler(cmd, port)));
let handler = MockConnectionBehavior::register_new(
&id,
Arc::new(move |cmd, port| handler(cmd, port)),
);
let client = client_builder.build().unwrap();
let connection = client.get_generic_connection().unwrap();
#[cfg(feature = "cluster-async")]
Expand All @@ -454,7 +471,7 @@ impl MockEnv {
connection,
#[cfg(feature = "cluster-async")]
async_connection,
handler: RemoveHandler(vec![id]),
handler,
}
}
}
Loading

0 comments on commit 728dac2

Please sign in to comment.