Skip to content

Commit

Permalink
Add a module level flag indicating whether os.fork needs a patch
Browse files Browse the repository at this point in the history
Drop use of monkeypatch fixture for os.fork isolation. Add some
docs and comments to the two new fixtures and rename the test
module.
  • Loading branch information
sgillies committed Dec 11, 2023
1 parent 9b62394 commit 82687d2
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 119 deletions.
42 changes: 41 additions & 1 deletion tiledb/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_ctx_var = ContextVar("ctx")

already_warned = False
_needs_fork_wrapper = sys.platform != "win32" and sys.version_info < (3, 12)


class Config(lt.Config):
Expand Down Expand Up @@ -345,6 +346,16 @@ def __init__(self, config: Config = None):

self._set_default_tags()

# The core tiledb library uses threads and it's easy
# to experience deadlocks when forking a process that is using
# tiledb. The project doesn't have a solution for this at the
# moment other than to avoid using fork(), which is the same
# recommendation that Python makes. Python 3.12 warns if you
# fork() when multiple threads are detected and Python 3.14 will
# make it so you never accidentally fork(): multiprocessing will
# default to "spawn" on Linux.
_ensure_os_fork_wrap()

def __repr__(self):
return "tiledb.Ctx() [see Ctx.config() for configuration]"

Expand Down Expand Up @@ -439,7 +450,6 @@ def check_ipykernel_warn_once():
global already_warned
if not already_warned:
try:
import sys
import warnings

if "ipykernel" in sys.modules and tuple(
Expand Down Expand Up @@ -521,7 +531,37 @@ def default_ctx(config: Union["Config", dict] = None) -> "Ctx":
ctx = _ctx_var.get()
if config is not None:
raise tiledb.TileDBError("Global context already initialized!")
_ensure_os_fork_wrap()
except LookupError:
ctx = tiledb.Ctx(config)
_ctx_var.set(ctx)
return ctx


def _ensure_os_fork_wrap():
global _needs_fork_wrapper
if _needs_fork_wrapper:
import os
import warnings
from functools import wraps

def warning_wrapper(func):
@wraps(func)
def wrapper():
warnings.warn(
"TileDB is a multithreading library and deadlocks "
"are likely if fork() is called after a TileDB "
"context has been created (such as for array "
"access). To safely use TileDB with "
"multiprocessing or concurrent.futures, choose "
"'spawn' as the start method for child processes. "
"For example: "
"multiprocessing.set_start_method('spawn').",
UserWarning,
)
return func()

return wrapper

os.fork = warning_wrapper(os.fork)
_needs_fork_wrapper = False
26 changes: 16 additions & 10 deletions tiledb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,19 @@ def __init__(self, params=None):


@pytest.fixture(scope="function", autouse=True)
def isolate_os_fork(monkeypatch):
# Use monkeypatch to set an attribute to itself, what?
# This makes sure that before any test is run, we save the original
# value of os.fork, i.e. <built-in function fork>, and then
# restore it at the end of every test. Calling Ctx() may patch
# os.fork at runtime.
if sys.platform == "win32":
pass
else:
monkeypatch.setattr(os, "fork", os.fork)
def isolate_os_fork(original_os_fork):
"""Guarantee that tests start and finish with no os.fork patch."""
# Python 3.12 warns about fork() and threads. Tiledb only patches
# os.fork for Pythons 3.8-3.11.
if sys.platform != "win32" and sys.version_info < (3, 12):
tiledb.ctx._needs_fork_wrapper = True
os.fork = original_os_fork
yield
tiledb.ctx._needs_fork_wrapper = True
os.fork = original_os_fork


@pytest.fixture(scope="session")
def original_os_fork():
"""Provides the original unpatched os.fork."""
return os.fork
96 changes: 96 additions & 0 deletions tiledb/tests/test_fork_ctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Tests combining fork with tiledb context threads.
Background: the core tiledb library uses threads and it's easy to
experience deadlocks when forking a process that is using tiledb. The
project doesn't have a solution for this at the moment other than to
avoid using fork(), which is the same recommendation that Python makes.
Python 3.12 warns if you fork() when multiple threads are detected and
Python 3.14 will make it so you never accidentally fork():
multiprocessing will default to "spawn" on Linux.
"""

import multiprocessing
import os
import sys
import warnings

import pytest

import tiledb


def test_no_warning_fork_without_ctx():
"""Get no warning if no tiledb context exists."""
with warnings.catch_warnings():
warnings.simplefilter("error")
pid = os.fork()
if pid == 0:
os._exit(0)
else:
os.wait()


@pytest.mark.skipif(
sys.platform == "win32", reason="fork() is not available on Windows"
)
def test_warning_fork_with_ctx():
"""Get a warning if we fork after creating a tiledb context."""
_ = tiledb.Ctx()
with pytest.warns(UserWarning, match="TileDB is a multithreading library"):
pid = os.fork()
if pid == 0:
os._exit(0)
else:
os.wait()


@pytest.mark.skipif(
sys.platform == "win32", reason="fork() is not available on Windows"
)
def test_warning_fork_with_default_ctx():
"""Get a warning if we fork after creating a default context."""
_ = tiledb.default_ctx()
with pytest.warns(UserWarning, match="TileDB is a multithreading library"):
pid = os.fork()
if pid == 0:
os._exit(0)
else:
os.wait()

pass


def test_no_warning_multiprocessing_without_ctx():
"""Get no warning if no tiledb context exists."""
with warnings.catch_warnings():
warnings.simplefilter("error")
mp = multiprocessing.get_context("fork")
p = mp.Process()
p.start()
p.join()


@pytest.mark.skipif(
sys.platform == "win32", reason="fork() is not available on Windows"
)
def test_warning_multiprocessing_with_ctx():
"""Get a warning if we fork after creating a tiledb context."""
_ = tiledb.Ctx()
mp = multiprocessing.get_context("fork")
p = mp.Process()
with pytest.warns(UserWarning, match="TileDB is a multithreading library"):
p.start()
p.join()


@pytest.mark.skipif(
sys.platform == "win32", reason="fork() is not available on Windows"
)
def test_warning_multiprocessing_with_default_ctx():
"""Get a warning if we fork after creating a default context."""
_ = tiledb.default_ctx()
mp = multiprocessing.get_context("fork")
p = mp.Process()
with pytest.warns(UserWarning, match="TileDB is a multithreading library"):
p.start()
p.join()
108 changes: 0 additions & 108 deletions tiledb/tests/test_multiprocessing.py

This file was deleted.

0 comments on commit 82687d2

Please sign in to comment.