Sqla2 integration #420

Open · wants to merge 11 commits into master
16 changes: 10 additions & 6 deletions .github/dependabot.yml
@@ -1,8 +1,12 @@
version: 2
updates:
- package-ecosystem: pip
directory: "/"
schedule:
interval: daily
time: "04:00"
open-pull-requests-limit: 100
- package-ecosystem: pip
directory: "/"
schedule:
interval: weekly
open-pull-requests-limit: 100
- package-ecosystem: "github-actions"
open-pull-requests-limit: 99
directory: "/"
schedule:
interval: weekly
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -65,7 +65,7 @@ jobs:
python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@master
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.pypi_password }}
2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@ all: clean test dists

.PHONY: test
test:
pytest
pytest -v

dists:
python setup.py sdist
2 changes: 1 addition & 1 deletion dataset/chunked.py
@@ -80,6 +80,6 @@ def flush(self):
if self.callback is not None:
self.callback(self.queue)
self.queue.sort(key=dict.keys)
for fields, items in itertools.groupby(self.queue, key=dict.keys):
for _, items in itertools.groupby(self.queue, key=dict.keys):
self.table.update_many(list(items), self.keys)
super().flush()
62 changes: 24 additions & 38 deletions dataset/database.py
@@ -44,7 +44,6 @@ def __init__(

self.lock = threading.RLock()
self.local = threading.local()
self.connections = {}

if len(parsed_url.query):
query = parse_qs(parsed_url.query)
@@ -54,9 +53,10 @@
schema = schema_qs.pop()

self.schema = schema
self.engine = create_engine(url, **engine_kwargs)
self.is_postgres = self.engine.dialect.name == "postgresql"
self.is_sqlite = self.engine.dialect.name == "sqlite"
self._engine = create_engine(url, **engine_kwargs)
self.is_postgres = self._engine.dialect.name == "postgresql"
self.is_sqlite = self._engine.dialect.name == "sqlite"
self.is_mysql = "mysql" in self._engine.dialect.dbapi.__name__
if on_connect_statements is None:
on_connect_statements = []

@@ -72,7 +72,7 @@ def _run_on_connect(dbapi_con, con_record):
on_connect_statements.append("PRAGMA journal_mode=WAL")

if len(on_connect_statements):
event.listen(self.engine, "connect", _run_on_connect)
event.listen(self._engine, "connect", _run_on_connect)

self.types = Types(is_postgres=self.is_postgres)
self.url = url
@@ -81,39 +81,38 @@ def _run_on_connect(dbapi_con, con_record):
self._tables = {}

@property
def executable(self):
def conn(self):
"""Connection against which statements will be executed."""
with self.lock:
tid = threading.get_ident()
if tid not in self.connections:
self.connections[tid] = self.engine.connect()
return self.connections[tid]
try:
return self.local.conn
except AttributeError:
self.local.conn = self._engine.connect()
self.local.conn.begin()

Review comment:

It seems calling conn.begin() here means that only one transaction is ever opened, even in the presence of multiple table.insert() calls.

IMHO it would be better to have explicit semantics where db.conn just returns the connection and the transaction context is handled by the respective operations (insert, etc.). There should be a distinction between the autocommit case (the user did not request a transaction) and the explicit-commit case (the user started a transaction).
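
For illustration, a minimal sketch of the semantics this comment proposes: db.conn hands back a bare connection, and each operation commits itself unless the user has opened an explicit transaction. The names and structure below are hypothetical, not part of this PR.

import threading

from sqlalchemy import create_engine


class Database:
    def __init__(self, url):
        self._engine = create_engine(url)
        self.local = threading.local()

    @property
    def conn(self):
        # Hand back the thread-local connection; do not start a transaction here.
        if not hasattr(self.local, "conn"):
            self.local.conn = self._engine.connect()
        return self.local.conn

    def begin(self):
        # Explicit case: the user asked for a transaction.
        self.local.tx = self.conn.begin()

    def commit(self):
        tx = getattr(self.local, "tx", None)
        if tx is not None:
            tx.commit()
            self.local.tx = None

    def execute(self, stmt, *args, **kwargs):
        # Autocommit case: no user transaction, so commit right after the statement.
        result = self.conn.execute(stmt, *args, **kwargs)
        if getattr(self.local, "tx", None) is None:
            self.conn.commit()
        return result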

return self.local.conn

@property
def op(self):
"""Get an alembic operations context."""
ctx = MigrationContext.configure(self.executable)
ctx = MigrationContext.configure(self.conn)
return Operations(ctx)

@property
def inspect(self):
"""Get a SQLAlchemy inspector."""
return inspect(self.executable)
return inspect(self.conn)

def has_table(self, name):
return self.inspect.has_table(name, schema=self.schema)

@property
def metadata(self):
"""Return a SQLAlchemy schema cache object."""
return MetaData(schema=self.schema, bind=self.executable)
return MetaData(schema=self.schema)

@property
def in_transaction(self):
"""Check if this database is in a transactional context."""
if not hasattr(self.local, "tx"):
return False
return len(self.local.tx) > 0
return self.conn.in_transaction()

def _flush_tables(self):
"""Clear the table metadata after transaction rollbacks."""
@@ -125,32 +124,22 @@ def begin(self):

No data will be written until the transaction has been committed.
"""
if not hasattr(self.local, "tx"):

Review comment:

Wouldn't this remove support for nested transactions, as in https://dataset.readthedocs.io/en/latest/quickstart.html#using-transactions?
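
For reference, nested transactions could be preserved under SQLAlchemy 2.x by keeping a stack of transaction objects and opening SAVEPOINTs via Connection.begin_nested() for the inner levels. A sketch, assuming conn no longer begins a transaction implicitly on first access:

def begin(self):
    if not hasattr(self.local, "tx"):
        self.local.tx = []
    if self.conn.in_transaction():
        # Inner level: open a SAVEPOINT rather than a second outer transaction.
        self.local.tx.append(self.conn.begin_nested())
    else:
        self.local.tx.append(self.conn.begin())

def commit(self):
    if getattr(self.local, "tx", None):
        self.local.tx.pop().commit()

def rollback(self):
    if getattr(self.local, "tx", None):
        self.local.tx.pop().rollback()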

self.local.tx = []
self.local.tx.append(self.executable.begin())
if not self.in_transaction:
self.conn.begin()

def commit(self):
"""Commit the current transaction.

Make all statements executed since the transaction was begun permanent.
"""
if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop()
tx.commit()
# Removed in 2020-12, I'm a bit worried this means that some DDL
# operations in transactions won't cause metadata to refresh any
# more:
# self._flush_tables()
self.conn.commit()

def rollback(self):
"""Roll back the current transaction.

Discard all statements executed since the transaction was begun.
"""
if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop()
tx.rollback()
self._flush_tables()
self.conn.rollback()

def __enter__(self):
"""Start a transaction."""
Expand All @@ -170,13 +159,10 @@ def __exit__(self, error_type, error_value, traceback):

def close(self):
"""Close database connections. Makes this object unusable."""
with self.lock:
for conn in self.connections.values():
conn.close()
self.connections.clear()
self.engine.dispose()
self.local = threading.local()
self._engine.dispose()
self._tables = {}
self.engine = None
self._engine = None

@property
def tables(self):
@@ -322,7 +308,7 @@ def query(self, query, *args, **kwargs):
_step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0:
_step = None
rp = self.executable.execute(query, *args, **kwargs)
rp = self.conn.execute(query, *args, **kwargs)
return ResultIter(rp, row_type=self.row_type, step=_step)

def __repr__(self):
62 changes: 34 additions & 28 deletions dataset/table.py
@@ -116,7 +116,9 @@ def insert(self, row, ensure=None, types=None):
Returns the inserted row's primary key.
"""
row = self._sync_columns(row, ensure, types=types)
res = self.db.executable.execute(self.table.insert(row))
res = self.db.conn.execute(self.table.insert(), row)
# if not self.db.in_transaction:
# self.db.conn.commit()

Review comment (@miraculixx, Jun 20, 2023):

Shouldn't db.commit() be called here? Otherwise the transaction stays open (started by db.conn on first connect) until the user calls db.commit(). Perhaps instead of calling db.conn.execute() directly, there should be a db.execute() that handles transactions implicitly, i.e. one that distinguishes between user-created transactions and autocommit.
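
To make the concern concrete, a short usage example of the behaviour described above; the database URL and data are illustrative only:

import dataset

db = dataset.connect("sqlite:///example.db")
table = db["people"]
table.insert({"name": "Ada"})
# Under this PR, db.conn opened a transaction on first use, so the inserted
# row stays invisible to other connections until the user commits explicitly:
db.commit()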
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
return True
@@ -181,7 +183,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk)
self.db.conn.execute(self.table.insert(), chunk)
chunk = []

def update(self, row, keys, ensure=None, types=None, return_count=False):
@@ -206,8 +208,8 @@ def update(self, row, keys, ensure=None, types=None, return_count=False):
clause = self._args_to_clause(args)
if not len(row):
return self.count(clause)
stmt = self.table.update(whereclause=clause, values=row)
rp = self.db.executable.execute(stmt)
stmt = self.table.update().where(clause).values(row)
rp = self.db.conn.execute(stmt)
if rp.supports_sane_rowcount():
return rp.rowcount
if return_count:
@@ -241,11 +243,12 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = self.table.update(
whereclause=and_(True, *cl),
values={col: bindparam(col, required=False) for col in columns},
stmt = (
self.table.update()
.where(and_(True, *cl))
.values({col: bindparam(col, required=False) for col in columns})
)
self.db.executable.execute(stmt, chunk)
self.db.conn.execute(stmt, chunk)
chunk = []

def upsert(self, row, keys, ensure=None, types=None):
@@ -293,8 +296,8 @@ def delete(self, *clauses, **filters):
if not self.exists:
return False
clause = self._args_to_clause(filters, clauses=clauses)
stmt = self.table.delete(whereclause=clause)
rp = self.db.executable.execute(stmt)
stmt = self.table.delete().where(clause)
rp = self.db.conn.execute(stmt)
return rp.rowcount > 0

def _reflect_table(self):
@@ -303,7 +306,10 @@ def _reflect_table(self):
self._columns = None
try:
self._table = SQLATable(
self.name, self.db.metadata, schema=self.db.schema, autoload=True
self.name,
self.db.metadata,
schema=self.db.schema,
autoload_with=self.db.conn,
)
except NoSuchTableError:
self._table = None
@@ -345,15 +351,15 @@ def _sync_table(self, columns):
for column in columns:
if not column.name == self._primary_id:
self._table.append_column(column)
self._table.create(self.db.executable, checkfirst=True)
self._table.create(self.db.conn, checkfirst=True)
self._columns = None
elif len(columns):
with self.db.lock:
self._reflect_table()
self._threading_warn()
for column in columns:
if not self.has_column(column.name):
self.db.op.add_column(self.name, column, self.db.schema)
self.db.op.add_column(self.name, column, schema=self.db.schema)
self._reflect_table()

def _sync_columns(self, row, ensure, types=None):
@@ -378,6 +384,7 @@ def _sync_columns(self, row, ensure, types=None):
_type = self.db.types.guess(value)
sync_columns[name] = Column(name, _type)
out[name] = value

self._sync_table(sync_columns.values())
return out

@@ -500,7 +507,7 @@ def drop_column(self, name):
table.drop_column('created_at')

"""
if self.db.engine.dialect.name == "sqlite":
if self.db.is_sqlite:
raise RuntimeError("SQLite does not support dropping columns.")
name = self._get_column_name(name)
with self.db.lock:
@@ -509,7 +516,7 @@
return

self._threading_warn()
self.db.op.drop_column(self.table.name, name, self.table.schema)
self.db.op.drop_column(self.table.name, name, schema=self.table.schema)
self._reflect_table()

def drop(self):
@@ -520,7 +527,7 @@
with self.db.lock:
if self.exists:
self._threading_warn()
self.table.drop(self.db.executable, checkfirst=True)
self.table.drop(self.db.conn, checkfirst=True)
self._table = None
self._columns = None
self.db._tables.pop(self.name, None)
@@ -581,7 +588,7 @@ def create_index(self, columns, name=None, **kw):
kw["mysql_length"] = mysql_length

idx = Index(name, *columns, **kw)
idx.create(self.db.executable)
idx.create(self.db.conn)

def find(self, *_clauses, **kwargs):
"""Perform a simple search on the table.
@@ -625,14 +632,13 @@ def find(self, *_clauses, **kwargs):

order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
query = self.table.select().where(args).limit(_limit).offset(_offset)
if len(order_by):
query = query.order_by(*order_by)

conn = self.db.executable
conn = self.db.conn
if _streamed:
conn = self.db.engine.connect()
conn = conn.execution_options(stream_results=True)
conn = self.db._engine.connect().execution_options(stream_results=True)

return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step)

@@ -666,9 +672,9 @@ def count(self, *_clauses, **kwargs):
return 0

args = self._args_to_clause(kwargs, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = select(func.count()).where(args)
query = query.select_from(self.table)
rp = self.db.executable.execute(query)
rp = self.db.conn.execute(query)
return rp.fetchone()[0]

def __len__(self):
@@ -703,11 +709,11 @@ def distinct(self, *args, **_filter):
if not len(columns):
return iter([])

q = expression.select(
columns,
distinct=True,
whereclause=clause,
order_by=[c.asc() for c in columns],
q = (
expression.select(*columns)
.where(clause)
.group_by(*columns)
.order_by(*(c.asc() for c in columns))
)
return self.db.query(q)
