Skip to content

Commit

Permalink
add detailed exception classes for each error condition
Browse files Browse the repository at this point in the history
  • Loading branch information
mki-c2c committed Feb 4, 2025
1 parent 9bffdbd commit 5ccc5c0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 38 deletions.
22 changes: 20 additions & 2 deletions geonetwork/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
from typing import Dict, Any
from requests import Response


class GnException(Exception):
pass
def __init__(self, code: int, details: Dict[str, Any]):
super().__init__()
self.code = code
self.details = details


class AuthException(GnException):
pass


class APIVersionException(GnException):
pass
def __init__(self, details: Dict[str, Any]):
super().__init__(501, details)


class ParameterException(GnException):
pass


class TimeoutException(GnException):
def __init__(self, details: Dict[str, Any]):
super().__init__(504, details)


def raise_for_status(response: Response):
if 400 <= response.status_code < 600:
raise GnException(response.status_code, {"response": response})
32 changes: 17 additions & 15 deletions geonetwork/gn_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from io import BytesIO
from typing import Union, Literal, IO, Any
from typing import Union, Literal, IO, Any, Dict
from .gn_session import GnSession, Credentials, logger
from .exceptions import APIVersionException, ParameterException
from .exceptions import APIVersionException, ParameterException, raise_for_status


GN_VERSION_RANGE = ["4.2.8", "4.4.5"]
Expand Down Expand Up @@ -41,17 +41,16 @@ def _init_xsrf_token(self):
def _get_version(self):
version_url = self.api_url + "/site"
resp = self.session.get(version_url)
resp.raise_for_status()
raise_for_status(resp)
version = resp.json().get("system/platform/version")
if (
(version is None)
or (version < GN_VERSION_RANGE[0])
or (version > GN_VERSION_RANGE[1])
):
raise APIVersionException(
{
"code": 501,
"msg": f"Version {version} not in allowed range {GN_VERSION_RANGE}",
details={
"message": f"Version {version} not in allowed range {GN_VERSION_RANGE}",
}
)
logger.info("GN API Session started with geonetwork server version %s", version)
Expand All @@ -68,8 +67,8 @@ def get_record_zip(self, uuid: str) -> IO[bytes]:
headers={"accept": "application/zip"},
)
if resp.status_code == 404:
raise ParameterException({"code": 404, "msg": f"UUID {uuid} not found"})
resp.raise_for_status()
raise ParameterException(code=404, details={"message": f"UUID {uuid} not found"})
raise_for_status(resp)
return BytesIO(resp.content)

def put_record_zip(self, zipdata: IO[bytes], overwrite: bool = True) -> Any:
Expand All @@ -87,7 +86,7 @@ def put_record_zip(self, zipdata: IO[bytes], overwrite: bool = True) -> Any:
"uuidProcessing": "OVERWRITE" if overwrite else "GENERATEUUID",
},
)
resp.raise_for_status()
raise_for_status(resp)
results = resp.json()
if results["errors"]:
clean_error_stack = [
Expand All @@ -98,7 +97,10 @@ def put_record_zip(self, zipdata: IO[bytes], overwrite: bool = True) -> Any:
for err in results["errors"]
]

raise ParameterException({"code": 404, "details": clean_error_stack})
raise ParameterException(code=400, details={
"message": f"POST {self.api_url}/records failed",
"stack": clean_error_stack
})

# take first id of results ids
serial_id = next(iter(results["metadataInfos"].values()))["uuid"]
Expand All @@ -121,7 +123,7 @@ def get_metadataxml(self, uuid):
url,
headers=headers,
)
resp.raise_for_status()
raise_for_status(resp)
return resp.content

UuidProcs = Literal["NOTHING", "OVERWRITE", "GENERATEUUID", "REMOVE_AND_REPLACE"]
Expand All @@ -143,7 +145,7 @@ def upload_metadata(self, metadata, groupid='100', uuidprocessing: UuidProcs = "
params=params,
files={"file": metadata},
)
response.raise_for_status()
raise_for_status(response)
return response

def get_thesaurus_dict(self):
Expand Down Expand Up @@ -203,10 +205,10 @@ def delete_thesaurus_dict(self, name):
"""
url = self.api_url + "/registries/vocabularies/" + name
response = self.session.delete(url)
response.raise_for_status()
raise_for_status(response)
return response.json()

def search(self, query: dict[str, Any]) -> dict[str, Any]:
def search(self, query: Dict[str, Any]) -> Dict[str, Any]:
"""
Use geonetwork API to search metadata
:param query: query generated by frontend app like datahub of geonetwork
Expand All @@ -218,7 +220,7 @@ def search(self, query: dict[str, Any]) -> dict[str, Any]:
url,
json=query
)
resp.raise_for_status()
raise_for_status(resp)
return resp.json()

