Adding signal processing to the task monitor condition (#137)
*Description of changes:*

The task monitor needs to process the following 4 MWAA signals for the
graceful update project:
1. Termination signal: Graceful termination of the workers when the
environment is going through a graceful update
2. Resume signal: Reverting the state of graceful termination and resuming
work when the environment is going through a rollback after attempting a
graceful update
3. Kill signal: Shutting down the worker without waiting for the current
Airflow tasks to finish when the environment is going through a forced
update
4. Activation signal: Starting consumption of work from the queue after
termination protection has been enabled on the corresponding Fargate task

The processing is gated behind certain environment variables, which are
either absent or set to false for an environment that does not have
graceful updates enabled.
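
Not part of this commit, but for illustration: the task monitor (see task_monitor.py below) reads signals as JSON files from /usr/local/mwaa/signals, each carrying executionId, signalType, createdAt, and processed fields. A minimal sketch of producing such a file is shown below; the file-naming scheme and the epoch-seconds createdAt value are assumptions, since the service-side producer is not included here.

```python
# Hypothetical sketch of writing a signal file in the format the task monitor reads.
# Field names come from task_monitor.py in this commit; the file name format and the
# epoch-seconds createdAt are assumptions.
import json
import os
import time
import uuid

SIGNALS_DIR = "/usr/local/mwaa/signals"  # MWAA_SIGNALS_DIRECTORY in task_monitor.py


def write_signal(signal_type: str) -> str:
    """Drop one of the four supported signals: activation, termination, resume, kill."""
    os.makedirs(SIGNALS_DIR, exist_ok=True)
    signal = {
        "executionId": str(uuid.uuid4()),  # used by the monitor for logging
        "signalType": signal_type,
        "createdAt": time.time(),          # assumed: a comparable timestamp
        "processed": False,                # flipped to True once the monitor handles it
    }
    path = os.path.join(SIGNALS_DIR, f"{int(time.time())}_{signal_type}.json")
    with open(path, "w") as f:
        json.dump(signal, f)
    return path
```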


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Signed-off-by: ashishgo-aws <[email protected]>
ashishgo-aws authored Sep 20, 2024
1 parent b4af67c commit 155f0e0
Showing 6 changed files with 271 additions and 46 deletions.
3 changes: 3 additions & 0 deletions images/airflow/2.9.2/docker-compose.yaml
@@ -54,6 +54,9 @@ x-airflow-common: &airflow-common
MWAA__LOGGING__AIRFLOW_WORKER_LOGS_ENABLED: ${MWAA__LOGGING__AIRFLOW_WORKER_LOGS_ENABLED}
MWAA__LOGGING__AIRFLOW_WORKER_LOG_GROUP_ARN: ${MWAA__LOGGING__AIRFLOW_WORKER_LOG_GROUP_ARN}
MWAA__LOGGING__AIRFLOW_WORKER_LOG_LEVEL: ${MWAA__LOGGING__AIRFLOW_WORKER_LOG_LEVEL}
MWAA__CORE__TASK_MONITORING_ENABLED: ${MWAA__CORE__TASK_MONITORING_ENABLED}
MWAA__CORE__TERMINATE_IF_IDLE: ${MWAA__CORE__TERMINATE_IF_IDLE}
MWAA__CORE__MWAA_SIGNAL_HANDLING_ENABLED: ${MWAA__CORE__MWAA_SIGNAL_HANDLING_ENABLED}

volumes:
- ./dags:/usr/local/airflow/dags
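
These variables are passed through to the container environment and treated in the Python code as opt-in string flags. As a minimal sketch of that pattern (mirroring how sqs_broker.py reads MWAA__CORE__TASK_MONITORING_ENABLED in the next file; the helper name here is illustrative, not part of the codebase):

```python
import os


def _flag_enabled(name: str) -> bool:
    # Opt-in flag: only the exact string "true" enables the behavior; an absent
    # variable or any other value leaves it disabled.
    return os.environ.get(name, "false") == "true"


task_monitoring_enabled = _flag_enabled("MWAA__CORE__TASK_MONITORING_ENABLED")
terminate_if_idle = _flag_enabled("MWAA__CORE__TERMINATE_IF_IDLE")
signal_handling_enabled = _flag_enabled("MWAA__CORE__MWAA_SIGNAL_HANDLING_ENABLED")
```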
4 changes: 2 additions & 2 deletions images/airflow/2.9.2/python/mwaa/celery/sqs_broker.py
@@ -400,8 +400,8 @@ def __init__(self, *args, **kwargs):

self.hub = kwargs.get("hub") or get_event_loop()

# Dynamic workers have the MWAA__CORE__TASK_MONITORING_ENABLED set to 'true'.
# This will be used to determine if idle worker checks are to be enabled.
# MWAA__CORE__TASK_MONITORING_ENABLED is set to 'true' for workers where we want to monitor the number of tasks currently being
# executed on the worker. This will be used to determine if idle worker checks are to be enabled.
self.idle_worker_monitoring_enabled = (
os.environ.get("MWAA__CORE__TASK_MONITORING_ENABLED", "false") == "true"
)
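
The broker and the task monitor coordinate through a one-byte shared memory block named celery_work_consumption_<AIRFLOW_ENV_ID> (created in task_monitor.py below), where a value of 1 means SQS consumption is paused. The broker-side read is not shown in this diff; a rough sketch of attaching to that flag, with the helper name being an assumption, might look like:

```python
import os
from multiprocessing import shared_memory


def _consumption_paused() -> bool:
    # Hypothetical helper: attach to the one-byte flag maintained by the task
    # monitor and treat a byte value of 1 as "pause SQS consumption".
    block_name = f'celery_work_consumption_{os.environ.get("AIRFLOW_ENV_ID", "")}'
    try:
        block = shared_memory.SharedMemory(name=block_name)
    except FileNotFoundError:
        # Monitor not running (e.g., task monitoring disabled): never pause.
        return False
    try:
        return block.buf[0] == 1
    finally:
        block.close()
```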
220 changes: 199 additions & 21 deletions images/airflow/2.9.2/python/mwaa/celery/task_monitor.py
@@ -12,6 +12,8 @@
from enum import Enum
from multiprocessing import shared_memory
from builtins import memoryview
import math
import time
from typing import Any, Dict, List

# 3rd-party imports
@@ -55,21 +57,34 @@
)
DEFAULT_QUEUE_ENV_KEY = "AIRFLOW__CELERY__DEFAULT_QUEUE"

# Allow at least 3 minutes for the worker to warm up before checking for idleness or
# cleaning abandoned resources.
# A worker may be busy loading the required libraries or polling messages from the environment SQS queue, so we allow the worker to
# warm up before checking for idleness or cleaning abandoned resources.
IDLENESS_CHECK_WARMUP_WAIT_PERIOD = timedelta(minutes=3)
CLEANUP_ABANDONED_RESOURCES_WARMUP_WAIT_PERIOD = timedelta(minutes=3)
# Allow at least 1 second between any two consecutive idleness checks.
# If the idleness check is made too aggressively, then we will reuse the result of the previous check until
# the check delay threshold is reached.
IDLENESS_CHECK_DELAY_PERIOD = timedelta(seconds=1)
# The worker should be idle for at least 2 consecutive idleness check before being
# declared idle.
# The worker should be idle for CONSECUTIVE_IDLENESS_CHECK_THRESHOLD consecutive checks before being declared idle.
CONSECUTIVE_IDLENESS_CHECK_THRESHOLD = 2
# Allow at least 1 minute between any two consecutive abandoned resource cleanups.
# The monitor is also responsible for performing cleanups/corrections of in-memory state in case of issues such as failing to
# terminate the Airflow task process once it has finished or failing to delete the message from the environment SQS queue. This
# process requires scanning multiple shared memory blocks and possibly executing SQS operations, so we do this only once a time
# threshold since the last cleanup has been breached.
CLEANUP_ABANDONED_RESOURCES_DELAY_PERIOD = timedelta(minutes=1)
# If in the improbable case of the worker picking up new tasks after having paused its
# consumption, we reset the worker to non-idle state and backoff from checking for
# idleness for a minute.
# In the improbable case of the worker picking up new tasks after having paused its consumption, we reset the worker to a non-idle
# state and back off from checking for idleness for a certain period.
IDLENESS_RESET_BACKOFF_PERIOD = timedelta(minutes=1)
# In case of issues, signals can arrive late and out of order. So, we scan all unprocessed signals within a time range.
# Only a handful of signals are expected for each worker, so this repeated processing should be very light.
SIGNAL_SEARCH_TIME_RANGE = timedelta(hours=1)
# If a worker activation signal has not been received within a certain threshold, then we give up on waiting for the signal
# and exit the worker. The assumption here is that the signal was somehow lost and will not arrive at all. So, exiting this non-active
# worker allows it to be replaced by a new one.
ACTIVATION_WAIT_TIME_LIMIT = timedelta(minutes=10)
# Workers will be allowed a specific time window for a graceful shutdown, starting from the moment a termination signal is processed,
# before they are forcibly killed.
TERMINATION_TIME_LIMIT = timedelta(hours=12)

BOTO_RETRY_CONFIGURATION = botocore.config.Config( # type: ignore
retries={
# The standard retry mode provides exponential backoff with a base of 2 and max
@@ -79,6 +94,7 @@
"mode": "standard",
}
)
MWAA_SIGNALS_DIRECTORY = "/usr/local/mwaa/signals"

CeleryTask = Dict[str, Any]

@@ -136,14 +152,16 @@ def _create_shared_mem_celery_state():
# will use to signal the toggle of a flag which tells the Celery SQS channel to
# pause/unpause further consumption of available SQS messages. It is maintained by the
# worker monitor.
def _create_shared_mem_work_consumption_block():
# When MWAA signal handling is enabled, the Airflow Task consumption will be turned off by default and
# it will be enabled only when the activation signal has been received by the worker.
def _create_shared_mem_work_consumption_block(mwaa_signal_handling_enabled: bool):
celery_work_consumption_block_name = (
f'celery_work_consumption_{os.environ.get("AIRFLOW_ENV_ID", "")}'
)
celery_work_consumption_flag_block = shared_memory.SharedMemory(
create=True, size=1, name=celery_work_consumption_block_name
)
celery_work_consumption_flag_block.buf[0] = 0
celery_work_consumption_flag_block.buf[0] = 1 if mwaa_signal_handling_enabled else 0
return celery_work_consumption_flag_block


@@ -273,25 +291,62 @@ def _cleanup_undead_process(process_id: int):
)


def _get_next_unprocessed_signal() -> (str, dict):
signal_search_start_timestamp = math.ceil((datetime.now(tz=tz.tzutc()) - SIGNAL_SEARCH_TIME_RANGE).timestamp())
signal_filenames = os.listdir(MWAA_SIGNALS_DIRECTORY) if os.path.exists(MWAA_SIGNALS_DIRECTORY) else []
sorted_filenames = sorted(signal_filenames)
for signal_filename in sorted_filenames:
# In case of issues, signals can arrive late and out of order. So, we scan all unprocessed signals in a search time range.
# Only a handful of signals are expected for each worker, so this repeated processing should be very light.
signal_file_path = os.path.join(MWAA_SIGNALS_DIRECTORY, signal_filename)
file_timestamp = os.path.getctime(signal_file_path)
if file_timestamp > signal_search_start_timestamp:
with open(signal_file_path, "r") as file_data:
try:
signal_data = json.load(file_data)
except json.JSONDecodeError as e:
logger.info(f"Error decoding file {signal_file_path}, signal will be ignored: {e}")
signal_data = None
if signal_data and not signal_data["processed"]:
return signal_file_path, signal_data
return None, None


def _marked_signal_as_processed(signal_filepath, signal_data):
signal_data["processed"] = True
with open(signal_filepath, "w") as file_pointer:
json.dump(signal_data, file_pointer)
logger.info(f"Successfully processed signal {signal_data['executionId']}")


class WorkerTaskMonitor:
"""
Monitor for the task count associated with the worker.
:param mwaa_signal_handling_enabled: Whether the monitor should expect certain signals to be sent from MWAA.
These signals will represent MWAA service side events such as start of an environment update.
"""

def __init__(self):
def __init__(
self,
mwaa_signal_handling_enabled: bool,
):
"""
Initialize a WorkerTaskMonitor instance.
"""
# Allow at least 3 minutes for the worker to warm up before checking for
# idleness or cleaning abandoned resources.
self.mwaa_signal_handling_enabled = mwaa_signal_handling_enabled

# A worker may be busy loading the required libraries or polling messages from the environment SQS queue,
# so we allow the worker to warm up before checking for idleness or cleaning abandoned resources.
self.idleness_check_warmup_timestamp = (
datetime.now(tz=tz.tzutc()) + IDLENESS_CHECK_WARMUP_WAIT_PERIOD
)
self.cleanup_check_warmup_timestamp = (
datetime.now(tz=tz.tzutc()) + CLEANUP_ABANDONED_RESOURCES_WARMUP_WAIT_PERIOD
)

# Allow at least 1 second to elapse between any two checks for idleness.
# If the idleness check is made too aggressively, then we will reuse the result of the previous check until
# the check delay threshold is reached.
self.idleness_check_delay_timestamp = (
datetime.now(tz=tz.tzutc()) + IDLENESS_CHECK_DELAY_PERIOD
)
@@ -300,8 +355,31 @@ def __init__(self):
# CONSECUTIVE_IDLENESS_CHECK_THRESHOLD to declare the worker as idle.
self.consecutive_idleness_count = 0

# If MWAA signal handling is enabled, then the monitor will wait for the activation signal before starting consumption of work.
# The activation signal will be sent once service-side changes have been made to ensure that it is safe for the worker to start working.
self.waiting_for_activation = True if self.mwaa_signal_handling_enabled else False
# The monitor keeps track of the start of the period for which it has been waiting for activation. This is used to check
# if a time limit has expired and if the monitor should give up.
self.activation_wait_start = datetime.now(tz=tz.tzutc())
# If MWAA signal handling is enabled, then the monitor will periodically check whether a kill signal has been sent by MWAA for the worker.
# If the signal is found, the monitor will kill the worker without waiting for the current Airflow tasks to be completed.
self.marked_for_kill = False
# If MWAA signal handling is enabled, then the monitor will periodically check whether a termination signal has been sent by MWAA for the
# worker. If the signal is found, the monitor will terminate the worker after waiting for the current Airflow tasks to be completed.
self.marked_for_termination = False

# If resume and termination signals are received out of order, then processing them in arrival order can lead to undesired results.
# So, we maintain the creation timestamp of the last processed termination or resume signal to check whether the latest observed
# signal should be processed or not.
self.last_termination_or_resume_signal_timestamp = None
# When a termination signal is received by a worker, it is given TERMINATION_TIME_LIMIT amount of time to gracefully
# shut down by finishing up the current Airflow tasks. But if termination signals are received late due to an issue, then
# we need the TERMINATION_TIME_LIMIT to start from the point in time of processing the signal and not the time when the
# signal was sent from MWAA.
self.last_termination_processing_time = None

self.celery_state = _create_shared_mem_celery_state()
self.celery_work_consumption_block = _create_shared_mem_work_consumption_block()
self.celery_work_consumption_block = _create_shared_mem_work_consumption_block(self.mwaa_signal_handling_enabled)
self.cleanup_celery_state = _create_shared_mem_cleanup_celery_state()
self.abandoned_celery_tasks_from_last_check: List[CeleryTask] = []
self.undead_process_ids_from_last_check = []
@@ -346,6 +424,26 @@ def is_worker_idle(self):
)
return self.last_idleness_check_result

def is_marked_for_kill(self):
"""
Checks whether the worker has been marked for kill. If MWAA signal handling is enabled, then the monitor will periodically check
whether a kill signal has been sent by MWAA for the worker. If the signal is found, the monitor will kill the worker without waiting
for the current Airflow tasks to be completed.
:return: True if the worker has been marked for kill. False otherwise.
"""
return self.marked_for_kill

def is_marked_for_termination(self):
"""
Checks whether the worker has been marked for termination. If MWAA signal handling is enabled, then the monitor will periodically
check whether a termination signal has been sent by MWAA for the worker. If the signal is found, the monitor will terminate the worker
after waiting for the current Airflow tasks to be completed.
:return: True if the worker has been marked for termination. False otherwise.
"""
return self.marked_for_termination

def _get_current_task_count(self):
"""
Get count of tasks currently getting executed on the worker. Any task present in
@@ -362,6 +460,65 @@ def _get_current_task_count(self):
current_task_count += 1
return current_task_count

def process_next_signal(self):
"""
Process any signals sent by MWAA. This method processes the first signal it finds
in chronological order among the available unprocessed signals.
"""
if not self.mwaa_signal_handling_enabled:
logger.info("Signal handling is not enabled for this worker.")
return
if self.closed:
logger.warning(
"Using process_next_signal() of a task monitor "
"after it has been closed."
)
return
signal_filepath, signal_data = _get_next_unprocessed_signal()
if not signal_data:
logger.info("No new signal found.")
return
signal_id = signal_data["executionId"]
signal_type = signal_data["signalType"]
signal_timestamp = signal_data["createdAt"]
logger.info(f"Processing signal {signal_id} of type {signal_type} created at {signal_timestamp}")
if signal_type == "activation":
self.waiting_for_activation = False
elif signal_type == "kill":
self.marked_for_kill = True
elif signal_type == "termination":
if (self.last_termination_or_resume_signal_timestamp is None or
self.last_termination_or_resume_signal_timestamp < signal_timestamp):
self.marked_for_termination = True
self.last_termination_or_resume_signal_timestamp = signal_timestamp
self.last_termination_processing_time = datetime.now(tz=tz.tzutc())
elif signal_type == "resume":
if (self.last_termination_or_resume_signal_timestamp is None or
self.last_termination_or_resume_signal_timestamp < signal_timestamp):
self.marked_for_termination = False
self.last_termination_or_resume_signal_timestamp = signal_timestamp
self.last_termination_processing_time = None
else:
logger.warning(f"Unknown signal type {signal_type}, ignoring.")
should_consume_work = not (self.waiting_for_activation or self.marked_for_kill or self.marked_for_termination)
self.resume_task_consumption() if should_consume_work else self.pause_task_consumption()
_marked_signal_as_processed(signal_filepath, signal_data)

def is_activation_wait_time_limit_breached(self):
"""
This method checks if the time limit for waiting for activation has been breached or not.
:return: True, if the time limit for waiting for activation has been breached.
"""
return self.waiting_for_activation and datetime.now(tz=tz.tzutc()) > self.activation_wait_start + ACTIVATION_WAIT_TIME_LIMIT

def is_termination_time_limit_breached(self):
"""
This method checks if the termination time limit has been breached or not.
:return: True, if the worker has been marked for termination and the allowed time limit for termination has been breached.
"""
return (self.marked_for_termination and self.last_termination_processing_time and
datetime.now(tz=tz.tzutc()) > self.last_termination_processing_time + TERMINATION_TIME_LIMIT)

def pause_task_consumption(self):
"""
celery_work_consumption_block represents the toggle switch for accepting any
@@ -376,8 +533,17 @@ def pause_task_consumption(self):
)
return

logger.info("Pausing task consumption.")
was_consumption_unpaused = self.celery_work_consumption_block.buf[0] == 0
self.celery_work_consumption_block.buf[0] = 1
if was_consumption_unpaused:
# When we toggle Airflow task consumption to the paused state, we wait a few seconds in order
# for any in-flight messages in the SQS broker layer to be processed and the
# corresponding Airflow task instances to be created. Once that is done, we can
# start gracefully shutting down the worker. Without this, the SQS broker may
# consume messages from the queue, terminate before creating the corresponding
# Airflow task instance and abandon SQS messages in-flight.
logger.info("Pausing task consumption.")
time.sleep(5)

def resume_task_consumption(self):
"""
@@ -392,8 +558,14 @@ def resume_task_consumption(self):
"after it has been closed."
)
return
logger.info("Unpausing task consumption.")
was_consumption_paused = self.celery_work_consumption_block.buf[0] == 1
self.celery_work_consumption_block.buf[0] = 0
if was_consumption_paused:
# When we toggle Airflow task consumption to the unpaused state, we wait a few seconds in order
# for any in-flight messages in the SQS queue to start getting consumed by
# the broker layer before checking for worker idleness.
logger.info("Unpausing task consumption.")
time.sleep(5)

def reset_monitor_state(self):
"""
@@ -451,12 +623,18 @@ def close(self):

logger.info("Closing task monitor...")

# Report a metric about the number of current task, and a warning in case this
# is greater than zero.
# Report a metric about the number of current tasks, and log a warning in case this is greater than zero. If the worker was
# marked for kill, or was marked for termination and the allowed time limit for termination has been breached, then we do
# not report this metric because the task interruption is expected and should not be used for alarming.
task_count = self._get_current_task_count()
if task_count > 0:
logger.warning("There are non-zero ongoing tasks.")
self.stats.incr(f"mwaa.task_monitor.interrupted_tasks_at_shutdown", task_count) # type: ignore
if self.marked_for_kill or self.is_termination_time_limit_breached():
if task_count > 0:
logger.warning("Worker is being forcibly shutdown via expected methods, "
"interrupted_tasks_at_shutdown metric will not be emitted.")
else:
self.stats.incr(f"mwaa.task_monitor.interrupted_tasks_at_shutdown", task_count) # type: ignore

# Close shared memory objects.
self.celery_state.close()
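
The consumer of this new API (the worker entrypoint's monitoring loop) is outside this diff. A rough, hypothetical sketch of how such a loop could drive the monitor, using only the methods added or referenced above; the polling interval, exit behavior, and import path are assumptions:

```python
import time

from mwaa.celery.task_monitor import WorkerTaskMonitor


def run_monitor_loop(signal_handling_enabled: bool) -> None:
    # Hypothetical driver loop; the real loop lives in the worker entrypoint.
    monitor = WorkerTaskMonitor(mwaa_signal_handling_enabled=signal_handling_enabled)
    try:
        while True:
            # Pick up at most one unprocessed MWAA signal per iteration.
            monitor.process_next_signal()
            if monitor.is_marked_for_kill():
                break  # forced update: shut down without draining tasks
            if monitor.is_activation_wait_time_limit_breached():
                break  # activation signal presumed lost; exit so the worker is replaced
            if monitor.is_termination_time_limit_breached():
                break  # graceful-shutdown window exhausted
            if monitor.is_marked_for_termination() and monitor.is_worker_idle():
                break  # graceful update: current Airflow tasks have finished
            time.sleep(1)
    finally:
        monitor.close()
```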