Skip to content

Commit

Permalink
added docstring, broke apart super long test into many smaller ones
Browse files Browse the repository at this point in the history
  • Loading branch information
srivarra committed Feb 25, 2025
1 parent f673c2b commit 9f6dc2a
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 59 deletions.
25 changes: 25 additions & 0 deletions src/anndata/_core/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ def __get__(self, instance: NS | None, cls: type[NS]) -> NS | type[NS]:


def _check_namespace_signature(ns_class: type) -> None:
"""Validate the signature of a namespace class for AnnData extensions.
This function ensures that any class used to extend AnnData functionality
has an `__init__` method that accepts an AnnData instance as its second
parameter (after `self`), properly named 'adata' and with the correct
type annotation.
Parameters
----------
ns_class : type
The namespace class to validate.
Raises
------
TypeError
If the `__init__` method has fewer than 2 parameters (missing the AnnData parameter).
AttributeError
If the second parameter of `__init__` lacks a type annotation.
TypeError
If the second parameter of `__init__` is not named 'adata'.
TypeError
If the second parameter of `__init__` is not annotated as the 'AnnData' class.
TypeError
If both the name and type annotation of the second parameter are incorrect.
"""
sig = inspect.signature(ns_class.__init__)
params = list(sig.parameters.values())

Expand Down
180 changes: 121 additions & 59 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@


def test_find_stacklevel():
"""Test that find_stacklevel returns a positive integer.
This function helps determine the correct stacklevel for warnings, so
we just need to verify it returns a sensible value.
"""
level = extensions.find_stacklevel()
assert isinstance(level, int)
# It should be at least 1, otherwise something is wrong.
Expand Down Expand Up @@ -54,16 +59,13 @@ class Dummy:
assert getattr(dummy_obj, "dummy") is ns_instance


def test_register_namespace(monkeypatch):
"""Test the behavior of the register_anndata_namespace decorator.
def test_register_namespace_basic():
"""Test the basic behavior of the register_anndata_namespace decorator.
This test verifies that:
- A new namespace can be registered successfully.
- The accessor is available on AnnData instances.
- The accessor is cached on the AnnData instance.
- An warning is raised if the namespace name is overridden.
"""

original_dummy = getattr(ad.AnnData, "dummy", None)

# Register a new namespace called 'dummy'.
Expand All @@ -84,9 +86,58 @@ def greet(self) -> str:
assert ns_instance._adata is adata
assert ns_instance.greet() == "hello"

# Clean up
if original_dummy is not None:
setattr(ad.AnnData, "dummy", original_dummy)
else:
if hasattr(ad.AnnData, "dummy"):
delattr(ad.AnnData, "dummy")


def test_register_namespace_caching():
"""Test that the namespace accessor is cached on the AnnData instance."""
original_dummy = getattr(ad.AnnData, "dummy", None)

# Register a new namespace
@ad.register_anndata_namespace("dummy")
class DummyNamespace:
def __init__(self, adata: ad.AnnData):
self._adata = adata

def greet(self) -> str:
return "hello"

# Create an AnnData instance
rng = np.random.default_rng(42)
adata = ad.AnnData(X=rng.poisson(1, size=(10, 10)))

# Access the namespace to trigger caching
ns_instance = adata.dummy

# Verify caching behavior on the AnnData instance.
assert adata.dummy is ns_instance

# Clean up
if original_dummy is not None:
setattr(ad.AnnData, "dummy", original_dummy)
else:
if hasattr(ad.AnnData, "dummy"):
delattr(ad.AnnData, "dummy")


def test_register_namespace_override():
"""Test that a warning is raised when overriding an existing namespace."""
original_dummy = getattr(ad.AnnData, "dummy", None)

# Register a namespace first
@ad.register_anndata_namespace("dummy")
class DummyNamespace:
def __init__(self, adata: ad.AnnData):
self._adata = adata

def greet(self) -> str:
return "hello"

# Now, override the same namespace and check that a warning is emitted.
with pytest.warns(
UserWarning, match="Overriding existing custom namespace 'dummy'"
Expand All @@ -102,109 +153,120 @@ def greet(self) -> str:
return "world"

# A new AnnData instance should now use the overridden accessor.
adata2 = ad.AnnData(X=rng.poisson(1, size=(10, 10)))
assert adata2.dummy.greet() == "world"
rng = np.random.default_rng(42)
adata = ad.AnnData(X=rng.poisson(1, size=(10, 10)))
assert adata.dummy.greet() == "world"

# Clean up by restoring any previously existing attribute.
# Clean up
if original_dummy is not None:
setattr(ad.AnnData, "dummy", original_dummy)
else:
if hasattr(ad.AnnData, "dummy"):
delattr(ad.AnnData, "dummy")


def test_register_existing_attributes(monkeypatch):
def test_register_existing_attributes():
"""
Test that registering an accessor with a name that is an attribute of AnnData raises an attribute error.
Test that registering an accessor with a name that is a reserved attribute of AnnData raises an attribute error.
i.e. we do not want users to override say, `AnnData.X` or `AnnData.obs_names`, etc...
We only test a representative sample of important attributes rather than all of them.
"""
reserved_names = dir(ad.AnnData)

for reserved_name in reserved_names:
# Test a representative sample of key AnnData attributes
key_attributes = [
"X",
"obs",
"var",
"uns",
"obsm",
"varm",
"layers",
"copy",
"write",
]

for attr in key_attributes:
with pytest.raises(
AttributeError,
match=f"cannot override reserved attribute {reserved_name!r}",
match=f"cannot override reserved attribute {attr!r}",
):

@ad.register_anndata_namespace(reserved_name)
@ad.register_anndata_namespace(attr)
class DummyNamespace:
def __init__(self, adata: ad.AnnData) -> None:
self._adata = adata


def test_check_namespace_signature_comprehensive():
"""Comprehensive test for _check_namespace_signature covering all edge cases.
def test_check_namespace_signature_valid():
"""Test that a namespace with valid signature is accepted."""

