-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Breakout connection from athena DB init #328
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,19 +32,22 @@ def __init__(self, region: str, work_group: str, profile: str, schema_name: str) | |
self.work_group = work_group | ||
self.profile = profile | ||
self.schema_name = schema_name | ||
self.connection = None | ||
|
||
def connect(self): | ||
# the profile may not be required, provided the above three AWS env vars | ||
# are set. If both are present, the env vars take precedence | ||
connect_kwargs = {} | ||
if self.profile is not None: | ||
connect_kwargs["profile_name"] = self.profile | ||
|
||
for aws_env_name in [ | ||
"AWS_ACCESS_KEY_ID", | ||
"AWS_SECRET_ACCESS_KEY", | ||
"AWS_SESSION_TOKEN", | ||
]: | ||
if aws_env_val := os.environ.get(aws_env_name): | ||
connect_kwargs[aws_env_name.lower()] = aws_env_val | ||
|
||
self.connection = pyathena.connect( | ||
region_name=self.region, | ||
work_group=self.work_group, | ||
|
@@ -102,8 +105,11 @@ def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list: | |
output.append((column[0], pyarrow.int64())) | ||
case "double": | ||
output.append((column[0], pyarrow.float64())) | ||
# This is future proofing - we don't see this type currently. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought we did see it in one edge case - not in a normal production flow, I think, but it might have been when dealing with flat tables? Or I'm making it up. |
||
case "decimal": | ||
output.append((column[0], pyarrow.decimal128(column[4], column[5]))) | ||
output.append( # pragma: no cover | ||
(column[0], pyarrow.decimal128(column[4], column[5])) | ||
) | ||
case "boolean": | ||
output.append((column[0], pyarrow.bool_())) | ||
case "date": | ||
|
@@ -168,7 +174,8 @@ def create_schema(self, schema_name) -> None: | |
glue_client.create_database(DatabaseInput={"Name": schema_name}) | ||
|
||
def close(self) -> None: | ||
return self.connection.close() # pragma: no cover | ||
if self.connection is not None: # pragma: no cover | ||
self.connection.close() | ||
|
||
|
||
class AthenaParser(base.DatabaseParser): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,6 +120,10 @@ def __init__(self, schema_name: str): | |
# technology | ||
self.db_type = None | ||
|
||
@abc.abstractmethod | ||
def connect(self): | ||
"""Initiates connection configuration of the database""" | ||
Comment on lines
+123
to
+125
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we wanted to be Extremely Python ™️, we would probably write this abstract class as a context manager, so you could do something like `with backend: ...`
to automatically get connect/close support (and make sure that callers don't forget one of the calls, even when an exception happens). But that's not necessary right now, and it can often be a little annoying to reorganize code to support that style (though it does have some benefits - mainly the exception handling). |
||
|
||
@abc.abstractmethod | ||
def cursor(self) -> DatabaseCursor: | ||
"""Returns a connection to the backing database""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,14 +78,12 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str): | |
# TODO: reevaluate as DuckDB's local schema support evolves. | ||
# https://duckdb.org/docs/sql/statements/set.html#syntax | ||
if not (args.get("schema_name") is None or args["schema_name"] == "main"): | ||
print( | ||
print( # pragma: no cover | ||
"Warning - local schema names are not yet supported by duckDB's " | ||
"python library - using 'main' instead" | ||
) | ||
schema_name = "main" | ||
backend = duckdb.DuckDatabaseBackend(args["database"]) | ||
if load_ndjson_dir: | ||
backend.insert_tables(read_ndjson_dir(load_ndjson_dir)) | ||
elif db_config.db_type == "athena": | ||
if ( | ||
args.get("schema_name") is not None | ||
|
@@ -110,5 +108,10 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str): | |
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.") | ||
else: | ||
raise errors.CumulusLibraryError(f"'{db_config.db_type}' is not a supported database.") | ||
|
||
if "prepare" not in args.keys(): | ||
backend.connect() | ||
elif not args["prepare"]: | ||
backend.connect() | ||
Comment on lines
+111
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
if backend.connection is not None and db_config.db_type == "duckdb" and load_ndjson_dir: | ||
backend.insert_tables(read_ndjson_dir(load_ndjson_dir)) | ||
return (backend, schema_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is slightly out of date now, since it refers to "above three AWS env vars" and it got moved around