From a78aafef05b77a8046c9d0d657196fa8378201d1 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 22 Aug 2024 18:30:50 +0100 Subject: [PATCH] WIP: may have reintroduced deadlocks --- pyop2/compilation.py | 39 ++++++++++++++++++++++++++++++--------- pyop2/global_kernel.py | 17 +++++++++++------ test/unit/test_caching.py | 2 +- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index c4ebe9306..6105a445a 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -55,6 +55,7 @@ from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError +import pyop2.global_kernel from petsc4py import PETSc @@ -420,6 +421,7 @@ def load_hashkey(*args, **kwargs): # JBTODO: This should not be memory cached +# ...benchmarking disagrees with my assessment @mpi.collective @memory_cache(hashkey=load_hashkey, broadcast=False) def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), @@ -440,8 +442,6 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). """ - from pyop2.global_kernel import GlobalKernel - if isinstance(jitmodule, str): class StrCode(object): def __init__(self, code, argtypes): @@ -451,7 +451,7 @@ def __init__(self, code, argtypes): # cache key self.argtypes = argtypes code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, GlobalKernel): + elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): code = jitmodule else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) @@ -477,7 +477,7 @@ def __init__(self, code, argtypes): # This call is cached in memory by the OS dll = ctypes.CDLL(so_name) - if isinstance(jitmodule, GlobalKernel): + if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) fn = getattr(dll, fn_name) @@ -511,6 +511,13 @@ def read(self, filename): raise FileNotFoundError("File not on disk, cache miss") return filename + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return self[key] + def _make_so_hashkey(compiler, jitmodule, extension, comm): if extension == "cpp": @@ -546,7 +553,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): @mpi.collective @parallel_cache( hashkey=_make_so_hashkey, - cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so"), + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") ) def make_so(compiler, jitmodule, extension, comm, filename=None): """Build a shared library and load it @@ -560,12 +567,15 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" if filename is None: - tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") - tempdir.mkdir(exist_ok=True) - filename = tempdir.joinpath(f"foo{next(FILE_CYCLER)}.c") + # JBTODO: Remove this directory at some point? + pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") + tempdir = pyop2_tempdir.joinpath(f"{os.getpid()}") + # ~ tempdir = Path(mkdtemp(dir=pyop2_tempdir.joinpath(f"{os.getpid()}"))) + # This path + filename should be unique + filename = tempdir.joinpath("foo.c") else: + pyop2_tempdir = None filename = Path(filename).absolute() - filename.parent.mkdir(exist_ok=True) # Compilation communicators are reference counted on the PyOP2 comm icomm = mpi.internal_comm(comm, compiler) @@ -590,6 +600,11 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): # Compile on compilation communicator (ccomm) rank 0 if comm.rank == 0: + if pyop2_tempdir is None: + filename.parent.mkdir(exist_ok=True) + else: + pyop2_tempdir.mkdir(exist_ok=True) + tempdir.mkdir(exist_ok=True) logfile = path.joinpath(f"{base}_p{pid}.log") errfile = path.joinpath(f"{base}_p{pid}.err") with progress(INFO, 'Compiling wrapper'): @@ -612,6 +627,12 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): return soname +# JBTODO: Probably don't want to do this if we fail to compile... +# ~ @atexit +# ~ def _cleanup_tempdir(): + # ~ pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") + + def _run(cc, logfile, errfile, step="Compilation", filemode="w"): debug(f"{step} command: {' '.join(cc)}") try: diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index d9108119f..7e313a5e8 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -10,7 +10,8 @@ import pytools from petsc4py import PETSc -from pyop2 import compilation, mpi +from pyop2 import mpi +from pyop2.compilation import load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -397,11 +398,15 @@ def compile(self, comm): + tuple(self.local_kernel.ldargs) ) - return compilation.load(self, extension, self.name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - comm=comm) + return load( + self, + extension, + self.name, + cppargs=cppargs, + ldargs=ldargs, + restype=ctypes.c_int, + comm=comm + ) @cached_property def argtypes(self): diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 6ab909b29..e335ec680 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -290,7 +290,7 @@ def cache(self): _cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval) if _cache_collection is None: _cache_collection = {default_cache_name: DEFAULT_CACHE()} - mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache_collection) + int_comm.Set_attr(mpi.comm_cache_keyval, _cache_collection) return _cache_collection[default_cache_name] @pytest.fixture