Breakout connection from athena DB init #328

Merged 2 commits on Dec 16, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion cumulus_library/__init__.py
@@ -6,4 +6,4 @@
from cumulus_library.study_manifest import StudyManifest

__all__ = ["BaseTableBuilder", "CountsBuilder", "StudyConfig", "StudyManifest"]
__version__ = "4.1.2"
__version__ = "4.1.3"
13 changes: 10 additions & 3 deletions cumulus_library/databases/athena.py
@@ -32,19 +32,22 @@ def __init__(self, region: str, work_group: str, profile: str, schema_name: str)
self.work_group = work_group
self.profile = profile
self.schema_name = schema_name
self.connection = None

def connect(self):
# the profile may not be required, provided the above three AWS env vars
# are set. If both are present, the env vars take precedence
Comment on lines 38 to 39
Contributor: This comment is slightly out of date now, since it refers to "above three AWS env vars" and it got moved around.

connect_kwargs = {}
if self.profile is not None:
connect_kwargs["profile_name"] = self.profile

for aws_env_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]:
if aws_env_val := os.environ.get(aws_env_name):
connect_kwargs[aws_env_name.lower()] = aws_env_val

self.connection = pyathena.connect(
region_name=self.region,
work_group=self.work_group,
@@ -102,8 +105,11 @@ def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output.append((column[0], pyarrow.int64()))
case "double":
output.append((column[0], pyarrow.float64()))
# This is future proofing - we don't see this type currently.
Contributor: I thought we did see it in one edge case - I don't think in a normal production flow, but it might have been when dealing with flat tables? Or I'm making it up.
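For reference, the indices used in this branch line up with the DB-API cursor.description layout - (name, type_code, display_size, internal_size, precision, scale, null_ok) - so a hypothetical decimal column would map like this:

# Illustrative only - a made-up description tuple for a DECIMAL(10, 2) column.
import pyarrow

column = ("unit_price", "decimal", None, None, 10, 2, True)
pa_type = pyarrow.decimal128(column[4], column[5])  # precision=10, scale=2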

case "decimal":
-output.append((column[0], pyarrow.decimal128(column[4], column[5])))
+output.append(  # pragma: no cover
+    (column[0], pyarrow.decimal128(column[4], column[5]))
+)
case "boolean":
output.append((column[0], pyarrow.bool_()))
case "date":
@@ -168,7 +174,8 @@ def create_schema(self, schema_name) -> None:
glue_client.create_database(DatabaseInput={"Name": schema_name})

def close(self) -> None:
-return self.connection.close()  # pragma: no cover
+if self.connection is not None:  # pragma: no cover
+    self.connection.close()


class AthenaParser(base.DatabaseParser):
4 changes: 4 additions & 0 deletions cumulus_library/databases/base.py
@@ -120,6 +120,10 @@ def __init__(self, schema_name: str):
# technology
self.db_type = None

@abc.abstractmethod
def connect(self):
"""Initiates connection configuration of the database"""
Comment on lines +123 to +125
Contributor: If we wanted to be Extremely Python ™️, we would probably write this abstract class as a context manager, so you could do:

with AthenaDatabase(...) as db:
    # use db

To automatically get connect/close support (and make sure that callers don't forget one of the calls, even when an exception happens).

But not necessary right now and can often be a little annoying to reorganize code to support that style (but does have some benefits - mainly the exception handling).
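
A minimal sketch of that pattern (illustrative, not part of this PR), assuming the connect()/close() methods defined on this class:

import abc

class DatabaseBackend(abc.ABC):
    def __enter__(self):
        # Open the connection when the `with` block is entered.
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close unconditionally, even if the block raised an exception.
        self.close()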


@abc.abstractmethod
def cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database"""
11 changes: 8 additions & 3 deletions cumulus_library/databases/duckdb.py
@@ -25,12 +25,16 @@ class DuckDatabaseBackend(base.DatabaseBackend):
def __init__(self, db_file: str, schema_name: str | None = None):
super().__init__("main")
self.db_type = "duckdb"
self.db_file = db_file
self.connection = None

def connect(self):
"""Connects to the local duckdb database"""
# As of the 1.0 duckdb release, local scopes, where schema names can be provided
# as configuration to duckdb.connect, are not supported.
# https://duckdb.org/docs/sql/statements/set.html#syntax

# This is where the connection config would be supplied when it is supported
-self.connection = duckdb.connect(db_file)
+self.connection = duckdb.connect(self.db_file)
# Aliasing Athena's as_pandas to duckDB's df cast
duckdb.DuckDBPyConnection.as_pandas = duckdb.DuckDBPyConnection.df

@@ -208,7 +212,8 @@ def create_schema(self, schema_name):
self.connection.sql(f"CREATE SCHEMA {schema_name}")

def close(self) -> None:
-self.connection.close()
+if self.connection is not None:
+    self.connection.close()


class DuckDbParser(base.DatabaseParser):
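The caller-facing effect of these changes is a two-step lifecycle; a minimal sketch with the in-memory DuckDB backend, mirroring the updated tests below:

# Construction no longer opens a connection - connect() does that now.
db = DuckDatabaseBackend(":memory:")
db.connect()
db.cursor().execute("SELECT 1")
db.close()  # the new guard makes this safe even if connect() was never called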
11 changes: 7 additions & 4 deletions cumulus_library/databases/utils.py
@@ -78,14 +78,12 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str):
# TODO: reevaluate as DuckDB's local schema support evolves.
# https://duckdb.org/docs/sql/statements/set.html#syntax
if not (args.get("schema_name") is None or args["schema_name"] == "main"):
-print(
+print(  # pragma: no cover
"Warning - local schema names are not yet supported by duckDB's "
"python library - using 'main' instead"
)
schema_name = "main"
backend = duckdb.DuckDatabaseBackend(args["database"])
-if load_ndjson_dir:
-    backend.insert_tables(read_ndjson_dir(load_ndjson_dir))
elif db_config.db_type == "athena":
if (
args.get("schema_name") is not None
@@ -110,5 +108,10 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str):
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.")
else:
raise errors.CumulusLibraryError(f"'{db_config.db_type}' is not a supported database.")

if "prepare" not in args.keys():
backend.connect()
elif not args["prepare"]:
backend.connect()
Comment on lines +111 to +114
Contributor: nit: if not args.get("prepare"):

if backend.connection is not None and db_config.db_type == "duckdb" and load_ndjson_dir:
backend.insert_tables(read_ndjson_dir(load_ndjson_dir))
return (backend, schema_name)
1 change: 1 addition & 0 deletions tests/test_athena.py
@@ -54,6 +54,7 @@ def test_upload_parquet_response_handling(mock_session):
profile="profile",
schema_name="db_schema",
)
db.connect()
client = mock.MagicMock()
with open(path / "test_data/aws/boto3.client.athena.get_work_group.json") as f:
client.get_work_group.return_value = json.load(f)
9 changes: 9 additions & 0 deletions tests/test_cli.py
@@ -119,6 +119,7 @@ def test_cli_path_mapping(mock_load_json, monkeypatch, tmp_path, args, raises, e
args = duckdb_args(args, tmp_path)
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
assert (expected,) in db.cursor().execute("show tables").fetchall()


@@ -140,6 +141,7 @@ def test_count_builder_mapping(tmp_path):
)
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
assert [
("study_python_counts_valid__lib_transactions",),
("study_python_counts_valid__table1",),
@@ -271,6 +273,7 @@ def test_clean(tmp_path, args, expected, raises):
with mock.patch.object(builtins, "input", lambda _: "y"):
cli.main(cli_args=duckdb_args(args, tmp_path))
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
for table in db.cursor().execute("show tables").fetchall():
assert expected not in table

@@ -455,6 +458,7 @@ def test_cli_executes_queries(tmp_path, build_args, export_args, expected_tables
cli.main(cli_args=export_args)

db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
found_tables = (
db.cursor()
.execute("SELECT table_schema,table_name FROM information_schema.tables")
@@ -541,6 +545,7 @@ def test_cli_transactions(tmp_path, study, finishes, raises):
]
cli.main(cli_args=args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/{study}_duck.db")
db.connect()
query = db.cursor().execute(f"SELECT * from {study}__lib_transactions").fetchall()
assert query[0][2] == "started"
if finishes:
@@ -579,6 +584,7 @@ def test_cli_stats_rebuild(tmp_path):
cli.main(cli_args=[*arg_list, f"{tmp_path}/export"])
cli.main(cli_args=[*arg_list, f"{tmp_path}/export", "--statistics"])
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
expected = (
db.cursor()
.execute(
@@ -690,6 +696,7 @@ def test_cli_umls_parsing(mock_config, mode, tmp_path):
def test_cli_single_builder(tmp_path):
cli.main(cli_args=duckdb_args(["build", "--builder=patient", "--target=core"], tmp_path))
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = {x[0] for x in db.cursor().execute("show tables").fetchall()}
assert {
"core__patient",
@@ -708,6 +715,7 @@ def test_cli_finds_study_from_manifest_prefix(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = {x[0] for x in db.cursor().execute("show tables").fetchall()}
assert "study_different_name__table" in tables

@@ -806,6 +814,7 @@ def test_dedicated_schema(tmp_path):
cli.main(cli_args=core_build_args)
cli.main(cli_args=build_args)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
tables = (
db.cursor()
.execute("SELECT table_schema,table_name FROM information_schema.tables")
18 changes: 13 additions & 5 deletions tests/test_databases.py
@@ -223,22 +223,25 @@ def test_pyarrow_types_from_sql(db, data, expected, raises):
does_not_raise(),
),
(
{**{"db_type": "athena"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
does_not_raise(),
),
(
{**{"db_type": "athena", "database": "test"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "database": "test", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
does_not_raise(),
),
(
{**{"db_type": "athena", "database": "testtwo"}, **ATHENA_KWARGS},
{**{"db_type": "athena", "database": "testtwo", "prepare": False}, **ATHENA_KWARGS},
databases.AthenaDatabaseBackend,
pytest.raises(SystemExit),
),
(
{**{"db_type": "athena", "load_ndjson_dir": "file.json"}, **ATHENA_KWARGS},
{
**{"db_type": "athena", "load_ndjson_dir": "file.json", "prepare": False},
**ATHENA_KWARGS,
},
databases.AthenaDatabaseBackend,
pytest.raises(SystemExit),
),
@@ -253,6 +256,7 @@ def test_create_db_backend(args, expected_type, raises):
def test_create_db_backend(args, expected_type, raises):
with raises:
db, schema = databases.create_db_backend(args)
db.connect()
assert isinstance(db, expected_type)
if args.get("schema_name"):
assert args["schema_name"] == schema
@@ -347,6 +351,7 @@ def test_upload_file_athena(mock_botocore, args, sse, keycount, expected, raises
mock_clientobj.get_work_group.return_value = mock_data
mock_clientobj.list_objects_v2.return_value = {"KeyCount": keycount}
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
with raises:
location = db.upload_file(**args)
assert location == expected
@@ -383,6 +388,7 @@ def test_athena_pandas_cursor(mock_pyathena):
(None, "B", None, None, None),
)
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
res, desc = db.execute_as_pandas("ignored query")
assert res.equals(
pandas.DataFrame(
@@ -398,6 +404,7 @@
@mock.patch("pyathena.connect")
def test_athena_parser(mock_pyathena):
db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
db.connect()
parser = db.parser()
assert isinstance(parser, databases.AthenaParser)

@@ -411,7 +418,8 @@ def test_athena_env_var_priority(mock_pyathena):
@mock.patch("pyathena.connect")
def test_athena_env_var_priority(mock_pyathena):
os.environ["AWS_ACCESS_KEY_ID"] = "secret"
-databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
+db = databases.AthenaDatabaseBackend(**ATHENA_KWARGS)
+db.connect()
assert mock_pyathena.call_args[1]["aws_access_key_id"] == "secret"


1 change: 1 addition & 0 deletions tests/test_discovery.py
@@ -40,6 +40,7 @@ def test_discovery(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
cursor = db.cursor()
table_rows, cols = conftest.get_sorted_table_data(cursor, "discovery__code_sources")
table_rows = [tuple(x or "" for x in row) for row in table_rows]
2 changes: 2 additions & 0 deletions tests/test_duckdb.py
@@ -59,6 +59,7 @@ def test_duckdb_core_build_and_export(tmp_path):
)
def test_duckdb_from_iso8601_timestamp(timestamp, expected):
db = databases.DuckDatabaseBackend(":memory:")
db.connect()
parsed = db.cursor().execute(f"select from_iso8601_timestamp('{timestamp}')").fetchone()[0]
assert parsed == expected

@@ -95,6 +96,7 @@ def test_duckdb_load_ndjson_dir(tmp_path):
def test_duckdb_table_schema():
"""Verify we can detect schemas correctly, even for nested camel case fields"""
db = databases.DuckDatabaseBackend(":memory:")
db.connect()

with tempfile.TemporaryDirectory() as tmpdir:
os.mkdir(f"{tmpdir}/observation")
1 change: 1 addition & 0 deletions tests/test_logging.py
@@ -151,6 +151,7 @@ def test_migrate_transactions_athena(mock_pyathena):
profile="test",
schema_name="test",
)
db.connect()
manifest = study_manifest.StudyManifest("./tests/test_data/study_valid/")
config = base_utils.StudyConfig(schema="test", db=db)
log_utils.log_transaction(
1 change: 1 addition & 0 deletions tests/test_static_file.py
@@ -18,6 +18,7 @@ def test_static_file(tmp_path):
)
)
db = databases.DuckDatabaseBackend(f"{tmp_path}/duck.db")
db.connect()
cursor = db.cursor()
table_rows, cols = conftest.get_sorted_table_data(cursor, "study_static_file__table")
expected_cols = {"CUI", "TTY", "CODE", "SAB", "STR"}
1 change: 1 addition & 0 deletions tests/testbed_utils.py
@@ -247,6 +247,7 @@ def build(self, study="core") -> duckdb.DuckDBPyConnection:
"db_type": "duckdb",
"database": db_file,
"load_ndjson_dir": str(self.path),
"prepare": False,
}
)
config = base_utils.StudyConfig(