Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] shuffle testing #3492

Merged
merged 8 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ setup_commands:
- uv v
- echo "source $HOME/.venv/bin/activate" >> $HOME/.bashrc
- source .venv/bin/activate
- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}}
- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}} \{{OTHER_INSTALLS}}
25 changes: 25 additions & 0 deletions .github/ci-scripts/format_env_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import argparse
import json


def parse_env_var_str(env_var_str: str) -> dict:
iter = map(
lambda s: s.strip().split("="),
filter(lambda s: s, env_var_str.split(",")),
)
return {k: v for k, v in iter}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
raunakab marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument("--env-vars", required=True)
args = parser.parse_args()

env_vars = parse_env_var_str(args.env_vars)
ray_env_vars = {
"env_vars": {
"DAFT_ENABLE_RAY_TRACING": "1",
**env_vars,
},
}
raunakab marked this conversation as resolved.
Show resolved Hide resolved
print(json.dumps(ray_env_vars))
25 changes: 25 additions & 0 deletions .github/ci-scripts/read_inline_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# /// script
# requires-python = ">=3.12"
# dependencies = []
# ///

raunakab marked this conversation as resolved.
Show resolved Hide resolved
import re

import tomllib

REGEX = r"(?m)^# /// (?P<type>[a-zA-Z0-9-]+)$\s(?P<content>(^#(| .*)$\s)+)^# ///$"


def read(script: str) -> dict | None:
name = "script"
matches = list(filter(lambda m: m.group("type") == name, re.finditer(REGEX, script)))
if len(matches) > 1:
raise ValueError(f"Multiple {name} blocks found")
elif len(matches) == 1:
content = "".join(
line[2:] if line.startswith("# ") else line[1:]
for line in matches[0].group("content").splitlines(keepends=True)
)
return tomllib.loads(content)
else:
return None
68 changes: 42 additions & 26 deletions .github/ci-scripts/templatize_ray_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
# /// script
# requires-python = ">=3.12"
# dependencies = ['pydantic']
# ///

import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import read_inline_metadata
from pydantic import BaseModel, Field

CLUSTER_NAME_PLACEHOLDER = "\\{{CLUSTER_NAME}}"
DAFT_INSTALL_PLACEHOLDER = "\\{{DAFT_INSTALL}}"
OTHER_INSTALL_PLACEHOLDER = "\{{OTHER_INSTALLS}}"
raunakab marked this conversation as resolved.
Show resolved Hide resolved
PYTHON_VERSION_PLACEHOLDER = "\\{{PYTHON_VERSION}}"
CLUSTER_PROFILE__NODE_COUNT = "\\{{CLUSTER_PROFILE/node_count}}"
CLUSTER_PROFILE__INSTANCE_TYPE = "\\{{CLUSTER_PROFILE/instance_type}}"
CLUSTER_PROFILE__IMAGE_ID = "\\{{CLUSTER_PROFILE/image_id}}"
CLUSTER_PROFILE__SSH_USER = "\\{{CLUSTER_PROFILE/ssh_user}}"
CLUSTER_PROFILE__VOLUME_MOUNT = "\\{{CLUSTER_PROFILE/volume_mount}}"

NOOP_STEP = "echo 'noop step; skipping'"


@dataclass
class Profile:
Expand All @@ -22,6 +34,11 @@ class Profile:
volume_mount: Optional[str] = None


class Metadata(BaseModel, extra="allow"):
dependencies: list[str] = Field(default_factory=list)
env: dict[str, str] = Field(default_factory=dict)


profiles: dict[str, Optional[Profile]] = {
"debug_xs-x86": Profile(
instance_type="t3.large",
Expand Down Expand Up @@ -50,15 +67,16 @@ class Profile:
content = sys.stdin.read()

parser = ArgumentParser()
parser.add_argument("--cluster-name")
parser.add_argument("--cluster-name", required=True)
parser.add_argument("--daft-wheel-url")
parser.add_argument("--daft-version")
parser.add_argument("--python-version")
parser.add_argument("--cluster-profile")
parser.add_argument("--python-version", required=True)
parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86"])
parser.add_argument("--working-dir", required=True)
parser.add_argument("--entrypoint-script", required=True)
args = parser.parse_args()

if args.cluster_name:
content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name)
content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name)

