Skip to content

Commit

Permalink
Refactor tests to use per-test imports
Browse files Browse the repository at this point in the history
  • Loading branch information
tcjennings committed Feb 6, 2025
1 parent e86c1dc commit b9cdf53
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 14 deletions.
18 changes: 18 additions & 0 deletions tests/db/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import importlib
import sys
from collections.abc import Generator

import pytest


@pytest.fixture(autouse=True, scope="function")
def import_deps() -> Generator:
_ = importlib.import_module("lsst.cmservice.handlers.interface")
_ = importlib.import_module("lsst.cmservice.handlers.jobs")
_ = importlib.import_module("lsst.cmservice.handlers.functions")
_ = importlib.import_module("lsst.cmservice.handlers.script_handler")
yield
del sys.modules["lsst.cmservice.handlers.interface"]
del sys.modules["lsst.cmservice.handlers.jobs"]
del sys.modules["lsst.cmservice.handlers.functions"]
del sys.modules["lsst.cmservice.handlers.script_handler"]
4 changes: 2 additions & 2 deletions tests/db/test_daemon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
from asyncio import sleep
from datetime import datetime

Expand All @@ -16,7 +17,6 @@


@pytest.mark.asyncio()
@pytest.mark.skip(reason="Test passes when called directly, fails in general run.")
async def test_daemon_db(engine: AsyncEngine) -> None:
"""Test creating a job, add it to the work queue, and start processing."""

Expand All @@ -29,7 +29,7 @@ async def test_daemon_db(engine: AsyncEngine) -> None:
campaign = await interface.load_and_create_campaign(
session,
"tests/fixtures/seeds/example_trivial.yaml",
"trivial_panda",
f"trivial_panda_{uuid.uuid1().int}",
"test_daemon",
"trivial_panda#campaign",
)
Expand Down
12 changes: 10 additions & 2 deletions tests/db/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import importlib
import os
import sys
from pathlib import Path
from typing import Any

Expand All @@ -10,7 +12,6 @@

from lsst.cmservice import db
from lsst.cmservice.common.enums import LevelEnum, StatusEnum
from lsst.cmservice.handlers import interface, jobs

from .util_functions import cleanup, create_tree

