Skip to content

Commit

Permalink
Add an option to skip waiting for job completion
Browse files (browse the repository at this point in the history)
  • Loading branch information
marwan116 committed Jun 25, 2024
1 parent a4c32eb commit 50d403b
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions anyscale_provider/operators/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class SubmitAnyscaleJob(BaseOperator):
"project",
"max_retries",
"fetch_logs",
"wait_for_completion",
"job_timeout_seconds",
"poll_interval",
)
Expand All @@ -82,6 +83,7 @@ def __init__(
project: str | None = None,
max_retries: int = 1,
fetch_logs: bool = False,
wait_for_completion: bool = True,
job_timeout_seconds: float = 3600,
poll_interval: float = 60,
*args: Any,
Expand All @@ -103,6 +105,7 @@ def __init__(
self.project = project
self.max_retries = max_retries
self.fetch_logs = fetch_logs
self.wait_for_completion = wait_for_completion
self.job_timeout_seconds = job_timeout_seconds
self.poll_interval = poll_interval

Expand Down Expand Up @@ -144,21 +147,24 @@ def execute(self, context: Context) -> None:

self.log.info(f"Submitted Anyscale job with ID: {self.job_id}")

current_state = str(self.hook.get_job_status(self.job_id).state)
self.log.info(f"Current job state for {self.job_id} is: {current_state}")

if current_state == JobState.SUCCEEDED:
self.log.info(f"Job {self.job_id} completed successfully.")
elif current_state == JobState.FAILED:
raise AirflowException(f"Job {self.job_id} failed.")
elif current_state in (JobState.STARTING, JobState.RUNNING):
self.defer(
trigger=AnyscaleJobTrigger(conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
else:
raise Exception(f"Unexpected state `{current_state}` for job_id `{self.job_id}`.")
if self.wait_for_completion:
current_state = str(self.hook.get_job_status(self.job_id).state)
self.log.info(f"Current job state for {self.job_id} is: {current_state}")

if current_state == JobState.SUCCEEDED:
self.log.info(f"Job {self.job_id} completed successfully.")
elif current_state == JobState.FAILED:
raise AirflowException(f"Job {self.job_id} failed.")
elif current_state in (JobState.STARTING, JobState.RUNNING):
self.defer(
trigger=AnyscaleJobTrigger(
conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
else:
raise Exception(f"Unexpected state `{current_state}` for job_id `{self.job_id}`.")

def execute_complete(self, context: Context, event: Any) -> None:
current_job_id = event["job_id"]
Expand Down

0 comments on commit 50d403b

Please sign in to comment.