Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🪛 Moving to SQLAlchemy 2.0 #540

Merged
merged 10 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

services:
mysql:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ values = [
]
await database.execute_many(query=query, values=values)

# Run a database query.
# Run a database query.
query = "SELECT * FROM HighScores"
rows = await database.fetch_all(query=query)
print('High Scores:', rows)
Expand Down Expand Up @@ -115,4 +115,4 @@ for examples of how to start using databases together with SQLAlchemy core expre
[quart]: https://gitlab.com/pgjones/quart
[aiohttp]: https://github.com/aio-libs/aiohttp
[tornado]: https://github.com/tornadoweb/tornado
[fastapi]: https://github.com/tiangolo/fastapi
[fastapi]: https://github.com/tiangolo/fastapi
69 changes: 44 additions & 25 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import uuid

import aiopg
from aiopg.sa.engine import APGCompiler_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.core import DatabaseURL
from databases.backends.common.records import Record, Row, create_column_maps
from databases.backends.compilers.psycopg import PGCompiler_psycopg
from databases.backends.dialects.psycopg import PGDialect_psycopg
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record,
Record as RecordInterface,
TransactionBackend,
)

Expand All @@ -34,10 +35,10 @@ def __init__(
self._pool: typing.Union[aiopg.Pool, None] = None

def _get_dialect(self) -> Dialect:
dialect = PGDialect_psycopg2(
dialect = PGDialect_psycopg(
json_serializer=json.dumps, json_deserializer=lambda x: x
)
dialect.statement_compiler = APGCompiler_psycopg2
dialect.statement_compiler = PGCompiler_psycopg
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
Expand Down Expand Up @@ -117,50 +118,55 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect

cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
return [
rows = [
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
return [Record(row, result_columns, dialect, column_maps) for row in rows]
finally:
cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
return Row(
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, _, _ = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
Expand All @@ -173,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
cursor = await self._connection.cursor()
try:
for single_query in queries:
single_query, args, context = self._compile(single_query)
single_query, args, _, _ = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
cursor.close()
Expand All @@ -182,36 +188,37 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
cursor.close()

def transaction(self) -> TransactionBackend:
return AiopgTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

args = compiled.construct_params()
for key, val in args.items():
if key in compiled._bind_processors:
Expand All @@ -224,11 +231,23 @@ def _compile(
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping
result_map = compiled._result_columns

else:
args = {}
result_map = None
compiled_query = compiled.string

logger.debug("Query: %s\nArgs: %s", compiled.string, args)
return compiled.string, args, CompilationContext(execution_context)
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled.string, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> aiopg.connection.Connection:
Expand Down
63 changes: 41 additions & 22 deletions databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record,
Record as RecordInterface,
TransactionBackend,
)

Expand Down Expand Up @@ -108,50 +108,57 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect

async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()
metadata = CursorResultMetaData(context, cursor.description)
return [
rows = [
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
return [
Record(row, result_columns, dialect, column_maps) for row in rows
]
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
row = await cursor.fetchone()
if row is None:
return None
metadata = CursorResultMetaData(context, cursor.description)
return Row(
row = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
return Record(row, result_columns, dialect, column_maps)
finally:
await cursor.close()

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, _, _ = self._compile(query)
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
Expand All @@ -166,7 +173,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
async with self._connection.cursor() as cursor:
try:
for single_query in queries:
single_query, args, context = self._compile(single_query)
single_query, args, _, _ = self._compile(single_query)
await cursor.execute(single_query, args)
finally:
await cursor.close()
Expand All @@ -175,36 +182,37 @@ async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
dialect = self._dialect
async with self._connection.cursor() as cursor:
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
record = Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
yield Record(record, result_columns, dialect, column_maps)
finally:
await cursor.close()

def transaction(self) -> TransactionBackend:
return AsyncMyTransaction(self)

def _compile(
self, query: ClauseElement
) -> typing.Tuple[str, dict, CompilationContext]:
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

execution_context = self._dialect.execution_ctx_cls()
execution_context.dialect = self._dialect

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

args = compiled.construct_params()
for key, val in args.items():
if key in compiled._bind_processors:
Expand All @@ -217,12 +225,23 @@ def _compile(
compiled._ad_hoc_textual,
compiled._loose_column_name_matching,
)

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping
result_map = compiled._result_columns

else:
args = {}
result_map = None
compiled_query = compiled.string

query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA)
return compiled.string, args, CompilationContext(execution_context)
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled.string, args, result_map, CompilationContext(execution_context)

@property
def raw_connection(self) -> asyncmy.connection.Connection:
Expand Down
Empty file.
Loading
Loading