From 0f47b6d1c905a6b429a9abf6bd2cb415fe622032 Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Wed, 26 Oct 2022 14:18:23 -0400 Subject: [PATCH] update appwapper generator (#13) --- src/codeflare_sdk/cluster/config.py | 9 +++++---- .../templates/{new-template.yml => new-template.yaml} | 2 +- src/codeflare_sdk/utils/generate_yaml.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) rename src/codeflare_sdk/templates/{new-template.yml => new-template.yaml} (99%) diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 921ffb16b..3e825e95f 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -1,10 +1,10 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass class ClusterConfiguration: name: str - head_info: list = [] - machine_types: list = [] #["m4.xlarge", "g4dn.xlarge"] + head_info: list = field(default_factory=list) + machine_types: list = field(default_factory=list) #["m4.xlarge", "g4dn.xlarge"] min_cpus: int = 1 max_cpus: int = 1 min_worker: int = 1 @@ -14,5 +14,6 @@ class ClusterConfiguration: gpu: int = 0 template: str = "src/codeflare_sdk/templates/new-template.yaml" instascale: bool = False - envs: dict = {} + envs: dict = field(default_factory=dict) image: str = "ghcr.io/ibm-ai-foundation/base:ray1.13.0-py38-gpu-pytorch1.12.0cu116-20220826-202124" + diff --git a/src/codeflare_sdk/templates/new-template.yml b/src/codeflare_sdk/templates/new-template.yaml similarity index 99% rename from src/codeflare_sdk/templates/new-template.yml rename to src/codeflare_sdk/templates/new-template.yaml index b66bc0541..24daf3a17 100644 --- a/src/codeflare_sdk/templates/new-template.yml +++ b/src/codeflare_sdk/templates/new-template.yaml @@ -98,7 +98,7 @@ spec: # The value of `resources` is a string-integer mapping. # Currently, `resources` must be provided in the specific format demonstrated below: # resources: '"{\"Custom1\": 1, \"Custom2\": 5}"' - num-gpus: 0 + num-gpus: '0' #pod template template: spec: diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 66225a709..aaf47be57 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -108,7 +108,7 @@ def update_nodes(item, appwrapper_name, min_cpu, max_cpu, min_memory, max_memory worker["replicas"] = workers worker["minReplicas"] = workers worker["maxReplicas"] = workers - worker["rayStartParams"]["num-gpus"] = int(gpu) + worker["rayStartParams"]["num-gpus"] = str(int(gpu)) for comp in [head, worker]: spec = comp.get("template").get("spec")