if args.daft_wheel_url and args.daft_version:
raise ValueError(
Expand All @@ -72,26 +90,24 @@ class Profile:
daft_install = "getdaft"
content = content.replace(DAFT_INSTALL_PLACEHOLDER, daft_install)

if args.python_version:
content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version)

if cluster_profile := args.cluster_profile:
cluster_profile: str
if cluster_profile not in profiles:
raise Exception(f'Cluster profile "{cluster_profile}" not found')

profile = profiles[cluster_profile]
if profile is None:
raise Exception(f'Cluster profile "{cluster_profile}" not yet implemented')

assert profile is not None
content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count))
content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type)
content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id)
content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user)
if profile.volume_mount:
content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount)
else:
content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, "echo 'Nothing to mount; skipping'")
content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version)

profile = profiles[args.cluster_profile]
content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count))
content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type)
content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id)
content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user)
content = content.replace(
CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount if profile.volume_mount else NOOP_STEP
)

working_dir = Path(args.working_dir)
assert working_dir.exists() and working_dir.is_dir()
entrypoint_script_fullpath: Path = working_dir / args.entrypoint_script
assert entrypoint_script_fullpath.exists() and entrypoint_script_fullpath.is_file()
with open(entrypoint_script_fullpath) as f:
metadata = Metadata(**read_inline_metadata.read(f.read()))

content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies))

print(content)
60 changes: 41 additions & 19 deletions .github/workflows/run-cluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,45 @@ on:
workflow_dispatch:
inputs:
daft_wheel_url:
description: Daft python-wheel URL
type: string
description: A public https url pointing directly to a daft python-wheel to install
required: false
daft_version:
description: Daft version (errors if both this and "Daft python-wheel URL" are provided)
type: string
description: A released version of daft on PyPi to install (errors if both this and `daft_wheel_url` are provided)
required: false
python_version:
description: Python version
type: string
description: The version of python to use
required: false
default: "3.9"
cluster_profile:
description: Cluster profile
type: choice
options:
- medium-x86
- debug_xs-x86
description: The profile to use for the cluster
required: false
default: medium-x86
command:
type: string
description: The command to run on the cluster
required: true
working_dir:
description: Working directory
type: string
description: The working directory to submit to the cluster
required: false
default: .github/working-dir
entrypoint_script:
description: Entry-point python script (must be inside of the working directory)
type: string
required: true
entrypoint_args:
description: Entry-point arguments
type: string
required: false
default: ""
env_vars:
description: Environment variables
type: string
required: false
default: ""

jobs:
run-command:
Expand All @@ -42,6 +52,8 @@ jobs:
id-token: write
contents: read
steps:
- name: Log workflow inputs
run: echo "${{ toJson(github.event.inputs) }}"
- name: Checkout repo
uses: actions/checkout@v4
with:
Expand All @@ -63,15 +75,25 @@ jobs:
- name: Dynamically update ray config file
run: |
source .venv/bin/activate
(cat .github/assets/.template.yaml \
| python .github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}'
(cat .github/assets/template.yaml | \
uv run \
--python 3.12 \
.github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}' \
--working-dir '${{ inputs.working_dir }}' \
--entrypoint-script '${{ inputs.entrypoint_script }}'
) >> .github/assets/ray.yaml
cat .github/assets/ray.yaml
- name: Setup ray env vars
run: |
source .venv/bin/activate
ray_env_var=$(python .github/ci-scripts/format_env_vars.py --env-vars '${{ inputs.env_vars }}')
echo $ray_env_var
echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV
- name: Download private ssh key
run: |
KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text)
Expand All @@ -88,15 +110,15 @@ jobs:
- name: Submit job to ray cluster
run: |
source .venv/bin/activate
if [[ -z '${{ inputs.command }}' ]]; then
if [[ -z '${{ inputs.entrypoint_script }}' ]]; then
echo 'Invalid command submitted; command cannot be empty'
exit 1
fi
ray job submit \
--working-dir ${{ inputs.working_dir }} \
--address http://localhost:8265 \
--runtime-env-json '{"env_vars": {"DAFT_ENABLE_RAY_TRACING": "1"}}' \
-- ${{ inputs.command }}
--runtime-env-json "$ray_env_var" \
-- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
- name: Download log files from ray cluster
run: |
source .venv/bin/activate
Expand Down
Loading
Loading