Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GitHub refactor #137

Merged
merged 1 commit into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 41 additions & 145 deletions src/discord-cluster-manager/cogs/github_cog.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import asyncio
import json
import tempfile
import zipfile
from datetime import datetime, timedelta, timezone
from typing import Optional

import discord
import requests
from consts import GPUType
from discord import app_commands
from discord.ext import commands
from env import GITHUB_REPO, GITHUB_TOKEN
from github import Github
from github_runner import GitHubRun
from leaderboard_eval import amd_requirements, nvidia_requirements
from report import generate_report
from run_eval import CompileResult, FullResult, RunResult
from utils import build_task_config, get_github_branch_name, send_discord_message, setup_logging
from utils import build_task_config, send_discord_message, setup_logging

logger = setup_logging()

Expand Down Expand Up @@ -68,36 +62,23 @@ async def run_github(
else:
reference_content = None

run_id = await self.trigger_github_run(
artifacts = await self.execute_github_run(
lang=lang,
gpu_type=selected_gpu,
script_content=script_content,
reference_content=reference_content,
thread=thread,
)

if run_id:
await thread.send(
f"GitHub Action triggered! Run ID: {run_id}\nMonitoring progress..."
)
status, result, url = await self.check_workflow_status(run_id, thread, gpu_type)

await thread.send(f"Training completed with status: {status}")

if isinstance(result, FullResult):
await generate_report(thread, result)
else:
if len(result) > 1900:
await self.bot.send_chunked_message(thread, result, code_block=True)
else:
await thread.send(f"```\nLogs:\n{result}\n```")

if url:
await thread.send(f"View the full run at: <{url}>")
logs = artifacts["run-result"]["result.json"].decode("utf-8")
data = json.loads(logs)
if "compile" in data and data["compile"] is not None:
comp = CompileResult(**data["compile"])
else:
await thread.send(
"Failed to trigger GitHub Action. Please check the configuration."
)

comp = None
run = RunResult(**data["run"])
result = FullResult(success=True, error="", compile=comp, run=run)
await generate_report(thread, result)
return thread, result

except Exception as e:
Expand All @@ -106,9 +87,14 @@ async def run_github(
await thread.send(f"Error processing request: {str(e)}")
raise

async def trigger_github_run(
self, lang: str, gpu_type: GPUType, script_content: str, reference_content: Optional[str]
):
async def execute_github_run(
self,
lang: str,
gpu_type: GPUType,
script_content: str,
reference_content: Optional[str],
thread: discord.Thread,
) -> dict:
if lang == "cu" and gpu_type == GPUType.AMD:
# TODO implement HIP
raise ValueError("Cannot use CUDA runs with AMD GPUs")
Expand All @@ -123,14 +109,11 @@ async def trigger_github_run(
)

logger.info(f"Attempting to trigger GitHub action for {lang_name} on {gpu_type.name}")
gh = Github(GITHUB_TOKEN)
repo = gh.get_repo(GITHUB_REPO)

try:
trigger_time = datetime.now(timezone.utc)
workflow_file = gpu_type.value
workflow = repo.get_workflow(workflow_file)
workflow_file = gpu_type.value
run = GitHubRun(workflow_file)

try:
payload = json.dumps(config)

inputs = {"payload": payload}
Expand All @@ -140,116 +123,29 @@ async def trigger_github_run(
else:
inputs["requirements"] = amd_requirements

success = workflow.create_dispatch(get_github_branch_name(), inputs=inputs)
if success:
await asyncio.sleep(2)
runs = list(workflow.get_runs())

for run in runs:
if run.created_at.replace(tzinfo=timezone.utc) > trigger_time:
return run.id
return None

except Exception as e:
logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True)
return None

async def check_workflow_status(self, run_id, thread, gpu_type):
logger.info(f"Starting to monitor workflow status for run {run_id}")
gh = Github(GITHUB_TOKEN)
repo = gh.get_repo(GITHUB_REPO)
start_time = datetime.now(timezone.utc)
timeout_minutes = 5
timeout = timedelta(minutes=timeout_minutes)

while True:
try:
run = repo.get_workflow_run(run_id)
elapsed_time = datetime.now(timezone.utc) - start_time

if elapsed_time > timeout:
try:
run.cancel()
# Wait briefly to ensure cancellation is processed
# And Verify the run was actually cancelled
await asyncio.sleep(5)
run = repo.get_workflow_run(run_id)
if run.status != "completed":
logger.warning(f"Failed to cancel workflow run {run_id}")
except Exception as e:
logger.error(f"Error cancelling workflow: {str(e)}")

await thread.send(
f"Workflow cancelled - exceeded {timeout_minutes} minute timeout"
)
return (
"cancelled",
f"Workflow exceeded {timeout_minutes} minute timeout",
run.html_url,
)

if run.status == "completed":
result = await self.download_results(run_id)
return run.conclusion, result, run.html_url

if not await run.trigger(inputs):
await thread.send(
f"Workflow: {run.status} running for "
f"{elapsed_time.total_seconds():.2f} seconds\n"
f"Live view: <{run.html_url}>"
"Failed to trigger GitHub Action. Please check the configuration."
)
await asyncio.sleep(20)
except Exception as e:
logger.error("Error", exc_info=e)
return "error", str(e), None
return {}

async def download_results(self, run_id) -> FullResult:
try:
data = await self.download_artifact(run_id, name="run-result")
logs = data["result.json"].decode("utf-8")
data = json.loads(logs)
if "compile" in data and data["compile"] is not None:
comp = CompileResult(**data["compile"])
else:
comp = None
run = RunResult(**data["run"])
return FullResult(success=True, error="", compile=comp, run=run)
except Exception as e:
logger.error("Error downloading artifacts", exc_info=e)
return FullResult(
success=False,
error=f"Error downloading artifacts: {repr(e)}",
compile=None,
run=None,
status_msg = await thread.send(
"**Running on GitHub...**\n" "> ⏳ Waiting for workflow to start..."
)
await run.wait_for_completion(lambda x: self.wait_callback(x, thread, status_msg))
await thread.send(f"Running completed with status: {run.status}")

async def download_artifact(self, run_id, name: str):
logger.info(f"Attempting to download artifact {name} for run {run_id}")
gh = Github(GITHUB_TOKEN)
repo = gh.get_repo(GITHUB_REPO)

run = repo.get_workflow_run(run_id)
artifacts = run.get_artifacts()
return await run.download_artifacts()

for artifact in artifacts:
if artifact.name == name:
url = artifact.archive_download_url
headers = {"Authorization": f"token {GITHUB_TOKEN}"}
response = requests.get(url, headers=headers)

if response.status_code == 200:
with tempfile.NamedTemporaryFile("w+b") as temp:
temp.write(response.content)
temp.flush()
except Exception as e:
logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True)
raise

with zipfile.ZipFile(temp.name) as z:
artifact_dict = {}
for file in z.namelist():
with z.open(file) as f:
artifact_dict[file] = f.read()
async def wait_callback(self, run: GitHubRun, thread: discord.Thread, msg: discord.Message):
message = (
f"**Running on GitHub...**\n"
f"> Workflow [{run.run_id}]({run.html_url}): {run.status}\n"
f"> ⏳ {run.elapsed_time.total_seconds():.1f} seconds\n"
)

return artifact_dict
else:
raise RuntimeError(
f"Failed to download artifact. Status code: {response.status_code}"
)
raise RuntimeError(f"Could not find artifact {name}")
await msg.edit(content=message)
8 changes: 1 addition & 7 deletions src/discord-cluster-manager/cogs/verify_run_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,7 @@ async def verify_github_run(

message_contents = [msg.content async for msg in github_thread.history(limit=None)]

required_patterns = [
"Processing `.*` with",
"GitHub Action triggered! Run ID:",
"Training completed with status: success",
"'check': 'pass'",
"View the full run at:",
]
required_patterns = ["Processing `.*` with", "Running on GitHub...", "'check': 'pass'"]

all_patterns_found = all(
any(re.search(pattern, content, re.DOTALL) is not None for content in message_contents)
Expand Down
Loading
Loading