diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index 7c6cb6cf9..f0b1f95ee 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -765,7 +765,7 @@ def exec_compute_potential_insn_direct(self, actx, insn, bound_expr, evaluate, from meshmode.discretization import Discretization if return_timing_data: - from pytential.source import UnableToCollectTimingData + from sumpy.fmm import UnableToCollectTimingData from warnings import warn warn( "Timing data collection not supported.", diff --git a/pytential/source.py b/pytential/source.py index fa442d547..83f48ca8d 100644 --- a/pytential/source.py +++ b/pytential/source.py @@ -22,11 +22,19 @@ import numpy as np -from pytools import memoize_in +from pytools import memoize_in, memoize_method from arraycontext import flatten, unflatten from meshmode.dof_array import DOFArray +from meshmode.discretization import Discretization +from arraycontext import ArrayContext -from sumpy.fmm import UnableToCollectTimingData +from sumpy.fmm import (SumpyTimingFuture, + SumpyTreeIndependentDataForWrangler, SumpyExpansionWrangler) +from sumpy.expansion import DefaultExpansionFactory + +from functools import partial +from collections import defaultdict +from typing import Optional, Mapping, Union, Callable __doc__ = """ @@ -96,9 +104,12 @@ def evaluate_kernel_arguments(actx, evaluate, kernel_arguments, flat=True): return kernel_args +default_expansion_factory = DefaultExpansionFactory() + + class PointPotentialSource(_SumpyP2PMixin, PotentialSource): """ - .. attribute:: nodes + .. method:: nodes An :class:`pyopencl.array.Array` of shape ``[ambient_dim, ndofs]``. @@ -108,7 +119,48 @@ class PointPotentialSource(_SumpyP2PMixin, PotentialSource): .. automethod:: exec_compute_potential_insn """ - def __init__(self, nodes): + def __init__(self, nodes, *, + fmm_order: Optional[int] = False, + fmm_level_to_order: Optional[Union[bool, Callable[..., int]]] = None, + expansion_factory: Optional[DefaultExpansionFactory] + = default_expansion_factory, + tree_build_kwargs: Optional[Mapping] = None, + trav_build_kwargs: Optional[Mapping] = None, + setup_actx: Optional[ArrayContext] = None): + """ + :arg nodes: The point potential source given as a + :class:`pyopencl.array.Array` + :arg fmm_order: The order of the FMM for all levels if *fmm_order* is not + *False*. Mutually exclusive with argument *fmm_level_to_order*. + If both arguments are not given a direct point-to-point calculation + is used. + :arg fmm_level_to_order: An optional callable that returns the FMM order + to use for a given level. Mutually exclusive with *fmm_order* + argument. + :arg expansion_factory: An expansion factory to get the expansion objects + when an FMM is used. + :arg tree_build_kwargs: Keyword arguments to be passed when building the + tree for an FMM. + :arg trav_build_kwargs: Keyword arguments to be passed when building a + traversal for an FMM. + :arg setup_actx: An array context to be used when building a tree + for an FMM. + """ + + if fmm_order is not False and fmm_level_to_order is not None: + raise TypeError("may not specify both fmm_order and fmm_level_to_order") + + if fmm_level_to_order is None: + if fmm_order is not False: + def fmm_level_to_order(kernel, kernel_args, tree, level): # noqa pylint:disable=function-redefined + return fmm_order + else: + fmm_level_to_order = False + self.fmm_level_to_order = fmm_level_to_order + self.expansion_factory = expansion_factory + self.tree_build_kwargs = tree_build_kwargs if tree_build_kwargs else {} + self.trav_build_kwargs = trav_build_kwargs if trav_build_kwargs else {} + self._setup_actx = setup_actx self._nodes = nodes @property @@ -131,6 +183,32 @@ def ndofs(self): for coord_ary in self._nodes: return coord_ary.shape[0] + def copy(self, *, nodes=None, fmm_order=None, fmm_level_to_order=None, + expansion_factory=None, tree_build_kwargs=None, trav_build_kwargs=None, + setup_actx=None): + if nodes is None: + nodes = self._nodes + if setup_actx is None: + setup_actx = self._setup_actx + if fmm_level_to_order is None and fmm_order is None: + fmm_level_to_order = self.fmm_level_to_order + if expansion_factory is None: + expansion_factory = self.expansion_factory + if tree_build_kwargs is None: + tree_build_kwargs = self.tree_build_kwargs + if trav_build_kwargs is None: + trav_build_kwargs = self.trav_build_kwargs + + return type(self)( + nodes=nodes, + fmm_order=fmm_order, + fmm_level_to_order=fmm_level_to_order, + expansion_factory=expansion_factory, + tree_build_kwargs=tree_build_kwargs, + trav_build_kwargs=trav_build_kwargs, + setup_actx=setup_actx, + ) + @property def complex_dtype(self): return { @@ -161,46 +239,124 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr, evaluate, costs): raise NotImplementedError - def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate, - return_timing_data): - if return_timing_data: - from warnings import warn - warn( - "Timing data collection not supported.", - category=UnableToCollectTimingData) + @memoize_method + def _get_tree(self, target_discr): + """Builds a tree for targets given by *target_discr* and caches the + result. Needed only when an FMM is used. + """ + from boxtree import TreeBuilder + from boxtree.traversal import FMMTraversalBuilder + + actx = self._setup_actx + sources = self._nodes + targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray) + tree_build = TreeBuilder(actx.context) + trav_build = FMMTraversalBuilder(actx.context, + **self.trav_build_kwargs) + tree, _ = tree_build(actx.queue, sources, targets=targets, + **self.tree_build_kwargs) + trav, _ = trav_build(actx.queue, tree) + return tree, trav + + @memoize_method + def _get_exec_insn_func(self, source_kernels, target_kernels, target_discr): + if self.fmm_level_to_order is False: + def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data): + sources = self._nodes + targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray) + p2p = self.get_p2p(actx, source_kernels=source_kernels, + target_kernels=target_kernels) + + evt, output = p2p(actx.queue, + targets=targets, + sources=sources, + strength=strengths, **kernel_args) + + if return_timing_data: + timing_data = {"eval_direct": + SumpyTimingFuture(actx.queue, [evt]).result()} + else: + timing_data = None + return timing_data, output + else: + from boxtree.fmm import drive_fmm + + kernel = target_kernels[0].get_base_kernel() + local_expansion_factory = \ + self.expansion_factory.get_local_expansion_class(kernel) + local_expansion_factory = partial(local_expansion_factory, kernel) + mpole_expansion_factory = \ + self.expansion_factory.get_multipole_expansion_class(kernel) + mpole_expansion_factory = partial(mpole_expansion_factory, kernel) + + tree, trav = self._get_tree(target_discr) + + def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data): + tree_indep = SumpyTreeIndependentDataForWrangler( + actx.context, + mpole_expansion_factory, + local_expansion_factory, + target_kernels=target_kernels, + source_kernels=source_kernels, + ) + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, + fmm_level_to_order=self.fmm_level_to_order, + kernel_extra_kwargs=kernel_args) + timing_data = {} if return_timing_data else None + output = drive_fmm(wrangler, strengths, timing_data=timing_data) + return timing_data, output - p2p = None + return exec_insn + def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate, + return_timing_data): kernel_args = evaluate_kernel_arguments( actx, evaluate, insn.kernel_arguments, flat=False) strengths = [evaluate(density) for density in insn.densities] - # FIXME: Do this all at once - results = [] - for o in insn.outputs: - target_discr = bound_expr.places.get_discretization( - o.target_name.geometry, o.target_name.discr_stage) - - # no on-disk kernel caching - if p2p is None: - p2p = self.get_p2p(actx, source_kernels=insn.source_kernels, - target_kernels=insn.target_kernels) - - evt, output_for_each_kernel = p2p(actx.queue, - targets=flatten(target_discr.nodes(), actx, leaf_class=DOFArray), - sources=self._nodes, - strength=strengths, **kernel_args) - - from meshmode.discretization import Discretization - result = output_for_each_kernel[o.target_kernel_index] - if isinstance(target_discr, Discretization): - template_ary = actx.thaw(target_discr.nodes()[0]) - result = unflatten(template_ary, result, actx, strict=False) + if any(knl.is_complex_valued for knl in insn.target_kernels) or \ + any(_entry_dtype(actx, strength).kind == "c" for + strength in strengths): + dtype = self.complex_dtype + else: + dtype = self.real_dtype - results.append((o.name, result)) + outputs_grouped_by_target = defaultdict(list) + for o in insn.outputs: + outputs_grouped_by_target[o.target_name].append(o) - timing_data = {} - return results, timing_data + results = [] + timing_data_arr = [] + for target_name, output_group in outputs_grouped_by_target.items(): + target_discr = bound_expr.places.get_discretization( + target_name.geometry, target_name.discr_stage) + + exec_insn = self._get_exec_insn_func( + source_kernels=insn.source_kernels, + target_kernels=insn.target_kernels, + target_discr=target_discr, + ) + + timing_data, output_for_each_kernel = \ + exec_insn(actx, strengths, kernel_args, + dtype, return_timing_data) + timing_data_arr.append(timing_data) + + for o in output_group: + result = output_for_each_kernel[o.target_kernel_index] + if isinstance(target_discr, Discretization): + template_ary = actx.thaw(target_discr.nodes()[0]) + result = unflatten(template_ary, result, actx, strict=False) + + results.append((o.name, result)) + + timing_data = defaultdict(list) + if return_timing_data and timing_data_arr: + for timing_data in timing_data_arr: + for description, result in timing_data.items(): + timing_data[description].merge(result) + + return results, dict(timing_data) # }}} diff --git a/pytential/unregularized.py b/pytential/unregularized.py index b4a5d5dee..c2d7a63eb 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -99,7 +99,7 @@ def exec_compute_potential_insn(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate, return_timing_data): if return_timing_data: from warnings import warn - from pytential.source import UnableToCollectTimingData + from sumpy.fmm import UnableToCollectTimingData warn( "Timing data collection not supported.", category=UnableToCollectTimingData)