Skip to content

Commit

Permalink
Improved ssh retry mechanism (#501)
Browse files Browse the repository at this point in the history
This commit improves the retry mechanism's resilience. Before this commit,
the mechanism covered only connection opening. Now, it handles exceptions raised
during both connection opening and connection communication. In addition,
now the `SSHConnector` correctly retrieves the free disk size to determine
the available storage space, instead of using the full disk size. Finally, two
logging messages were fixed: the Docker `DEBUG` logging message carried a stray
trailing newline, which is now stripped, and the step exit status was logged in
lowercase, whereas StreamFlow displays statuses in uppercase.
  • Loading branch information
LanderOtto authored Jan 17, 2025
1 parent 1558b3f commit d3b82ec
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
python -m pip install -r docs/requirements.txt
- name: "Build documentation and check for consistency"
env:
CHECKSUM: "6642d8d7533c7cf3ec9de7bdbf1871e3a313dd73fa6551fe3bb10e4e94e7ff08"
CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94"
run: |
cd docs
HASH="$(make checksum | tail -n1)"
Expand Down
2 changes: 1 addition & 1 deletion streamflow/deployment/connector/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ async def get_available_locations(
(json_end := output.rfind("}")) != -1
):
if json_start != 0 and logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Docker Compose log: {output[:json_start]}")
logger.debug(f"Docker Compose log: {output[:json_start].strip()}")
locations = json.loads(output[json_start : json_end + 1])
else:
raise WorkflowExecutionException(
Expand Down
208 changes: 121 additions & 87 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any

import asyncssh
from asyncssh import ChannelOpenError, ConnectionLost
from asyncssh import ChannelOpenError, ConnectionLost, DisconnectError

from streamflow.core import utils
from streamflow.core.data import StreamWrapper, StreamWrapperContextManager
Expand Down Expand Up @@ -42,55 +42,14 @@ def __init__(
streamflow_config_dir: str,
config: SSHConfig,
max_concurrent_sessions: int,
retries: int,
retry_delay: int,
):
self._streamflow_config_dir: str = streamflow_config_dir
self._config: SSHConfig = config
self._max_concurrent_sessions: int = max_concurrent_sessions
self._ssh_connection: asyncssh.SSHClientConnection | None = None
self._connecting = False
self._retries = retries
self._retry_delay = retry_delay
self._connecting: bool = False
self._connect_event: asyncio.Event = asyncio.Event()

async def get_connection(self) -> asyncssh.SSHClientConnection:
if self._ssh_connection is None:
if not self._connecting:
self._connecting = True
for i in range(1, self._retries + 1):
try:
self._ssh_connection = await self._get_connection(self._config)
break
except (ConnectionError, ConnectionLost) as e:
if i == self._retries:
logger.exception(
f"Impossible to connect to {self._config.hostname}: {e}"
)
self._connect_event.set()
self.close()
raise
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}. "
f"Waiting {self._retry_delay} seconds for the next attempt."
)
except asyncssh.Error:
self._connect_event.set()
self.close()
raise
await asyncio.sleep(self._retry_delay)
self._connect_event.set()
else:
await self._connect_event.wait()
if self._ssh_connection is None:
raise WorkflowExecutionException(
f"Impossible to connect to {self._config.hostname}"
)
return self._ssh_connection

def get_hostname(self) -> str:
return self._config.hostname
self.connection_attempts: int = 0

async def _get_connection(
self, config: SSHConfig
Expand Down Expand Up @@ -132,19 +91,53 @@ def _get_param_from_file(self, file_path: str):
with open(file_path) as f:
return f.read().strip()

def close(self):
self._connecting = False
async def close(self):
if self._ssh_connection is not None:
if len(self._ssh_connection._channels) > 0:
logger.warning(
f"Channels still open after closing the SSH connection to {self.get_hostname()}. Forcing closing."
)
self._ssh_connection.close()
await self._ssh_connection.wait_closed()
self._ssh_connection = None
if self._connect_event.is_set():
self._connect_event.clear()
self._connecting = False

def full(self) -> bool:
if self._ssh_connection:
return len(self._ssh_connection._channels) >= self._max_concurrent_sessions
else:
return False
return (
self._ssh_connection
and len(self._ssh_connection._channels) >= self._max_concurrent_sessions
)

async def get_connection(self) -> asyncssh.SSHClientConnection:
if self._ssh_connection is None:
if not self._connecting:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
await self.close()
raise
finally:
self._connect_event.set()
else:
await self._connect_event.wait()
if self._ssh_connection is None:
raise WorkflowExecutionException(
f"Impossible to connect to {self._config.hostname}"
)
return self._ssh_connection

def get_hostname(self) -> str:
return self._config.hostname

async def reset(self):
await self.close()
self.connection_attempts += 1
self._connect_event.clear()


class SSHContextManager:
Expand All @@ -154,6 +147,8 @@ def __init__(
contexts: MutableSequence[SSHContext],
command: str,
environment: MutableMapping[str, str] | None,
retries: int,
retry_delay: int,
stdin: int = asyncio.subprocess.PIPE,
stdout: int = asyncio.subprocess.PIPE,
stderr: int = asyncio.subprocess.PIPE,
Expand All @@ -167,16 +162,40 @@ def __init__(
self.encoding: str | None = encoding
self._condition: asyncio.Condition = condition
self._contexts: MutableSequence[SSHContext] = contexts
self._retries: int = retries
self._retry_delay: int = retry_delay
self._selected_context: SSHContext | None = None
self._proc: asyncssh.SSHClientProcess | None = None

async def __aenter__(self) -> asyncssh.SSHClientProcess:
async with self._condition:
available_contexts = self._contexts
while True:
for context in self._contexts:
if not context.full():
ssh_connection = await context.get_connection()
if (
len(
available_contexts := [
c
for c in available_contexts
if c.connection_attempts < self._retries
]
)
== 0
):
raise WorkflowExecutionException(
f"Hosts {[c.get_hostname() for c in self._contexts]} have no "
f"more available contexts: terminating."
)
elif (
len(
free_contexts := [c for c in available_contexts if not c.full()]
)
== 0
):
await self._condition.wait()
else:
for context in free_contexts:
try:
ssh_connection = await context.get_connection()
self._selected_context = context
self._proc = await ssh_connection.create_process(
self.command,
Expand All @@ -187,13 +206,27 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
encoding=self.encoding,
)
await self._proc.__aenter__()
self._selected_context.connection_attempts = 0
return self._proc
except ChannelOpenError as coe:
logger.warning(
f"Error opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: [{coe.code}] {coe.reason}"
)
await self._condition.wait()
except (
ChannelOpenError,
ConnectionError,
ConnectionLost,
DisconnectError,
) as exc:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Error {type(exc).__name__} opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: {str(exc)}"
)
if not isinstance(exc, ChannelOpenError):
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {context.get_hostname()}: attempt {context.connection_attempts}"
)
self._selected_context = None
await context.reset()
await asyncio.sleep(self._retry_delay)

async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self._condition:
Expand All @@ -219,15 +252,14 @@ def __init__(
streamflow_config_dir=streamflow_config_dir,
config=config,
max_concurrent_sessions=max_concurrent_sessions,
retries=retries,
retry_delay=retry_delay,
)
for _ in range(max_connections)
]
self._retries = retries
self._retry_delay = retry_delay

