Skip to content

Commit

Permalink
test: use MemoryPool in testing actx
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jul 13, 2024
1 parent fb971ab commit b28d938
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(self, queue, allocator=None,
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)", stacklevel=2)

super().__init__(queue, allocator,
compile_trace_callback=compile_trace_callback)

Expand Down Expand Up @@ -509,11 +510,34 @@ class PytestPyOpenCLArrayContextFactory(
_PytestPyOpenCLArrayContextFactoryWithClass):
actx_class = PyOpenCLArrayContext

def __call__(self):
actx = super().__call__()
if actx.allocator is not None:
return actx

from pyopencl.tools import ImmediateAllocator, MemoryPool
alloc = MemoryPool(ImmediateAllocator(actx.queue))

return self.actx_class(
actx.queue,
allocator=alloc,
force_device_scalars=self.force_device_scalars)


class PytestPytatoPyOpenCLArrayContextFactory(
_PytestPytatoPyOpenCLArrayContextFactory):
actx_class = PytatoPyOpenCLArrayContext

def __call__(self):
actx = super().__call__()
if actx.allocator is not None:
return actx

from pyopencl.tools import ImmediateAllocator, MemoryPool
alloc = MemoryPool(ImmediateAllocator(actx.queue))

return self.actx_class(actx.queue, allocator=alloc)


# deprecated
class PytestPyOpenCLArrayContextFactoryWithHostScalars(
Expand Down

0 comments on commit b28d938

Please sign in to comment.