Skip to content

Commit

Permalink
Create database for each DAG run
Browse files Browse the repository at this point in the history
  • Loading branch information
lwrubel committed Feb 11, 2025
1 parent 81412a2 commit 560a286
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 3 deletions.
1 change: 1 addition & 0 deletions compose.prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ x-airflow-common:
AIRFLOW_VAR_SUL_PUB_KEY: ${AIRFLOW_VAR_SUL_PUB_KEY}
AIRFLOW_VAR_DATA_DIR: /opt/airflow/data
AIRFLOW_VAR_PUBLISH_DIR: /opt/airflow/data/latest
AIRFLOW_VAR_RIALTO_POSTGRES: "postgresql+psycopg2://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOSTNAME}"
volumes:
- /opt/app/rialto/rialto-airflow/current/rialto_airflow:/opt/airflow/rialto_airflow
- /data:/opt/airflow/data
Expand Down
4 changes: 4 additions & 0 deletions compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ x-airflow-common:
AIRFLOW_VAR_DATA_DIR: /opt/airflow/data
AIRFLOW_VAR_PUBLISH_DIR: /opt/airflow/data/latest
AIRFLOW_VAR_OPENALEX_EMAIL: ${AIRFLOW_VAR_OPENALEX_EMAIL}
AIRFLOW_VAR_RIALTO_POSTGRES: "postgresql+psycopg2://airflow:airflow@postgres"
volumes:
- ${AIRFLOW_PROJ_DIR:-.}/rialto_airflow:/opt/airflow/rialto_airflow
- ${AIRFLOW_PROJ_DIR:-.}/logs:/opt/airflow/logs
Expand All @@ -97,6 +98,9 @@ x-airflow-common:
services:
postgres:
image: postgres:16
# make available for local testing outside docker
ports:
- "5432:5432"
environment:
POSTGRES_USER: airflow
POSTGRES_PASSWORD: airflow
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ dependencies = [
"polars>=1.2",
"pyalex",
"more-itertools",
"sqlalchemy>=2.0.38",
"psycopg2>=2.9.10",
]

[tool.pytest.ini_options]
Expand Down
6 changes: 6 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ prompt-toolkit==3.0.50
# via
# dimcli
# ipython
psycopg2==2.9.10
# via rialto-airflow (pyproject.toml)
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
Expand Down Expand Up @@ -105,6 +107,8 @@ sphinxcontrib-qthelp==2.0.0
# via sphinx
sphinxcontrib-serializinghtml==2.0.0
# via sphinx
sqlalchemy==2.0.38
# via rialto-airflow (pyproject.toml)
stack-data==0.6.3
# via ipython
tqdm==4.67.1
Expand All @@ -113,6 +117,8 @@ traitlets==5.14.3
# via
# ipython
# matplotlib-inline
typing-extensions==4.12.2
# via sqlalchemy
tzdata==2025.1
# via pandas
urllib3==2.3.0
Expand Down
7 changes: 6 additions & 1 deletion rialto_airflow/dags/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from rialto_airflow.harvest.doi_sunet import create_doi_sunet_pickle
from rialto_airflow.harvest.sul_pub import sul_pub_csv
from rialto_airflow.harvest.contribs import create_contribs
from rialto_airflow.utils import create_snapshot_dir, rialto_authors_file
from rialto_airflow.utils import (
create_database,
create_snapshot_dir,
rialto_authors_file,
)

