From 105333b2e10c06ed5bd1963aaa6295c5c53662d0 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 18 Nov 2024 10:18:59 -0800 Subject: [PATCH] Support AMD Runners --- .github/workflows/amd_workflow.yml | 58 +++++++++++++++++ ...train_workflow.yml => nvidia_workflow.yml} | 0 discord-bot.py | 62 +++++++++++-------- 3 files changed, 95 insertions(+), 25 deletions(-) create mode 100644 .github/workflows/amd_workflow.yml rename .github/workflows/{train_workflow.yml => nvidia_workflow.yml} (100%) diff --git a/.github/workflows/amd_workflow.yml b/.github/workflows/amd_workflow.yml new file mode 100644 index 0000000..5683138 --- /dev/null +++ b/.github/workflows/amd_workflow.yml @@ -0,0 +1,58 @@ +name: AMD PyTorch Job + +on: + workflow_dispatch: + inputs: + script_content: + description: 'Content of Python script' + required: true + type: string + filename: + description: 'Name of Python script' + required: true + type: string + +jobs: + train: + runs-on: [amdgpu-mi250-x86-64] + steps: + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Create script + shell: python + run: | + with open('${{ github.event.inputs.filename }}', 'w') as f: + f.write('''${{ github.event.inputs.script_content }}''') + + - name: Install dependencies + run: | + if grep -rE "(import numpy|from numpy)" "${{ github.event.inputs.filename }}"; then + echo "Numpy detected, installing numpy" + pip install numpy + fi + # Check if 'import torch' is in any Python file + if grep -rE "(import torch|from torch)" "${{ github.event.inputs.filename }}"; then + echo "PyTorch detected, installing PyTorch for ROCm" + pip install torch --index-url https://download.pytorch.org/whl/rocm6.2 + fi + # Check if 'import triton' is in any Python file + if grep -rE "(import triton|from triton)" "${{ github.event.inputs.filename }}"; then + echo "Triton detected, installing triton" + pip install triton + fi + + - name: Run script + run: | + python "${{ github.event.inputs.filename }}" > training.log 2>&1 + + - name: Upload artifacts + uses: actions/upload-artifact@v3 + if: always() + with: + name: training-artifacts + path: | + training.log + ${{ github.event.inputs.filename }} diff --git a/.github/workflows/train_workflow.yml b/.github/workflows/nvidia_workflow.yml similarity index 100% rename from .github/workflows/train_workflow.yml rename to .github/workflows/nvidia_workflow.yml diff --git a/discord-bot.py b/discord-bot.py index cdb5c56..52f40ac 100644 --- a/discord-bot.py +++ b/discord-bot.py @@ -10,6 +10,7 @@ import zipfile import subprocess import argparse +from enum import Enum # Set up logging logging.basicConfig( @@ -23,6 +24,18 @@ load_dotenv() logger.info("Environment variables loaded") +class GPUType(Enum): + NVIDIA = "nvidia_workflow.yml" + AMD = "amd_workflow.yml" + +def get_gpu_type(message_content): + """ + Determine GPU type based on message content + """ + if "AMD" in message_content.upper(): + return GPUType.AMD + return GPUType.NVIDIA # Default to NVIDIA if not specified + def get_github_branch_name(): """ Runs a git command to determine the remote branch name, to be used in the GitHub Workflow @@ -41,15 +54,11 @@ def get_github_branch_name(): return 'main' # Validate environment variables -if not os.getenv('DISCORD_TOKEN'): - logger.error("DISCORD_TOKEN not found in environment variables") - raise ValueError("DISCORD_TOKEN not found") -if not os.getenv('GITHUB_TOKEN'): - logger.error("GITHUB_TOKEN not found in environment variables") - raise ValueError("GITHUB_TOKEN not found") -if not os.getenv('GITHUB_REPO'): - logger.error("GITHUB_REPO not found in environment variables") - raise ValueError("GITHUB_REPO not found") +required_env_vars = ['DISCORD_TOKEN', 'GITHUB_TOKEN', 'GITHUB_REPO'] +for var in required_env_vars: + if not os.getenv(var): + logger.error(f"{var} not found in environment variables") + raise ValueError(f"{var} not found") logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}") @@ -58,21 +67,21 @@ def get_github_branch_name(): intents.message_content = True client = discord.Client(intents=intents) - -async def trigger_github_action(script_content, filename): +async def trigger_github_action(script_content, filename, gpu_type): """ Triggers the GitHub action with custom script contents and filename """ - logger.info("Attempting to trigger GitHub action") + logger.info(f"Attempting to trigger GitHub action for {gpu_type.name} GPU") gh = Github(os.getenv('GITHUB_TOKEN')) repo = gh.get_repo(os.getenv('GITHUB_REPO')) try: trigger_time = datetime.now(timezone.utc) - logger.info(f"Looking for workflow 'train_workflow.yml' in repo {os.getenv('GITHUB_REPO')}") + workflow_file = gpu_type.value + logger.info(f"Looking for workflow '{workflow_file}' in repo {os.getenv('GITHUB_REPO')}") - workflow = repo.get_workflow("train_workflow.yml") - logger.info("Found workflow, attempting to dispatch") + workflow = repo.get_workflow(workflow_file) + logger.info(f"Found workflow, attempting to dispatch for {gpu_type.name}") success = workflow.create_dispatch(get_github_branch_name(), { 'script_content': script_content, @@ -178,7 +187,7 @@ async def on_ready(): logger.info(f'Logged in as {client.user}') for guild in client.guilds: try: - if globals().get('args') and args.debug: # TODO: Fix Do this properly, maybe subclass `discord.Client` for better argument passing + if globals().get('args') and args.debug: await guild.me.edit(nick="Cluster Bot (Staging)") else: await guild.me.edit(nick="Cluster Bot") @@ -199,14 +208,18 @@ async def on_message(message): for attachment in message.attachments: logger.info(f"Processing attachment: {attachment.filename}") if attachment.filename.endswith('.py'): + # Determine GPU type from message + gpu_type = get_gpu_type(message.content) + logger.info(f"Selected {gpu_type.name} GPU for processing") + # Create a thread directly from the original message thread = await message.create_thread( - name=f"Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}", + name=f"{gpu_type.name} Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}", auto_archive_duration=1440 # Archive after 24 hours of inactivity ) # Send initial message in the thread - await thread.send(f"Found {attachment.filename}! Starting training process...") + await thread.send(f"Found {attachment.filename}! Starting training process on {gpu_type.name} GPU...") try: # Download the file content @@ -216,14 +229,13 @@ async def on_message(message): logger.info(f"Successfully read {attachment.filename} content") # Trigger GitHub Action - run_id = await trigger_github_action(script_content, attachment.filename) + run_id = await trigger_github_action(script_content, attachment.filename, gpu_type) - # TODO: This is is very hacky await asyncio.sleep(10) if run_id: - logger.info(f"Successfully triggered workflow with run ID: {run_id}") - await thread.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...") + logger.info(f"Successfully triggered {gpu_type.name} workflow with run ID: {run_id}") + await thread.send(f"GitHub Action triggered successfully on {gpu_type.name}! Run ID: {run_id}\nMonitoring progress...") # Monitor the workflow status, logs, url = await check_workflow_status(run_id, thread) @@ -242,8 +254,8 @@ async def on_message(message): if url: await thread.send(f"View the full run at: {url}") else: - logger.error("Missing run_id. Failed to trigger GitHub Action") - await thread.send("Failed to trigger GitHub Action. Please check the configuration.") + logger.error(f"Missing run_id. Failed to trigger GitHub Action for {gpu_type.name}") + await thread.send(f"Failed to trigger GitHub Action for {gpu_type.name}. Please check the configuration.") except Exception as e: logger.error(f"Error processing request: {str(e)}", exc_info=True) @@ -252,7 +264,7 @@ async def on_message(message): break if not any(att.filename.endswith('.py') for att in message.attachments): - await message.reply("Please attach a file named 'train.py' to your message.") + await message.reply("Please attach a Python file to your message. Include 'AMD' in your message to use AMD GPU, otherwise NVIDIA will be used.") # Run the bot if __name__ == "__main__":