Skip to content

Commit

Permalink
Add a pytential specific DefaultExpansionFactory to get a QBX local e…
Browse files Browse the repository at this point in the history
…xpansion
  • Loading branch information
isuruf authored and inducer committed Apr 18, 2022
1 parent 30e81b4 commit 4e576a0
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pytools import memoize_method, memoize_in, single_valued
from pytential.qbx.target_assoc import QBXTargetAssociationFailedException
from pytential.source import LayerPotentialSourceBase
from sumpy.expansion import DefaultExpansionFactory as DefaultExpansionFactoryBase

import logging
logger = logging.getLogger(__name__)
Expand All @@ -37,11 +38,20 @@
.. autoclass:: QBXLayerPotentialSource
.. autoclass:: QBXTargetAssociationFailedException
.. autoclass:: DefaultExpansionFactory
"""


# {{{ QBX layer potential source

class DefaultExpansionFactory(DefaultExpansionFactoryBase):
"""A expansion factory to create QBX local, local and multipole expansions
"""
def get_qbx_local_expansion_class(self, kernel):
return self.get_local_expansion_class(kernel)


class _not_provided: # noqa: N801
pass

Expand Down Expand Up @@ -188,7 +198,6 @@ def fmm_level_to_order(kernel, kernel_args, tree, level): # noqa pylint:disable
self.fmm_backend = fmm_backend

if expansion_factory is None:
from sumpy.expansion import DefaultExpansionFactory
expansion_factory = DefaultExpansionFactory()
self.expansion_factory = expansion_factory

Expand Down Expand Up @@ -533,9 +542,15 @@ def _tree_indep_data_for_wrangler(self, source_kernels, target_kernels):
local_expn_class = \
self.expansion_factory.get_local_expansion_class(base_kernel)

try:
qbx_local_expn_class = \
self.expansion_factory.get_qbx_local_expansion_class(base_kernel)
except AttributeError:
qbx_local_expn_class = local_expn_class

fmm_mpole_factory = partial(mpole_expn_class, base_kernel)
fmm_local_factory = partial(local_expn_class, base_kernel)
qbx_local_factory = partial(local_expn_class, base_kernel)
qbx_local_factory = partial(qbx_local_expn_class, base_kernel)

if self.fmm_backend == "sumpy":
from pytential.qbx.fmm import \
Expand Down

0 comments on commit 4e576a0

Please sign in to comment.