Skip to content

Commit

Permalink
Kjob storage configuration for run command.
Browse files (browse the repository at this point in the history)
  • Loading branch information
mbobrovskyi committed Mar 5, 2025
1 parent c7b4958 commit d93b9b0
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 62 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,24 @@ jobs:
with:
run-id: '${{needs.set-variables.outputs.run-id}}'
run-unit-tests:
needs: [install-dependencies, set-variables]
uses: ./.github/workflows/unit_tests.yaml
with:
run-id: ${{needs.set-variables.outputs.run-id}}
concurrency: # We support one build or nightly test to run at a time currently.
group: unit-tests-${{needs.set-variables.outputs.run-id}}
cancel-in-progress: true
needs: [linter, set-variables]
run-integration-tests:
needs: [install-dependencies, set-variables]
uses: ./.github/workflows/integration_tests.yaml
with:
run-id: '${{needs.set-variables.outputs.run-id}}'
concurrency: # We support one build or nightly test to run at a time currently.
group: integration-tests-${{needs.set-variables.outputs.run-id}}
cancel-in-progress: true
secrets: inherit
needs: [run-unit-tests, set-variables]
cluster-private:
needs: [run-integration-tests, set-variables]
needs: [linter, run-unit-tests, run-integration-tests, set-variables]
uses: ./.github/workflows/cluster_private.yaml
concurrency: # We support one build or nightly test to run at a time currently.
group: cluster-private-${{needs.set-variables.outputs.run-id}}
Expand All @@ -152,7 +152,7 @@ jobs:
location: '${{needs.set-variables.outputs.location}}'
secrets: inherit
cluster-create:
needs: [run-integration-tests, set-variables]
needs: [linter, run-unit-tests, run-integration-tests, set-variables]
concurrency: # We support one build or nightly test to run at a time currently.
group: cluster-create-${{needs.set-variables.outputs.run-id}}
cancel-in-progress: true
Expand Down
7 changes: 3 additions & 4 deletions src/xpk/commands/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

from argparse import Namespace

from ..core.cluster import setup_k8s_env, create_k8s_service_account
from ..core.cluster import setup_k8s_env, create_xpk_k8s_service_account
from ..core.commands import run_command_for_value
from ..core.config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE, XPK_SA, DEFAULT_NAMESPACE
from ..core.gcloud_context import add_zone_and_project
from ..core.kueue import LOCAL_QUEUE_NAME
from ..core.storage import get_auto_mount_gcsfuse_storages
from ..core.storage import get_auto_mount_gcsfuse_storages, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
from ..utils.console import xpk_exit, xpk_print
from .common import set_cluster_command
from ..core.kjob import AppProfileDefaults, prepare_kjob, Kueue_TAS_annotation
Expand Down Expand Up @@ -55,7 +54,7 @@ def batch(args: Namespace) -> None:

def submit_job(args: Namespace) -> None:
k8s_api_client = setup_k8s_env(args)
create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
create_xpk_k8s_service_account()
gcs_fuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)

cmd = (
Expand Down
12 changes: 12 additions & 0 deletions src/xpk/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

from argparse import Namespace

from ..core.cluster import create_xpk_k8s_service_account, setup_k8s_env
from ..core.commands import run_command_with_full_controls
from ..core.gcloud_context import add_zone_and_project
from ..core.kueue import LOCAL_QUEUE_NAME
from ..core.storage import get_auto_mount_gcsfuse_storages, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
from ..utils.console import xpk_exit, xpk_print
from .common import set_cluster_command
from ..core.kjob import AppProfileDefaults, prepare_kjob, Kueue_TAS_annotation
Expand Down Expand Up @@ -50,6 +52,10 @@ def run(args: Namespace) -> None:


def submit_job(args: Namespace) -> None:
k8s_api_client = setup_k8s_env(args)
create_xpk_k8s_service_account()
gcs_fuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)

cmd = (
'kubectl kjob create slurm'
f' --profile {AppProfileDefaults.NAME.value}'
Expand All @@ -59,6 +65,12 @@ def submit_job(args: Namespace) -> None:
' --rm'
)

if len(gcs_fuse_storages) > 0:
cmd += (
' --pod-template-annotation'
f' {GCS_FUSE_ANNOTATION_KEY}={GCS_FUSE_ANNOTATION_VALUE}'
)

if args.ignore_unknown_flags:
cmd += ' --ignore-unknown-flags'

