Skip to content

Commit

Permalink
Extract MPI code from execute_task()
Browse files Browse the repository at this point in the history
The `execute_task()` function is used by multiple executors, but the MPI
code is specific to HTEX.
  • Loading branch information
rjmello committed Nov 19, 2024
1 parent 9fb5269 commit 6290750
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,28 +590,28 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i
os.environ[key] = prefix_table[key]


def execute_task(bufs, mpi_launcher: Optional[str] = None):
"""Deserialize the buffer and execute the task.
def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
node_list = resource_spec.get("MPI_NODELIST")
if node_list is None:
return
nodes_for_task = node_list.split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)


def execute_task(bufs: bytes):
"""Deserialize the buffer and execute the task.
Returns the result or throws exception.
"""
user_ns = locals()
user_ns.update({'__builtins__': __builtins__})

f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False)
f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, copy=False)

for varname in resource_spec:
envname = "PARSL_" + str(varname).upper()
os.environ[envname] = str(resource_spec[varname])

if resource_spec.get("MPI_NODELIST"):
worker_id = os.environ['PARSL_WORKER_RANK']
nodes_for_task = resource_spec["MPI_NODELIST"].split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
assert mpi_launcher
update_resource_spec_env_vars(mpi_launcher,
resource_spec=resource_spec,
node_info=nodes_for_task)
# We might need to look into callability of the function from itself
# since we change it's name in the new namespace
prefix = "parsl_"
Expand Down Expand Up @@ -786,8 +786,10 @@ def manager_is_alive():
ready_worker_count.value -= 1
worker_enqueued = False

_init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"])

try:
result = execute_task(req['buffer'], mpi_launcher=mpi_launcher)
result = execute_task(req['buffer'])
serialized_result = serialize(result, buffer_threshold=1000000)
except Exception as e:
logger.info('Caught an exception: {}'.format(e))
Expand Down

0 comments on commit 6290750

Please sign in to comment.