Fix on_failure_callback when task receives SIGKILL (apache#15537)
This PR fixes a case where a task would not call its on_failure_callback
when it was killed with SIGKILL, for example by the OOM killer. The issue
was that the task PID was being recorded in the wrong place, and the local
task job heartbeat was not comparing the recorded PID against the correct
PID of the task runner's process.

Instead of setting the task PID in check_and_change_state_before_execution,
it is now set correctly in the _run_raw_task method.
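
To illustrate the change, here is a minimal, hypothetical sketch of the PID bookkeeping after this commit. RecordedTask and HeartbeatCheck are stand-ins invented for illustration only, not Airflow's TaskInstance or LocalTaskJob APIs: the PID is recorded by the process that actually runs the task, and the heartbeat compares that recorded PID against the task runner's child process rather than the supervising job's own os.getpid().

import os


class RecordedTask:
    """Hypothetical stand-in for the task instance's PID bookkeeping."""

    def __init__(self) -> None:
        self.pid = None  # unknown until the task process actually starts

    def run_raw_task(self) -> None:
        # Record the PID where the task really executes, not in a
        # pre-execution check that may run in a different process.
        self.pid = os.getpid()


class HeartbeatCheck:
    """Hypothetical stand-in for the heartbeat's PID comparison."""

    def __init__(self, task: RecordedTask, runner_pid: int) -> None:
        self.task = task
        self.runner_pid = runner_pid  # PID of the task runner's child process

    def check(self) -> None:
        # Compare against the runner's process PID instead of os.getpid(),
        # and skip the check while no PID has been recorded yet.
        if self.task.pid is not None and self.task.pid != self.runner_pid:
            raise RuntimeError("PID of job runner does not match")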
ephraimbuddy authored May 5, 2021
1 parent 13faa69 commit 817b599
Showing 4 changed files with 34 additions and 19 deletions.
9 changes: 6 additions & 3 deletions airflow/jobs/local_task_job.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 
-import os
 import signal
 from typing import Optional
 
@@ -154,6 +153,10 @@ def handle_task_exit(self, return_code: int) -> None:
         # task exited by itself, so we need to check for error file
         # incase it failed due to runtime exception/error
         error = None
+        if self.task_instance.state == State.RUNNING:
+            # This is for a case where the task received a sigkill
+            # while running
+            self.task_instance.set_state(State.FAILED)
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
         self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
@@ -184,9 +187,9 @@ def heartbeat_callback(self, session=None):
                 )
                 raise AirflowException("Hostname of job runner does not match")
 
-            current_pid = os.getpid()
+            current_pid = self.task_runner.process.pid
             same_process = ti.pid == current_pid
-            if not same_process:
+            if ti.pid is not None and not same_process:
                 self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid)
                 raise AirflowException("PID of job runner does not match")
         elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'):
5 changes: 3 additions & 2 deletions airflow/models/taskinstance.py
@@ -1072,7 +1072,6 @@ def check_and_change_state_before_execution(  # pylint: disable=too-many-argumen
         if not test_mode:
             session.add(Log(State.RUNNING, self))
         self.state = State.RUNNING
-        self.pid = os.getpid()
         self.end_date = None
         if not test_mode:
             session.merge(self)
@@ -1127,7 +1126,9 @@ def _run_raw_task(
         self.refresh_from_db(session=session)
         self.job_id = job_id
         self.hostname = get_hostname()
-
+        self.pid = os.getpid()
+        session.merge(self)
+        session.commit()
         actual_start_date = timezone.utcnow()
         Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}')
         try:
33 changes: 22 additions & 11 deletions tests/jobs/test_local_task_job.py
@@ -27,6 +27,7 @@
 from unittest.mock import patch
 
 import pytest
+from parameterized import parameterized
 
 from airflow import settings
 from airflow.exceptions import AirflowException, AirflowFailException
@@ -92,8 +93,7 @@ def test_localtaskjob_essential_attr(self):
         check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
         assert all(check_result_2)
 
-    @patch('os.getpid')
-    def test_localtaskjob_heartbeat(self, mock_pid):
+    def test_localtaskjob_heartbeat(self):
         session = settings.Session()
         dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
 
@@ -114,19 +114,23 @@ def test_localtaskjob_heartbeat(self, mock_pid):
         session.commit()
 
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        ti.task = op1
+        ti.refresh_from_task(op1)
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.process = mock.Mock()
         with pytest.raises(AirflowException):
             job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter
 
-        mock_pid.return_value = 1
+        job1.task_runner.process.pid = 1
         ti.state = State.RUNNING
         ti.hostname = get_hostname()
         ti.pid = 1
         session.merge(ti)
         session.commit()
-
+        assert ti.pid != os.getpid()
         job1.heartbeat_callback(session=None)
 
-        mock_pid.return_value = 2
+        job1.task_runner.process.pid = 2
         with pytest.raises(AirflowException):
             job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter
 
@@ -496,9 +500,15 @@ def task_function(ti):
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
-    def test_process_kill_call_on_failure_callback(self):
+    @parameterized.expand(
+        [
+            (signal.SIGTERM,),
+            (signal.SIGKILL,),
+        ]
+    )
+    def test_process_kill_calls_on_failure_callback(self, signal_type):
         """
-        Test that ensures that when a task is killed with sigterm
+        Test that ensures that when a task is killed with sigterm or sigkill
         on_failure_callback gets executed
         """
         # use shared memory value so we can properly track value change even if
@@ -547,13 +557,14 @@ def task_function(ti):
         process = multiprocessing.Process(target=job1.run)
         process.start()
 
-        for _ in range(0, 10):
+        for _ in range(0, 20):
             ti.refresh_from_db()
-            if ti.state == State.RUNNING:
+            if ti.state == State.RUNNING and ti.pid is not None:
                 break
             time.sleep(0.2)
         assert ti.state == State.RUNNING
-        os.kill(ti.pid, signal.SIGTERM)
+        assert ti.pid is not None
+        os.kill(ti.pid, signal_type)
         process.join(timeout=10)
         assert failure_callback_called.value == 1
         assert task_terminated_externally.value == 1
@@ -584,5 +595,5 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
         mock_get_task_runner.return_value.return_code.side_effects = return_codes
 
         job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
-        with assert_queries_count(13):
+        with assert_queries_count(15):
             job.run()
6 changes: 3 additions & 3 deletions tests/models/test_taskinstance.py
@@ -2080,8 +2080,8 @@ def tearDown(self) -> None:
     @parameterized.expand(
         [
             # Expected queries, mark_success
-            (10, False),
-            (5, True),
+            (12, False),
+            (7, True),
         ]
     )
     def test_execute_queries_count(self, expected_query_count, mark_success):
@@ -2117,7 +2117,7 @@ def test_execute_queries_count_store_serialized(self):
             session=session,
         )
 
-        with assert_queries_count(10):
+        with assert_queries_count(12):
             ti._run_raw_task()
 
     def test_operator_field_with_serialization(self):