Expand Down
7 changes: 3 additions & 4 deletions src/xpk/commands/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
"""

from ..core.commands import run_command_with_full_controls, run_command_for_value, run_command_with_updates
from ..core.cluster import get_cluster_credentials, add_zone_and_project, setup_k8s_env, create_k8s_service_account
from ..core.config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE, XPK_SA, DEFAULT_NAMESPACE
from ..core.storage import get_auto_mount_gcsfuse_storages
from ..core.cluster import get_cluster_credentials, add_zone_and_project, setup_k8s_env, create_xpk_k8s_service_account
from ..core.storage import get_auto_mount_gcsfuse_storages, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
from ..utils.console import xpk_exit, xpk_print
from argparse import Namespace

Expand Down Expand Up @@ -82,7 +81,7 @@ def connect_to_new_interactive_shell(args: Namespace) -> int:
xpk_exit(err_code)

k8s_api_client = setup_k8s_env(args)
create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
create_xpk_k8s_service_account()
gcs_fuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)

cmd = (
Expand Down
2 changes: 1 addition & 1 deletion src/xpk/commands/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
update_cluster_with_gcpfilestore_driver_if_necessary,
add_zone_and_project,
get_cluster_network,
DEFAULT_NAMESPACE,
)
from ..core.config import DEFAULT_NAMESPACE
from ..core.kjob import (
KJOB_API_GROUP_NAME,
KJOB_API_GROUP_VERSION,
Expand Down
17 changes: 6 additions & 11 deletions src/xpk/commands/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,13 @@
"""