def close(self):
for c in self._contexts:
c.close()
async def close(self):
await asyncio.gather(*(asyncio.create_task(c.close()) for c in self._contexts))

def get(
self,
Expand All @@ -247,6 +279,8 @@ def get(
stdout=stdout,
stderr=stderr,
encoding=encoding,
retries=self._retries,
retry_delay=self._retry_delay,
)


Expand Down Expand Up @@ -429,7 +463,7 @@ async def _get_available_location(self, location: str) -> Hardware:
location=location,
command="nproc && "
"free | grep Mem | awk '{print $2}' && "
"df -aT | tail -n +2 | awk 'NF == 1 {device = $1; getline; $0 = device $0} {print $7, $2, $3}'",
"df -aT | tail -n +2 | awk 'NF == 1 {device = $1; getline; $0 = device $0} {print $7, $2, $5}'",
stderr=asyncio.subprocess.STDOUT,
) as proc:
result = await proc.wait()
Expand Down Expand Up @@ -653,27 +687,27 @@ async def run(
workdir=workdir,
)
command = utils.encode_command(command)
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
else:
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
return result.stdout.strip(), result.returncode if capture_output else None
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
return (result.stdout.strip(), result.returncode) if capture_output else None

async def undeploy(self, external: bool) -> None:
for ssh_context in self.ssh_context_factories.values():
ssh_context.close()
await asyncio.gather(
*(
asyncio.create_task(ssh_context.close())
for ssh_context in self.ssh_context_factories.values()
)
)
self.ssh_context_factories = {}
for ssh_context in self.data_transfer_context_factories.values():
ssh_context.close()
await asyncio.gather(
*(
asyncio.create_task(ssh_context.close())
for ssh_context in self.data_transfer_context_factories.values()
)
)
self.data_transfer_context_factories = {}
2 changes: 1 addition & 1 deletion streamflow/workflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def _get_inputs(self, input_ports: MutableMapping[str, Port]):
if logger.isEnabledFor(logging.DEBUG):
if check_termination(inputs.values()):
logger.debug(
f"Step {self.name} received termination token with Status {_reduce_statuses([t.value for t in inputs.values()]).name.lower()}"
f"Step {self.name} received termination token with Status {_reduce_statuses([t.value for t in inputs.values()]).name}"
)
else:
logger.debug(
Expand Down
11 changes: 7 additions & 4 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import asyncio
import os
import re
from collections.abc import Callable, MutableSequence
from typing import Any

import asyncssh
import pytest
import pytest_asyncio

Expand Down Expand Up @@ -145,7 +145,10 @@ async def test_ssh_connector_multiple_request_fail(context: StreamFlowContext) -
*(connector.get_available_locations() for _ in range(3)),
return_exceptions=True,
):
assert isinstance(result, (ConnectionError, asyncssh.Error)) or (
isinstance(result, WorkflowExecutionException)
and result.args[0] == "Impossible to connect to .*"
assert (
re.match(
r"Hosts \[.*] have no more available contexts: terminating.",
result.args[0],
)
is not None
)
2 changes: 2 additions & 0 deletions tests/utils/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ async def get_ssh_deployment_config(_context: StreamFlowContext):
"hostname": "127.0.0.1:2222",
"sshKey": f.name,
"username": "linuxserver.io",
"retries": 2,
"retryDelay": 5,
}
],
"maxConcurrentSessions": 10,
Expand Down

0 comments on commit d3b82ec

Please sign in to comment.