Allow point FMMs #199

Draft: wants to merge 5 commits into base: main
Changes from all commits
pytential/qbx/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -765,7 +765,7 @@ def exec_compute_potential_insn_direct(self, actx, insn, bound_expr, evaluate,
from meshmode.discretization import Discretization

if return_timing_data:
-            from pytential.source import UnableToCollectTimingData
+            from sumpy.fmm import UnableToCollectTimingData
from warnings import warn
warn(
"Timing data collection not supported.",
pytential/source.py (228 changes: 192 additions & 36 deletions)
@@ -22,11 +22,19 @@

import numpy as np

-from pytools import memoize_in
+from pytools import memoize_in, memoize_method
from arraycontext import flatten, unflatten
from meshmode.dof_array import DOFArray
from meshmode.discretization import Discretization
from arraycontext import ArrayContext

from sumpy.fmm import UnableToCollectTimingData
from sumpy.fmm import (SumpyTimingFuture,
SumpyTreeIndependentDataForWrangler, SumpyExpansionWrangler)
from sumpy.expansion import DefaultExpansionFactory

from functools import partial
from collections import defaultdict
from typing import Optional, Mapping, Union, Callable


__doc__ = """
@@ -96,9 +104,12 @@ def evaluate_kernel_arguments(actx, evaluate, kernel_arguments, flat=True):
return kernel_args


default_expansion_factory = DefaultExpansionFactory()


class PointPotentialSource(_SumpyP2PMixin, PotentialSource):
"""
-    .. attribute:: nodes
+    .. method:: nodes

An :class:`pyopencl.array.Array` of shape ``[ambient_dim, ndofs]``.

@@ -108,7 +119,48 @@ class PointPotentialSource(_SumpyP2PMixin, PotentialSource):
.. automethod:: exec_compute_potential_insn
"""

-    def __init__(self, nodes):
def __init__(self, nodes, *,
            fmm_order: Union[int, bool] = False,
fmm_level_to_order: Optional[Union[bool, Callable[..., int]]] = None,
expansion_factory: Optional[DefaultExpansionFactory]
= default_expansion_factory,
tree_build_kwargs: Optional[Mapping] = None,
trav_build_kwargs: Optional[Mapping] = None,
setup_actx: Optional[ArrayContext] = None):
"""
        :arg nodes: The nodes of the point potential source, given as a
            :class:`pyopencl.array.Array` of shape ``[ambient_dim, ndofs]``.
        :arg fmm_order: The FMM order to use on all levels, if not *False*.
            Mutually exclusive with *fmm_level_to_order*. If neither argument
            is given, a direct point-to-point calculation is used.
        :arg fmm_level_to_order: An optional callable that returns the FMM order
            to use for a given level. Mutually exclusive with *fmm_order*.
        :arg expansion_factory: An expansion factory from which to obtain
            multipole and local expansion classes when an FMM is used.
        :arg tree_build_kwargs: Keyword arguments passed when building the
            tree for an FMM.
        :arg trav_build_kwargs: Keyword arguments passed when building a
            traversal for an FMM.
        :arg setup_actx: An array context used when building the tree and
            traversal for an FMM.
"""

if fmm_order is not False and fmm_level_to_order is not None:
raise TypeError("may not specify both fmm_order and fmm_level_to_order")

if fmm_level_to_order is None:
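            # no callable given: build a constant-order one from fmm_order,
            # or record False to request direct point-to-point evaluation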
if fmm_order is not False:
def fmm_level_to_order(kernel, kernel_args, tree, level): # noqa pylint:disable=function-redefined
return fmm_order
else:
fmm_level_to_order = False
self.fmm_level_to_order = fmm_level_to_order
self.expansion_factory = expansion_factory
self.tree_build_kwargs = tree_build_kwargs if tree_build_kwargs else {}
self.trav_build_kwargs = trav_build_kwargs if trav_build_kwargs else {}
self._setup_actx = setup_actx
self._nodes = nodes

@property
@@ -131,6 +183,32 @@ def ndofs(self):
for coord_ary in self._nodes:
return coord_ary.shape[0]

def copy(self, *, nodes=None, fmm_order=None, fmm_level_to_order=None,
expansion_factory=None, tree_build_kwargs=None, trav_build_kwargs=None,
setup_actx=None):
if nodes is None:
nodes = self._nodes
if setup_actx is None:
setup_actx = self._setup_actx
        if fmm_level_to_order is None and fmm_order is None:
            # neither given: keep this source's FMM configuration
            fmm_level_to_order = self.fmm_level_to_order
        if fmm_order is None:
            # the constructor expects False, not None, to mean "no fixed order"
            fmm_order = False
if expansion_factory is None:
expansion_factory = self.expansion_factory
if tree_build_kwargs is None:
tree_build_kwargs = self.tree_build_kwargs
if trav_build_kwargs is None:
trav_build_kwargs = self.trav_build_kwargs

return type(self)(
nodes=nodes,
fmm_order=fmm_order,
fmm_level_to_order=fmm_level_to_order,
expansion_factory=expansion_factory,
tree_build_kwargs=tree_build_kwargs,
trav_build_kwargs=trav_build_kwargs,
setup_actx=setup_actx,
)

@property
def complex_dtype(self):
return {
@@ -161,46 +239,124 @@ def cost_model_compute_potential_insn(self, actx, insn, bound_expr,
evaluate, costs):
raise NotImplementedError

def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
return_timing_data):
if return_timing_data:
from warnings import warn
warn(
"Timing data collection not supported.",
category=UnableToCollectTimingData)
@memoize_method
def _get_tree(self, target_discr):
"""Builds a tree for targets given by *target_discr* and caches the
result. Needed only when an FMM is used.
"""
from boxtree import TreeBuilder
from boxtree.traversal import FMMTraversalBuilder

        actx = self._setup_actx
        assert actx is not None, "setup_actx is required to build an FMM tree"
sources = self._nodes
targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray)
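        # build a tree over sources and targets, then an FMM traversal for it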
tree_build = TreeBuilder(actx.context)
trav_build = FMMTraversalBuilder(actx.context,
**self.trav_build_kwargs)
tree, _ = tree_build(actx.queue, sources, targets=targets,
**self.tree_build_kwargs)
trav, _ = trav_build(actx.queue, tree)
return tree, trav

@memoize_method
def _get_exec_insn_func(self, source_kernels, target_kernels, target_discr):
if self.fmm_level_to_order is False:
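            # no FMM configured: evaluate directly with sumpy's P2P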
def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data):
sources = self._nodes
targets = flatten(target_discr.nodes(), actx, leaf_class=DOFArray)
p2p = self.get_p2p(actx, source_kernels=source_kernels,
target_kernels=target_kernels)

evt, output = p2p(actx.queue,
targets=targets,
sources=sources,
strength=strengths, **kernel_args)

if return_timing_data:
timing_data = {"eval_direct":
SumpyTimingFuture(actx.queue, [evt]).result()}
else:
timing_data = None
return timing_data, output
else:
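            # FMM path: set up expansion factories, a cached tree/traversal,
            # and a sumpy expansion wrangler, then let boxtree drive the FMM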
from boxtree.fmm import drive_fmm

kernel = target_kernels[0].get_base_kernel()
local_expansion_factory = \
self.expansion_factory.get_local_expansion_class(kernel)
local_expansion_factory = partial(local_expansion_factory, kernel)
mpole_expansion_factory = \
self.expansion_factory.get_multipole_expansion_class(kernel)
mpole_expansion_factory = partial(mpole_expansion_factory, kernel)

tree, trav = self._get_tree(target_discr)

def exec_insn(actx, strengths, kernel_args, dtype, return_timing_data):
tree_indep = SumpyTreeIndependentDataForWrangler(
actx.context,
mpole_expansion_factory,
local_expansion_factory,
target_kernels=target_kernels,
source_kernels=source_kernels,
)
wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=self.fmm_level_to_order,
kernel_extra_kwargs=kernel_args)
timing_data = {} if return_timing_data else None
output = drive_fmm(wrangler, strengths, timing_data=timing_data)
return timing_data, output

p2p = None
return exec_insn

def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate,
return_timing_data):
kernel_args = evaluate_kernel_arguments(
actx, evaluate, insn.kernel_arguments, flat=False)
strengths = [evaluate(density) for density in insn.densities]

# FIXME: Do this all at once
results = []
for o in insn.outputs:
target_discr = bound_expr.places.get_discretization(
o.target_name.geometry, o.target_name.discr_stage)

# no on-disk kernel caching
if p2p is None:
p2p = self.get_p2p(actx, source_kernels=insn.source_kernels,
target_kernels=insn.target_kernels)

evt, output_for_each_kernel = p2p(actx.queue,
targets=flatten(target_discr.nodes(), actx, leaf_class=DOFArray),
sources=self._nodes,
strength=strengths, **kernel_args)

from meshmode.discretization import Discretization
result = output_for_each_kernel[o.target_kernel_index]
if isinstance(target_discr, Discretization):
template_ary = actx.thaw(target_discr.nodes()[0])
result = unflatten(template_ary, result, actx, strict=False)
if any(knl.is_complex_valued for knl in insn.target_kernels) or \
any(_entry_dtype(actx, strength).kind == "c" for
strength in strengths):
dtype = self.complex_dtype
else:
dtype = self.real_dtype

results.append((o.name, result))
outputs_grouped_by_target = defaultdict(list)
for o in insn.outputs:
outputs_grouped_by_target[o.target_name].append(o)

timing_data = {}
return results, timing_data
results = []
timing_data_arr = []
for target_name, output_group in outputs_grouped_by_target.items():
target_discr = bound_expr.places.get_discretization(
target_name.geometry, target_name.discr_stage)

exec_insn = self._get_exec_insn_func(
source_kernels=insn.source_kernels,
target_kernels=insn.target_kernels,
target_discr=target_discr,
)

timing_data, output_for_each_kernel = \
exec_insn(actx, strengths, kernel_args,
dtype, return_timing_data)
timing_data_arr.append(timing_data)

for o in output_group:
result = output_for_each_kernel[o.target_kernel_index]
if isinstance(target_discr, Discretization):
template_ary = actx.thaw(target_discr.nodes()[0])
result = unflatten(template_ary, result, actx, strict=False)

results.append((o.name, result))

        merged_timing_data = {}
        if return_timing_data:
            for timing_data in timing_data_arr:
                for description, result in timing_data.items():
                    # merge() returns a new TimingResult; accumulate it instead
                    # of letting the loop variable shadow the collected dict
                    if description in merged_timing_data:
                        merged_timing_data[description] = \
                                merged_timing_data[description].merge(result)
                    else:
                        merged_timing_data[description] = result

        return results, merged_timing_data

# }}}
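Not part of the diff: a minimal usage sketch of the constructor options added above. The array context ``actx`` and the node counts are assumptions; leaving out the FMM arguments keeps the previous direct point-to-point behavior.

import numpy as np
from pytential.source import PointPotentialSource

nodes = actx.from_numpy(np.random.rand(3, 10000))  # shape [ambient_dim, ndofs]

# default: direct point-to-point evaluation, as before this change
direct_source = PointPotentialSource(nodes)

# same nodes, evaluated with an FMM of order 10 on every level; setup_actx
# is used to build the tree and traversal
fmm_source = PointPotentialSource(nodes, fmm_order=10, setup_actx=actx)

# or choose the order per level, using the same signature as the internal
# callable in __init__ above
def order_for_level(kernel, kernel_args, tree, level):
    return max(3, 10 - level)

adaptive_source = PointPotentialSource(
        nodes, fmm_level_to_order=order_for_level, setup_actx=actx)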

pytential/unregularized.py (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ def exec_compute_potential_insn(self, actx: PyOpenCLArrayContext,
insn, bound_expr, evaluate, return_timing_data):
if return_timing_data:
from warnings import warn
-            from pytential.source import UnableToCollectTimingData
+            from sumpy.fmm import UnableToCollectTimingData
warn(
"Timing data collection not supported.",
category=UnableToCollectTimingData)
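
Also not part of the diff: end to end, the new FMM path is exercised through the usual ``bind``/``sym`` workflow, which this PR leaves unchanged. A sketch under assumptions (``actx`` is a ``PyOpenCLArrayContext``; ``sym.S`` with ``qbx_forced_limit=None`` is taken to be the plain off-surface evaluation operator; all names are illustrative):

import numpy as np
from pytential import bind, sym
from pytential.source import PointPotentialSource
from pytential.target import PointsTarget
from sumpy.kernel import LaplaceKernel

source = PointPotentialSource(
        actx.from_numpy(np.random.rand(3, 50000)),
        fmm_order=8, setup_actx=actx)
targets = PointsTarget(actx.from_numpy(np.random.rand(3, 1000)))

sigma = actx.from_numpy(np.ones(50000))  # source strengths
op = sym.S(LaplaceKernel(3), sym.var("sigma"), qbx_forced_limit=None)
pot = bind((source, targets), op)(actx, sigma=sigma)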