Skip to content

Commit

Permalink
Making shared_memory configurable (#969)
Browse files Browse the repository at this point in the history
* Making shared_memory configurable

* fix eol space
  • Loading branch information
RsEnts authored Feb 7, 2025
1 parent 323faa3 commit ceab4f4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 6 additions & 3 deletions axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class GCSFuseMount(VolumeMount):
cpu: Defaults to 250m. Increase if higher throughput needed.
memory: Defaults to 256Mi. Set proportionally to number of files processed (not filesize).
ephemeral_gb: Defaults to 5Gi. Used for staging temp files before uploading to GCS.
shared_memory: Defaults to 1Gi. Used by e.g. Grain-related jobs, which store prefetch
elements in shared memory. Setting it to 0 means unlimited shared memory.
read_only: Whether the mount should be read-only.
"""

Expand All @@ -329,6 +331,7 @@ class GCSFuseMount(VolumeMount):
cpu: str = "250m"
memory: str = "256Mi"
ephemeral_gb: str = "5Gi"
shared_memory: str = "1Gi"


@dataclass(kw_only=True)
Expand Down Expand Up @@ -625,12 +628,12 @@ def _build_uploader_container(self) -> Nested[Any]:
volumeMounts=volume_mounts,
)

def _build_shared_memory_volumes(self):
def _build_shared_memory_volumes(self, shared_memory: str):
volume = {
"name": "shared-memory",
"emptyDir": {
"medium": "Memory",
"sizeLimit": "1Gi",
"sizeLimit": shared_memory,
},
}
return volume
Expand All @@ -651,7 +654,7 @@ def _build_pod(self) -> Nested[Any]:
if cfg.gcsfuse_mount:
# Adds a shared memory volume when gcsfuse is enabled. This is useful when grain
# prefetch is enabled.
volumes.append(self._build_shared_memory_volumes())
volumes.append(self._build_shared_memory_volumes(cfg.gcsfuse_mount.shared_memory))
# Mount a GCS bucket as a volume.
annotations.update(
{
Expand Down
10 changes: 9 additions & 1 deletion axlearn/cloud/gcp/job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ class Config(Bundler.Config):
location_hint=["test-location-hint", None],
enable_tpu_smart_repair=[True, False],
host_mount_spec=[["name=host-mount,host_path=/tmp,mount_path=/host-tmp"], None],
gcsfuse_mount_spec=[["mount_path=/tmp/gcsfuse", "gcs_path=/tmp/gcs_path"], None],
gcsfuse_mount_spec=[
["mount_path=/tmp/gcsfuse", "gcs_path=/tmp/gcs_path", "shared_memory=5Gi"],
None,
],
priority_class=[None, "such-high-priority"],
)
def test_build_pod(
Expand Down Expand Up @@ -458,6 +461,11 @@ def test_build_pod(
for v in pod_spec["volumes"]:
if v["name"] == "shared-memory":
self.assertIn("sizeLimit", v["emptyDir"])
size_limit_request = [x for x in gcsfuse_mount_spec if "shared_memory" in x]
self.assertLessEqual(len(size_limit_request), 1)
if size_limit_request:
size_limit_request = size_limit_request[0].split("=")[1]
self.assertEqual(v["emptyDir"]["sizeLimit"], size_limit_request)
else:
self.assertNotIn("shared-memory", [v["name"] for v in pod_spec["volumes"]])

Expand Down

0 comments on commit ceab4f4

Please sign in to comment.