Skip to content

Commit

Permalink
make code containers functions instead of methods
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Apr 19, 2022
1 parent 4e576a0 commit 2dd9746
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 98 deletions.
8 changes: 4 additions & 4 deletions pytential/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ def _get_qbx_discretization(self, geometry, discr_stage):
except KeyError:
dofdesc = sym.DOFDescriptor(geometry, discr_stage)

from pytential.qbx.refinement import refiner_code_container
wrangler = refiner_code_container(lpot_source._setup_actx).get_wrangler()

from pytential.qbx.refinement import _refine_for_global_qbx
# NOTE: this adds the required discretizations to the cache
_refine_for_global_qbx(self, dofdesc,
lpot_source.refiner_code_container.get_wrangler(),
_copy_collection=False)

_refine_for_global_qbx(self, dofdesc, wrangler, _copy_collection=False)
discr = self._get_discr_from_cache(geometry, discr_stage)

return discr
Expand Down
16 changes: 9 additions & 7 deletions pytential/linalg/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ def partition_by_nodes(
discr = places.get_discretization(dofdesc.geometry, dofdesc.discr_stage)

if tree_kind is not None:
from pytential.qbx.utils import tree_code_container
tcc = tree_code_container(lpot_source._setup_actx)

from arraycontext import thaw
builder = lpot_source.tree_code_container.build_tree()
tree, _ = builder(actx.queue,
tree, _ = tcc.build_tree()(actx.queue,
particles=flatten(
thaw(discr.nodes(), actx), actx, leaf_class=DOFArray
),
Expand Down Expand Up @@ -594,12 +596,12 @@ def prg():

# {{{ perform area query

builder = lpot_source.tree_code_container.build_tree()
tree, _ = builder(actx.queue, sources,
max_particles_in_box=max_particles_in_box)
from pytential.qbx.utils import tree_code_container
tcc = tree_code_container(lpot_source._setup_actx)

builder = lpot_source.tree_code_container.build_area_query()
query, _ = builder(actx.queue, tree, pxy.centers, pxy.radii)
tree, _ = tcc.build_tree()(actx.queue, sources,
max_particles_in_box=max_particles_in_box)
query, _ = tcc.build_area_query()(actx.queue, tree, pxy.centers, pxy.radii)

tree = tree.get(actx.queue)
query = query.get(actx.queue)
Expand Down
67 changes: 9 additions & 58 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,60 +341,6 @@ def copy(

# }}}

# {{{ code containers

@property
def tree_code_container(self):
@memoize_in(self._setup_actx, (
QBXLayerPotentialSource, "tree_code_container"))
def make_container():
from pytential.qbx.utils import TreeCodeContainer
return TreeCodeContainer(self._setup_actx)

return make_container()

@property
def refiner_code_container(self):
@memoize_in(self._setup_actx, (
QBXLayerPotentialSource, "refiner_code_container"))
def make_container():
from pytential.qbx.refinement import RefinerCodeContainer
return RefinerCodeContainer(
self._setup_actx, self.tree_code_container)

return make_container()

@property
def target_association_code_container(self):
@memoize_in(self._setup_actx, (
QBXLayerPotentialSource, "target_association_code_container"))
def make_container():
from pytential.qbx.target_assoc import TargetAssociationCodeContainer
return TargetAssociationCodeContainer(
self._setup_actx, self.tree_code_container)

return make_container()

@property
def qbx_fmm_geometry_data_code_container(self):
@memoize_in(self._setup_actx, (
QBXLayerPotentialSource, "qbx_fmm_geometry_data_code_container"))
def make_container(
debug, ambient_dim, well_sep_is_n_away,
from_sep_smaller_crit):
from pytential.qbx.geometry import QBXFMMGeometryDataCodeContainer
return QBXFMMGeometryDataCodeContainer(
self._setup_actx,
ambient_dim, self.tree_code_container, debug,
_well_sep_is_n_away=well_sep_is_n_away,
_from_sep_smaller_crit=from_sep_smaller_crit)

return make_container(
self.debug, self.ambient_dim,
self._well_sep_is_n_away, self._from_sep_smaller_crit)

# }}}

# {{{ internal API

@memoize_method
Expand All @@ -409,11 +355,16 @@ def qbx_fmm_geometry_data(self, places, name,
:class:`pytential.target.TargetBase`
instance
"""
from pytential.qbx.geometry import QBXFMMGeometryData
from pytential.qbx.geometry import qbx_fmm_geometry_data_code_container
code_container = qbx_fmm_geometry_data_code_container(
self._setup_actx, self.ambient_dim,
debug=self.debug,
well_sep_is_n_away=self._well_sep_is_n_away,
from_sep_smaller_crit=self._from_sep_smaller_crit)

return QBXFMMGeometryData(places, name,
self.qbx_fmm_geometry_data_code_container,
target_discrs_and_qbx_sides,
from pytential.qbx.geometry import QBXFMMGeometryData
return QBXFMMGeometryData(
places, name, code_container, target_discrs_and_qbx_sides,
target_association_tolerance=self.target_association_tolerance,
tree_kind=self._tree_kind,
debug=self.debug)
Expand Down
39 changes: 30 additions & 9 deletions pytential/qbx/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np

from pytools import memoize_method, log_process
from pytools import memoize_method, memoize_in, log_process
from arraycontext import PyOpenCLArrayContext, flatten, freeze
from meshmode.dof_array import DOFArray

Expand Down Expand Up @@ -107,16 +107,17 @@ class target_state(Enum): # noqa


class QBXFMMGeometryDataCodeContainer(TreeCodeContainerMixin):
def __init__(self, actx: PyOpenCLArrayContext, ambient_dim,
tree_code_container, debug,
_well_sep_is_n_away, _from_sep_smaller_crit):
def __init__(self,
actx: PyOpenCLArrayContext, ambient_dim: int, debug: bool,
_well_sep_is_n_away: int, _from_sep_smaller_crit: str) -> None:
self._setup_actx = actx
self.ambient_dim = ambient_dim
self.tree_code_container = tree_code_container
self.debug = debug
self._well_sep_is_n_away = _well_sep_is_n_away
self._from_sep_smaller_crit = _from_sep_smaller_crit

self._setup_actx = actx.clone()
self.debug = debug
from pytential.qbx.utils import tree_code_container
self.tree_code_container = tree_code_container(actx)

@memoize_method
def copy_targets_kernel(self):
Expand Down Expand Up @@ -260,6 +261,26 @@ def rotation_classes_builder(self):
from boxtree.rotation_classes import RotationClassesBuilder
return RotationClassesBuilder(self._setup_actx.context)


def qbx_fmm_geometry_data_code_container(
actx: PyOpenCLArrayContext, ambient_dim: int, *,
debug: bool,
well_sep_is_n_away: int,
from_sep_smaller_crit: str) -> QBXFMMGeometryDataCodeContainer:
@memoize_in(actx, (
QBXFMMGeometryDataCodeContainer, qbx_fmm_geometry_data_code_container))
def make_container(
_ambient_dim, _debug,
_well_sep_is_n_away, _from_sep_smaller_crit):
return QBXFMMGeometryDataCodeContainer(
actx, _ambient_dim, _debug,
_well_sep_is_n_away=_well_sep_is_n_away,
_from_sep_smaller_crit=_from_sep_smaller_crit)

return make_container(
ambient_dim, debug,
well_sep_is_n_away, from_sep_smaller_crit)

# }}}


Expand Down Expand Up @@ -759,9 +780,9 @@ def user_target_to_center(self):
PointsTarget(target_info.targets[:, self.ncenters:]),
target_side_prefs.astype(np.int32))]

from pytential.qbx.target_assoc import target_association_code_container
target_association_wrangler = (
self.lpot_source.target_association_code_container
.get_wrangler(actx))
target_association_code_container(actx).get_wrangler(actx))

tgt_assoc_result = associate_targets_to_qbx_centers(
self.places,
Expand Down
20 changes: 15 additions & 5 deletions pytential/qbx/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from arraycontext import PyOpenCLArrayContext, flatten
from meshmode.dof_array import DOFArray

from pytools import memoize_method
from pytools import memoize_method, memoize_in
from boxtree.area_query import AreaQueryElementwiseTemplate
from boxtree.tools import InlineBinarySearch
from pytential.qbx.utils import (
Expand Down Expand Up @@ -219,9 +219,11 @@

class RefinerCodeContainer(TreeCodeContainerMixin):

def __init__(self, actx: PyOpenCLArrayContext, tree_code_container):
def __init__(self, actx: PyOpenCLArrayContext):
self.array_context = actx
self.tree_code_container = tree_code_container

from pytential.qbx.utils import tree_code_container
self.tree_code_container = tree_code_container(actx)

@memoize_method
def expansion_disk_undisturbed_by_sources_checker(
Expand Down Expand Up @@ -271,6 +273,14 @@ def element_prop_threshold_checker(self):
def get_wrangler(self):
return RefinerWrangler(self.array_context, self)


def refiner_code_container(actx: PyOpenCLArrayContext) -> RefinerCodeContainer:
@memoize_in(actx, (RefinerCodeContainer, refiner_code_container))
def make_container():
return RefinerCodeContainer(actx)

return make_container()

# }}}


Expand Down Expand Up @@ -964,8 +974,8 @@ def refine_geometry_collection(places,
if not isinstance(lpot_source, QBXLayerPotentialSource):
continue

_refine_for_global_qbx(places, dofdesc,
lpot_source.refiner_code_container.get_wrangler(),
wrangler = refiner_code_container(lpot_source._setup_actx).get_wrangler()
_refine_for_global_qbx(places, dofdesc, wrangler,
group_factory=group_factory,
kernel_length_scale=kernel_length_scale,
scaled_max_curvature_threshold=scaled_max_curvature_threshold,
Expand Down
18 changes: 15 additions & 3 deletions pytential/qbx/target_assoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np

from pytools import memoize_method
from pytools import memoize_method, memoize_in
from boxtree.tools import DeviceDataRecord
from boxtree.area_query import AreaQueryElementwiseTemplate
from boxtree.tools import InlineBinarySearch
Expand Down Expand Up @@ -440,9 +440,11 @@ class QBXTargetAssociation(DeviceDataRecord):

class TargetAssociationCodeContainer(TreeCodeContainerMixin):

def __init__(self, actx: PyOpenCLArrayContext, tree_code_container):
def __init__(self, actx: PyOpenCLArrayContext):
self.array_context = actx
self.tree_code_container = tree_code_container

from pytential.qbx.utils import tree_code_container
self.tree_code_container = tree_code_container(actx)

@property
def cl_context(self):
Expand Down Expand Up @@ -493,6 +495,16 @@ def get_wrangler(self, actx: PyOpenCLArrayContext):
return TargetAssociationWrangler(actx, code_container=self)


def target_association_code_container(
actx: PyOpenCLArrayContext) -> TargetAssociationCodeContainer:
@memoize_in(actx, (
TargetAssociationCodeContainer, target_association_code_container))
def make_container():
return TargetAssociationCodeContainer(actx)

return make_container()


class TargetAssociationWrangler(TreeWranglerBase):

@log_process(logger)
Expand Down
11 changes: 10 additions & 1 deletion pytential/qbx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np

from pytools import memoize_method, log_process
from pytools import memoize_method, memoize_in, log_process
from arraycontext import PyOpenCLArrayContext
from meshmode.dof_array import DOFArray

Expand Down Expand Up @@ -91,6 +91,14 @@ def build_area_query(self):
from boxtree.area_query import AreaQueryBuilder
return AreaQueryBuilder(self.array_context.context)


def tree_code_container(actx: PyOpenCLArrayContext) -> TreeCodeContainer:
@memoize_in(actx, (TreeCodeContainer, tree_code_container))
def make_container():
return TreeCodeContainer(actx)

return make_container()

# }}}


Expand All @@ -110,6 +118,7 @@ def peer_list_finder(self):
def particle_list_filter(self):
return self.tree_code_container.particle_list_filter()


# }}}


Expand Down
15 changes: 4 additions & 11 deletions test/test_global_qbx.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,8 @@ def targets_from_sources(sign, dist, dim=2):
# {{{ run target associator and check

from pytential.qbx.target_assoc import (
TargetAssociationCodeContainer, associate_targets_to_qbx_centers)

from pytential.qbx.utils import TreeCodeContainer
code_container = TargetAssociationCodeContainer(
actx, TreeCodeContainer(actx))
target_association_code_container, associate_targets_to_qbx_centers)
code_container = target_association_code_container(actx)

target_assoc = (
associate_targets_to_qbx_centers(
Expand Down Expand Up @@ -543,13 +540,9 @@ def test_target_association_failure(actx_factory):
)

from pytential.qbx.target_assoc import (
TargetAssociationCodeContainer, associate_targets_to_qbx_centers,
target_association_code_container, associate_targets_to_qbx_centers,
QBXTargetAssociationFailedException)

from pytential.qbx.utils import TreeCodeContainer

code_container = TargetAssociationCodeContainer(
actx, TreeCodeContainer(actx))
code_container = target_association_code_container(actx)

with pytest.raises(QBXTargetAssociationFailedException):
associate_targets_to_qbx_centers(
Expand Down

0 comments on commit 2dd9746

Please sign in to comment.