diff --git a/src/psij/executors/batch/batch_scheduler_executor.py b/src/psij/executors/batch/batch_scheduler_executor.py index 3a7cc327..20ba5764 100644 --- a/src/psij/executors/batch/batch_scheduler_executor.py +++ b/src/psij/executors/batch/batch_scheduler_executor.py @@ -3,6 +3,7 @@ import subprocess import time import traceback +import weakref from abc import abstractmethod from datetime import timedelta from pathlib import Path @@ -639,7 +640,8 @@ def __init__(self, name: str, config: BatchSchedulerExecutorConfig, self.name = name self.daemon = True self.config = config - self.executor = executor + self.done = False + self.executor = weakref.ref(executor, self._stop) # native_id -> job self._jobs: Dict[str, List[Job]] = {} # counts consecutive errors while invoking qstat or equivalent @@ -647,26 +649,32 @@ def __init__(self, name: str, config: BatchSchedulerExecutorConfig, self._jobs_lock = RLock() def run(self) -> None: - logger.debug('Executor %s: queue poll thread started', self.executor) + logger.debug('Executor %s: queue poll thread started', self.executor()) time.sleep(self.config.initial_queue_polling_delay) - while True: + while not self.done: self._poll() time.sleep(self.config.queue_polling_interval) + def _stop(self, exec: object) -> None: + self.done = True + def _poll(self) -> None: + exec = self.executor() + if exec is None: + return with self._jobs_lock: if len(self._jobs) == 0: return jobs_copy = dict(self._jobs) logger.info('Polling for %s jobs', len(jobs_copy)) try: - out = self.executor._run_command(self.executor.get_status_command(jobs_copy.keys())) + if exec: + out = exec._run_command(exec.get_status_command(jobs_copy.keys())) except subprocess.CalledProcessError as ex: out = ex.output exit_code = ex.returncode except Exception as ex: - self._handle_poll_error(True, - ex, + self._handle_poll_error(exec, True, ex, f'Failed to poll for job status: {traceback.format_exc()}') return else: @@ -674,10 +682,9 @@ def _poll(self) -> None: self._poll_error_count = 0 logger.debug('Output from status command: %s', out) try: - status_map = self.executor.parse_status_output(exit_code, out) + status_map = exec.parse_status_output(exit_code, out) except Exception as ex: - self._handle_poll_error(False, - ex, + self._handle_poll_error(exec, False, ex, f'Failed to poll for job status: {traceback.format_exc()}') return try: @@ -689,13 +696,13 @@ def _poll(self) -> None: message='Failed to update job status: %s' % traceback.format_exc()) for job in job_list: - self.executor._set_job_status(job, status) + exec._set_job_status(job, status) if status.state.final: with self._jobs_lock: del self._jobs[native_id] except Exception as ex: msg = traceback.format_exc() - self._handle_poll_error(True, ex, 'Error updating job statuses {}'.format(msg)) + self._handle_poll_error(exec, True, ex, 'Error updating job statuses {}'.format(msg)) def _get_job_status(self, native_id: str, status_map: Dict[str, JobStatus]) -> JobStatus: if native_id in status_map: @@ -703,7 +710,8 @@ def _get_job_status(self, native_id: str, status_map: Dict[str, JobStatus]) -> J else: return JobStatus(JobState.COMPLETED) - def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None: + def _handle_poll_error(self, exec: BatchSchedulerExecutor, immediate: bool, ex: Exception, + msg: str) -> None: logger.warning('Polling error: %s', msg) self._poll_error_count += 1 if immediate or (self._poll_error_count > self.config.queue_polling_error_threshold): @@ -720,7 +728,7 @@ def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None: self._jobs.clear() for job_list in jobs_copy.values(): for job in job_list: - self.executor._set_job_status(job, JobStatus(JobState.FAILED, message=msg)) + exec._set_job_status(job, JobStatus(JobState.FAILED, message=msg)) def register_job(self, job: Job) -> None: assert job.native_id