Skip to content

Commit

Permalink
log fetching behavior updates
Browse files Browse the repository at this point in the history
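- fetch job logs in the trigger once the job reaches a terminal state, instead of in the operator's execute_complete
- change the default of fetch_logs from False to True
- remove the fixed 30-second sleep heuristic that preceded log fetching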
  • Loading branch information
marwan116 committed Jun 25, 2024
1 parent 55e495b commit 003a533
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
18 changes: 5 additions & 13 deletions anyscale_provider/operators/anyscale.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import time
from datetime import timedelta
from typing import Any

Expand Down Expand Up @@ -82,7 +81,7 @@ def __init__(
cloud: str | None = None,
project: str | None = None,
max_retries: int = 1,
fetch_logs: bool = False,
fetch_logs: bool = True,
wait_for_completion: bool = True,
job_timeout_seconds: float = 3600,
poll_interval: float = 60,
Expand Down Expand Up @@ -158,7 +157,10 @@ def execute(self, context: Context) -> None:
elif current_state in (JobState.STARTING, JobState.RUNNING):
self.defer(
trigger=AnyscaleJobTrigger(
conn_id=self.conn_id, job_id=self.job_id, poll_interval=self.poll_interval
conn_id=self.conn_id,
job_id=self.job_id,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
Expand All @@ -169,16 +171,6 @@ def execute(self, context: Context) -> None:
def execute_complete(self, context: Context, event: Any) -> None:
current_job_id = event["job_id"]

if self.fetch_logs:
job_status = self.hook.get_job_status(current_job_id)

# Heuristic to wait for the job logs to be complete
time.sleep(30)

logs = self.hook.get_job_logs(current_job_id, run=job_status.runs[-1].name)
for log in logs.split("\n"):
print(log)

if event["state"] == JobState.FAILED:
self.log.info(f"Anyscale job {current_job_id} ended with state: {event['state']}")
raise AirflowException(f"Job {current_job_id} failed with error {event['message']}")
Expand Down
14 changes: 13 additions & 1 deletion anyscale_provider/triggers/anyscale.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from functools import partial
from typing import Any, AsyncIterator

from airflow.compat.functools import cached_property
Expand Down Expand Up @@ -28,11 +29,12 @@ class AnyscaleJobTrigger(BaseTrigger):
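:param poll_interval: Optional. Interval in seconds between status checks. Defaults to 60 seconds.
:param fetch_logs: Optional. Whether to fetch and print the job's logs after it reaches a terminal state. Defaults to True.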
"""

def __init__(self, conn_id: str, job_id: str, poll_interval: float = 60):
def __init__(self, conn_id: str, job_id: str, poll_interval: float = 60, fetch_logs: bool = True):
super().__init__() # type: ignore[no-untyped-call]
self.conn_id = conn_id
self.job_id = job_id
self.poll_interval = poll_interval
self.fetch_logs = fetch_logs

@cached_property
def hook(self) -> AnyscaleHook:
Expand All @@ -57,6 +59,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
while not self._is_terminal_state(self.job_id):
await asyncio.sleep(self.poll_interval)

if self.fetch_logs:
job_status = self.hook.get_job_status(self.job_id)
loop = asyncio.get_event_loop()
logs = await loop.run_in_executor(
None, partial(self.hook.get_job_logs, job_id=self.job_id, run=job_status.runs[-1].name)
)

for log in logs.split("\n"):
print(log)

# Once out of the loop, the job has reached a terminal status
job_status = self.hook.get_job_status(self.job_id)
job_state = str(job_status.state)
Expand Down

0 comments on commit 003a533

Please sign in to comment.