from ..core.cluster import (
create_k8s_service_account,
create_xpk_k8s_service_account,
get_cluster_credentials,
setup_k8s_env,
)
from ..core.commands import run_command_with_updates, run_commands
from ..core.config import (
GCS_FUSE_ANNOTATION_KEY,
GCS_FUSE_ANNOTATION_VALUE,
VERTEX_TENSORBOARD_FEATURE_FLAG,
XPK_CURRENT_VERSION,
parse_env_config,
XPK_SA,
DEFAULT_NAMESPACE,
)
from ..core.commands import run_command_with_updates, run_commands
from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION, parse_env_config
from ..core.docker_container import (
get_main_container_docker_image,
get_user_workload_container,
Expand Down Expand Up @@ -68,6 +61,8 @@
get_storages_to_mount,
get_storage_volume_mounts_yaml_for_gpu,
get_storage_volumes_yaml_for_gpu,
GCS_FUSE_ANNOTATION_KEY,
GCS_FUSE_ANNOTATION_VALUE,
)
from ..core.system_characteristics import (
AcceleratorType,
Expand Down Expand Up @@ -491,7 +486,7 @@ def workload_create(args) -> None:
0 if successful and 1 otherwise.
"""
k8s_api_client = setup_k8s_env(args)
create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
create_xpk_k8s_service_account()

workload_exists = check_if_workload_exists(args)

Expand Down
17 changes: 12 additions & 5 deletions src/xpk/core/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
INSTALLER_NCC_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
INSTALLER_NCC_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'

DEFAULT_NAMESPACE = 'default'
XPK_SA = 'xpk-sa'


# TODO(vbarr): Remove this function when jobsets gets enabled by default on
# GKE clusters.
Expand Down Expand Up @@ -232,18 +235,22 @@ def setup_k8s_env(args) -> k8s_client.ApiClient:
return k8s_client.ApiClient() # pytype: disable=bad-return-type


def create_k8s_service_account(name: str, namespace: str) -> None:
def create_xpk_k8s_service_account() -> None:
k8s_core_client = k8s_client.CoreV1Api()
sa = k8s_client.V1ServiceAccount(metadata=k8s_client.V1ObjectMeta(name=name))
sa = k8s_client.V1ServiceAccount(
metadata=k8s_client.V1ObjectMeta(name=XPK_SA)
)

xpk_print(f'Creating a new service account: {name}')
xpk_print(f'Creating a new service account: {XPK_SA}')
try:
k8s_core_client.create_namespaced_service_account(
namespace, sa, pretty=True
DEFAULT_NAMESPACE, sa, pretty=True
)
xpk_print(f'Created a new service account: {sa} successfully')
except ApiException:
xpk_print(f'Service account: {name} already exists. Skipping its creation')
xpk_print(
f'Service account: {XPK_SA} already exists. Skipping its creation'
)


def update_gke_cluster_with_clouddns(args) -> int:
Expand Down
4 changes: 0 additions & 4 deletions src/xpk/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')

CONFIGS_KEY = 'configs'
DEFAULT_NAMESPACE = 'default'
XPK_SA = 'xpk-sa'
CFG_BUCKET_KEY = 'cluster-state-gcs-bucket'
CLUSTER_NAME_KEY = 'cluster-name'
PROJECT_KEY = 'project-id'
Expand All @@ -58,8 +56,6 @@
KJOB_SHELL_WORKING_DIRECTORY,
]
VERTEX_TENSORBOARD_FEATURE_FLAG = XPK_CURRENT_VERSION >= '0.4.0'
GCS_FUSE_ANNOTATION_KEY = 'gke-gcsfuse/volumes'
GCS_FUSE_ANNOTATION_VALUE = 'true'


yaml = ruamel.yaml.YAML()
Expand Down
26 changes: 7 additions & 19 deletions src/xpk/core/kjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from kubernetes import client as k8s_client
from kubernetes.client import ApiClient
from kubernetes.client.rest import ApiException
from .cluster import setup_k8s_env
from .config import XPK_SA, DEFAULT_NAMESPACE
from .storage import Storage, get_auto_mount_storages, GCS_FUSE_TYPE, GCP_FILESTORE_TYPE
from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE
from .storage import get_auto_mount_storages
from ..utils.console import xpk_print, xpk_exit
from .commands import run_command_for_value, run_kubectl_apply, run_command_with_updates
from .config import XpkConfig, KJOB_SHELL_IMAGE, KJOB_SHELL_INTERACTIVE_COMMAND, KJOB_SHELL_WORKING_DIRECTORY, KJOB_BATCH_IMAGE, KJOB_BATCH_WORKING_DIRECTORY
Expand All @@ -32,11 +31,7 @@

KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
KJOB_API_GROUP_VERSION = "v1alpha1"
KJOB_API_VOLUME_BUNDLE_KIND = "VolumeBundle"
KJOB_API_VOLUME_BUNDLE_PLURAL = KJOB_API_VOLUME_BUNDLE_KIND.lower() + "s"
KJOB_API_VOLUME_BUNDLE_CRD_NAME = (
f"{KJOB_API_VOLUME_BUNDLE_PLURAL}.{KJOB_API_GROUP_NAME}"
)
KJOB_API_VOLUME_BUNDLE_PLURAL = "volumebundles"
VOLUME_BUNDLE_TEMPLATE_PATH = "/../templates/volume_bundle.yaml"


Expand Down Expand Up @@ -288,16 +283,10 @@ def prepare_kjob(args: Namespace) -> int:
system = get_cluster_system_characteristics(args)

k8s_api_client = setup_k8s_env(args)
storages: list[Storage] = get_auto_mount_storages(k8s_api_client)
gcs_fuse_storages = list(
filter(lambda storage: storage.type == GCS_FUSE_TYPE, storages)
)
gcp_filestore_storages = list(
filter(lambda storage: storage.type == GCP_FILESTORE_TYPE, storages)
)
storages = get_auto_mount_storages(k8s_api_client)

service_account = ""
if len(gcs_fuse_storages) > 0 or len(gcp_filestore_storages) > 0:
if len(storages) > 0:
service_account = XPK_SA

job_err_code = create_job_template_instance(args, system, service_account)
Expand All @@ -308,8 +297,7 @@ def prepare_kjob(args: Namespace) -> int:
if pod_err_code > 0:
return pod_err_code

all_storages = gcs_fuse_storages + gcp_filestore_storages
volume_bundles = [item.name for item in all_storages]
volume_bundles = [item.name for item in storages]

return create_app_profile_instance(args, volume_bundles)

Expand Down Expand Up @@ -387,7 +375,7 @@ def create_volume_bundle_instance(
body=data,
)
xpk_print(
f"Created {KJOB_API_VOLUME_BUNDLE_CRD_NAME} object:"
f"Created {KJOB_API_VOLUME_BUNDLE_PLURAL}.{KJOB_API_GROUP_NAME} object:"
f" {data['metadata']['name']}"
)
except ApiException as e:
Expand Down
9 changes: 3 additions & 6 deletions src/xpk/core/pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
limitations under the License.
"""

from .cluster import XPK_SA
from ..core.docker_container import get_user_workload_container
from ..core.gcloud_context import zone_to_region
from ..core.nodepool import get_all_nodepools_programmatic
from ..utils.console import xpk_exit, xpk_print
from .config import (
GCS_FUSE_ANNOTATION_KEY,
GCS_FUSE_ANNOTATION_VALUE,
AcceleratorType,
)
from .storage import XPK_SA, Storage, get_storage_volumes_yaml
from .config import AcceleratorType
from .storage import Storage, get_storage_volumes_yaml, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
from .system_characteristics import SystemCharacteristics

PathwaysExpectedInstancesMap = {
Expand Down
6 changes: 4 additions & 2 deletions src/xpk/core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,20 @@
from kubernetes.utils import FailToCreateError
from tabulate import tabulate

from .config import XPK_SA
from .cluster import XPK_SA
from ..utils.console import xpk_exit, xpk_print

STORAGE_CRD_PATH = "/../api/storage_crd.yaml"
STORAGE_TEMPLATE_PATH = "/../templates/storage.yaml"
XPK_API_GROUP_NAME = "xpk.x-k8s.io"
XPK_API_GROUP_VERSION = "v1"
STORAGE_CRD_KIND = "Storage"
STORAGE_CRD_PLURAL = STORAGE_CRD_KIND.lower() + "s"
STORAGE_CRD_PLURAL = "storages"
STORAGE_CRD_NAME = f"{XPK_API_GROUP_NAME}.{STORAGE_CRD_PLURAL}"
GCS_FUSE_TYPE = "gcsfuse"
GCP_FILESTORE_TYPE = "gcpfilestore"
GCS_FUSE_ANNOTATION_KEY = "gke-gcsfuse/volumes"
GCS_FUSE_ANNOTATION_VALUE = "true"


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions src/xpk/core/workload_decorators/storage_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import yaml

from ..config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict
from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE


def decorate_jobset(jobset_manifest_str, storages) -> str:
Expand Down

0 comments on commit d93b9b0

Please sign in to comment.