Review notes (text before the first "diff --git" line is skipped by git-apply and patch).

Purpose of the patch: move job-log fetching out of the operator's execute_complete()
(which slept 30 seconds in the worker) into AnyscaleJobTrigger.run(), pass fetch_logs
from the operator to the trigger, and flip the operator's fetch_logs default to True.

Revisions made to the added lines of the trigger's run() hunk:
  * asyncio.get_running_loop() replaces asyncio.get_event_loop() inside the coroutine.
  * hook.get_job_status() is run in the default executor, like hook.get_job_logs(),
    so the synchronous call does not block the event loop.
  * an empty job_status.runs no longer raises IndexError on runs[-1].
  * log lines are emitted through self.log instead of print().
  * the trigger class docstring documents the new fetch_logs parameter.

NOTE(review): line breaks and indentation were reconstructed from a copy of this patch
in which whitespace had been collapsed. Indentation is inferred from Python nesting and
must be confirmed against the target files; "git apply --ignore-whitespace" tolerates
whitespace differences in context lines only.
NOTE(review): the "index" lines are carried over unchanged; the post-image hash of the
trigger file no longer matches the revised content.
NOTE(review): no hunk touches AnyscaleJobTrigger.serialize() (not visible here). Confirm
it returns fetch_logs; otherwise a trigger re-created from its serialized form falls back
to the default fetch_logs=True and an operator-level fetch_logs=False is ignored.

diff --git a/anyscale_provider/operators/anyscale.py b/anyscale_provider/operators/anyscale.py
index 9ebe45c..e0ffde3 100644
--- a/anyscale_provider/operators/anyscale.py
+++ b/anyscale_provider/operators/anyscale.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import time
 from datetime import timedelta
 from typing import Any
 
@@ -82,7 +81,7 @@ def __init__(
         cloud: str | None = None,
         project: str | None = None,
         max_retries: int = 1,
-        fetch_logs: bool = False,
+        fetch_logs: bool = True,
         wait_for_completion: bool = True,
         job_timeout_seconds: float = 3600,
         poll_interval: float = 60,
@@ -158,7 +157,10 @@ def execute(self, context: Context) -> None:
             elif current_state in (JobState.STARTING, JobState.RUNNING):
                 self.defer(
                     trigger=AnyscaleJobTrigger(
-                        conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval
+                        conn_id=self.conn_id,
+                        job_id=self.job_id,
+                        poll_interval=self.poll_interval,
+                        fetch_logs=self.fetch_logs,
                     ),
                     method_name="execute_complete",
                     timeout=timedelta(seconds=self.job_timeout_seconds),
@@ -169,16 +171,6 @@ def execute(self, context: Context) -> None:
     def execute_complete(self, context: Context, event: Any) -> None:
         current_job_id = event["job_id"]
 
-        if self.fetch_logs:
-            job_status = self.hook.get_job_status(current_job_id)
-
-            # Heuristic to wait for the job logs to be complete
-            time.sleep(30)
-
-            logs = self.hook.get_job_logs(current_job_id, run=job_status.runs[-1].name)
-            for log in logs.split("\n"):
-                print(log)
-
         if event["state"] == JobState.FAILED:
             self.log.info(f"Anyscale job {current_job_id} ended with state: {event['state']}")
             raise AirflowException(f"Job {current_job_id} failed with error {event['message']}")
diff --git a/anyscale_provider/triggers/anyscale.py b/anyscale_provider/triggers/anyscale.py
index eea950b..551d302 100644
--- a/anyscale_provider/triggers/anyscale.py
+++ b/anyscale_provider/triggers/anyscale.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+from functools import partial
 from typing import Any, AsyncIterator
 
 from airflow.compat.functools import cached_property
@@ -28,11 +29,13 @@ class AnyscaleJobTrigger(BaseTrigger):
     :param poll_interval: Optional. Interval in seconds between status checks. Defaults to 60 seconds.
+    :param fetch_logs: Optional. Whether to fetch and log the job's logs once it finishes. Defaults to True.
     """
 
-    def __init__(self, conn_id: str, job_id: str, poll_interval: float = 60):
+    def __init__(self, conn_id: str, job_id: str, poll_interval: float = 60, fetch_logs: bool = True):
         super().__init__()  # type: ignore[no-untyped-call]
         self.conn_id = conn_id
         self.job_id = job_id
         self.poll_interval = poll_interval
+        self.fetch_logs = fetch_logs
 
     @cached_property
     def hook(self) -> AnyscaleHook:
@@ -57,6 +60,19 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             while not self._is_terminal_state(self.job_id):
                 await asyncio.sleep(self.poll_interval)
 
+            if self.fetch_logs:
+                # The hook methods are synchronous; run them in the default executor
+                # so they do not block the event loop while they execute.
+                loop = asyncio.get_running_loop()
+                job_status = await loop.run_in_executor(None, self.hook.get_job_status, self.job_id)
+                # Guard against an empty run list so runs[-1] cannot raise IndexError.
+                if job_status.runs:
+                    logs = await loop.run_in_executor(
+                        None, partial(self.hook.get_job_logs, job_id=self.job_id, run=job_status.runs[-1].name)
+                    )
+                    for log in logs.split("\n"):
+                        self.log.info(log)
+
             # Once out of the loop, the job has reached a terminal status
             job_status = self.hook.get_job_status(self.job_id)
             job_state = str(job_status.state)