Skip to content

Commit

Permalink
feat: add new test db for unit testing and some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
leon-schmid committed Nov 25, 2024
1 parent bd8f26d commit 38a91e4
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 176 deletions.
26 changes: 17 additions & 9 deletions admyral/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,29 @@ class SecretsManagerType(str, Enum):


ENV_ADMYRAL_DATABASE_URL = "ADMYRAL_DATABASE_URL"
ENV_ADMYRAL_TEST_DATABASE_URL = "ADMYRAL_TEST_DATABASE_URL"
ENV_ADMYRAL_SECRETS_MANAGER_TYPE = "ADMYRAL_SECRETS_MANAGER"


ADMYRAL_DATABASE_URL = os.getenv(
ENV_ADMYRAL_DATABASE_URL,
"postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/admyral",
)
if ADMYRAL_DATABASE_URL.startswith("postgresql"):
if ADMYRAL_DATABASE_URL.startswith("postgresql://"):
ADMYRAL_DATABASE_URL = ADMYRAL_DATABASE_URL.replace(
"postgresql://", "postgresql+asyncpg://"
)
ADMYRAL_DATABASE_TYPE = DatabaseType.POSTGRES
else:
raise NotImplementedError(f"Unsupported database type: {ADMYRAL_DATABASE_URL}")

ADMYRAL_TEST_DATABASE_URL = os.getenv(
ENV_ADMYRAL_TEST_DATABASE_URL,
"postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5433/admyral",
)

if ADMYRAL_TEST_DATABASE_URL.startswith("postgresql://"):
ADMYRAL_TEST_DATABASE_URL = ADMYRAL_TEST_DATABASE_URL.replace(
"postgresql://", "postgresql+asyncpg://"
)

if ADMYRAL_DATABASE_URL.startswith("postgresql://"):
ADMYRAL_DATABASE_URL = ADMYRAL_DATABASE_URL.replace(
"postgresql://", "postgresql+asyncpg://"
)

ADMYRAL_SECRETS_MANAGER_TYPE = SecretsManagerType(
os.getenv(ENV_ADMYRAL_SECRETS_MANAGER_TYPE, SecretsManagerType.SQL)
Expand All @@ -151,8 +159,8 @@ class GlobalConfig(BaseModel):
default_user_email: str = "[email protected]"
telemetry_disabled: bool = ADMYRAL_DISABLE_TELEMETRY
storage_directory: str = get_local_storage_path()
database_type: DatabaseType = ADMYRAL_DATABASE_TYPE
database_url: str = ADMYRAL_DATABASE_URL
test_database_url: str = ADMYRAL_TEST_DATABASE_URL
temporal_host: str = ADMYRAL_TEMPORAL_HOST
secrets_manager_type: SecretsManagerType = ADMYRAL_SECRETS_MANAGER_TYPE
posthog_api_key: str = ADMYRAL_POSTHOG_API_KEY
Expand Down
16 changes: 9 additions & 7 deletions admyral/db/admyral_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
WorkflowControlResultsSchema,
)
from admyral.db.alembic.database_manager import DatabaseManager
from admyral.config.config import GlobalConfig, CONFIG
from admyral.config.config import CONFIG
from admyral.logger import get_logger
from admyral.utils.time import utc_now
from admyral.utils.crypto import generate_hs256
Expand Down Expand Up @@ -74,11 +74,11 @@ async def commit(self) -> None:


class AdmyralStore(StoreInterface):
def __init__(self, config: GlobalConfig) -> None:
self.config = config
def __init__(self, database_url: str) -> None:
self.database_url = database_url

self.engine = create_async_engine(
self.config.database_url, echo=True, future=True, pool_pre_ping=True
database_url, echo=True, future=True, pool_pre_ping=True
)
self.async_session_maker = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
Expand All @@ -90,8 +90,10 @@ def __init__(self, config: GlobalConfig) -> None:

# TODO: pass down config
@classmethod
async def create_store(cls, skip_setup: bool = False) -> "AdmyralStore":
store = cls(CONFIG)
async def create_store(
cls, skip_setup: bool = False, database_url: str | None = None
) -> "AdmyralStore":
store = cls(database_url or CONFIG.database_url)
if not skip_setup:
await store.setup()
store.performed_setup = True
Expand All @@ -105,7 +107,7 @@ async def create_store(cls, skip_setup: bool = False) -> "AdmyralStore":
async def setup(self):
logger.info("Setting up Admyral store.")

database_manager = DatabaseManager(self.engine, self.config)
database_manager = DatabaseManager(self.engine, self.database_url)

does_db_exist = await database_manager.database_exists()
if not does_db_exist:
Expand Down
50 changes: 19 additions & 31 deletions admyral/db/alembic/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from sqlalchemy.engine import Connection
from functools import partial

from admyral.config.config import GlobalConfig, DatabaseType


