Skip to content

Commit

Permalink
ci(athena): use a different database for each python version to avoid clobbering data (#10930)
Browse files Browse the repository at this point in the history

Ensure that Athena test runs do not clobber each other by creating a database per user+python-version pair.
  • Loading branch information
cpcloud authored Mar 3, 2025
1 parent a220b47 commit 1829169
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
40 changes: 26 additions & 14 deletions ibis/backends/athena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,17 @@ def do_connect(
s3_staging_dir: str,
cursor_class: type[Cursor] = ArrowCursor,
memtable_volume: str | None = None,
schema_name: str = "default",
catalog_name: str = "awsdatacatalog",
**config: Any,
) -> None:
"""Create an Ibis client connected to an Amazon Athena instance."""
self.con = pyathena.connect(
s3_staging_dir=s3_staging_dir, cursor_class=cursor_class, **config
s3_staging_dir=s3_staging_dir,
cursor_class=cursor_class,
schema_name=schema_name,
catalog_name=catalog_name,
**config,
)

if memtable_volume is None:
Expand Down Expand Up @@ -441,23 +447,29 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
pass

def _finalize_memtable(self, name: str) -> None:
path = f"{self._memtable_volume_path}/{name}"
sql = sge.Drop(
kind="TABLE",
this=sg.to_identifier(name, quoted=self.compiler.quoted),
exists=True,
)

with self._safe_raw_sql(sql, unload=False):
pass

self._fs.rm(path, recursive=True)
self.drop_table(name, force=True)
self._fs.rm(f"{self._memtable_volume_path}/{name}", recursive=True)

def create_database(
self, name: str, /, *, catalog: str | None = None, force: bool = False
self,
name: str,
/,
*,
location: str | None = None,
catalog: str | None = None,
force: bool = False,
) -> None:
name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted)
sql = sge.Create(this=name, kind="SCHEMA", exists=force)
sql = sge.Create(
this=name,
kind="SCHEMA",
exists=force,
properties=None
if location is None
else sge.Properties(
expressions=[sge.LocationProperty(this=sge.convert(location))]
),
)
with self._safe_raw_sql(sql, unload=False):
pass

Expand Down
41 changes: 28 additions & 13 deletions ibis/backends/athena/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import concurrent.futures
import getpass
import sys
import uuid
from os import environ as env
from typing import TYPE_CHECKING, Any

Expand All @@ -30,6 +29,9 @@
IBIS_ATHENA_S3_STAGING_DIR = env.get(
"IBIS_ATHENA_S3_STAGING_DIR", "s3://aws-athena-query-results-ibis-testing"
)
IBIS_ATHENA_TEST_DATABASE = (
f"{getpass.getuser()}_{''.join(map(str, sys.version_info[:3]))}"
)
AWS_REGION = env.get("AWS_REGION", "us-east-2")
AWS_PROFILE = env.get("AWS_PROFILE")
CONNECT_ARGS = dict(
Expand All @@ -49,7 +51,10 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)

ddl = sge.Create(
kind="TABLE",
this=sge.Schema(this=sg.table(name), expressions=sg_schema),
this=sge.Schema(
this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True),
expressions=sg_schema,
),
properties=sge.Properties(
expressions=[
sge.ExternalProperty(),
Expand All @@ -61,16 +66,19 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)

fs.put(str(file), f"{folder.removeprefix('s3://')}/{name}/{file.name}")

drop_query = sge.Drop(kind="TABLE", this=sg.to_identifier(name, quoted=True)).sql(
Athena
)
drop_query = sge.Drop(
kind="TABLE", this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE), exists=True
).sql(Athena)

create_query = ddl.sql(Athena)

with connection.con.cursor() as cur:
cur.execute(drop_query)
cur.execute(create_query)

assert connection.table(name).count().execute() > 0
assert (
connection.table(name, database=IBIS_ATHENA_TEST_DATABASE).count().execute() > 0
)


class TestConf(BackendTest):
Expand All @@ -82,35 +90,42 @@ class TestConf(BackendTest):

deps = ("pyathena", "fsspec")

@staticmethod
def format_table(name: str) -> str:
    """Render `name` as a quoted table reference inside the test database.

    The reference is qualified with `IBIS_ATHENA_TEST_DATABASE` and
    serialized to SQL text using the Athena dialect.
    """
    qualified = sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True)
    return qualified.sql(Athena)

def _load_data(self, **_: Any) -> None:
import fsspec

files = self.data_dir.joinpath("parquet").glob("*.parquet")

user = getpass.getuser()
python_version = "".join(map(str, sys.version_info[:3]))
folder = f"{user}_{python_version}_{uuid.uuid4()}"

fs = fsspec.filesystem("s3")

connection = self.connection
folder = f"{IBIS_ATHENA_S3_STAGING_DIR}/{folder}"
db_dir = f"{IBIS_ATHENA_S3_STAGING_DIR}/{IBIS_ATHENA_TEST_DATABASE}"

connection.create_database(
IBIS_ATHENA_TEST_DATABASE, location=db_dir, force=True
)

with concurrent.futures.ThreadPoolExecutor() as executor:
for future in concurrent.futures.as_completed(
executor.submit(
create_table, connection, fs=fs, file=file, folder=folder
create_table, connection, fs=fs, file=file, folder=db_dir
)
for file in files
):
future.result()

def postload(self, **kw):
    """Connect with `schema_name` set to the test database.

    The new connection is stored on `self.connection`; any extra keyword
    arguments are forwarded to `self.connect` unchanged.
    """
    backend = self.connect(schema_name=IBIS_ATHENA_TEST_DATABASE, **kw)
    self.connection = backend

@staticmethod
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
    """Create an Athena backend from `CONNECT_ARGS` plus extra keywords.

    `tmpdir` and `worker_id` are accepted as keyword-only arguments but are
    not used by this method; everything else in `kw` is passed through to
    `ibis.athena.connect`.
    """
    # Both mappings are unpacked in the call itself so that a key present in
    # both still raises TypeError, rather than being silently overridden.
    backend = ibis.athena.connect(**CONNECT_ARGS, **kw)
    return backend

def _remap_column_names(self, table_name: str) -> dict[str, str]:
table = self.connection.table(table_name)
table = self.connection.table(table_name, database=IBIS_ATHENA_TEST_DATABASE)
return table.rename(
dict(zip(TEST_TABLES[table_name].names, table.schema().names))
)
Expand Down

0 comments on commit 1829169

Please sign in to comment.