diff --git a/pytential/symbolic/execution.py b/pytential/symbolic/execution.py index e1a31accd..afbd42d8b 100644 --- a/pytential/symbolic/execution.py +++ b/pytential/symbolic/execution.py @@ -381,36 +381,37 @@ def exec_assign(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate): def exec_compute_potential_insn( self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate): from pytential.qbx.distributed import DistributedQBXLayerPotentialSource - return_timing_data = self.timing_data is not None - is_distributed_fmm = None + mpi_rank = self.comm.Get_rank() use_target_specific_qbx = None fmm_backend = None qbx_order = None fmm_level_to_order = None expansion_factory = None - if self.comm.Get_rank() == 0: - source = bound_expr.places.get_geometry(insn.source.geometry) - is_distributed_fmm = isinstance( - source, DistributedQBXLayerPotentialSource) - if is_distributed_fmm: - use_target_specific_qbx = source._use_target_specific_qbx - fmm_backend = source.fmm_backend - qbx_order = source.qbx_order - fmm_level_to_order = source.fmm_level_to_order - expansion_factory = source.expansion_factory - - is_distributed_fmm = self.comm.bcast(is_distributed_fmm, root=0) - if is_distributed_fmm: - use_target_specific_qbx = self.comm.bcast( - use_target_specific_qbx, root=0) - fmm_backend = self.comm.bcast(fmm_backend, root=0) - qbx_order = self.comm.bcast(qbx_order, root=0) - fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0) - expansion_factory = self.comm.bcast(expansion_factory, root=0) - - if is_distributed_fmm and self.comm.Get_rank() != 0: + if mpi_rank == 0: + source: DistributedQBXLayerPotentialSource = \ + bound_expr.places.get_geometry(insn.source.geometry) + if not isinstance(source, DistributedQBXLayerPotentialSource): + raise TypeError("Distributed execution mapper can only process" + "distributed layer potential source") + + use_target_specific_qbx = source._use_target_specific_qbx + fmm_backend = source.fmm_backend + qbx_order = source.qbx_order + fmm_level_to_order = source.fmm_level_to_order + expansion_factory = source.expansion_factory + + use_target_specific_qbx = self.comm.bcast( + use_target_specific_qbx, root=0) + fmm_backend = self.comm.bcast(fmm_backend, root=0) + qbx_order = self.comm.bcast(qbx_order, root=0) + fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0) + expansion_factory = self.comm.bcast(expansion_factory, root=0) + + assert isinstance(fmm_backend, str) + + if mpi_rank != 0: source = DistributedQBXLayerPotentialSource( self.comm, actx.context, @@ -420,18 +421,18 @@ def exec_compute_potential_insn( fmm_backend=fmm_backend, expansion_factory=expansion_factory) - if self.comm.Get_rank() == 0 or is_distributed_fmm: - result, timing_data = ( - source.exec_compute_potential_insn( - actx, insn, bound_expr, evaluate, return_timing_data)) + return_timing_data = self.timing_data is not None + result, timing_data = ( + source.exec_compute_potential_insn( + actx, insn, bound_expr, evaluate, return_timing_data)) - if return_timing_data: - # The compiler ensures this. - assert insn not in self.timing_data + if return_timing_data: + # The compiler ensures this. + assert insn not in self.timing_data - self.timing_data[insn] = timing_data + self.timing_data[insn] = timing_data - return result + return result def __call__(self, expr, *args, **kwargs): if self.comm.Get_rank() == 0: