From 84ad0e27877c7fdc25338e6790d97ad88660ee0e Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Fri, 20 Jan 2023 16:43:49 +0300 Subject: [PATCH] ydb sqlalchemy experiment --- .gitignore | 4 + examples/_sqlalchemy_example/example.py | 229 ++++++++++++++ examples/_sqlalchemy_example/fill_tables.py | 83 +++++ examples/_sqlalchemy_example/models.py | 34 ++ test-requirements.txt | 2 + tests/sqlalchemy/conftest.py | 22 ++ tests/sqlalchemy/test_dbapi.py | 84 +++++ tests/sqlalchemy/test_sqlalchemy.py | 27 ++ tox.ini | 8 +- ydb/_dbapi/__init__.py | 36 +++ ydb/_dbapi/connection.py | 73 +++++ ydb/_dbapi/cursor.py | 172 +++++++++++ ydb/_dbapi/errors.py | 92 ++++++ ydb/_sqlalchemy/__init__.py | 324 ++++++++++++++++++++ ydb/_sqlalchemy/types.py | 28 ++ 15 files changed, 1217 insertions(+), 1 deletion(-) create mode 100644 examples/_sqlalchemy_example/example.py create mode 100644 examples/_sqlalchemy_example/fill_tables.py create mode 100644 examples/_sqlalchemy_example/models.py create mode 100644 tests/sqlalchemy/conftest.py create mode 100644 tests/sqlalchemy/test_dbapi.py create mode 100644 tests/sqlalchemy/test_sqlalchemy.py create mode 100644 ydb/_dbapi/__init__.py create mode 100644 ydb/_dbapi/connection.py create mode 100644 ydb/_dbapi/cursor.py create mode 100644 ydb/_dbapi/errors.py create mode 100644 ydb/_sqlalchemy/__init__.py create mode 100644 ydb/_sqlalchemy/types.py diff --git a/.gitignore b/.gitignore index 45896947..55c4ea54 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ ydb.egg-info/ /tox /venv /ydb_certs +/ydb_data /tmp +.coverage +/cov_html +/build diff --git a/examples/_sqlalchemy_example/example.py b/examples/_sqlalchemy_example/example.py new file mode 100644 index 00000000..00cd80d3 --- /dev/null +++ b/examples/_sqlalchemy_example/example.py @@ -0,0 +1,229 @@ +import datetime +import logging +import argparse +import sqlalchemy as sa +from sqlalchemy import orm, exc, sql +from sqlalchemy import Table, Column, Integer, String, Float, TIMESTAMP 
+from ydb._sqlalchemy import register_dialect + +from fill_tables import fill_all_tables, to_days +from models import Base, Series, Episodes + + +def describe_table(engine, name): + inspect = sa.inspect(engine) + print(f"describe table {name}:") + for col in inspect.get_columns(name): + print(f"\t{col['name']}: {col['type']}") + + +def simple_select(conn): + stm = sa.select(Series).where(Series.series_id == 1) + res = conn.execute(stm) + print(res.one()) + + +def simple_insert(conn): + stm = Episodes.__table__.insert().values( + series_id=3, season_id=6, episode_id=1, title="TBD" + ) + conn.execute(stm) + + +def test_types(conn): + types_tb = Table( + "test_types", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("str", String), + Column("num", Float), + Column("dt", TIMESTAMP), + ) + types_tb.drop(bind=conn.engine, checkfirst=True) + types_tb.create(bind=conn.engine, checkfirst=True) + + stm = types_tb.insert().values( + id=1, + str=b"Hello World!", + num=3.1415, + dt=datetime.datetime.now(), + ) + conn.execute(stm) + + # GROUP BY + stm = sa.select(types_tb.c.str, sa.func.max(types_tb.c.num)).group_by( + types_tb.c.str + ) + rs = conn.execute(stm) + for x in rs: + print(x) + + +def run_example_orm(engine): + Base.metadata.bind = engine + Base.metadata.drop_all() + Base.metadata.create_all() + + session = orm.sessionmaker(bind=engine)() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + fill_all_tables(session.connection()) + + try: + session.add_all( + [ + Episodes( + series_id=2, + season_id=1, + episode_id=1, + title="Minimum Viable Product", + air_date=to_days("2014-04-06"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=2, + title="The Cap Table", + air_date=to_days("2014-04-13"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=3, + title="Articles of Incorporation", + air_date=to_days("2014-04-20"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=4, + 
title="Fiduciary Duties", + air_date=to_days("2014-04-27"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=5, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ), + ] + ) + session.commit() + except exc.DatabaseError: + print("Episodes already added!") + session.rollback() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + rs = session.query(Episodes).filter(Episodes.title == "abc??").all() + for e in rs: + print(e.title) + + print("Episodes count:", session.query(Episodes).count()) + + max_episode = session.query(sql.expression.func.max(Episodes.episode_id)).scalar() + print("Maximum episodes id:", max_episode) + + session.add( + Episodes( + series_id=2, + season_id=1, + episode_id=max_episode + 1, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ) + ) + + print("Episodes count:", session.query(Episodes).count()) + + +def run_example_core(engine): + with engine.connect() as conn: + # raw sql + rs = conn.execute("SELECT 1 AS value") + print(rs.fetchone()["value"]) + + fill_all_tables(conn) + + for t in "series seasons episodes".split(): + describe_table(engine, t) + + tb = sa.Table("episodes", sa.MetaData(engine), autoload=True) + stm = ( + sa.select([tb.c.title]) + .where(sa.and_(tb.c.series_id == 1, tb.c.season_id == 3)) + .where(tb.c.title.like("%")) + .order_by(sa.asc(tb.c.title)) + # TODO: limit isn't working now + # .limit(3) + ) + rs = conn.execute(stm) + print(rs.fetchall()) + + simple_select(conn) + + simple_insert(conn) + + # simple join + stm = sa.select( + [Episodes.__table__.join(Series, Episodes.series_id == Series.series_id)] + ).where(sa.and_(Series.series_id == 1, Episodes.season_id == 1)) + rs = conn.execute(stm) + for row in rs: + print(f"{row.series_title}({row.episode_id}): {row.title}") + + rs = conn.execute(sa.select(Episodes).where(Episodes.series_id == 3)) + print(rs.fetchall()) + + # count + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + 
print("Episodes cnt:", cnt) + + # simple delete + conn.execute(sa.delete(Episodes).where(Episodes.title == "TBD")) + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + print("Episodes cnt:", cnt) + + test_types(conn) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""", + ) + parser.add_argument( + "-d", + "--database", + help="Name of the database to use", + default="/local", + ) + parser.add_argument( + "-e", + "--endpoint", + help="Endpoint url to use", + default="grpc://localhost:2136", + ) + + args = parser.parse_args() + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": args.database, "endpoint": args.endpoint}, + ) + + logging.basicConfig(level=logging.INFO) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + run_example_core(engine) + # run_example_orm(engine) + + +if __name__ == "__main__": + main() diff --git a/examples/_sqlalchemy_example/fill_tables.py b/examples/_sqlalchemy_example/fill_tables.py new file mode 100644 index 00000000..5a9eb954 --- /dev/null +++ b/examples/_sqlalchemy_example/fill_tables.py @@ -0,0 +1,83 @@ +import iso8601 + +import sqlalchemy as sa +from models import Base, Series, Seasons, Episodes + + +def to_days(date): + timedelta = iso8601.parse_date(date) - iso8601.parse_date("1970-1-1") + return timedelta.days + + +def fill_series(conn): + data = [ + ( + 1, + "IT Crowd", + "The IT Crowd is a British sitcom produced by Channel 4, written by Graham Linehan, produced by " + "Ash Atalla and starring Chris O'Dowd, Richard Ayoade, Katherine Parkinson, and Matt Berry.", + to_days("2006-02-03"), + ), + ( + 2, + "Silicon Valley", + "Silicon Valley is an American comedy television series created by Mike Judge, John Altschuler and " + "Dave Krinsky. 
The series focuses on five young men who founded a startup company in Silicon Valley.", + to_days("2014-04-06"), + ), + ] + conn.execute(sa.insert(Series).values(data)) + + +def fill_seasons(conn): + data = [ + (1, 1, "Season 1", to_days("2006-02-03"), to_days("2006-03-03")), + (1, 2, "Season 2", to_days("2007-08-24"), to_days("2007-09-28")), + (1, 3, "Season 3", to_days("2008-11-21"), to_days("2008-12-26")), + (1, 4, "Season 4", to_days("2010-06-25"), to_days("2010-07-30")), + (2, 1, "Season 1", to_days("2014-04-06"), to_days("2014-06-01")), + (2, 2, "Season 2", to_days("2015-04-12"), to_days("2015-06-14")), + (2, 3, "Season 3", to_days("2016-04-24"), to_days("2016-06-26")), + (2, 4, "Season 4", to_days("2017-04-23"), to_days("2017-06-25")), + (2, 5, "Season 5", to_days("2018-03-25"), to_days("2018-05-13")), + ] + conn.execute(sa.insert(Seasons).values(data)) + + +def fill_episodes(conn): + data = [ + (1, 1, 1, "Yesterday's Jam", to_days("2006-02-03")), + (1, 1, 2, "Calamity Jen", to_days("2006-02-03")), + (1, 1, 3, "Fifty-Fifty", to_days("2006-02-10")), + (1, 1, 4, "The Red Door", to_days("2006-02-17")), + (1, 1, 5, "The Haunting of Bill Crouse", to_days("2006-02-24")), + (1, 1, 6, "Aunt Irma Visits", to_days("2006-03-03")), + (1, 2, 1, "The Work Outing", to_days("2006-08-24")), + (1, 2, 2, "Return of the Golden Child", to_days("2007-08-31")), + (1, 2, 3, "Moss and the German", to_days("2007-09-07")), + (1, 2, 4, "The Dinner Party", to_days("2007-09-14")), + (1, 2, 5, "Smoke and Mirrors", to_days("2007-09-21")), + (1, 2, 6, "Men Without Women", to_days("2007-09-28")), + (1, 3, 1, "From Hell", to_days("2008-11-21")), + (1, 3, 2, "Are We Not Men?", to_days("2008-11-28")), + (1, 3, 3, "Tramps Like Us", to_days("2008-12-05")), + (1, 3, 4, "The Speech", to_days("2008-12-12")), + (1, 3, 5, "Friendface", to_days("2008-12-19")), + (1, 3, 6, "Calendar Geeks", to_days("2008-12-26")), + (1, 4, 1, "Jen The Fredo", to_days("2010-06-25")), + (1, 4, 2, "The Final Countdown", 
to_days("2010-07-02")), + (1, 4, 3, "Something Happened", to_days("2010-07-09")), + (1, 4, 4, "Italian For Beginners", to_days("2010-07-16")), + (1, 4, 5, "Bad Boys", to_days("2010-07-23")), + (1, 4, 6, "Reynholm vs Reynholm", to_days("2010-07-30")), + ] + conn.execute(sa.insert(Episodes).values(data)) + + +def fill_all_tables(conn): + Base.metadata.drop_all(conn.engine) + Base.metadata.create_all(conn.engine) + + fill_series(conn) + fill_seasons(conn) + fill_episodes(conn) diff --git a/examples/_sqlalchemy_example/models.py b/examples/_sqlalchemy_example/models.py new file mode 100644 index 00000000..a02349a9 --- /dev/null +++ b/examples/_sqlalchemy_example/models.py @@ -0,0 +1,34 @@ +import sqlalchemy.orm as orm +from sqlalchemy import Column, Integer, Unicode + + +Base = orm.declarative_base() + + +class Series(Base): + __tablename__ = "series" + + series_id = Column(Integer, primary_key=True) + title = Column(Unicode) + series_info = Column(Unicode) + release_date = Column(Integer) + + +class Seasons(Base): + __tablename__ = "seasons" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + title = Column(Unicode) + first_aired = Column(Integer) + last_aired = Column(Integer) + + +class Episodes(Base): + __tablename__ = "episodes" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + episode_id = Column(Integer, primary_key=True) + title = Column(Unicode) + air_date = Column(Integer) diff --git a/test-requirements.txt b/test-requirements.txt index 9f592875..d1ca4276 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -47,3 +47,5 @@ sqlalchemy==1.4.26 pylint-protobuf cython freezegun==1.2.2 +grpcio-tools +pytest-cov diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..6ebabac3 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,22 @@ +import pytest +import sqlalchemy as sa + +from ydb._sqlalchemy 
import register_dialect + + +@pytest.fixture(scope="module") +def engine(endpoint, database): + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": database, "endpoint": endpoint}, + ) + + yield engine + engine.dispose() + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn diff --git a/tests/sqlalchemy/test_dbapi.py b/tests/sqlalchemy/test_dbapi.py new file mode 100644 index 00000000..407cc9e4 --- /dev/null +++ b/tests/sqlalchemy/test_dbapi.py @@ -0,0 +1,84 @@ +from ydb import _dbapi as dbapi + + +def test_dbapi(endpoint, database): + conn = dbapi.connect(endpoint, database=database) + assert conn + + conn.commit() + conn.rollback() + + cur = conn.cursor() + assert cur + + cur.execute( + "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", + context={"isddl": True}, + ) + + cur.execute('INSERT INTO test(id, text) VALUES (1, "foo")') + + cur.execute("SELECT id, text FROM test") + assert cur.fetchone() == (1, "foo"), "fetchone is ok" + + cur.execute("SELECT id, text FROM test WHERE id = %(id)s", {"id": 1}) + assert cur.fetchone() == (1, "foo"), "parametrized query is ok" + + cur.execute( + "INSERT INTO test(id, text) VALUES (%(id1)s, %(text1)s), (%(id2)s, %(text2)s)", + {"id1": 2, "text1": "", "id2": 3, "text2": "bar"}, + ) + + cur.execute( + "UPDATE test SET text = %(t)s WHERE id = %(id)s", {"id": 2, "t": "foo2"} + ) + + cur.execute("SELECT id FROM test") + assert cur.fetchall() == [(1,), (2,), (3,)], "fetchall is ok" + + cur.execute("SELECT id FROM test ORDER BY id DESC") + assert cur.fetchmany(2) == [(3,), (2,)], "fetchmany is ok" + assert cur.fetchmany(1) == [(1,)] + + cur.execute("SELECT id FROM test ORDER BY id LIMIT 2") + assert cur.fetchall() == [(1,), (2,)], "limit clause without params is ok" + + # TODO: Failed to convert type: Int64 to Uint64 + # cur.execute("SELECT id FROM test ORDER BY id LIMIT %(limit)s", {"limit": 2}) + # assert 
def clear_sql(stm):
    """Return *stm* with every run of whitespace collapsed to one space.

    Used to compare rendered SQL against a single-line literal regardless
    of how the statement was wrapped. Leading and trailing whitespace is
    removed as well.

    The previous implementation replaced each newline with a space and then
    replaced a single space with a single space (a no-op), so the doubled
    space produced by ``" \\n"`` was never collapsed.
    """
    # str.split() with no argument splits on any whitespace run and drops
    # empty fields, so joining with one space normalises the whole string.
    return " ".join(stm.split())
class Connection:
    """DB-API style connection backed by a YDB driver and a session pool.

    The constructor creates the driver (waiting for discovery) and a
    ``ydb.SessionPool`` on top of it. ``commit`` and ``rollback`` are no-ops:
    cursors execute every statement with ``commit_tx=True``.
    """

    def __init__(self, endpoint, database=None, **conn_kwargs):
        """Connect to *endpoint* / *database*.

        Extra keyword arguments are forwarded to ``ydb.DriverConfig``.
        Raises ``DatabaseError`` when the driver cannot be initialised.
        """
        self.endpoint = endpoint
        self.database = database
        self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs)
        self.pool = ydb.SessionPool(self.driver)

    def cursor(self):
        """Return a new ``Cursor`` bound to this connection."""
        return Cursor(self)

    def _full_path(self, table_path):
        """Resolve *table_path* against the database root.

        ``posixpath.join`` returns an absolute *table_path* unchanged, so
        callers that already pass full paths are unaffected. Without a
        database the path is used as given.
        """
        if not self.database:
            return table_path
        return posixpath.join(self.database, table_path)

    def describe(self, table_path):
        """Return the column descriptions of the table at *table_path*.

        Raises ``DatabaseError`` for any failure; the original exception is
        kept as ``__cause__``.
        """
        full_path = self._full_path(table_path)
        try:
            res = self.pool.retry_operation_sync(
                lambda cli: cli.describe_table(full_path)
            )
            return res.columns
        except ydb.Error as e:
            raise DatabaseError(e.message, e.issues, e.status) from e
        except Exception as e:
            raise DatabaseError(f"Failed to describe table {table_path}") from e

    def check_exists(self, table_path):
        """Return True if the scheme client can describe *table_path*.

        The path is resolved the same way ``describe`` resolves it, so both
        methods accept the same table names. ``ydb.SchemeError`` is treated
        as "does not exist"; other errors propagate.
        """
        try:
            self.driver.scheme_client.describe_path(self._full_path(table_path))
            return True
        except ydb.SchemeError:
            return False

    def commit(self):
        """No-op: statements are committed as they are executed."""
        pass

    def rollback(self):
        """No-op: there is no open transaction to roll back."""
        pass

    def close(self):
        """Stop the session pool and the driver.

        Safe to call more than once. The driver is stopped even when
        stopping the pool raises.
        """
        # Detach first so a repeated close() finds nothing left to stop.
        pool, self.pool = self.pool, None
        driver, self.driver = self.driver, None
        try:
            if pool is not None:
                pool.stop()
        finally:
            if driver is not None:
                driver.stop()

    @staticmethod
    def _create_driver(endpoint, database, **conn_kwargs):
        """Build a ``ydb.Driver`` and wait (up to 5 seconds) until it is ready.

        Result sets are configured to return native Python values for date,
        datetime, timestamp, interval and JSON columns. On failure the
        driver is stopped before ``DatabaseError`` is raised.
        """
        # TODO: add cache for initialized drivers/pools?
        driver_config = ydb.DriverConfig(
            endpoint,
            database=database,
            table_client_settings=ydb.TableClientSettings()
            .with_native_date_in_result_sets(True)
            .with_native_datetime_in_result_sets(True)
            .with_native_timestamp_in_result_sets(True)
            .with_native_interval_in_result_sets(True)
            .with_native_json_in_result_sets(True),
            **conn_kwargs,
        )
        driver = ydb.Driver(driver_config)
        try:
            driver.wait(timeout=5, fail_fast=True)
        except ydb.Error as e:
            # Previously only the generic branch stopped the driver.
            driver.stop()
            raise DatabaseError(e.message, e.issues, e.status) from e
        except Exception as e:
            driver.stop()
            raise DatabaseError(
                f"Failed to connect to YDB, details {driver.discovery_debug_details()}"
            ) from e
        return driver
def get_column_type(type_obj):
    """Return ``str()`` of the native form of a YDB column type."""
    return str(ydb.convert.type_to_native(type_obj))


# Exact Python type -> YQL type name used in generated DECLARE statements.
# Lookup is by exact type, so ``bool`` is never mistaken for ``int`` and
# ``datetime.datetime`` is never mistaken for ``datetime.date``.
_SCALAR_TYPE_NAMES = {
    bool: "Bool",
    bytes: "String",
    str: "Utf8",
    int: "Int64",
    float: "Double",
    decimal.Decimal: "Decimal(22, 9)",
    datetime.date: "Date",
    datetime.datetime: "Timestamp",
    datetime.timedelta: "Interval",
    uuid.UUID: "Uuid",
}


def _generate_type_str(value):
    """Return the YQL type name inferred from the Python *value*.

    ``dict`` maps to ``Struct<...>``, ``tuple`` to ``Tuple<...>``, ``list``
    to ``List<...>`` and ``set`` to ``Set<...>``; the item type of a list or
    set is taken from one of its elements.

    Raises ``ProgrammingError`` when the type has no mapping or when an
    empty list/set leaves the item type unknown (this used to surface as
    ``IndexError`` / ``StopIteration``).
    """
    tvalue = type(value)

    if tvalue is dict:
        fields = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items())
        return f"Struct<{fields}>"

    if tvalue is tuple:
        items = ", ".join(_generate_type_str(x) for x in value)
        return f"Tuple<{items}>"

    if tvalue is list or tvalue is set:
        if not value:
            raise ProgrammingError(
                f"Cannot infer the item type of an empty {tvalue.__name__}."
            )
        container = "List" if tvalue is list else "Set"
        return f"{container}<{_generate_type_str(next(iter(value)))}>"

    stype = _SCALAR_TYPE_NAMES.get(tvalue)
    if stype is None:
        # The type and value go into the message: the extra positional
        # arguments of the error are ``issues`` and ``status``, and a type
        # object passed as ``issues`` is not iterable.
        raise ProgrammingError(
            f"Cannot translate python type to ydb type: {tvalue!r} (value: {value!r})"
        )
    return stype


def _generate_declare_stms(params: dict) -> str:
    """Return one ``DECLARE <name> AS <type>; `` statement per parameter."""
    return "".join(
        f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items()
    )


class Cursor(object):
    """DB-API style cursor that runs each statement through the session pool.

    ``rows`` is an iterator over the rows of the last executed statement
    (``None`` before the first one); ``description`` is filled while the
    result is iterated.
    """

    def __init__(self, connection):
        self.connection = connection
        self.description = None
        self.arraysize = 1
        self.rows = None
        # List copy of the remaining rows, built lazily for ``rowcount``.
        self._rows_prefetched = None

    def execute(self, sql, parameters=None, context=None):
        """Execute *sql* and make its rows available to the fetch methods.

        *parameters* is a mapping for ``%(name)s`` placeholders; each one is
        rewritten to ``$name`` and a matching ``DECLARE`` is prepended.
        When ``context["isddl"]`` is true the statement is sent with
        ``execute_scheme``; otherwise it is prepared and executed with
        ``commit_tx=True``.

        Any previous result, including the row list cached for
        ``rowcount``, is discarded first.

        Raises ``DatabaseError`` when YDB reports an error.
        """
        self.description = None
        self.rows = None
        self._rows_prefetched = None
        sql_params = None

        if parameters:
            sql = sql % {k: f"${k}" for k in parameters}
            sql_params = {f"${k}": v for k, v in parameters.items()}
            declare_stms = _generate_declare_stms(sql_params)
            sql = f"{declare_stms}{sql}"

        logger.info("execute sql: %s, params: %s", sql, sql_params)

        def _execute_in_pool(cli):
            if context and context.get("isddl"):
                return cli.execute_scheme(sql)
            prepared_query = cli.prepare(sql)
            return cli.transaction().execute(
                prepared_query, sql_params, commit_tx=True
            )

        # Translate errors outside the callback so the pool's retry helper
        # still receives the original ydb error.
        try:
            chunks = self.connection.pool.retry_operation_sync(_execute_in_pool)
        except ydb.Error as e:
            raise DatabaseError(e.message, e.issues, e.status) from e

        rows = self._rows_iterable(chunks)
        # Pull the first row now so ``description`` is set right after
        # execute(), then put it back in front of the iterator.
        try:
            first_row = next(rows)
        except StopIteration:
            pass
        else:
            rows = itertools.chain((first_row,), rows)

        self.rows = rows

    def _rows_iterable(self, chunks_iterable):
        """Yield the rows of every chunk, updating ``description`` per chunk."""
        try:
            for chunk in chunks_iterable:
                self.description = [
                    (
                        col.name,
                        get_column_type(col.type),
                        None,
                        None,
                        None,
                        None,
                        None,
                    )
                    for col in chunk.columns
                ]
                for row in chunk.rows:
                    # returns tuple to be compatible with SqlAlchemy and because
                    # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany
                    yield row[::]
        except ydb.Error as e:
            raise DatabaseError(e.message, e.issues, e.status) from e

    def _ensure_prefetched(self):
        """Materialise the remaining rows once and return them as a list.

        Returns ``None`` when there is no result.
        """
        if self.rows is not None and self._rows_prefetched is None:
            self._rows_prefetched = list(self.rows)
            self.rows = iter(self._rows_prefetched)
        return self._rows_prefetched

    def executemany(self, sql, seq_of_parameters):
        """Execute *sql* once per parameter set and chain all their rows."""
        self.description = None
        self.rows = None
        self._rows_prefetched = None

        results = []
        for parameters in seq_of_parameters:
            self.execute(sql, parameters)
            if self.rows is not None:
                results.append(self.rows)
        self.rows = itertools.chain.from_iterable(results)

    def executescript(self, script):
        """Execute *script* as a single statement without parameters."""
        return self.execute(script)

    def fetchone(self):
        """Return the next row, or ``None`` when there is none."""
        if self.rows is None:
            return None
        return next(self.rows, None)

    def fetchmany(self, size=None):
        """Return up to *size* rows (default ``arraysize``) as a list."""
        if self.rows is None:
            return []
        size = self.arraysize if size is None else size
        return list(itertools.islice(self.rows, size))

    def fetchall(self):
        """Return all remaining rows as a list."""
        if self.rows is None:
            return []
        return list(self.rows)

    def nextset(self):
        """Drain the current result; always returns ``None``."""
        self.fetchall()

    def setinputsizes(self, sizes):
        """No-op."""
        pass

    def setoutputsize(self, column=None):
        """No-op."""
        pass

    def close(self):
        """Drop the current result."""
        self.rows = None
        self._rows_prefetched = None

    @property
    def rowcount(self):
        """Number of rows not yet fetched when first read, or -1 without a result.

        Reading it materialises the remaining rows; they stay fetchable.
        """
        prefetched = self._ensure_prefetched()
        return -1 if prefetched is None else len(prefetched)
class Warning(Exception):
    """DB-API ``Warning`` exception."""

    pass


class Error(Exception):
    """Base class of the DB-API error hierarchy.

    Attributes:
        message: the formatted issue tree when *issues* can be formatted,
            otherwise the plain *message*.
        issues: the issue objects as passed in.
        status: the status as passed in.
    """

    def __init__(self, message, issues=None, status=None):
        super(Error, self).__init__(message)

        pretty_issues = _pretty_issues(issues)
        self.issues = issues
        self.message = pretty_issues or message
        self.status = status


class InterfaceError(Error):
    pass


class DatabaseError(Error):
    pass


class DataError(DatabaseError):
    pass


class OperationalError(DatabaseError):
    pass


class IntegrityError(DatabaseError):
    pass


class InternalError(DatabaseError):
    pass


class ProgrammingError(DatabaseError):
    pass


class NotSupportedError(DatabaseError):
    pass


def _pretty_issues(issues):
    """Format *issues* as a multi-line text block starting with a newline.

    Returns ``None`` when there are no issues (``None`` or an empty
    collection) or when any issue could not be formatted, so the caller
    falls back to the plain message. An empty collection used to produce
    ``"\\n"``, which replaced the real message with a blank line.
    """
    if not issues:
        return None

    children_messages = [_get_messages(issue, root=True) for issue in issues]

    if None in children_messages:
        return None

    return "\n" + "\n".join(children_messages)


def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False):
    """Format *issue* and its nested issues, indented by nesting depth.

    Below the root, a chain of issues that each have exactly one child is
    collapsed: the messages of the chain are joined with ``", "`` on one
    line and formatting continues from the last issue of the chain.

    Returns ``None`` when *depth* reaches *max_depth* or when any nested
    issue returns ``None``.
    """
    if depth >= max_depth:
        return None

    margin_str = " " * depth * indent
    pre_message = ""
    children = ""

    if issue.issues:
        collapsed_messages = []
        while not root and len(issue.issues) == 1:
            collapsed_messages.append(issue.message)
            issue = issue.issues[0]

        if collapsed_messages:
            pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n"
            # The collapsed line takes this level; the issue goes one deeper.
            depth += 1
            margin_str = " " * depth * indent

        children_messages = [
            _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1)
            for iss in issue.issues
        ]

        if None in children_messages:
            return None

        children = "\n".join(children_messages)

    return (
        f"{pre_message}{margin_str}{issue.message}\n{margin_str}"
        f"severity level: {issue.severity}\n{margin_str}"
        f"issue code: {issue.issue_code}\n{children}"
    )
+""" +Experimental +Work in progress, breaking changes are possible. +""" +import ydb +import ydb._dbapi as dbapi + +import sqlalchemy as sa +from sqlalchemy import dialects +from sqlalchemy import Table +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import functions, literal_column +from sqlalchemy.sql.compiler import ( + IdentifierPreparer, + GenericTypeCompiler, + SQLCompiler, + DDLCompiler, +) +from sqlalchemy.sql.elements import ClauseList +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.util.compat import inspect_getfullargspec + +from ydb._sqlalchemy.types import UInt32, UInt64 + + +SQLALCHEMY_VERSION = tuple(sa.__version__.split(".")) +SA_14 = SQLALCHEMY_VERSION >= ("1", "4") + + +class YqlIdentifierPreparer(IdentifierPreparer): + def __init__(self, dialect): + super(YqlIdentifierPreparer, self).__init__( + dialect, + initial_quote="`", + final_quote="`", + ) + + def _requires_quotes(self, value): + # Force all identifiers to get quoted unless already quoted. 
+ return not ( + value.startswith(self.initial_quote) and value.endswith(self.final_quote) + ) + + +class YqlTypeCompiler(GenericTypeCompiler): + def visit_VARCHAR(self, type_, **kw): + return "STRING" + + def visit_unicode(self, type_, **kw): + return "UTF8" + + def visit_NVARCHAR(self, type_, **kw): + return "UTF8" + + def visit_TEXT(self, type_, **kw): + return "UTF8" + + def visit_FLOAT(self, type_, **kw): + return "DOUBLE" + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + def visit_uint32(self, type_, **kw): + return "UInt32" + + def visit_uint64(self, type_, **kw): + return "UInt64" + + def visit_uint8(self, type_, **kw): + return "UInt8" + + def visit_INTEGER(self, type_, **kw): + return "Int64" + + def visit_NUMERIC(self, type_, **kw): + return "Int64" + + +class ParametrizedFunction(functions.Function): + __visit_name__ = "parametrized_function" + + def __init__(self, name, params, *args, **kwargs): + super(ParametrizedFunction, self).__init__(name, *args, **kwargs) + self._func_name = name + self._func_params = params + self.params_expr = ClauseList( + operator=functions.operators.comma_op, group_contents=True, *params + ).self_group() + + +class YqlCompiler(SQLCompiler): + def group_by_clause(self, select, **kw): + # Hack to ensure it is possible to define labels in groupby. 
+ kw.update(within_columns_clause=True) + return super(YqlCompiler, self).group_by_clause(select, **kw) + + def visit_lambda(self, lambda_, **kw): + func = lambda_.func + spec = inspect_getfullargspec(func) + + if spec.varargs: + raise CompileError("Lambdas with *args are not supported") + + try: + keywords = spec.keywords + except AttributeError: + keywords = spec.varkw + + if keywords: + raise CompileError("Lambdas with **kwargs are not supported") + + text = "(" + ", ".join("$" + arg for arg in spec.args) + ")" + " -> " + + args = [literal_column("$" + arg) for arg in spec.args] + text += "{ RETURN " + self.process(func(*args), **kw) + " ;}" + + return text + + def visit_parametrized_function(self, func, **kwargs): + name = func.name + name_parts = [] + for name in name.split("::"): + fname = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + + name_parts.append(fname) + + name = "::".join(name_parts) + params = func.params_expr._compiler_dispatch(self, **kwargs) + args = self.function_argspec(func, **kwargs) + return "%(name)s%(params)s%(args)s" % dict(name=name, params=params, args=args) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + # Copypaste of `sa.sql.compiler.SQLCompiler.visit_function` with + # `::` as namespace separator instead of `.` + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) + else: + name = sa.sql.compiler.FUNCTIONS.get(func.__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + name = name + "%(expr)s" + return "::".join( + [ + ( + self.preparer.quote(tok) + if 
self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, sa.sql.elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + +class YqlDdlCompiler(DDLCompiler): + pass + + +def upsert(table): + return sa.sql.Insert(table) + + +COLUMN_TYPES = { + ydb.PrimitiveType.Int8: sa.INTEGER, + ydb.PrimitiveType.Int16: sa.INTEGER, + ydb.PrimitiveType.Int32: sa.INTEGER, + ydb.PrimitiveType.Int64: sa.INTEGER, + ydb.PrimitiveType.Uint8: sa.INTEGER, + ydb.PrimitiveType.Uint16: sa.INTEGER, + ydb.PrimitiveType.Uint32: UInt32, + ydb.PrimitiveType.Uint64: UInt64, + ydb.PrimitiveType.Float: sa.FLOAT, + ydb.PrimitiveType.Double: sa.FLOAT, + ydb.PrimitiveType.String: sa.TEXT, + ydb.PrimitiveType.Utf8: sa.TEXT, + ydb.PrimitiveType.Json: sa.JSON, + ydb.PrimitiveType.JsonDocument: sa.JSON, + ydb.DecimalType: sa.DECIMAL, + ydb.PrimitiveType.Yson: sa.TEXT, + ydb.PrimitiveType.Date: sa.DATE, + ydb.PrimitiveType.Datetime: sa.DATETIME, + ydb.PrimitiveType.Timestamp: sa.DATETIME, + ydb.PrimitiveType.Interval: sa.INTEGER, + ydb.PrimitiveType.Bool: sa.BOOLEAN, + ydb.PrimitiveType.DyNumber: sa.TEXT, +} + + +def _get_column_type(t): + if isinstance(t, ydb.OptionalType): + t = t.item + + if isinstance(t, ydb.DecimalType): + return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale) + + return COLUMN_TYPES[t] + + +class YqlDialect(DefaultDialect): + name = "yql" + supports_alter = False + max_identifier_length = 63 + supports_sane_rowcount = False + supports_statement_cache = False + + supports_native_enum = False + supports_native_boolean = True + supports_smallserial = False + + supports_sequences = False + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = False + supports_empty_insert = False + supports_multivalues_insert = True + default_paramstyle = "qmark" + + isolation_level = None + + preparer = 
YqlIdentifierPreparer + statement_compiler = YqlCompiler + ddl_compiler = YqlDdlCompiler + type_compiler = YqlTypeCompiler + + driver = ydb.Driver + + @staticmethod + def dbapi(): + return dbapi + + def _check_unicode_returns(self, *args, **kwargs): + # Normally, this would do 2 SQL queries, which isn't quite necessary. + return "conditional" + + def get_columns(self, connection, table_name, schema=None, **kw): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + qt = table_name.name if isinstance(table_name, Table) else table_name + + if SA_14: + raw_conn = connection.connection + else: + raw_conn = connection.raw_connection() + + columns = raw_conn.describe(qt) + as_compatible = [] + for column in columns: + as_compatible.append( + { + "name": column.name, + "type": _get_column_type(column.type), + "nullable": True, + } + ) + + return as_compatible + + def has_table(self, connection, table_name, schema=None, **kwargs): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + quote = self.identifier_preparer.quote_identifier + qtable = quote(table_name) + + # TODO: use `get_columns` instead. + statement = "SELECT * FROM " + qtable + try: + connection.execute(statement) + return True + except Exception: + return False + + def get_pk_constraint(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def get_foreign_keys(self, connection, table_name, schema=None, **kwargs): + # foreign keys unsupported + return [] + + def get_indexes(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def do_commit(self, dbapi_connection) -> None: + # TODO: needs to implement? 
+ pass + + def do_execute(self, cursor, statement, parameters, context=None) -> None: + c = None + if context is not None and context.isddl: + c = {"isddl": True} + cursor.execute(statement, parameters, c) + + +def register_dialect( + name="yql", + module=__name__, + cls="YqlDialect", +): + return dialects.registry.register(name, module, cls) diff --git a/ydb/_sqlalchemy/types.py b/ydb/_sqlalchemy/types.py new file mode 100644 index 00000000..21748ec1 --- /dev/null +++ b/ydb/_sqlalchemy/types.py @@ -0,0 +1,28 @@ +from sqlalchemy.types import Integer +from sqlalchemy.sql import type_api +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy import util, exc + + +class UInt32(Integer): + __visit_name__ = "uint32" + + +class UInt64(Integer): + __visit_name__ = "uint64" + + +class UInt8(Integer): + __visit_name__ = "uint8" + + +class Lambda(ColumnElement): + + __visit_name__ = "lambda" + + def __init__(self, func): + if not util.callable(func): + raise exc.ArgumentError("func must be callable") + + self.type = type_api.NULLTYPE + self.func = func