Skip to content

Commit

Permalink
make fetching of job logs optional
Browse files Browse the repository at this point in the history
  • Loading branch information
marwan116 committed Jun 25, 2024
1 parent 4e72068 commit 03be79b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
17 changes: 14 additions & 3 deletions anyscale_provider/operators/anyscale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import time
from datetime import timedelta
from typing import Any

Expand Down Expand Up @@ -59,6 +60,7 @@ class SubmitAnyscaleJob(BaseOperator):
"cloud",
"project",
"max_retries",
"fetch_logs",
"job_timeout_seconds",
"poll_interval",
)
Expand All @@ -79,6 +81,7 @@ def __init__(
cloud: str | None = None,
project: str | None = None,
max_retries: int = 1,
fetch_logs: bool = False,
job_timeout_seconds: float = 3600,
poll_interval: float = 60,
*args: Any,
Expand All @@ -99,6 +102,7 @@ def __init__(
self.cloud = cloud
self.project = project
self.max_retries = max_retries
self.fetch_logs = fetch_logs
self.job_timeout_seconds = job_timeout_seconds
self.poll_interval = poll_interval

Expand Down Expand Up @@ -161,12 +165,21 @@ def execute(self, context: Context) -> str | None:
def execute_complete(self, context: Context, event: Any) -> None:
    """Finish the task once the trigger reports the job's terminal event.

    Optionally fetches the job logs and writes them to the task log, then
    reports the final state carried by ``event``.

    :param context: Airflow task context (not used by this method).
    :param event: Trigger payload; the ``"job_id"``, ``"state"`` and
        ``"message"`` keys are read here.
    :raises AirflowException: If ``event["state"]`` equals ``JobState.FAILED``.
    """
    current_job_id = event["job_id"]

    if self.fetch_logs:
        # Log retrieval is auxiliary: a failure here must not hide the job's
        # real outcome, so it is reported as a warning and execution continues.
        try:
            job_status = self.hook.get_job_status(current_job_id)
            runs = job_status.runs
            if runs:
                # Heuristic to wait for the job logs to be complete
                time.sleep(30)

                logs = self.hook.get_job_logs(current_job_id, run=runs[-1].name)
                # Use the operator logger (not print) so every line carries
                # the task's log formatting and handlers.
                for line in logs.split("\n"):
                    self.log.info(line)
            else:
                self.log.warning(f"Anyscale job {current_job_id} has no runs; skipping log fetch")
        except Exception:
            # NOTE(review): the hook's exception types are not visible from
            # here, hence the broad catch; narrow it if the SDK documents them.
            self.log.warning(f"Could not fetch logs for Anyscale job {current_job_id}", exc_info=True)

    if event["state"] == JobState.FAILED:
        self.log.info(f"Anyscale job {current_job_id} ended with state: {event['state']}")
        raise AirflowException(f"Job {current_job_id} failed with error {event['message']}")

    self.log.info(f"Anyscale job {current_job_id} completed with state: {event['state']}")


class RolloutAnyscaleService(BaseOperator):
Expand Down Expand Up @@ -357,5 +370,3 @@ def execute_complete(self, context: Context, event: Any) -> None:
raise AirflowException(error_msg)
else:
self.log.info(f"Anyscale service deployment {service_name} completed successfully")

return None
11 changes: 1 addition & 10 deletions anyscale_provider/triggers/anyscale.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
from functools import partial
from typing import Any, AsyncIterator

from airflow.compat.functools import cached_property
Expand Down Expand Up @@ -58,16 +57,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
while not self._is_terminal_status(self.job_id):
await asyncio.sleep(self.poll_interval)

# Fetch and print logs
job_status = self.hook.get_job_status(self.job_id)
loop = asyncio.get_running_loop()
logs = await loop.run_in_executor(
None, partial(self.hook.get_job_logs, job_id=self.job_id, run=job_status.runs[-1].name)
)
for log in logs.split("\n"):
self.log.info(log)

# Once out of the loop, the job has reached a terminal status
job_status = self.hook.get_job_status(self.job_id)
job_state = str(job_status.state)
self.log.info(f"Current job status for {self.job_id} is: {job_state}")
yield TriggerEvent(
Expand Down

0 comments on commit 03be79b

Please sign in to comment.