def close_session(self):
Expand Down
29 changes: 19 additions & 10 deletions geonetwork/gn_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import requests
from requests.exceptions import ConnectTimeout, ConnectionError, HTTPError
from typing import Union, Dict, Any
from collections import namedtuple
from .exceptions import AuthException, TimeoutException


Credentials = namedtuple("Credentials", ["login", "password"])
Expand Down Expand Up @@ -46,14 +48,21 @@ def request(self, *args: Any, **kwargs: Any) -> Any:
url = args[1] if len(args) >= 2 else kwargs.get("url")
request_headers = kwargs.get("headers", {})
consolidated_headers = {**self.base_headers, **request_headers}
r = super().request(
*args, **{
**kwargs,
"auth": self.credentials,
"headers": consolidated_headers,
"verify": self.verifytls,
}
)
logger.debug("Queried [%s] %s, got status %s", method, url, r.status_code)
logger.debug("Header: %s", consolidated_headers)
try:
r = super().request(
*args, **{
**kwargs,
"auth": self.credentials,
"headers": consolidated_headers,
"verify": self.verifytls,
}
)
except (ConnectTimeout, ConnectionError) as err:
logger.debug("[%s] %s: %s", method, url, err.__class__.__name__)
raise TimeoutException({"message": f"connection failed to {url}", "error": err})
logger.debug("[%s] %s, status %s", method, url, r.status_code)
logger.debug("Headers: %s", consolidated_headers)
if r.status_code in [401, 403]:
logger.debug("Authentication failed at [%s] %s", method, url)
raise AuthException(r.status_code, {"message": f"auth failed at {url}", "response": r})
return r
15 changes: 7 additions & 8 deletions tests/test_gn_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from io import BytesIO
from requests.exceptions import HTTPError
import requests_mock
from geonetwork import GnApi
from geonetwork.exceptions import APIVersionException, ParameterException
from geonetwork.exceptions import APIVersionException, ParameterException, AuthException


@pytest.fixture
Expand Down Expand Up @@ -34,9 +33,9 @@ def site_callback(request, context):
context.status_code = 401
return {"system/platform/version": "4.3.2"}
m.get('http://geonetwork/api/site', json=site_callback, cookies=cookies)
with pytest.raises(HTTPError) as err:
with pytest.raises(AuthException) as err:
GnApi("http://geonetwork/api")
assert "401" in str(err.value)
assert err.value.code == 401


def test_unsupported_version():
Expand All @@ -49,7 +48,7 @@ def site_callback(request, context):
m.get('http://geonetwork/api/site', json=site_callback, cookies=cookies)
with pytest.raises(APIVersionException) as err:
GnApi("http://geonetwork/api")
assert err.value.args[0]["code"] == 501
assert err.value.code == 501


def test_record_zip(init_gn):
Expand All @@ -70,7 +69,7 @@ def test_record_zip_unknown_uuid(init_gn):
m.get('http://geonetwork/api/records/1232', status_code=404)
with pytest.raises(ParameterException) as err:
init_gn.get_record_zip("1232")
assert err.value.args[0]["code"] == 404
assert err.value.code == 404


def test_upload_zip(init_gn):
Expand Down Expand Up @@ -122,8 +121,8 @@ def record_callback(request, context):
zipdata = BytesIO(b"dummy_zip")
with pytest.raises(ParameterException) as err:
init_gn.put_record_zip(zipdata)
assert err.value.args[0]["code"] == 404
assert err.value.args[0]["details"] == [
assert err.value.code == 400
assert err.value.details["stack"] == [
{"message": "err1", "stack": ["line1", "line2"]},
{"message": "err2", "stack": ["e2/line1", " at e2/line2"]},
]
Expand Down
22 changes: 19 additions & 3 deletions tests/test_gn_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from geonetwork import GnSession
import requests_mock
from requests.exceptions import ConnectTimeout
import pytest
from geonetwork import GnSession
from geonetwork.exceptions import AuthException, TimeoutException


def test_anonymous():
Expand Down Expand Up @@ -62,5 +65,18 @@ def text_callback(request, context):
context.status_code = 401
return "test"
m.get("http://mock_server", text=text_callback)
resp = gns.get("http://mock_server")
assert resp.status_code == 401
with pytest.raises(AuthException) as err:
gns.get("http://mock_server")
assert err.value.code == 401


def test_timeout():
gns = GnSession(("test", "test"))
with requests_mock.Mocker() as m:

def timeout_callback(request, context):
raise ConnectTimeout
m.get("http://mock_server", text=timeout_callback)
with pytest.raises(TimeoutException) as err:
gns.get("http://mock_server")
assert err.value.code == 504

0 comments on commit 5ccc5c0

Please sign in to comment.