Commit

Merge pull request #502 from ExaWorks/clean_queue_poll_threads
Fix thread "leak"
hategan authored Feb 12, 2025
2 parents 7e7dcbe + 65a5c83 commit 7088769
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions src/psij/executors/batch/batch_scheduler_executor.py
```diff
@@ -3,6 +3,7 @@
 import subprocess
 import time
 import traceback
+import weakref
 from abc import abstractmethod
 from datetime import timedelta
 from pathlib import Path
@@ -639,45 +640,51 @@ def __init__(self, name: str, config: BatchSchedulerExecutorConfig,
         self.name = name
         self.daemon = True
         self.config = config
-        self.executor = executor
+        self.done = False
+        self.executor = weakref.ref(executor, self._stop)
         # native_id -> job
         self._jobs: Dict[str, List[Job]] = {}
         # counts consecutive errors while invoking qstat or equivalent
         self._poll_error_count = 0
         self._jobs_lock = RLock()
 
     def run(self) -> None:
-        logger.debug('Executor %s: queue poll thread started', self.executor)
+        logger.debug('Executor %s: queue poll thread started', self.executor())
         time.sleep(self.config.initial_queue_polling_delay)
-        while True:
+        while not self.done:
             self._poll()
             time.sleep(self.config.queue_polling_interval)
 
+    def _stop(self, exec: object) -> None:
+        self.done = True
+
     def _poll(self) -> None:
+        exec = self.executor()
+        if exec is None:
+            return
         with self._jobs_lock:
             if len(self._jobs) == 0:
                 return
             jobs_copy = dict(self._jobs)
         logger.info('Polling for %s jobs', len(jobs_copy))
         try:
-            out = self.executor._run_command(self.executor.get_status_command(jobs_copy.keys()))
+            if exec:
+                out = exec._run_command(exec.get_status_command(jobs_copy.keys()))
         except subprocess.CalledProcessError as ex:
             out = ex.output
             exit_code = ex.returncode
         except Exception as ex:
-            self._handle_poll_error(True,
-                                    ex,
+            self._handle_poll_error(exec, True, ex,
                                     f'Failed to poll for job status: {traceback.format_exc()}')
             return
         else:
             exit_code = 0
             self._poll_error_count = 0
         logger.debug('Output from status command: %s', out)
         try:
-            status_map = self.executor.parse_status_output(exit_code, out)
+            status_map = exec.parse_status_output(exit_code, out)
         except Exception as ex:
-            self._handle_poll_error(False,
-                                    ex,
+            self._handle_poll_error(exec, False, ex,
                                     f'Failed to poll for job status: {traceback.format_exc()}')
             return
         try:
@@ -689,21 +696,22 @@ def _poll(self) -> None:
                                        message='Failed to update job status: %s' %
                                                traceback.format_exc())
                 for job in job_list:
-                    self.executor._set_job_status(job, status)
+                    exec._set_job_status(job, status)
                     if status.state.final:
                         with self._jobs_lock:
                             del self._jobs[native_id]
         except Exception as ex:
             msg = traceback.format_exc()
-            self._handle_poll_error(True, ex, 'Error updating job statuses {}'.format(msg))
+            self._handle_poll_error(exec, True, ex, 'Error updating job statuses {}'.format(msg))
 
     def _get_job_status(self, native_id: str, status_map: Dict[str, JobStatus]) -> JobStatus:
         if native_id in status_map:
             return status_map[native_id]
         else:
             return JobStatus(JobState.COMPLETED)
 
-    def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None:
+    def _handle_poll_error(self, exec: BatchSchedulerExecutor, immediate: bool, ex: Exception,
+                           msg: str) -> None:
         logger.warning('Polling error: %s', msg)
         self._poll_error_count += 1
         if immediate or (self._poll_error_count > self.config.queue_polling_error_threshold):
@@ -720,7 +728,7 @@ def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None:
                 self._jobs.clear()
             for job_list in jobs_copy.values():
                 for job in job_list:
-                    self.executor._set_job_status(job, JobStatus(JobState.FAILED, message=msg))
+                    exec._set_job_status(job, JobStatus(JobState.FAILED, message=msg))
 
     def register_job(self, job: Job) -> None:
         assert job.native_id
```
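What the change does: before this commit, the queue poll thread held a strong reference to its executor and its `run` loop never terminated, so the daemon thread kept the executor pinned in memory and ran until interpreter shutdown — the "leak" the commit title refers to. The diff replaces that strong reference with `weakref.ref(executor, self._stop)`: the thread resolves the weak reference on every poll, and the `_stop` callback sets `self.done` so the `while not self.done:` loop exits once the executor is garbage collected.

Below is a minimal, self-contained sketch of the same pattern. The `Owner`/`PollThread` names are hypothetical stand-ins for illustration, not psij's API:

```python
import threading
import time
import weakref


class PollThread(threading.Thread):
    """Daemon thread that polls its owner without keeping it alive."""

    def __init__(self, owner: object, interval: float = 0.1) -> None:
        super().__init__(daemon=True)
        self.interval = interval
        self.done = False
        # Weak reference: this thread no longer pins the owner in memory.
        # The callback runs when the owner is garbage collected.
        self.owner = weakref.ref(owner, self._stop)

    def _stop(self, _ref: object) -> None:
        self.done = True

    def run(self) -> None:
        while not self.done:
            owner = self.owner()          # None once the owner has been collected
            if owner is not None:
                owner.poll()
            del owner                     # drop the strong local reference before sleeping
            time.sleep(self.interval)


class Owner:
    """Stands in for the executor; creating one starts a poll thread."""

    def __init__(self) -> None:
        self._thread = PollThread(self)
        self._thread.start()

    def poll(self) -> None:
        pass  # query the scheduler, update job statuses, ...


if __name__ == '__main__':
    o = Owner()
    time.sleep(0.3)   # poll thread runs while the owner is alive
    del o             # last strong reference gone: the weakref callback sets done
    time.sleep(0.3)   # the poll loop observes done and exits; no thread lingers
```

In psij's case the callback plus the `done` check replace the old `while True:` loop, so the poll thread winds down shortly after its `BatchSchedulerExecutor` is collected instead of keeping it alive for the life of the process.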
