Skip to content

Commit

Permalink
Multinode Cleanup (#1360)
Browse files Browse the repository at this point in the history
* nodecount cleanup

* adds node count to dict marshalling

* updates tests

* updates test

* remove unnecessary things

* pr review
  • Loading branch information
rcano-baseten authored Feb 6, 2025
1 parent f7be3e0 commit 27dced4
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 18 deletions.
31 changes: 17 additions & 14 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from truss.base.validation import (
validate_cpu_spec,
validate_memory_spec,
validate_node_count,
validate_python_executable_path,
validate_secret_name,
validate_secret_to_path_mapping,
Expand All @@ -43,7 +44,6 @@
DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
DEFAULT_USE_GPU = False
DEFAULT_NODE_COUNT = 1

DEFAULT_BLOB_BACKEND = HTTP_PUBLIC_BLOB_BACKEND

Expand Down Expand Up @@ -261,7 +261,7 @@ class Resources:
memory: str = DEFAULT_MEMORY
use_gpu: bool = DEFAULT_USE_GPU
accelerator: AcceleratorSpec = field(default_factory=AcceleratorSpec)
node_count: int = DEFAULT_NODE_COUNT
node_count: Optional[int] = None

@staticmethod
def from_dict(d):
Expand All @@ -273,24 +273,27 @@ def from_dict(d):
use_gpu = d.get("use_gpu", DEFAULT_USE_GPU)
if accelerator.accelerator is not None:
use_gpu = True
# TODO[rcano]: add validation for node count
node_count = d.get("node_count", DEFAULT_NODE_COUNT)

return Resources(
cpu=cpu,
memory=memory,
use_gpu=use_gpu,
accelerator=accelerator,
node_count=node_count,
)

r = Resources(cpu=cpu, memory=memory, use_gpu=use_gpu, accelerator=accelerator)

# only add node_count if not None. This helps keep
# config generated by truss init concise.
node_count = d.get("node_count")
validate_node_count(node_count)
r.node_count = node_count

return r

def to_dict(self):
return {
d = {
"cpu": self.cpu,
"memory": self.memory,
"use_gpu": self.use_gpu,
"accelerator": self.accelerator.to_str(),
}
if self.node_count is not None:
d["node_count"] = self.node_count
return d


@dataclass
Expand Down Expand Up @@ -775,7 +778,7 @@ def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:


DATACLASS_TO_REQ_KEYS_MAP = {
Resources: {"accelerator", "cpu", "memory", "use_gpu", "node_count"},
Resources: {"accelerator", "cpu", "memory", "use_gpu"},
Runtime: {"predict_concurrency"},
Build: {"model_server"},
TrussConfig: {
Expand Down
16 changes: 15 additions & 1 deletion truss/base/validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import re
from pathlib import PurePosixPath
from typing import Dict, Pattern
from typing import Any, Dict, Pattern

from truss.base.constants import REGISTRY_BUILD_SECRET_PREFIX
from truss.base.errors import ValidationError
Expand Down Expand Up @@ -122,3 +122,17 @@ def validate_python_executable_path(path: str) -> None:
raise ValidationError(
f"Invalid relative python executable path {path}. Provide an absolute path"
)


def validate_node_count(node_count: Any) -> None:
    """Validate the optional ``resources.node_count`` config value.

    ``None`` means the field was not provided and is accepted unchanged.
    Any other value must be a genuine integer greater than or equal to 1.

    Args:
        node_count: Raw value read from the resources section of the config.

    Raises:
        ValidationError: If the value is not an integer (booleans are
            rejected as well) or if it is smaller than 1.
    """
    fieldpath = "resources.node_count"
    if node_count is None:
        return
    # bool is a subclass of int, so without the explicit check `True` would
    # pass the isinstance test below and silently count as one node.
    if isinstance(node_count, bool) or not isinstance(node_count, int):
        raise ValidationError(
            f"{fieldpath} must be a positive integer. Got {node_count} of type '{type(node_count)}'"
        )
    if node_count < 1:
        raise ValidationError(
            f"{fieldpath} must be a positive integer. Got {node_count}."
        )
1 change: 0 additions & 1 deletion truss/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,6 @@ def default_config() -> Dict[str, Any]:
"cpu": "1",
"memory": "2Gi",
"use_gpu": False,
"node_count": 1,
},
"secrets": {},
"system_packages": [],
Expand Down
12 changes: 11 additions & 1 deletion truss/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@
"accelerator": "A10G:4",
},
),
(
{"node_count": 2},
Resources(node_count=2),
{
"cpu": DEFAULT_CPU,
"memory": DEFAULT_MEMORY,
"use_gpu": False,
"accelerator": None,
"node_count": 2,
},
),
],
)
def test_parse_resources(input_dict, expect_resources, output_dict):
Expand Down Expand Up @@ -170,7 +181,6 @@ def test_default_config_not_crowded_end_to_end():
accelerator: null
cpu: '1'
memory: 2Gi
node_count: 1
use_gpu: false
secrets: {}
system_packages: []
Expand Down
1 change: 0 additions & 1 deletion truss/tests/test_truss_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,6 @@ def generate_default_config():
"cpu": "1",
"memory": "2Gi",
"use_gpu": False,
"node_count": 1,
},
"secrets": {},
"system_packages": [],
Expand Down

0 comments on commit 27dced4

Please sign in to comment.