Skip to content

Commit

Permalink
Refactor distributed eval mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
gaohao95 committed Jan 8, 2024
1 parent d7425ab commit c270674
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions pytential/symbolic/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,36 +381,37 @@ def exec_assign(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate):
def exec_compute_potential_insn(
self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate):
from pytential.qbx.distributed import DistributedQBXLayerPotentialSource
return_timing_data = self.timing_data is not None

is_distributed_fmm = None
mpi_rank = self.comm.Get_rank()
use_target_specific_qbx = None
fmm_backend = None
qbx_order = None
fmm_level_to_order = None
expansion_factory = None

if self.comm.Get_rank() == 0:
source = bound_expr.places.get_geometry(insn.source.geometry)
is_distributed_fmm = isinstance(
source, DistributedQBXLayerPotentialSource)
if is_distributed_fmm:
use_target_specific_qbx = source._use_target_specific_qbx
fmm_backend = source.fmm_backend
qbx_order = source.qbx_order
fmm_level_to_order = source.fmm_level_to_order
expansion_factory = source.expansion_factory

is_distributed_fmm = self.comm.bcast(is_distributed_fmm, root=0)
if is_distributed_fmm:
use_target_specific_qbx = self.comm.bcast(
use_target_specific_qbx, root=0)
fmm_backend = self.comm.bcast(fmm_backend, root=0)
qbx_order = self.comm.bcast(qbx_order, root=0)
fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0)
expansion_factory = self.comm.bcast(expansion_factory, root=0)

if is_distributed_fmm and self.comm.Get_rank() != 0:
if mpi_rank == 0:
source: DistributedQBXLayerPotentialSource = \
bound_expr.places.get_geometry(insn.source.geometry)
if not isinstance(source, DistributedQBXLayerPotentialSource):
raise TypeError("Distributed execution mapper can only process"
"distributed layer potential source")

use_target_specific_qbx = source._use_target_specific_qbx
fmm_backend = source.fmm_backend
qbx_order = source.qbx_order
fmm_level_to_order = source.fmm_level_to_order
expansion_factory = source.expansion_factory

use_target_specific_qbx = self.comm.bcast(
use_target_specific_qbx, root=0)
fmm_backend = self.comm.bcast(fmm_backend, root=0)
qbx_order = self.comm.bcast(qbx_order, root=0)
fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0)
expansion_factory = self.comm.bcast(expansion_factory, root=0)

assert isinstance(fmm_backend, str)

if mpi_rank != 0:
source = DistributedQBXLayerPotentialSource(
self.comm,
actx.context,
Expand All @@ -420,18 +421,18 @@ def exec_compute_potential_insn(
fmm_backend=fmm_backend,
expansion_factory=expansion_factory)

if self.comm.Get_rank() == 0 or is_distributed_fmm:
result, timing_data = (
source.exec_compute_potential_insn(
actx, insn, bound_expr, evaluate, return_timing_data))
return_timing_data = self.timing_data is not None
result, timing_data = (
source.exec_compute_potential_insn(
actx, insn, bound_expr, evaluate, return_timing_data))

if return_timing_data:
# The compiler ensures this.
assert insn not in self.timing_data
if return_timing_data:
# The compiler ensures this.
assert insn not in self.timing_data

self.timing_data[insn] = timing_data
self.timing_data[insn] = timing_data

return result
return result

def __call__(self, expr, *args, **kwargs):
if self.comm.Get_rank() == 0:
Expand Down

0 comments on commit c270674

Please sign in to comment.