From 50d403b2286d285f290e1c7c71ea3089695e3206 Mon Sep 17 00:00:00 2001 From: Marwan Sarieddine Date: Tue, 25 Jun 2024 15:27:16 +0300 Subject: [PATCH] add optionality to avoid waiting on job completion --- anyscale_provider/operators/anyscale.py | 36 ++++++++++++++----------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/anyscale_provider/operators/anyscale.py b/anyscale_provider/operators/anyscale.py index ad2d0ea..9ebe45c 100644 --- a/anyscale_provider/operators/anyscale.py +++ b/anyscale_provider/operators/anyscale.py @@ -61,6 +61,7 @@ class SubmitAnyscaleJob(BaseOperator): "project", "max_retries", "fetch_logs", + "wait_for_completion", "job_timeout_seconds", "poll_interval", ) @@ -82,6 +83,7 @@ def __init__( project: str | None = None, max_retries: int = 1, fetch_logs: bool = False, + wait_for_completion: bool = True, job_timeout_seconds: float = 3600, poll_interval: float = 60, *args: Any, @@ -103,6 +105,7 @@ def __init__( self.project = project self.max_retries = max_retries self.fetch_logs = fetch_logs + self.wait_for_completion = wait_for_completion self.job_timeout_seconds = job_timeout_seconds self.poll_interval = poll_interval @@ -144,21 +147,24 @@ def execute(self, context: Context) -> None: self.log.info(f"Submitted Anyscale job with ID: {self.job_id}") - current_state = str(self.hook.get_job_status(self.job_id).state) - self.log.info(f"Current job state for {self.job_id} is: {current_state}") - - if current_state == JobState.SUCCEEDED: - self.log.info(f"Job {self.job_id} completed successfully.") - elif current_state == JobState.FAILED: - raise AirflowException(f"Job {self.job_id} failed.") - elif current_state in (JobState.STARTING, JobState.RUNNING): - self.defer( - trigger=AnyscaleJobTrigger(conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval), - method_name="execute_complete", - timeout=timedelta(seconds=self.job_timeout_seconds), - ) - else: - raise Exception(f"Unexpected state `{current_state}` for job_id `{self.job_id}`.") + if self.wait_for_completion: + current_state = str(self.hook.get_job_status(self.job_id).state) + self.log.info(f"Current job state for {self.job_id} is: {current_state}") + + if current_state == JobState.SUCCEEDED: + self.log.info(f"Job {self.job_id} completed successfully.") + elif current_state == JobState.FAILED: + raise AirflowException(f"Job {self.job_id} failed.") + elif current_state in (JobState.STARTING, JobState.RUNNING): + self.defer( + trigger=AnyscaleJobTrigger( + conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.job_timeout_seconds), + ) + else: + raise Exception(f"Unexpected state `{current_state}` for job_id `{self.job_id}`.") def execute_complete(self, context: Context, event: Any) -> None: current_job_id = event["job_id"]