diff --git a/CHANGES.md b/CHANGES.md index 57904cd..5b051d1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,10 @@ * Adding an end date to `CMIP6_UofT`'s temporal extent for better rendering in STAC Browser * Updates to datacube extension helper routines for `CMIP6_UofT`. * Make pyessv-archive a requirement for *only* the cmip6 implementation instead of for the whole CLI +* Fix bug where logger setup failed +* Simplify CLI argument constructor code (for cleaner and more testable code) +* Add tests for CLI and implementations when invoked through the CLI +* Refactored code dealing with requests and authentication to the `requests.py` file ## [0.6.0](https://github.com/crim-ca/stac-populator/tree/0.6.0) (2024-02-22) diff --git a/STACpopulator/cli.py b/STACpopulator/cli.py index cc4a9fe..44a295a 100644 --- a/STACpopulator/cli.py +++ b/STACpopulator/cli.py @@ -1,108 +1,19 @@ import argparse -import glob +import functools import importlib import logging -import os import sys +from types import ModuleType import warnings -from datetime import datetime -from http import cookiejar -from typing import Callable, Optional +from datetime import datetime, timezone +from typing import Callable -import requests -from requests.auth import AuthBase, HTTPBasicAuth, HTTPDigestAuth, HTTPProxyAuth -from requests.sessions import Session - -from STACpopulator import __version__ +from STACpopulator import __version__, implementations from STACpopulator.exceptions import STACPopulatorError from STACpopulator.logging import setup_logging -POPULATORS = {} - - -class HTTPBearerTokenAuth(AuthBase): - def __init__(self, token: str) -> None: - self._token = token - - def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: - r.headers["Authorization"] = f"Bearer {self._token}" - return r - - -class HTTPCookieAuth(cookiejar.MozillaCookieJar): - """ - Employ a cookie-jar file for authorization. - - Examples of useful command: - - .. code-block:: shell - - curl --cookie-jar /path/to/cookie-jar.txt [authorization-provider-arguments] - - curl \ - -k \ - -X POST \ - --cookie-jar /tmp/magpie-cookie.txt \ - -d '{"user_name":"...","password":"..."}' \ - -H 'Accept:application/json' \ - -H 'Content-Type:application/json' \ - 'https://{hostname}/magpie/signin' - - .. note:: - Due to implementation details with :mod:`requests`, this must be passed directly to the ``cookies`` - attribute rather than ``auth`` as in the case for other authorization handlers. - """ - - -def add_request_options(parser: argparse.ArgumentParser) -> None: - """ - Adds arguments to a parser to allow update of a request session definition used across a populator procedure. - """ - parser.add_argument( - "--no-verify", - "--no-ssl", - "--no-ssl-verify", - dest="verify", - action="store_false", - help="Disable SSL verification (not recommended unless for development/test servers).", - ) - parser.add_argument("--cert", type=argparse.FileType(), required=False, help="Path to a certificate file to use.") - parser.add_argument( - "--auth-handler", - choices=["basic", "digest", "bearer", "proxy", "cookie"], - required=False, - help="Authentication strategy to employ for the requests session.", - ) - parser.add_argument( - "--auth-identity", - required=False, - help="Bearer token, cookie-jar file or proxy/digest/basic username:password for selected authorization handler.", - ) - -def apply_request_options(session: Session, namespace: argparse.Namespace) -> None: - """ - Applies the relevant request session options from parsed input arguments. - """ - session.verify = namespace.verify - session.cert = namespace.cert - if namespace.auth_handler in ["basic", "digest", "proxy"]: - usr, pwd = namespace.auth_identity.split(":", 1) - if namespace.auth_handler == "basic": - session.auth = HTTPBasicAuth(usr, pwd) - elif namespace.auth_handler == "digest": - session.auth = HTTPDigestAuth(usr, pwd) - else: - session.auth = HTTPProxyAuth(usr, pwd) - elif namespace.auth_handler == "bearer": - session.auth = HTTPBearerTokenAuth(namespace.auth_identity) - elif namespace.auth_handler == "cookie": - session.cookies = HTTPCookieAuth(namespace.auth_identity) - session.cookies.load(namespace.auth_identity) - - -def make_main_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(prog="stac-populator", description="STACpopulator operations.") +def add_parser_args(parser: argparse.ArgumentParser) -> dict[str, Callable]: parser.add_argument( "--version", "-V", @@ -110,108 +21,47 @@ def make_main_parser() -> argparse.ArgumentParser: version=f"%(prog)s {__version__}", help="prints the version of the library and exits", ) - commands = parser.add_subparsers(title="command", dest="command", description="STAC populator command to execute.") - - run_cmd_parser = make_run_command_parser(parser.prog) - commands.add_parser( - "run", - prog=f"{parser.prog} {run_cmd_parser.prog}", - parents=[run_cmd_parser], - formatter_class=run_cmd_parser.formatter_class, - usage=run_cmd_parser.usage, - add_help=False, - help=run_cmd_parser.description, - description=run_cmd_parser.description, + parser.add_argument("--debug", action="store_const", const=logging.DEBUG, help="set logger level to debug") + parser.add_argument( + "--log_file", help="file to write log output to. By default logs will be written to the current directory." ) + commands_subparser = parser.add_subparsers( + title="command", dest="command", description="STAC populator command to execute.", required=True + ) + run_parser = commands_subparser.add_parser("run", description="Run a STACpopulator implementation") + populators_subparser = run_parser.add_subparsers( + title="populator", dest="populator", description="Implementation to run." + ) + for implementation_module_name, module in implementation_modules().items(): + implementation_parser = populators_subparser.add_parser(implementation_module_name) + module.add_parser_args(implementation_parser) - # add more commands as needed... - parser.add_argument("--debug", action="store_true", help="Set logger level to debug") - - return parser - - -def make_run_command_parser(parent) -> argparse.ArgumentParser: - """ - Groups all sub-populator CLI listed in :py:mod:`STACpopulator.implementations` as a common ``stac-populator`` CLI. - - Dispatches the provided arguments to the appropriate sub-populator CLI as requested. Each sub-populator CLI must - implement functions ``make_parser`` and ``main`` to generate the arguments and dispatch them to the corresponding - caller. The ``main`` function should accept a sequence of string arguments, which can be passed to the parser - obtained from ``make_parser``. - An optional ``runner`` can also be defined in each populator module. If provided, the namespace arguments that have - already been parsed to resolve the populator to run will be used directly, avoiding parsing arguments twice. - """ - parser = argparse.ArgumentParser(prog="run", description="STACpopulator implementation runner.") - subparsers = parser.add_subparsers(title="populator", dest="populator", description="Implementation to run.") - populators_impl = "implementations" - populators_dir = os.path.join(os.path.dirname(__file__), populators_impl) - populator_mods = glob.glob(f"{populators_dir}/**/[!__init__]*.py", recursive=True) # potential candidate scripts - for populator_path in sorted(populator_mods): - populator_script = populator_path.split(populators_dir, 1)[1][1:] - populator_py_mod = os.path.splitext(populator_script)[0].replace(os.sep, ".") - populator_name, pop_mod_file = populator_py_mod.rsplit(".", 1) - populator_root = f"STACpopulator.{populators_impl}.{populator_name}" - pop_mod_file_loc = f"{populator_root}.{pop_mod_file}" +@functools.cache +def implementation_modules() -> dict[str, ModuleType]: + modules = {} + for implementation_module_name in implementations.__all__: try: - populator_module = importlib.import_module(pop_mod_file_loc, populator_root) - except STACPopulatorError as e: - warnings.warn(f"Could not load extension {populator_name} because of error {e}") - continue - parser_maker: Callable[[], argparse.ArgumentParser] = getattr(populator_module, "make_parser", None) - populator_runner = getattr(populator_module, "runner", None) # optional, call main directly if not available - populator_caller = getattr(populator_module, "main", None) - if callable(parser_maker) and callable(populator_caller): - populator_parser = parser_maker() - populator_prog = f"{parent} {parser.prog} {populator_name}" - subparsers.add_parser( - populator_name, - prog=populator_prog, - parents=[populator_parser], - formatter_class=populator_parser.formatter_class, - add_help=False, # add help disabled otherwise conflicts with this main populator help - help=populator_parser.description, - description=populator_parser.description, - usage=populator_parser.usage, + modules[implementation_module_name] = importlib.import_module( + f".{implementation_module_name}", implementations.__package__ ) - POPULATORS[populator_name] = { - "name": populator_name, - "caller": populator_caller, - "parser": populator_parser, - "runner": populator_runner, - } - return parser + except STACPopulatorError as e: + warnings.warn(f"Could not load extension {implementation_module_name} because of error {e}") + return modules -def main(*args: str) -> Optional[int]: - parser = make_main_parser() - args = args or sys.argv[1:] # same as was parse args does, but we must provide them to subparser - ns = parser.parse_args(args=args) # if 'command' or 'populator' unknown, auto prints the help message with exit(2) - params = vars(ns) - populator_cmd = params.pop("command") - if not populator_cmd: - parser.print_help() - return 0 - result = None - if populator_cmd == "run": - populator_name = params.pop("populator") +def run(ns: argparse.Namespace) -> int: + if ns.command == "run": + logfile_name = ns.log_file or f"{ns.populator}_log_{datetime.now(timezone.utc).isoformat() + 'Z'}.jsonl" + setup_logging(logfile_name, ns.debug or logging.INFO) + return implementation_modules()[ns.populator].runner(ns) or 0 - # Setup the application logger: - fname = f"{populator_name}_log_{datetime.utcnow().isoformat() + 'Z'}.jsonl" - log_level = logging.DEBUG if ns.debug else logging.INFO - setup_logging(fname, log_level) - if not populator_name: - parser.print_help() - return 0 - populator_args = args[2:] # skip [command] [populator] - populator_caller = POPULATORS[populator_name]["caller"] - populator_runner = POPULATORS[populator_name]["runner"] - if populator_runner: - result = populator_runner(ns) - else: - result = populator_caller(*populator_args) - return 0 if result is None else result +def main(*args: str) -> int: + parser = argparse.ArgumentParser() + add_parser_args(parser) + ns = parser.parse_args(args or None) + return run(ns) if __name__ == "__main__": diff --git a/STACpopulator/implementations/CMIP6_UofT/__init__.py b/STACpopulator/implementations/CMIP6_UofT/__init__.py index e69de29..c623e47 100644 --- a/STACpopulator/implementations/CMIP6_UofT/__init__.py +++ b/STACpopulator/implementations/CMIP6_UofT/__init__.py @@ -0,0 +1,3 @@ +from .add_CMIP6 import add_parser_args, runner + +__all__ = ["add_parser_args", "runner"] diff --git a/STACpopulator/implementations/CMIP6_UofT/add_CMIP6.py b/STACpopulator/implementations/CMIP6_UofT/add_CMIP6.py index 4bfd172..5ed3fd8 100644 --- a/STACpopulator/implementations/CMIP6_UofT/add_CMIP6.py +++ b/STACpopulator/implementations/CMIP6_UofT/add_CMIP6.py @@ -2,13 +2,14 @@ import json import logging import os -from typing import Any, MutableMapping, NoReturn, Optional, Union +import sys +from typing import Any, MutableMapping, Optional, Union from pystac import STACValidationError from pystac.extensions.datacube import DatacubeExtension from requests.sessions import Session -from STACpopulator.cli import add_request_options, apply_request_options +from STACpopulator.requests import add_request_options, apply_request_options from STACpopulator.extensions.cmip6 import CMIP6Helper, CMIP6Properties from STACpopulator.extensions.datacube import DataCubeHelper from STACpopulator.extensions.thredds import THREDDSExtension, THREDDSHelper @@ -78,17 +79,17 @@ def create_stac_item( try: item.validate() - except STACValidationError: + except STACValidationError as e: raise Exception("Failed to validate STAC item") from e # print(json.dumps(item.to_dict())) return json.loads(json.dumps(item.to_dict())) -def make_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="CMIP6 STAC populator from a THREDDS catalog or NCML XML.") - parser.add_argument("stac_host", type=str, help="STAC API address") - parser.add_argument("href", type=str, help="URL to a THREDDS catalog or a NCML XML with CMIP6 metadata.") +def add_parser_args(parser: argparse.ArgumentParser) -> None: + parser.description="CMIP6 STAC populator from a THREDDS catalog or NCML XML." + parser.add_argument("stac_host", help="STAC API URL") + parser.add_argument("href", help="URL to a THREDDS catalog or a NCML XML with CMIP6 metadata.") parser.add_argument("--update", action="store_true", help="Update collection and its items") parser.add_argument( "--mode", @@ -105,10 +106,9 @@ def make_parser() -> argparse.ArgumentParser: ), ) add_request_options(parser) - return parser -def runner(ns: argparse.Namespace) -> Optional[int] | NoReturn: +def runner(ns: argparse.Namespace) -> int: LOGGER.info(f"Arguments to call: {vars(ns)}") with Session() as session: @@ -123,13 +123,14 @@ def runner(ns: argparse.Namespace) -> Optional[int] | NoReturn: ns.stac_host, data_loader, update=ns.update, session=session, config_file=ns.config, log_debug=ns.debug ) c.ingest() + return 0 -def main(*args: str) -> Optional[int]: - parser = make_parser() +def main(*args: str) -> int: + parser = argparse.ArgumentParser() ns = parser.parse_args(args or None) return runner(ns) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/STACpopulator/implementations/DirectoryLoader/__init__.py b/STACpopulator/implementations/DirectoryLoader/__init__.py index e69de29..ee80e45 100644 --- a/STACpopulator/implementations/DirectoryLoader/__init__.py +++ b/STACpopulator/implementations/DirectoryLoader/__init__.py @@ -0,0 +1,3 @@ +from .crawl_directory import add_parser_args, runner + +__all__ = ["add_parser_args", "runner"] diff --git a/STACpopulator/implementations/DirectoryLoader/crawl_directory.py b/STACpopulator/implementations/DirectoryLoader/crawl_directory.py index bae4c61..9b4cf01 100644 --- a/STACpopulator/implementations/DirectoryLoader/crawl_directory.py +++ b/STACpopulator/implementations/DirectoryLoader/crawl_directory.py @@ -1,11 +1,12 @@ import argparse import logging import os.path -from typing import Any, MutableMapping, NoReturn, Optional +import sys +from typing import Any, MutableMapping, Optional from requests.sessions import Session -from STACpopulator.cli import add_request_options, apply_request_options +from STACpopulator.requests import add_request_options, apply_request_options from STACpopulator.input import STACDirectoryLoader from STACpopulator.models import GeoJSONPolygon from STACpopulator.populator_base import STACpopulatorBase @@ -39,8 +40,8 @@ def create_stac_item(self, item_name: str, item_data: MutableMapping[str, Any]) return item_data -def make_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Directory STAC populator") +def add_parser_args(parser: argparse.ArgumentParser) -> None: + parser.description="Directory STAC populator" parser.add_argument("stac_host", type=str, help="STAC API URL.") parser.add_argument("directory", type=str, help="Path to a directory structure with STAC Collections and Items.") parser.add_argument("--update", action="store_true", help="Update collection and its items.") @@ -50,10 +51,9 @@ def make_parser() -> argparse.ArgumentParser: help="Limit search of STAC Collections only to first top-most matches in the crawled directory structure.", ) add_request_options(parser) - return parser -def runner(ns: argparse.Namespace) -> Optional[int] | NoReturn: +def runner(ns: argparse.Namespace) -> int: LOGGER.info(f"Arguments to call: {vars(ns)}") with Session() as session: @@ -63,13 +63,15 @@ def runner(ns: argparse.Namespace) -> Optional[int] | NoReturn: loader = STACDirectoryLoader(collection_dir, "item", prune=ns.prune) populator = DirectoryPopulator(ns.stac_host, loader, ns.update, collection_json, session=session) populator.ingest() + return 0 -def main(*args: str) -> Optional[int]: - parser = make_parser() +def main(*args: str) -> int: + parser = argparse.ArgumentParser() + add_parser_args(parser) ns = parser.parse_args(args or None) return runner(ns) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/STACpopulator/implementations/NEX_GDDP_UofT/__init__.py b/STACpopulator/implementations/NEX_GDDP_UofT/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/STACpopulator/implementations/NEX_GDDP_UofT/add_NEX-GDDP.py b/STACpopulator/implementations/NEX_GDDP_UofT/add_NEX-GDDP.py deleted file mode 100644 index e69de29..0000000 diff --git a/STACpopulator/implementations/__init__.py b/STACpopulator/implementations/__init__.py index e69de29..80c732b 100644 --- a/STACpopulator/implementations/__init__.py +++ b/STACpopulator/implementations/__init__.py @@ -0,0 +1,8 @@ +# By adding modules to __all__, they are discoverable by the cli.implementation_modules method and +# become available to be invoked through the CLI. +# All modules in this list must contain two functions: +# - add_parser_args(parser: argparse.ArgumentParser) -> None +# - adds additional arguments to the given parser needed to run this implementation +# - def runner(ns: argparse.Namespace) -> int: +# - runs the implementation given a namespace constructed from the parser arguments supplied +__all__ = ["CMIP6_UofT", "DirectoryLoader"] diff --git a/STACpopulator/logging.py b/STACpopulator/logging.py index 0d7caad..e278360 100644 --- a/STACpopulator/logging.py +++ b/STACpopulator/logging.py @@ -1,6 +1,6 @@ import datetime as dt import json -import logging +import logging.config LOG_RECORD_BUILTIN_ATTRS = { "args", diff --git a/STACpopulator/requests.py b/STACpopulator/requests.py new file mode 100644 index 0000000..c193400 --- /dev/null +++ b/STACpopulator/requests.py @@ -0,0 +1,60 @@ +import argparse +from http import cookiejar + +import requests +from requests.auth import AuthBase, HTTPBasicAuth, HTTPDigestAuth, HTTPProxyAuth +from requests.sessions import Session + + +class HTTPBearerTokenAuth(AuthBase): + def __init__(self, token: str) -> None: + self._token = token + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + r.headers["Authorization"] = f"Bearer {self._token}" + return r + + +def add_request_options(parser: argparse.ArgumentParser) -> None: + """ + Adds arguments to a parser to allow update of a request session definition used across a populator procedure. + """ + parser.add_argument( + "--no-verify", + "--no-ssl", + "--no-ssl-verify", + dest="verify", + action="store_false", + help="Disable SSL verification (not recommended unless for development/test servers).", + ) + parser.add_argument("--cert", type=argparse.FileType(), help="Path to a certificate file to use.") + parser.add_argument( + "--auth-handler", + choices=["basic", "digest", "bearer", "proxy", "cookie"], + help="Authentication strategy to employ for the requests session.", + ) + parser.add_argument( + "--auth-identity", + help="Bearer token, cookie-jar file or proxy/digest/basic username:password for selected authorization handler.", + ) + + +def apply_request_options(session: Session, namespace: argparse.Namespace) -> None: + """ + Applies the relevant request session options from parsed input arguments. + """ + session.verify = namespace.verify + session.cert = namespace.cert + if namespace.auth_handler in ["basic", "digest", "proxy"]: + usr, pwd = namespace.auth_identity.split(":", 1) + if namespace.auth_handler == "basic": + session.auth = HTTPBasicAuth(usr, pwd) + elif namespace.auth_handler == "digest": + session.auth = HTTPDigestAuth(usr, pwd) + else: + session.auth = HTTPProxyAuth(usr, pwd) + elif namespace.auth_handler == "bearer": + session.auth = HTTPBearerTokenAuth(namespace.auth_identity) + elif namespace.auth_handler == "cookie": + session.cookies = cookiejar.MozillaCookieJar(namespace.auth_identity) + session.cookies.load(namespace.auth_identity) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..add5bf3 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,48 @@ +import os +import re +import subprocess +import tempfile +from typing import Mapping + +import pytest + +from STACpopulator import implementations + + +def run_cli(*args: str, **kwargs: Mapping) -> subprocess.CompletedProcess: + return subprocess.run(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True, **kwargs) + + +@pytest.fixture(scope="session") +def populator_help_pattern(): + return re.compile(f"\{{({',?|'.join([imp.replace('.', '\\.') for imp in implementations.__all__])},?)+\}}") + + +def test_help(): + """Test that there are no errors when running a very basic command""" + proc = run_cli("stac-populator", "--help") + proc.check_returncode() + + +def test_run_implementation(populator_help_pattern): + """ + Test that all implementations can be loaded from the command line + + This test assumes that the pyessv-archive is installed in the default location. + Run `make setup-pyessv-archive` prior to running this test. + """ + proc = run_cli("stac-populator", "run", "--help") + proc.check_returncode() + populators = re.search(populator_help_pattern, proc.stdout) + assert set(implementations.__all__) == set(populators.group(0).strip("{}").split(",")) + + +def test_missing_implementation(populator_help_pattern): + """Test that implementations that can't load are missing from the options""" + with tempfile.TemporaryDirectory() as dirname: + pass # this allows us to get a dirname that does not exist + proc = run_cli("stac-populator", "run", "--help", env={**os.environ, "PYESSV_ARCHIVE_HOME": dirname}) + proc.check_returncode() + populators = re.search(populator_help_pattern, proc.stdout) + assert "CMIP6_UofT" in implementations.__all__ # sanity check + assert "CMIP6_UofT" not in set(populators.group(0).strip("{}").split(",")) diff --git a/tests/test_directory_loader.py b/tests/test_directory_loader.py index 88cf869..5a9c17f 100644 --- a/tests/test_directory_loader.py +++ b/tests/test_directory_loader.py @@ -1,30 +1,21 @@ +import abc import argparse +import functools import json import os +from typing import Any, Callable, Generator import pytest import responses from STACpopulator.implementations.DirectoryLoader import crawl_directory +from STACpopulator.cli import add_parser_args, main as cli_main -CUR_DIR = os.path.dirname(__file__) - - -@pytest.mark.parametrize( - "prune_option", - [True, False] -) -def test_directory_loader_populator_runner(prune_option: bool): - ns = argparse.Namespace() - stac_host = "http://test-host.com/stac/" - setattr(ns, "verify", False) - setattr(ns, "cert", None) - setattr(ns, "auth_handler", None) - setattr(ns, "stac_host", stac_host) - setattr(ns, "directory", os.path.join(CUR_DIR, "data/test_directory")) - setattr(ns, "prune", prune_option) - setattr(ns, "update", True) - - file_id_map = { +type RequestContext = Generator[responses.RequestsMock, None, None] + + +@pytest.fixture(scope="session") +def file_id_map() -> dict[str, str]: + return { "collection.json": "EuroSAT-subset-train", "item-0.json": "EuroSAT-subset-train-sample-0-class-AnnualCrop", "item-1.json": "EuroSAT-subset-train-sample-1-class-AnnualCrop", @@ -32,37 +23,68 @@ def test_directory_loader_populator_runner(prune_option: bool): "nested/item-0.json": "EuroSAT-subset-test-sample-0-class-AnnualCrop", "nested/item-1.json": "EuroSAT-subset-test-sample-1-class-AnnualCrop", } - file_contents = {} + + +@pytest.fixture(scope="package") +def file_contents(file_id_map: dict[str, str], request: pytest.FixtureRequest) -> dict[str, bytes]: + contents = {} for file_name in file_id_map: - ref_file = os.path.join(CUR_DIR, "data/test_directory", file_name) + ref_file = os.path.join(request.fspath.dirname, "data/test_directory", file_name) with open(ref_file, mode="r", encoding="utf-8") as f: json_data = json.load(f) - file_contents[file_name] = json.dumps(json_data, indent=None).encode() + contents[file_name] = json.dumps(json_data, indent=None).encode() + return contents + - with responses.RequestsMock(assert_all_requests_are_fired=False) as request_mock: - request_mock.add("GET", stac_host, json={"stac_version": "1.0.0", "type": "Catalog"}) - request_mock.add( +@pytest.fixture(autouse=True) +def request_mock(namespace: argparse.Namespace, file_id_map: dict[str, str]) -> RequestContext: + with responses.RequestsMock(assert_all_requests_are_fired=False) as mock_context: + mock_context.add("GET", namespace.stac_host, json={"stac_version": "1.0.0", "type": "Catalog"}) + mock_context.add( "POST", - f"{stac_host}collections", + f"{namespace.stac_host}collections", headers={"Content-Type": "application/json"}, ) - request_mock.add( + mock_context.add( "POST", - f"{stac_host}collections/{file_id_map['collection.json']}/items", + f"{namespace.stac_host}collections/{file_id_map['collection.json']}/items", headers={"Content-Type": "application/json"}, ) - request_mock.add( + mock_context.add( "POST", - f"{stac_host}collections/{file_id_map['nested/collection.json']}/items", + f"{namespace.stac_host}collections/{file_id_map['nested/collection.json']}/items", headers={"Content-Type": "application/json"}, ) - - crawl_directory.runner(ns) + yield mock_context + + +@pytest.mark.parametrize("prune_option", [True, False]) +class _TestDirectoryLoader(abc.ABC): + @abc.abstractmethod + @pytest.fixture + def namespace(self, *args: Any) -> argparse.Namespace: + raise NotImplementedError + + @abc.abstractmethod + @pytest.fixture + def runner(self, *args: Any) -> Callable: + raise NotImplementedError + + def test_runner( + self, + prune_option: bool, + namespace: argparse.Namespace, + file_id_map: dict[str, str], + file_contents: dict[str, bytes], + request_mock: RequestContext, + runner: Callable, + ): + runner() assert len(request_mock.calls) == (4 if prune_option else 8) - assert request_mock.calls[0].request.url == stac_host + assert request_mock.calls[0].request.url == namespace.stac_host - base_col = file_id_map['collection.json'] + base_col = file_id_map["collection.json"] assert request_mock.calls[1].request.path_url == "/stac/collections" assert request_mock.calls[1].request.body == file_contents["collection.json"] @@ -80,7 +102,7 @@ def test_directory_loader_populator_runner(prune_option: bool): assert request_mock.calls[item1_idx].request.body == file_contents["item-1.json"] if not prune_option: - assert request_mock.calls[4].request.url == stac_host + assert request_mock.calls[4].request.url == namespace.stac_host nested_col = file_id_map["nested/collection.json"] assert request_mock.calls[5].request.path_url == "/stac/collections" @@ -98,3 +120,47 @@ def test_directory_loader_populator_runner(prune_option: bool): assert request_mock.calls[item1_idx].request.path_url == f"/stac/collections/{nested_col}/items" assert request_mock.calls[item1_idx].request.body == file_contents["nested/item-1.json"] + + +class TestModule(_TestDirectoryLoader): + @pytest.fixture + def runner(self, namespace: argparse.Namespace) -> Callable: + return functools.partial(crawl_directory.runner, namespace) + + @pytest.fixture + def namespace(self, request: pytest.FixtureRequest, prune_option: bool) -> argparse.Namespace: + return argparse.Namespace( + verify=False, + cert=None, + auth_handler=None, + stac_host="http://example.com/stac/", + directory=os.path.join(request.fspath.dirname, "data/test_directory"), + prune=prune_option, + update=True, + ) + + +class TestFromCLI(_TestDirectoryLoader): + @pytest.fixture + def args(self, request: pytest.FixtureRequest, prune_option: bool) -> list[str]: + cmd_args = [ + "run", + "DirectoryLoader", + "http://example.com/stac/", + os.path.join(request.fspath.dirname, "data/test_directory"), + "--no-verify", + "--update", + ] + if prune_option: + cmd_args.append("--prune") + return cmd_args + + @pytest.fixture + def runner(self, args: list[str]) -> int: + return functools.partial(cli_main, *args) + + @pytest.fixture + def namespace(self, args: tuple[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser() + add_parser_args(parser) + return parser.parse_args(args)