We test the following cases:
1. Valid namespace: correct signature.
2. Missing the second parameter (i.e. the class only has self as a parameter).
3. Wrong parameter name: second parameter not named 'adata'.
4. Wrong annotation: second parameter annotated as wrong type.
5. Both wrong: wrong name and wrong annotation.
6. Missing annotation: no type annotation provided on the second parameter.
"""

# 1. Valid namespace: correct signature.
# Valid namespace: correct signature.
# should not raise any error.
@ad.register_anndata_namespace("valid")
class ValidNamespace:
def __init__(self, adata: ad.AnnData) -> None:
self.adata = adata

# Should not raise any error.
extensions._check_namespace_signature(ValidNamespace)

# 2. Missing the second parameter.
class MissingParamNamespace:
def __init__(self) -> None:
pass

def test_check_namespace_signature_missing_param():
"""Test that a namespace missing the second parameter is rejected."""
with pytest.raises(
TypeError,
match="Namespace initializer must accept an AnnData instance as the second parameter.",
):
extensions._check_namespace_signature(MissingParamNamespace)

# 3. Wrong parameter name: second parameter not named 'adata'.
class WrongNameNamespace:
def __init__(self, notadata: ad.AnnData) -> None:
self.notadata = notadata
@ad.register_anndata_namespace("missing_param")
class MissingParamNamespace:
def __init__(self) -> None:
pass


def test_check_namespace_signature_wrong_name():
"""Test that a namespace with wrong parameter name is rejected."""
with pytest.raises(
TypeError,
match="Namespace initializer's second parameter must be named 'adata', got 'notadata'.",
):
extensions._check_namespace_signature(WrongNameNamespace)

# 4. Wrong annotation: second parameter annotated as wrong type.
class WrongAnnotationNamespace:
def __init__(self, adata: int) -> None:
self.adata = adata
@ad.register_anndata_namespace("wrong_name")
class WrongNameNamespace:
def __init__(self, notadata: ad.AnnData) -> None:
self.notadata = notadata


def test_check_namespace_signature_wrong_annotation():
"""Test that a namespace with wrong parameter annotation is rejected."""
with pytest.raises(
TypeError,
match="Namespace initializer's second parameter must be annotated as the 'AnnData' class, got 'int'.",
):
extensions._check_namespace_signature(WrongAnnotationNamespace)

# 5. Both wrong: wrong name and wrong annotation.
class BothWrongNamespace:
def __init__(self, info: str) -> None:
self.info = info
@ad.register_anndata_namespace("wrong_annotation")
class WrongAnnotationNamespace:
def __init__(self, adata: int) -> None:
self.adata = adata


def test_check_namespace_signature_missing_annotation():
"""Test that a namespace with missing parameter annotation is rejected."""
with pytest.raises(AttributeError):

@ad.register_anndata_namespace("missing_annotation")
class MissingAnnotationNamespace:
def __init__(self, adata) -> None:
self.adata = adata


def test_check_namespace_signature_both_wrong():
"""Test that a namespace with both wrong name and annotation is rejected."""
with pytest.raises(
TypeError,
match=(
r"Namespace initializer's second parameter must be named 'adata', got 'info'\. "
r"And must be annotated as 'AnnData', got 'str'\."
),
):
extensions._check_namespace_signature(BothWrongNamespace)

# 6. Missing annotation: no type annotation provided on the second parameter.
class MissingAnnotationNamespace:
def __init__(self, adata) -> None:
self.adata = adata

with pytest.raises(AttributeError):
extensions._check_namespace_signature(MissingAnnotationNamespace)
@ad.register_anndata_namespace("both_wrong")
class BothWrongNamespace:
def __init__(self, info: str) -> None:
self.info = info

0 comments on commit 9f6dc2a

Please sign in to comment.