data_dir = Variable.get("data_dir")
publish_dir = Variable.get("publish_dir")
Expand All @@ -36,6 +40,7 @@ def setup():
Setup the data directory.
"""
snapshot_dir = create_snapshot_dir(data_dir)
create_database(snapshot_dir)
return snapshot_dir

@task()
Expand Down
23 changes: 23 additions & 0 deletions rialto_airflow/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,31 @@
import csv
import datetime
import logging
import os
from pathlib import Path
import re

from sqlalchemy import create_engine, text


def create_database(snapshot_dir):
    """Create a DAG-specific Postgres database for publications and author/orgs data.

    The database is named ``rialto_<timestamp>``, where ``<timestamp>`` is the
    final path component of snapshot_dir. Returns the new database's name.

    Raises ValueError if the derived database name contains characters that
    are unsafe to interpolate into SQL.
    """
    timestamp = Path(snapshot_dir).name
    database_name = f"rialto_{timestamp}"

    # CREATE DATABASE cannot take a bound parameter for the database name, so
    # validate before interpolating it into the statement.
    if not re.fullmatch(r"[A-Za-z0-9_]+", database_name):
        raise ValueError(f"invalid database name: {database_name}")

    # Connect via the default "postgres" database, since we are creating a new
    # one. CREATE DATABASE must run outside a transaction, so the engine is
    # created with AUTOCOMMIT isolation.
    # see discussion here: https://stackoverflow.com/questions/6506578/how-to-create-a-new-database-using-sqlalchemy
    # and https://docs.sqlalchemy.org/en/20/core/connections.html#understanding-the-dbapi-level-autocommit-isolation-level
    postgres_conn = f"{os.environ.get('AIRFLOW_VAR_RIALTO_POSTGRES')}/postgres"
    engine = create_engine(postgres_conn, isolation_level="AUTOCOMMIT")
    with engine.connect() as connection:
        # no explicit close() needed: the context manager releases the connection
        connection.execute(text(f"create database {database_name}"))
    engine.dispose()

    logging.info(f"created database {database_name}")
    return database_name


def create_snapshot_dir(data_dir):
snapshots_dir = Path(data_dir) / "snapshots"
Expand Down
41 changes: 39 additions & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import csv
from pathlib import Path

import os
import pytest

from sqlalchemy import create_engine, text
from sqlalchemy.pool import NullPool
from rialto_airflow import utils


Expand All @@ -19,11 +22,45 @@ def authors_csv(tmp_path):
return fixture_file


def test_create_snapshot_dir(tmp_path):
    """create_snapshot_dir() returns the path of a directory it created."""
    result = utils.create_snapshot_dir(tmp_path)
    assert Path(result).is_dir()


@pytest.fixture
def mock_rialto_postgres(monkeypatch):
    """Point AIRFLOW_VAR_RIALTO_POSTGRES at the local test Postgres instance."""
    conn_url = "postgresql+psycopg2://airflow:airflow@localhost:5432"
    monkeypatch.setenv("AIRFLOW_VAR_RIALTO_POSTGRES", conn_url)


def test_create_database(tmp_path, mock_rialto_postgres, request):
    """create_database() makes a new Postgres database named after the snapshot dir."""
    db_name = utils.create_database(tmp_path)
    assert db_name == "rialto_" + Path(tmp_path).name

    def teardown_database():
        # Drop the test database. Connect to the default "postgres" database,
        # since a database cannot be dropped while we hold a connection to it;
        # DROP DATABASE must run outside a transaction (AUTOCOMMIT).
        teardown_conn = f"{os.environ.get('AIRFLOW_VAR_RIALTO_POSTGRES')}/postgres"
        teardown_engine = create_engine(teardown_conn, poolclass=NullPool)
        with teardown_engine.connect() as connection:
            connection.execution_options(isolation_level="AUTOCOMMIT")
            connection.execute(text(f"drop database {db_name}"))
        teardown_engine.dispose()

    # Register the teardown with pytest so the database is dropped even if an
    # assertion below fails (previously the `request` fixture was unused and
    # teardown was only called on the success path).
    request.addfinalizer(teardown_database)

    # Verify the database was created by connecting to it.
    # NullPool keeps connections from lingering, which would block DROP DATABASE.
    postgres_conn = f"{os.environ.get('AIRFLOW_VAR_RIALTO_POSTGRES')}/{db_name}"
    engine = create_engine(postgres_conn, poolclass=NullPool)
    with engine.connect() as conn:
        # a successful connect proves the database exists
        assert conn
    engine.dispose()


def test_rialto_authors_orcids(tmp_path, authors_csv):
orcids = utils.rialto_authors_orcids(authors_csv)
assert len(orcids) == 2
Expand Down
61 changes: 61 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 560a286

Please sign in to comment.