Add flag for turning off caching (#313)
This is useful for testing purposes.
cthoyt authored Jan 15, 2025
1 parent f4572a6 commit 247a772
Showing 8 changed files with 119 additions and 23 deletions.
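Passing cache=False to the lookup functions skips both reading and writing the on-disk caches, so tests can exercise the API against ad hoc ontologies without leaving artifacts behind. A minimal usage sketch (the "chebi" prefix is an arbitrary example, not part of this change):

    import pyobo

    # Compute the mapping from scratch and skip the TSV cache entirely
    id_to_name = pyobo.get_id_name_mapping("chebi", cache=False)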
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -71,7 +71,7 @@ dependencies = [
     "humanize",
     "tabulate",
     "cachier",
-    "pystow>=0.6.0",
+    "pystow>=0.7.0",
     "bioversions>=0.7.0",
     "bioregistry>=0.11.33",
     "bioontologies>=0.5.2",
7 changes: 5 additions & 2 deletions src/pyobo/api/alts.py
@@ -8,7 +8,7 @@
 from typing_extensions import Unpack
 
 from .utils import get_version_from_kwargs
-from ..constants import GetOntologyKwargs, check_should_force
+from ..constants import GetOntologyKwargs, check_should_cache, check_should_force
 from ..getters import get_ontology
 from ..identifier_utils import wrap_norm_prefix
 from ..struct.reference import Reference
@@ -40,7 +40,10 @@ def get_id_to_alts(prefix: str, **kwargs: Unpack[GetOntologyKwargs]) -> Mapping[str, list[str]]:
     path = prefix_cache_join(prefix, name="alt_ids.tsv", version=version)
 
     @cached_multidict(
-        path=path, header=[f"{prefix}_id", "alt_id"], force=check_should_force(kwargs)
+        path=path,
+        header=[f"{prefix}_id", "alt_id"],
+        cache=check_should_cache(kwargs),
+        force=check_should_force(kwargs),
     )
     def _get_mapping() -> Mapping[str, list[str]]:
         ontology = get_ontology(prefix, **kwargs)
3 changes: 3 additions & 0 deletions src/pyobo/api/hierarchy.py
@@ -40,6 +40,7 @@ def get_hierarchy(
     force_process: bool = False,
     version: str | None = None,
     strict: bool = True,
+    cache: bool = True,
 ) -> nx.DiGraph:
     """Get hierarchy of parents as a directed graph.
@@ -77,6 +78,7 @@
         force_process=force_process,
         version=version,
         strict=strict,
+        cache=cache,
     )
 
 
@@ -94,6 +96,7 @@ def _get_hierarchy_helper(
     force_process: bool = False,
     version: str | None = None,
     strict: bool = True,
+    cache: bool = True,
 ) -> nx.DiGraph:
     rv = nx.DiGraph()
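The same flag is threaded through the hierarchy API. A sketch, assuming get_hierarchy is re-exported at the package root like the other lookup functions exercised in the tests below ("doid" is an arbitrary example prefix):

    import pyobo

    # Returns an nx.DiGraph of parent relations without touching the on-disk cache
    graph = pyobo.get_hierarchy("doid", cache=False)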
34 changes: 26 additions & 8 deletions src/pyobo/api/names.py
@@ -13,7 +13,7 @@
 
 from .alts import get_primary_identifier
 from .utils import get_version, get_version_from_kwargs
-from ..constants import GetOntologyKwargs, check_should_force
+from ..constants import GetOntologyKwargs, check_should_cache, check_should_force
 from ..getters import NoBuildError, get_ontology
 from ..identifier_utils import wrap_norm_prefix
 from ..utils.cache import cached_collection, cached_mapping, cached_multidict
@@ -110,7 +110,11 @@ def get_ids(prefix: str, **kwargs: Unpack[GetOntologyKwargs]) -> set[str]:
     version = get_version_from_kwargs(prefix, kwargs)
     path = prefix_cache_join(prefix, name="ids.tsv", version=version)
 
-    @cached_collection(path=path, force=check_should_force(kwargs))
+    @cached_collection(
+        path=path,
+        force=check_should_force(kwargs),
+        cache=check_should_cache(kwargs),
+    )
     def _get_ids() -> list[str]:
         ontology = get_ontology(prefix, **kwargs)
         return sorted(ontology.get_ids())
@@ -136,7 +140,12 @@ def get_id_name_mapping(
     version = get_version_from_kwargs(prefix, kwargs)
     path = prefix_cache_join(prefix, name="names.tsv", version=version)
 
-    @cached_mapping(path=path, header=[f"{prefix}_id", "name"], force=check_should_force(kwargs))
+    @cached_mapping(
+        path=path,
+        header=[f"{prefix}_id", "name"],
+        force=check_should_force(kwargs),
+        cache=check_should_cache(kwargs),
+    )
     def _get_id_name_mapping() -> Mapping[str, str]:
         ontology = get_ontology(prefix, **kwargs)
         return ontology.get_id_name_mapping()
@@ -175,15 +184,17 @@ def get_definition(
 
 
 def get_id_definition_mapping(
-    prefix: str,
-    **kwargs: Unpack[GetOntologyKwargs],
+    prefix: str, **kwargs: Unpack[GetOntologyKwargs]
 ) -> Mapping[str, str]:
     """Get a mapping of descriptions."""
     version = get_version_from_kwargs(prefix, kwargs)
     path = prefix_cache_join(prefix, name="definitions.tsv", version=version)
 
     @cached_mapping(
-        path=path, header=[f"{prefix}_id", "definition"], force=check_should_force(kwargs)
+        path=path,
+        header=[f"{prefix}_id", "definition"],
+        force=check_should_force(kwargs),
+        cache=check_should_cache(kwargs),
     )
     def _get_mapping() -> Mapping[str, str]:
         logger.info(
@@ -200,7 +211,11 @@ def get_obsolete(prefix: str, **kwargs: Unpack[GetOntologyKwargs]) -> set[str]:
     version = get_version_from_kwargs(prefix, kwargs)
     path = prefix_cache_join(prefix, name="obsolete.tsv", version=version)
 
-    @cached_collection(path=path, force=check_should_force(kwargs))
+    @cached_collection(
+        path=path,
+        force=check_should_force(kwargs),
+        cache=check_should_cache(kwargs),
+    )
     def _get_obsolete() -> list[str]:
         ontology = get_ontology(prefix, **kwargs)
         return sorted(ontology.get_obsolete())
@@ -225,7 +240,10 @@ def get_id_synonyms_mapping(
     path = prefix_cache_join(prefix, name="synonyms.tsv", version=version)
 
     @cached_multidict(
-        path=path, header=[f"{prefix}_id", "synonym"], force=check_should_force(kwargs)
+        path=path,
+        header=[f"{prefix}_id", "synonym"],
+        force=check_should_force(kwargs),
+        cache=check_should_cache(kwargs),
     )
     def _get_multidict() -> Mapping[str, list[str]]:
         logger.info("[%s v%s] no cached synonyms found. getting from OBO loader", prefix, version)
6 changes: 6 additions & 0 deletions src/pyobo/constants.py
@@ -142,6 +142,7 @@ class GetOntologyKwargs(SlimGetOntologyKwargs):
     """
 
     version: str | None
+    cache: bool
 
 
 def check_should_force(data: GetOntologyKwargs) -> bool:
@@ -152,6 +153,11 @@ def check_should_force(data: GetOntologyKwargs) -> bool:
     return data.get("force", False) or data.get("force_process", False)
 
 
+def check_should_cache(data: GetOntologyKwargs) -> bool:
+    """Determine whether caching should be done based on generic keyword arguments."""
+    return data.get("cache", True)
+
+
 class LookupKwargs(GetOntologyKwargs):
     """Represents all arguments passed to :func:`pyobo.get_ontology`.
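Given the definitions in this hunk, the two helpers read the generic keyword arguments as follows (an illustrative sketch):

    kwargs = {"force": False, "cache": False}
    check_should_force(kwargs)  # False
    check_should_cache(kwargs)  # False - caching explicitly disabled
    check_should_cache({})      # True - caching defaults to on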
28 changes: 17 additions & 11 deletions src/pyobo/getters.py
@@ -71,6 +71,7 @@ def get_ontology(
     version: str | None = None,
     robot_check: bool = True,
     upgrade: bool = True,
+    cache: bool = True,
 ) -> Obo:
     """Get the OBO for a given graph.
@@ -85,6 +86,8 @@
     :param upgrade:
         If set to true, will automatically upgrade relationships, such as
         ``obo:chebi#part_of`` to ``BFO:0000051``
+    :param cache:
+        Should cached objects be written? defaults to True
     :returns: An OBO object
     :raises OnlyOWLError: If the OBO foundry only has an OWL document for this resource.
@@ -107,20 +110,22 @@
         logger.info("UBERON has so much garbage in it that defaulting to non-strict parsing")
         strict = False
 
-    obonet_json_gz_path = prefix_directory_join(
-        prefix, name=f"{prefix}.obonet.json.gz", ensure_exists=False, version=version
-    )
-    if obonet_json_gz_path.exists() and not force:
-        from .reader import from_obonet
-        from .utils.cache import get_gzipped_graph
+    if cache:
+        obonet_json_gz_path = prefix_directory_join(
+            prefix, name=f"{prefix}.obonet.json.gz", ensure_exists=False, version=version
+        )
+        if obonet_json_gz_path.exists() and not force:
+            from .reader import from_obonet
+            from .utils.cache import get_gzipped_graph
 
-        logger.debug("[%s] using obonet cache at %s", prefix, obonet_json_gz_path)
-        return from_obonet(get_gzipped_graph(obonet_json_gz_path))
+            logger.debug("[%s] using obonet cache at %s", prefix, obonet_json_gz_path)
+            return from_obonet(get_gzipped_graph(obonet_json_gz_path))
 
     if has_nomenclature_plugin(prefix):
         obo = run_nomenclature_plugin(prefix, version=version)
-        logger.debug("[%s] caching nomenclature plugin", prefix)
-        obo.write_default(force=force_process)
+        if cache:
+            logger.debug("[%s] caching nomenclature plugin", prefix)
+            obo.write_default(force=force_process)
         return obo
 
     logger.debug("[%s] no obonet cache found at %s", prefix, obonet_json_gz_path)
@@ -140,7 +145,8 @@
         raise UnhandledFormatError(f"[{prefix}] unhandled ontology file format: {path.suffix}")
 
     obo = from_obo_path(path, prefix=prefix, strict=strict, version=version, upgrade=upgrade)
-    obo.write_default(force=force_process)
+    if cache:
+        obo.write_default(force=force_process)
     return obo
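The net effect in get_ontology: with cache=True (the default), the obonet cache is consulted and parsed results are written back via write_default; with cache=False, the ontology is always re-parsed and nothing is written. A usage sketch (the "chebi" prefix is an arbitrary example):

    from pyobo import get_ontology

    # Always re-parse from the source artifact; no cache read, no write_default
    obo = get_ontology("chebi", cache=False)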
3 changes: 2 additions & 1 deletion src/pyobo/utils/cache.py
@@ -43,9 +43,10 @@ def __init__(
         *,
         use_tqdm: bool = False,
         force: bool = False,
+        cache: bool = True,
     ):
         """Initialize the mapping cache."""
-        super().__init__(path=path, force=force)
+        super().__init__(path=path, cache=cache, force=force)
         self.header = header
         self.use_tqdm = use_tqdm
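These decorator classes are applied as in the names.py hunks above; a condensed sketch of the pattern (the path, header, and inner function here are illustrative, not from the repository):

    from pathlib import Path

    from pyobo.utils.cache import cached_mapping

    @cached_mapping(
        path=Path("names.tsv"),         # illustrative path
        header=["example_id", "name"],  # illustrative header
        cache=False,                    # skip reading/writing the TSV cache
        force=False,
    )
    def _get_id_name_mapping() -> dict[str, str]:
        return {"1": "example"}

    mapping = _get_id_name_mapping()  # computes and returns the mapping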
59 changes: 59 additions & 0 deletions tests/test_alt_ids.py
@@ -1,11 +1,17 @@
 """Tests for alternative identifiers."""
 
 import unittest
+from contextlib import ExitStack
+from unittest import mock
 
+import bioregistry
 from curies import Reference, ReferenceTuple
 
+import pyobo
+from pyobo import Reference as PyOBOReference
 from pyobo import get_name, get_name_by_curie, get_primary_curie, get_primary_identifier
 from pyobo.mocks import get_mock_id_alts_mapping, get_mock_id_name_mapping
+from pyobo.struct.struct import Obo, Term, make_ad_hoc_ontology
 
 mock_id_alts_mapping = get_mock_id_alts_mapping(
     {
@@ -26,6 +32,24 @@
     }
 )
 
+TEST_P1 = "test"
+
+bioregistry.manager.synonyms[TEST_P1] = TEST_P1
+bioregistry.manager.registry[TEST_P1] = bioregistry.Resource(
+    prefix=TEST_P1,
+    name="Test Semantic Space",
+    pattern="^\\d+$",
+)
+
+
+def patch_ontologies(ontology: Obo, targets: list[str]) -> ExitStack:
+    """Patch multiple ontologies."""
+    stack = ExitStack()
+    for target in targets:
+        patch = mock.patch(target, return_value=ontology)
+        stack.enter_context(patch)
+    return stack
+
 
 class TestAltIds(unittest.TestCase):
     """Tests for alternative identifiers."""
@@ -106,3 +130,38 @@ def test_no_alts(self, _, __):
         primary_id = get_primary_identifier("ncbitaxon", "52818")
         self.assertEqual("52818", primary_id)
         self.assertEqual("Allamanda cathartica", get_name("ncbitaxon", "52818"))
+
+    def test_api(self) -> None:
+        """Test getting the hierarchy."""
+        r1 = PyOBOReference(prefix=TEST_P1, identifier="1", name="test name")
+        r2 = PyOBOReference(prefix=TEST_P1, identifier="2")
+        t1 = Term(reference=r1).append_alt(r2)
+        t2 = Term(reference=r2)
+        ontology = make_ad_hoc_ontology(TEST_P1, terms=[t1, t2])
+
+        with patch_ontologies(
+            ontology, ["pyobo.api.names.get_ontology", "pyobo.api.alts.get_ontology"]
+        ):
+            ids_alts = pyobo.get_id_to_alts(TEST_P1, cache=False)
+            self.assertEqual({"1": ["2"]}, ids_alts)
+
+            alts_ids = pyobo.get_alts_to_id(TEST_P1, cache=False)
+            self.assertEqual({"2": "1"}, alts_ids)
+
+            self.assertEqual("1", pyobo.get_primary_identifier(r1, cache=False))
+            self.assertEqual("1", pyobo.get_primary_identifier(r2, cache=False))
+
+            self.assertEqual("test:1", pyobo.get_primary_curie(r1.curie, cache=False))
+            self.assertEqual("test:1", pyobo.get_primary_curie(r2.curie, cache=False))
+
+            ids = pyobo.get_ids(TEST_P1, cache=False)
+            self.assertEqual({"1", "2"}, ids)
+
+            id_name = pyobo.get_id_name_mapping(TEST_P1, cache=False)
+            self.assertEqual({t1.identifier: t1.name}, id_name)
+
+            name_id = pyobo.get_name_id_mapping(TEST_P1, cache=False)
+            self.assertEqual({t1.name: t1.identifier}, name_id)
+
+            self.assertEqual(t1.name, pyobo.get_name(r1, cache=False))
+            self.assertEqual(t1.name, pyobo.get_name(r2, cache=False))