Expand All @@ -23,6 +24,7 @@ async def check_run_script(
spec_block_name: str,
**kwargs: Any,
) -> db.Script:
interface = sys.modules["lsst.cmservice.handlers"].interface
script = await db.Script.create_row(
session,
parent_name=parent.fullname,
Expand Down Expand Up @@ -51,6 +53,7 @@ async def check_script(
spec_block_name: str,
**kwargs: Any,
) -> db.Script:
interface = sys.modules["lsst.cmservice.handlers"].interface
script = await db.Script.create_row(
session,
parent_name=parent.fullname,
Expand Down Expand Up @@ -90,6 +93,7 @@ async def check_script(
async def test_handlers_campaign_level_db(
engine: AsyncEngine,
tmp_path: Path,
import_deps: Any,
) -> None:
"""Test to run the write and purge methods of various scripts"""
temp_dir = str(tmp_path / "archive")
Expand Down Expand Up @@ -171,6 +175,7 @@ async def test_handlers_campaign_level_db(
async def test_handlers_step_level_db(
engine: AsyncEngine,
tmp_path: Path,
import_deps: Any,
monkeypatch: MonkeyPatch,
) -> None:
"""Test to run the write and purge methods of various scripts"""
Expand Down Expand Up @@ -211,6 +216,7 @@ async def test_handlers_step_level_db(
async def test_handlers_group_level_db(
engine: AsyncEngine,
tmp_path: Path,
import_deps: Any,
monkeypatch: MonkeyPatch,
) -> None:
"""Test to run the write and purge methods of various scripts"""
Expand Down Expand Up @@ -268,6 +274,8 @@ async def test_handlers_job_level_db(
tmp_path: Path,
) -> None:
"""Test to run the write and purge methods of various scripts"""
interface = importlib.import_module("lsst.cmservice.handlers.interface")
jobs = importlib.import_module("lsst.cmservice.handlers.jobs")
temp_dir = str(tmp_path / "archive")

logger = structlog.get_logger(__name__)
Expand Down Expand Up @@ -363,7 +371,7 @@ async def test_handlers_job_level_db(
)
assert status == StatusEnum.waiting

assert jobs.PandaScriptHandler.get_job_id({"Run Id": 322}) == 322 # type: ignore
assert jobs.PandaScriptHandler.get_job_id({"Run Id": 322}) == 322
assert jobs.HTCondorScriptHandler.get_job_id({"Submit dir": "dummy"}) == "dummy"

await cleanup(session, check_cascade=True)
4 changes: 2 additions & 2 deletions tests/db/test_job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
import uuid

Expand All @@ -10,7 +11,6 @@
from lsst.cmservice import db
from lsst.cmservice.common import errors
from lsst.cmservice.common.enums import LevelEnum, StatusEnum
from lsst.cmservice.handlers import interface

from .util_functions import (
check_get_methods,
Expand All @@ -25,7 +25,7 @@
@pytest.mark.asyncio()
async def test_job_db(engine: AsyncEngine) -> None:
"""Test `job` db table."""

interface = importlib.import_module("lsst.cmservice.handlers.interface")
# generate a uuid to avoid collisions
uuid_int = uuid.uuid1().int
logger = structlog.get_logger(__name__)
Expand Down
5 changes: 3 additions & 2 deletions tests/db/test_micro.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
from pathlib import Path

Expand All @@ -9,8 +10,6 @@

from lsst.cmservice.common import errors
from lsst.cmservice.common.enums import ScriptMethodEnum, StatusEnum
from lsst.cmservice.handlers import interface
from lsst.cmservice.handlers.script_handler import ScriptHandler

from .util_functions import cleanup

Expand All @@ -26,6 +25,8 @@ async def test_micro_db(
monkeypatch: MonkeyPatch,
) -> None:
"""Test fake end to end run using example/example_micro.yaml"""
ScriptHandler = importlib.import_module("lsst.cmservice.handlers.script_handler").ScriptHandler
interface = importlib.import_module("lsst.cmservice.handlers.interface")
monkeypatch.setattr("lsst.cmservice.config.config.butler.mock", True)

orig_method = ScriptHandler.default_method
Expand Down
5 changes: 4 additions & 1 deletion tests/db/test_pipetask_error_type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os

import pytest
Expand All @@ -7,7 +8,6 @@
from sqlalchemy.ext.asyncio import AsyncEngine

from lsst.cmservice import db
from lsst.cmservice.handlers import functions, interface

from .util_functions import cleanup, delete_all_rows

Expand All @@ -20,6 +20,8 @@ async def test_error_match_db(engine: AsyncEngine) -> None:
fake error which is not in the database.
"""

interface = importlib.import_module("lsst.cmservice.handlers.interface")
functions = importlib.import_module("lsst.cmservice.handlers.functions")
logger = structlog.get_logger(__name__)
async with engine.begin():
session = await create_async_session(engine, logger)
Expand Down Expand Up @@ -111,6 +113,7 @@ async def test_error_match_db(engine: AsyncEngine) -> None:
async def test_error_type_db(engine: AsyncEngine) -> None:
"""Test `error_type` db table."""

interface = importlib.import_module("lsst.cmservice.handlers.interface")
logger = structlog.get_logger(__name__)
async with engine.begin():
session = await create_async_session(engine, logger)
Expand Down
4 changes: 3 additions & 1 deletion tests/db/test_reports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
import uuid

Expand All @@ -8,7 +9,6 @@

from lsst.cmservice import db
from lsst.cmservice.common.enums import LevelEnum, StatusEnum
from lsst.cmservice.handlers import functions, interface

from .util_functions import (
cleanup,
Expand All @@ -19,6 +19,8 @@
@pytest.mark.asyncio()
async def test_reports_db(engine: AsyncEngine) -> None:
"""Test `job` db table."""
interface = importlib.import_module("lsst.cmservice.handlers.interface")
functions = importlib.import_module("lsst.cmservice.handlers.functions")

# generate a uuid to avoid collisions
uuid_int = uuid.uuid1().int
Expand Down
11 changes: 7 additions & 4 deletions tests/db/util_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
from typing import TypeAlias

import pytest
Expand All @@ -6,7 +7,6 @@
from lsst.cmservice import db
from lsst.cmservice.common import errors
from lsst.cmservice.common.enums import LevelEnum, StatusEnum, TableEnum
from lsst.cmservice.handlers import interface


async def add_scripts(
Expand Down Expand Up @@ -43,6 +43,7 @@ async def create_tree(
level: LevelEnum,
uuid_int: int,
) -> None:
interface = importlib.import_module("lsst.cmservice.handlers.interface")
specification = await interface.load_specification(session, "examples/empty_config.yaml")
_ = await specification.get_block(session, "campaign")

Expand Down Expand Up @@ -107,7 +108,7 @@ async def create_tree(
if level.value <= LevelEnum.group.value:
return

jobs = [
jobs_ = [
await db.Job.create_row(
session,
name=f"job_{uuid_int}",
Expand All @@ -117,8 +118,8 @@ async def create_tree(
for group_ in groups
]

for job_ in jobs:
await add_scripts(session, job_)
for job in jobs_:
await add_scripts(session, job)

return

Expand Down Expand Up @@ -291,6 +292,7 @@ async def check_scripts(
session: async_scoped_session,
entry: db.ElementMixin,
) -> None:
interface = importlib.import_module("lsst.cmservice.handlers.interface")
scripts = await entry.get_scripts(session)
assert len(scripts) == 2, f"Expected exactly two scripts for {entry.fullname} got {len(scripts)}"

Expand Down Expand Up @@ -376,6 +378,7 @@ async def check_get_methods(
entry_class: TypeAlias = db.ElementMixin,
parent_class: TypeAlias = db.ElementMixin,
) -> None:
interface = importlib.import_module("lsst.cmservice.handlers.interface")
check_getall_nonefound = await entry_class.get_rows(
session,
parent_name="bad",
Expand Down

0 comments on commit b9cdf53

Please sign in to comment.