Skip to content

Commit

Permalink
Fix obonet caching (#314)
Browse files Browse the repository at this point in the history
This PR makes sure that the parsed obonet file is actually cached. It
also makes sure the tests don't pollute the active cache directory.
  • Loading branch information
cthoyt authored Jan 15, 2025
1 parent 816826e commit 83b7f11
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
33 changes: 26 additions & 7 deletions src/pyobo/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from .identifier_utils import ParseError, wrap_norm_prefix
from .plugins import has_nomenclature_plugin, run_nomenclature_plugin
from .reader import from_obo_path
from .reader import from_obo_path, from_obonet
from .struct import Obo
from .utils.io import get_writer
from .utils.path import ensure_path, prefix_directory_join
Expand Down Expand Up @@ -110,16 +110,30 @@ def get_ontology(
logger.info("UBERON has so much garbage in it that defaulting to non-strict parsing")
strict = False

if cache:
if not cache:
logger.debug("[%s] caching was turned off, so dont look for an obonet file", prefix)
obonet_json_gz_path = None
else:
obonet_json_gz_path = prefix_directory_join(
prefix, name=f"{prefix}.obonet.json.gz", ensure_exists=False, version=version
)
logger.debug(
"[%s] caching is turned on, so look for an obonet file at %s",
prefix,
obonet_json_gz_path,
)
if obonet_json_gz_path.exists() and not force:
from .reader import from_obonet
from .utils.cache import get_gzipped_graph

logger.debug("[%s] using obonet cache at %s", prefix, obonet_json_gz_path)
return from_obonet(get_gzipped_graph(obonet_json_gz_path))
return from_obonet(
get_gzipped_graph(obonet_json_gz_path),
strict=strict,
version=version,
upgrade=upgrade,
)
else:
logger.debug("[%s] no obonet cache found at %s", prefix, obonet_json_gz_path)

if has_nomenclature_plugin(prefix):
obo = run_nomenclature_plugin(prefix, version=version)
Expand All @@ -128,8 +142,6 @@ def get_ontology(
obo.write_default(force=force_process)
return obo

logger.debug("[%s] no obonet cache found at %s", prefix, obonet_json_gz_path)

ontology_format, path = _ensure_ontology_path(prefix, force=force, version=version)
if path is None:
raise NoBuildError(prefix)
Expand All @@ -144,7 +156,14 @@ def get_ontology(
else:
raise UnhandledFormatError(f"[{prefix}] unhandled ontology file format: {path.suffix}")

obo = from_obo_path(path, prefix=prefix, strict=strict, version=version, upgrade=upgrade)
obo = from_obo_path(
path,
prefix=prefix,
strict=strict,
version=version,
upgrade=upgrade,
_cache_path=obonet_json_gz_path,
)
if cache:
obo.write_default(force=force_process)
return obo
Expand Down
6 changes: 6 additions & 0 deletions src/pyobo/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .struct.struct_utils import Annotation, Stanza
from .struct.typedef import comment as has_comment
from .struct.typedef import default_typedefs, has_ontology_root_term
from .utils.cache import write_gzipped_graph
from .utils.misc import STATIC_VERSION_REWRITES, cleanup_version

__all__ = [
Expand All @@ -60,6 +61,7 @@ def from_obo_path(
version: str | None,
upgrade: bool = True,
ignore_obsolete: bool = False,
_cache_path: Path | None = None,
) -> Obo:
"""Get the OBO graph from a path."""
path = Path(path).expanduser().resolve()
Expand Down Expand Up @@ -87,6 +89,10 @@ def from_obo_path(
# Make sure the graph is named properly
_clean_graph_ontology(graph, prefix)

if _cache_path:
logger.info("[%s] writing obonet cache to %s", prefix, _cache_path)
write_gzipped_graph(path=_cache_path, graph=graph)

# Convert to an Obo instance and return
return from_obonet(graph, strict=strict, version=version, upgrade=upgrade)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_obo_reader/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class TestGet(unittest.TestCase):
def setUp(self) -> None:
"""Set up the test with the mock ChEBI OBO file."""
with chebi_patch:
self.ontology = get_ontology("chebi")
self.ontology = get_ontology("chebi", cache=False)

def test_get_id_alts_mapping(self):
"""Make sure the alternative ids are mapped properly.
Expand Down

0 comments on commit 83b7f11

Please sign in to comment.