Skip to content

Commit

Permalink
Merge pull request #1226 from gboeing/fix
Browse files Browse the repository at this point in the history
improve file reading and context management
  • Loading branch information
gboeing authored Oct 21, 2024
2 parents 46944f8 + b87d555 commit 6140263
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ repos:
args: [--disable=MD013]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.9"
rev: "v0.7.0"
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.11.2"
rev: "v1.12.1"
hooks:
- id: mypy
additional_dependencies:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Read the v2 [migration guide](https://github.com/gboeing/osmnx/issues/1123)
- improve docstrings throughout package (#1116)
- improve logging and warnings throughout package (#1125)
- improve error messages throughout package (#1131)
- improve internal file handling context management (#1226)
- refactor features module for speed improvement and memory efficiency (#1157 #1205)
- refactor save_graph_xml function and \_osm_xml module for speed improvement and bug fixes (#1135)
- make save_graph_xml function accept only an unsimplified MultiDiGraph as its input data (#1135)
Expand Down
13 changes: 10 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
https://www.sphinx-doc.org/en/master/usage/configuration.html
"""

import re
import sys
from pathlib import Path

Expand All @@ -18,9 +19,15 @@
pkg_root_path = str(Path.cwd().parent.parent)
sys.path.insert(0, pkg_root_path)

# dynamically load version from /osmnx/_version.py
with Path.open(Path("../../osmnx/_version.py")) as f:
version = release = f.read().split(" = ")[1].replace('"', "")
# dynamically load version from __version__ variable in /osmnx/_version.py
file_text = Path("../../osmnx/_version.py").read_text(encoding="utf-8")
regex = re.compile("^__version__ = ['\"]([^'\"]*)['\"]", re.MULTILINE)
match = re.search(regex, file_text)
if match:
version = release = match.group(1)
else:
msg = "Unable to find version string in file."
raise ValueError(msg)

# mock import all required + optional dependency packages because readthedocs
# does not have them installed
Expand Down
22 changes: 14 additions & 8 deletions osmnx/_osm_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import bz2
import logging as lg
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -32,6 +33,7 @@
from ._version import __version__ as osmnx_version

if TYPE_CHECKING:
from collections.abc import Iterator
from xml.sax.xmlreader import AttributesImpl

import geopandas as gpd
Expand Down Expand Up @@ -124,16 +126,20 @@ def _overpass_json_from_xml(filepath: str | Path, encoding: str) -> dict[str, An
A parsed JSON response from the Overpass API.
"""

# open the XML file, handling bz2 or regular XML
def _opener(filepath: Path, encoding: str) -> TextIO:
# open the XML file, handling bz2 or plain XML. use a wrapper context
# manager to yield the file handle, to ensure file will always get closed
@contextmanager
def _opener(filepath: Path, encoding: str) -> Iterator[TextIO]:
if filepath.suffix == ".bz2":
return bz2.open(filepath, mode="rt", encoding=encoding)

# otherwise just open it if it's not bz2
return filepath.open(encoding=encoding)
with bz2.open(filepath, mode="rt", encoding=encoding) as f:
yield f
else:
with Path(filepath).open(mode="rt", encoding=encoding) as f:
yield f

# warn if this XML file was generated by OSMnx itself
with _opener(Path(filepath), encoding) as f:
filepath = Path(filepath)
with _opener(filepath, encoding) as f:
root_attrs = etree_parse(f).getroot().attrib # noqa: S314
if "generator" in root_attrs and "OSMnx" in root_attrs["generator"]:
msg = (
Expand All @@ -146,7 +152,7 @@ def _opener(filepath: Path, encoding: str) -> TextIO:
warn(msg, category=UserWarning, stacklevel=2)

# parse the XML to Overpass-like JSON
with _opener(Path(filepath), encoding) as f:
with _opener(filepath, encoding) as f:
handler = _OSMContentHandler()
sax_parse(f, handler) # noqa: S317
return handler.object
Expand Down
5 changes: 2 additions & 3 deletions tests/test_osmnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,9 +581,8 @@ def test_save_load() -> None: # noqa: PLR0915
G2 = ox.load_graphml(fp, node_dtypes=nd, edge_dtypes=ed)

# test loading graphml from a file stream
file_bytes = Path.open(Path("tests/input_data/short.graphml"), "rb").read()
data = str(file_bytes.decode())
G = ox.load_graphml(graphml_str=data, node_dtypes=nd, edge_dtypes=ed)
graphml = Path("tests/input_data/short.graphml").read_text(encoding="utf-8")
G = ox.load_graphml(graphml_str=graphml, node_dtypes=nd, edge_dtypes=ed)


def test_graph_from() -> None:
Expand Down

0 comments on commit 6140263

Please sign in to comment.