# TODO: why are we filtering out the alembic_version table?
def include_object(object, name, type_, reflected, compare_to):
Expand All @@ -25,9 +23,9 @@ def get_admyral_dir() -> str:


class DatabaseManager:
def __init__(self, engine: AsyncEngine, config: GlobalConfig) -> None:
def __init__(self, engine: AsyncEngine, database_url: str) -> None:
self.engine = engine
self.config = config
self.database_url = database_url

self.target_metadata = SQLModel.metadata

Expand All @@ -42,39 +40,29 @@ def __init__(self, engine: AsyncEngine, config: GlobalConfig) -> None:

def _get_postgres_setup_engine(self) -> str:
# https://stackoverflow.com/questions/6506578/how-to-create-a-new-database-using-sqlalchemy/8977109#8977109
db_name = self.config.database_url.split("/")[-1]
db_url = self.config.database_url[: -len(db_name)] + "postgres"
db_name = self.database_url.split("/")[-1]
db_url = self.database_url[: -len(db_name)] + "postgres"
return create_async_engine(db_url, echo=True, future=True, pool_pre_ping=True)

async def database_exists(self) -> bool:
if self.config.database_type == DatabaseType.POSTGRES:
engine = self._get_postgres_setup_engine()
try:
async with engine.connect() as conn:
result = await conn.execute(
text(
"select exists (select 1 from pg_database where datname = 'admyral')"
)
engine = self._get_postgres_setup_engine()
try:
async with engine.connect() as conn:
result = await conn.execute(
text(
"select exists (select 1 from pg_database where datname = 'admyral')"
)
return result.scalar()
except Exception:
return False

raise NotImplementedError(
f"Unimplemented database type in database_exists: {self.database_type}"
)
)
return result.scalar()
except Exception:
return False

async def create_database(self) -> None:
if self.config.database_type == DatabaseType.POSTGRES:
engine = self._get_postgres_setup_engine()
async with engine.connect() as conn:
await conn.execute(text("commit"))
await conn.execute(text("create database admyral"))
return

raise NotImplementedError(
f"Unimplemented database type in create_database: {self.database_type}"
)
engine = self._get_postgres_setup_engine()
async with engine.connect() as conn:
await conn.execute(text("commit"))
await conn.execute(text("create database admyral"))
return

async def drop_database(self) -> None:
# TODO:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""add controls and controls_workflows tables
Revision ID: af128e403421
Revision ID: 8fbcc7f42a0b
Revises: 7985f1c159a3
Create Date: 2024-11-20 15:15:35.528059
Create Date: 2024-11-25 17:42:24.701175
"""

Expand All @@ -14,7 +14,7 @@


# revision identifiers, used by Alembic.
revision: str = "af128e403421"
revision: str = "8fbcc7f42a0b"
down_revision: Union[str, None] = "7985f1c159a3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
Expand Down
4 changes: 2 additions & 2 deletions admyral/db/schemas/control_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ControlsWorkflowsMappingSchema(BaseSchema, table=True):
)

# primary keys
control_id: int = Field(primary_key=True)
control_id: int = Field(primary_key=True, default=None)
user_id: str = Field(sa_type=TEXT(), primary_key=True)
workflow_id: str = Field(sa_type=TEXT(), primary_key=True)

Expand Down Expand Up @@ -70,7 +70,7 @@ class ControlSchema(BaseSchema, table=True):
)

# primary keys
control_id: int = Field(primary_key=True)
control_id: int = Field(primary_key=True, default=None)
user_id: str = Field(sa_type=TEXT(), primary_key=True)

# other fields
Expand Down
22 changes: 22 additions & 0 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ services:
volumes:
- ${POSTGRES_VOLUME_PATH:-.admyral/postgres}:/var/lib/postgresql/data

postgresql-unit-tests:
container_name: postgresql-unit-tests
environment:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your-super-secret-and-long-postgres-password}
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_DB: admyral
image: postgres:${POSTGRES_VERSION:-16.4-bookworm}
healthcheck:
test: ["CMD-SHELL", "pg_isready", "-d", "admyral"]
interval: 30s
timeout: 60s
retries: 5
start_period: 80s
networks:
- admyral-network
ports:
- 5433:5432
expose:
- 5432
volumes:
- ${POSTGRES_VOLUME_PATH:-.admyral/postgres-test}:/var/lib/postgresql/data

temporal:
container_name: temporal
environment:
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import asyncio
from admyral.config.config import CONFIG

from admyral.db.admyral_store import AdmyralStore
from admyral.config.config import TEST_USER_ID
Expand All @@ -14,6 +15,6 @@ def event_loop():

@pytest.fixture(scope="session", autouse=True)
async def store(event_loop):
store = await AdmyralStore.create_store()
store = await AdmyralStore.create_store(CONFIG.test_database_url)
yield store
await store.clean_up_workflow_data_of(TEST_USER_ID)
Loading

0 comments on commit 38a91e4

Please sign in to comment.