Skip to content

Commit

Permalink
beeflow:useContainer support for SquashFS
Browse files Browse the repository at this point in the history
  • Loading branch information
arhall0 committed Jan 17, 2025
1 parent 8123777 commit 68c2a51
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
28 changes: 19 additions & 9 deletions beeflow/common/crt/charliecloud_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ def run_text(self, task): # noqa
container_path = '/'.join([self.container_archive, task_container_name]) + '.tar.gz'

# If use_container is specified, no copying is done, the file path is used
squashfs = False
if use_container:
task_container_name = self.get_ccname(use_container)
container_path = os.path.expanduser(use_container)
_, ext = os.path.splitext(container_path)
# infer if using SquashFS, similar to:
# https://hpc.github.io/charliecloud/ch-convert.html#format-inference
squashfs = ext in ['.sqfs', '.squash', '.squashfs']
else:
container_path = '/'.join([self.container_archive, task_container_name]) + '.tar.gz'

Expand All @@ -119,12 +124,20 @@ def run_text(self, task): # noqa
mpi_opt = ''
command = ' '.join(task.command)
env_code = '\n'.join([self.cc_setup if self.cc_setup else '', task_workdir_env])
deployed_path = deployed_image_root + '/' + task_container_name
pre_commands = [
Command(f'mkdir -p {deployed_image_root}\n'.split(), CommandType.ONE_PER_NODE),
Command(f'ch-convert -i tar -o dir {container_path} {deployed_path}\n'.split(),
CommandType.ONE_PER_NODE),
]
pre_commands = []
post_commands = []
if squashfs:
deployed_path = container_path
else:
deployed_path = deployed_image_root + '/' + task_container_name
pre_commands = [
Command(f'mkdir -p {deployed_image_root}\n'.split(), CommandType.ONE_PER_NODE),
Command(f'ch-convert -i tar -o dir {container_path} {deployed_path}\n'.split(),
CommandType.ONE_PER_NODE),
]
post_commands = [
Command(f'rm -rf {deployed_path}\n'.split(), type_=CommandType.ONE_PER_NODE),
]
# Need to convert the path from inside to outside base on the bind mounts
extra_opts = ''
if task.workdir is not None:
Expand All @@ -140,9 +153,6 @@ def run_text(self, task): # noqa
main_command = (f'ch-run {mpi_opt} {deployed_path} {self.chrun_opts} '
f'{extra_opts} {bind_mount_opts} -- {command}\n').split()
main_command = Command(main_command)
post_commands = [
Command(f'rm -rf {deployed_path}\n'.split(), type_=CommandType.ONE_PER_NODE),
]
return ContainerRuntimeResult(env_code, pre_commands, main_command, post_commands)

def build_text(self, userconfig, task):
Expand Down
70 changes: 70 additions & 0 deletions beeflow/tests/test_charliecloud_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Charliecloud driver tests."""
import pytest
from beeflow.common.crt.charliecloud_driver import CharliecloudDriver as crt_driver
from beeflow.common.wf_data import Task, Requirement


@pytest.mark.parametrize(
"use_container, pre_commands_exp, main_command_exp, post_commands_exp",
[
("cont.sqfs", "", "ch-run cont.sqfs env --cd -b : -- default", ""),
(
"cont.tar.gz",
"mkdir -p env one-per-node ch-convert -i tar -o dir cont.tar.gz env/cont one-per-node",
"ch-run env/cont env --cd -b : -- default",
"rm -rf env/cont one-per-node",
),
],
)
def test_run_text_use_container(
mocker, tmpdir, use_container, pre_commands_exp, main_command_exp, post_commands_exp
):
"""Test run_text with different useContainer DockerRequirements."""
tmpdir_str = str(tmpdir)
mocker.patch("beeflow.common.config_driver.BeeConfig.get", return_value="env")
mocker.patch(
"beeflow.common.config_driver.BeeConfig.resolve_path", return_value=tmpdir_str
)
mocker.patch("os.getenv", return_value=str(tmpdir))
requirements = [
Requirement(
"DockerRequirement",
{
"beeflow:containerName": None,
"beeflow:bindMounts": None,
"beeflow:copyContainer": None,
"dockerPull": None,
"beeflow:useContainer": use_container,
},
)
]
task = Task(
name="",
base_command="",
hints=[],
requirements=requirements,
inputs=[],
outputs=[],
stdout="",
stderr="",
workflow_id="",
workdir=tmpdir,
)
driver = crt_driver()
res = driver.run_text(task)
assert res.env_code.count(tmpdir_str) == 1
# simplify results for easier comparison
env_code = res.env_code.replace(tmpdir_str, "").replace("\n", " ")
pre_commands = " ".join(
[f'{" ".join(com.args)} {com.type}' for com in res.pre_commands]
).replace(tmpdir_str, "")
main_command = f'{" ".join(res.main_command.args)} {res.main_command.type}'
assert main_command.count(tmpdir_str) == 3
main_command = main_command.replace(tmpdir_str, "")
post_commands = " ".join(
[f'{" ".join(com.args)} {com.type}' for com in res.post_commands]
)
assert env_code == "env cd "
assert pre_commands == pre_commands_exp
assert main_command == main_command_exp
assert post_commands == post_commands_exp

0 comments on commit 68c2a51

Please sign in to comment.