diff --git a/src/discord-cluster-manager/cogs/github_cog.py b/src/discord-cluster-manager/cogs/github_cog.py index a2ae394..8ae34da 100644 --- a/src/discord-cluster-manager/cogs/github_cog.py +++ b/src/discord-cluster-manager/cogs/github_cog.py @@ -1,21 +1,15 @@ -import asyncio import json -import tempfile -import zipfile -from datetime import datetime, timedelta, timezone from typing import Optional import discord -import requests from consts import GPUType from discord import app_commands from discord.ext import commands -from env import GITHUB_REPO, GITHUB_TOKEN -from github import Github +from github_runner import GitHubRun from leaderboard_eval import amd_requirements, nvidia_requirements from report import generate_report from run_eval import CompileResult, FullResult, RunResult -from utils import build_task_config, get_github_branch_name, send_discord_message, setup_logging +from utils import build_task_config, send_discord_message, setup_logging logger = setup_logging() @@ -68,36 +62,23 @@ async def run_github( else: reference_content = None - run_id = await self.trigger_github_run( + artifacts = await self.execute_github_run( lang=lang, gpu_type=selected_gpu, script_content=script_content, reference_content=reference_content, + thread=thread, ) - if run_id: - await thread.send( - f"GitHub Action triggered! Run ID: {run_id}\nMonitoring progress..." - ) - status, result, url = await self.check_workflow_status(run_id, thread, gpu_type) - - await thread.send(f"Training completed with status: {status}") - - if isinstance(result, FullResult): - await generate_report(thread, result) - else: - if len(result) > 1900: - await self.bot.send_chunked_message(thread, result, code_block=True) - else: - await thread.send(f"```\nLogs:\n{result}\n```") - - if url: - await thread.send(f"View the full run at: <{url}>") + logs = artifacts["run-result"]["result.json"].decode("utf-8") + data = json.loads(logs) + if "compile" in data and data["compile"] is not None: + comp = CompileResult(**data["compile"]) else: - await thread.send( - "Failed to trigger GitHub Action. Please check the configuration." - ) - + comp = None + run = RunResult(**data["run"]) + result = FullResult(success=True, error="", compile=comp, run=run) + await generate_report(thread, result) return thread, result except Exception as e: @@ -106,9 +87,14 @@ async def run_github( await thread.send(f"Error processing request: {str(e)}") raise - async def trigger_github_run( - self, lang: str, gpu_type: GPUType, script_content: str, reference_content: Optional[str] - ): + async def execute_github_run( + self, + lang: str, + gpu_type: GPUType, + script_content: str, + reference_content: Optional[str], + thread: discord.Thread, + ) -> dict: if lang == "cu" and gpu_type == GPUType.AMD: # TODO implement HIP raise ValueError("Cannot use CUDA runs with AMD GPUs") @@ -123,14 +109,11 @@ async def trigger_github_run( ) logger.info(f"Attempting to trigger GitHub action for {lang_name} on {gpu_type.name}") - gh = Github(GITHUB_TOKEN) - repo = gh.get_repo(GITHUB_REPO) - try: - trigger_time = datetime.now(timezone.utc) - workflow_file = gpu_type.value - workflow = repo.get_workflow(workflow_file) + workflow_file = gpu_type.value + run = GitHubRun(workflow_file) + try: payload = json.dumps(config) inputs = {"payload": payload} @@ -140,116 +123,29 @@ async def trigger_github_run( else: inputs["requirements"] = amd_requirements - success = workflow.create_dispatch(get_github_branch_name(), inputs=inputs) - if success: - await asyncio.sleep(2) - runs = list(workflow.get_runs()) - - for run in runs: - if run.created_at.replace(tzinfo=timezone.utc) > trigger_time: - return run.id - return None - - except Exception as e: - logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True) - return None - - async def check_workflow_status(self, run_id, thread, gpu_type): - logger.info(f"Starting to monitor workflow status for run {run_id}") - gh = Github(GITHUB_TOKEN) - repo = gh.get_repo(GITHUB_REPO) - start_time = datetime.now(timezone.utc) - timeout_minutes = 5 - timeout = timedelta(minutes=timeout_minutes) - - while True: - try: - run = repo.get_workflow_run(run_id) - elapsed_time = datetime.now(timezone.utc) - start_time - - if elapsed_time > timeout: - try: - run.cancel() - # Wait briefly to ensure cancellation is processed - # And Verify the run was actually cancelled - await asyncio.sleep(5) - run = repo.get_workflow_run(run_id) - if run.status != "completed": - logger.warning(f"Failed to cancel workflow run {run_id}") - except Exception as e: - logger.error(f"Error cancelling workflow: {str(e)}") - - await thread.send( - f"Workflow cancelled - exceeded {timeout_minutes} minute timeout" - ) - return ( - "cancelled", - f"Workflow exceeded {timeout_minutes} minute timeout", - run.html_url, - ) - - if run.status == "completed": - result = await self.download_results(run_id) - return run.conclusion, result, run.html_url - + if not await run.trigger(inputs): await thread.send( - f"Workflow: {run.status} running for " - f"{elapsed_time.total_seconds():.2f} seconds\n" - f"Live view: <{run.html_url}>" + "Failed to trigger GitHub Action. Please check the configuration." ) - await asyncio.sleep(20) - except Exception as e: - logger.error("Error", exc_info=e) - return "error", str(e), None + return {} - async def download_results(self, run_id) -> FullResult: - try: - data = await self.download_artifact(run_id, name="run-result") - logs = data["result.json"].decode("utf-8") - data = json.loads(logs) - if "compile" in data and data["compile"] is not None: - comp = CompileResult(**data["compile"]) - else: - comp = None - run = RunResult(**data["run"]) - return FullResult(success=True, error="", compile=comp, run=run) - except Exception as e: - logger.error("Error downloading artifacts", exc_info=e) - return FullResult( - success=False, - error=f"Error downloading artifacts: {repr(e)}", - compile=None, - run=None, + status_msg = await thread.send( + "**Running on GitHub...**\n" "> ⏳ Waiting for workflow to start..." ) + await run.wait_for_completion(lambda x: self.wait_callback(x, thread, status_msg)) + await thread.send(f"Running completed with status: {run.status}") - async def download_artifact(self, run_id, name: str): - logger.info(f"Attempting to download artifact {name} for run {run_id}") - gh = Github(GITHUB_TOKEN) - repo = gh.get_repo(GITHUB_REPO) - - run = repo.get_workflow_run(run_id) - artifacts = run.get_artifacts() + return await run.download_artifacts() - for artifact in artifacts: - if artifact.name == name: - url = artifact.archive_download_url - headers = {"Authorization": f"token {GITHUB_TOKEN}"} - response = requests.get(url, headers=headers) - - if response.status_code == 200: - with tempfile.NamedTemporaryFile("w+b") as temp: - temp.write(response.content) - temp.flush() + except Exception as e: + logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True) + raise - with zipfile.ZipFile(temp.name) as z: - artifact_dict = {} - for file in z.namelist(): - with z.open(file) as f: - artifact_dict[file] = f.read() + async def wait_callback(self, run: GitHubRun, thread: discord.Thread, msg: discord.Message): + message = ( + f"**Running on GitHub...**\n" + f"> Workflow [{run.run_id}]({run.html_url}): {run.status}\n" + f"> ⏳ {run.elapsed_time.total_seconds():.1f} seconds\n" + ) - return artifact_dict - else: - raise RuntimeError( - f"Failed to download artifact. Status code: {response.status_code}" - ) - raise RuntimeError(f"Could not find artifact {name}") + await msg.edit(content=message) diff --git a/src/discord-cluster-manager/cogs/verify_run_cog.py b/src/discord-cluster-manager/cogs/verify_run_cog.py index 1d593ba..66d1d70 100644 --- a/src/discord-cluster-manager/cogs/verify_run_cog.py +++ b/src/discord-cluster-manager/cogs/verify_run_cog.py @@ -61,13 +61,7 @@ async def verify_github_run( message_contents = [msg.content async for msg in github_thread.history(limit=None)] - required_patterns = [ - "Processing `.*` with", - "GitHub Action triggered! Run ID:", - "Training completed with status: success", - "'check': 'pass'", - "View the full run at:", - ] + required_patterns = ["Processing `.*` with", "Running on GitHub...", "'check': 'pass'"] all_patterns_found = all( any(re.search(pattern, content, re.DOTALL) is not None for content in message_contents) diff --git a/src/discord-cluster-manager/github_runner.py b/src/discord-cluster-manager/github_runner.py new file mode 100644 index 0000000..12d13ed --- /dev/null +++ b/src/discord-cluster-manager/github_runner.py @@ -0,0 +1,156 @@ +import asyncio +import pprint +import tempfile +import zipfile +from datetime import datetime, timedelta, timezone +from typing import Awaitable, Callable, Optional + +import requests +from env import GITHUB_REPO, GITHUB_TOKEN +from github import Github, UnknownObjectException, WorkflowRun +from utils import get_github_branch_name, setup_logging + +logger = setup_logging() + + +class GitHubRun: + def __init__(self, workflow_file: str): + gh = Github(GITHUB_TOKEN) + self.repo = gh.get_repo(GITHUB_REPO) + self.workflow_file = workflow_file + self.run: Optional[WorkflowRun.WorkflowRun] = None + self.start_time = None + + @property + def run_id(self): + if self.run is None: + return None + return self.run.id + + @property + def html_url(self): + if self.run is None: + return None + return self.run.html_url + + @property + def status(self): + if self.run is None: + return None + return self.run.status + + @property + def elapsed_time(self): + if self.start_time is None: + return None + return datetime.now(timezone.utc) - self.start_time + + async def trigger(self, inputs: dict) -> bool: + """ + Trigger this run with the provided inputs. + Sets `self.run` to the new WorkflowRun on success. + + Returns: Whether the run was successfully triggered, + """ + trigger_time = datetime.now(timezone.utc) + try: + workflow = self.repo.get_workflow(self.workflow_file) + except UnknownObjectException as e: + logger.error(f"Could not find workflow {self.workflow_file}", exc_info=e) + raise ValueError(f"Could not find workflow {self.workflow_file}") from e + + logger.debug( + "Dispatching workflow %s on branch %s with inputs %s", + self.workflow_file, + get_github_branch_name(), + pprint.pformat(inputs), + ) + success = workflow.create_dispatch(get_github_branch_name(), inputs=inputs) + if success: + await asyncio.sleep(2) + runs = list(workflow.get_runs()) + + for run in runs: + if run.created_at.replace(tzinfo=timezone.utc) > trigger_time: + self.run = run + return True + return False + + async def wait_for_completion( + self, callback: Callable[["GitHubRun"], Awaitable[None]], timeout_minutes: int = 5 + ): + if self.run is None: + raise ValueError("Run needs to be triggered before a status check!") + + self.start_time = datetime.now(timezone.utc) + timeout = timedelta(minutes=timeout_minutes) + + while True: + try: + # update run status + self.run = run = self.repo.get_workflow_run(self.run_id) + + if self.elapsed_time > timeout: + try: + self.run.cancel() + # Wait briefly to ensure cancellation is processed + # And Verify the run was actually cancelled + await asyncio.sleep(5) + run = self.repo.get_workflow_run(self.run_id) + if run.status != "completed": + logger.warning(f"Failed to cancel workflow run {self.run_id}") + except Exception as e: + logger.error(f"Error cancelling workflow: {str(e)}", exc_info=e) + raise + + logger.warning( + f"Workflow {self.run_id} cancelled - " + f"exceeded {timeout_minutes} minute timeout" + ) + raise TimeoutError( + f"Workflow {self.run_id} cancelled - " + f"exceeded {timeout_minutes} minute timeout" + ) + + if run.status == "completed": + return + + await callback(self) + await asyncio.sleep(20) + except TimeoutError: + raise + except Exception as e: + logger.error(f"Error waiting for GitHub run {self.run_id}: {e}", exc_info=e) + raise + + async def download_artifacts(self) -> dict: + logger.info("Attempting to download artifacts for run %s", self.run_id) + artifacts = self.run.get_artifacts() + + extracted = {} + + for artifact in artifacts: + url = artifact.archive_download_url + headers = {"Authorization": f"token {GITHUB_TOKEN}"} + response = requests.get(url, headers=headers) + + if response.status_code == 200: + with tempfile.NamedTemporaryFile("w+b") as temp: + temp.write(response.content) + temp.flush() + + with zipfile.ZipFile(temp.name) as z: + artifact_dict = {} + for file in z.namelist(): + with z.open(file) as f: + artifact_dict[file] = f.read() + + extracted[artifact.name] = artifact_dict + else: + raise RuntimeError( + f"Failed to download artifact {artifact.name}. " + f"Status code: {response.status_code}" + ) + + logger.info("Download artifacts for run %s: %s", self.run_id, list(extracted.keys())) + return extracted