Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pool retries to execute calls #3

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 150 additions & 150 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"
ydb = "^3.18.3"
ydb = "^3.18.8"

[tool.poetry.group.dev.dependencies]
pre-commit = "^4.0.1"
Expand Down
64 changes: 22 additions & 42 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ async def session_pool(
"""
)

await session_pool.execute_with_retries(
f"""
DELETE FROM {name};
INSERT INTO {name} (id, val) VALUES
(0, 0),
(1, 1),
(2, 2),
(3, 3)
"""
)

yield session_pool


Expand All @@ -184,46 +195,15 @@ def session_pool_sync(
"""
)

yield session_pool


@pytest.fixture
async def session(
session_pool: ydb.aio.QuerySessionPool,
) -> AsyncGenerator[ydb.aio.QuerySession]:
for name in ["table", "table1", "table2"]:
await session_pool.execute_with_retries(
f"""
DELETE FROM {name};
INSERT INTO {name} (id, val) VALUES
(0, 0),
(1, 1),
(2, 2),
(3, 3)
"""
)

session = await session_pool.acquire()
yield session
await session_pool.release(session)


@pytest.fixture
def session_sync(
session_pool_sync: ydb.QuerySessionPool,
) -> Generator[ydb.QuerySession]:
for name in ["table", "table1", "table2"]:
session_pool_sync.execute_with_retries(
f"""
DELETE FROM {name};
INSERT INTO {name} (id, val) VALUES
(0, 0),
(1, 1),
(2, 2),
(3, 3)
"""
)
session_pool.execute_with_retries(
f"""
DELETE FROM {name};
INSERT INTO {name} (id, val) VALUES
(0, 0),
(1, 1),
(2, 2),
(3, 3)
"""
)

session = session_pool_sync.acquire()
yield session
session_pool_sync.release(session)
yield session_pool
2 changes: 2 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def _test_isolation_level_read_only(

connection.set_isolation_level(isolation_level)
cursor = connection.cursor()
maybe_await(connection.begin())

query = "UPSERT INTO foo(id) VALUES (1)"
if read_only:
with pytest.raises(dbapi.DatabaseError):
Expand Down
10 changes: 6 additions & 4 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ def _test_cursor_fetch_all_multiple_result_sets(

class TestCursor(BaseCursorTestSuit):
@pytest.fixture
def sync_cursor(self, session_sync: ydb.QuerySession) -> Generator[Cursor]:
cursor = Cursor(session_sync, ydb.QuerySerializableReadWrite())
def sync_cursor(
self, session_pool_sync: ydb.QuerySessionPool
) -> Generator[Cursor]:
cursor = Cursor(session_pool_sync, ydb.QuerySerializableReadWrite())
yield cursor
cursor.close()

Expand Down Expand Up @@ -173,9 +175,9 @@ def test_cursor_fetch_all_multiple_result_sets(
class TestAsyncCursor(BaseCursorTestSuit):
@pytest.fixture
async def async_cursor(
self, session: ydb.aio.QuerySession
self, session_pool: ydb.aio.QuerySessionPool
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(session, ydb.QuerySerializableReadWrite())
cursor = AsyncCursor(session_pool, ydb.QuerySerializableReadWrite())
yield cursor
await greenlet_spawn(cursor.close)

Expand Down
69 changes: 38 additions & 31 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,15 @@ def __init__(
if ydb_session_pool is not None:
self._shared_session_pool = True
self._session_pool = ydb_session_pool
settings = self._get_client_settings()
self._session_pool._query_client_settings = settings
self._driver = self._session_pool._driver
else:
driver_config = ydb.DriverConfig(
endpoint=self.endpoint,
database=self.database,
credentials=self.credentials,
query_client_settings=self._get_client_settings(),
)
self._driver = self._driver_cls(driver_config)
self._session_pool = self._pool_cls(self._driver, size=5)
Expand Down Expand Up @@ -126,11 +129,15 @@ def get_isolation_level(self) -> str:
msg = f"{self._tx_mode.name} is not supported"
raise NotSupportedError(msg)

def _maybe_init_tx(
self, session: ydb.QuerySession | ydb.aio.QuerySession
) -> None:
if self._tx_context is None and self.interactive_transaction:
self._tx_context = session.transaction(self._tx_mode)
def _get_client_settings(self) -> ydb.QueryClientSettings:
return (
ydb.QueryClientSettings()
.with_native_date_in_result_sets(True)
.with_native_datetime_in_result_sets(True)
.with_native_timestamp_in_result_sets(True)
.with_native_interval_in_result_sets(True)
.with_native_json_in_result_sets(False)
)


class Connection(BaseConnection):
Expand Down Expand Up @@ -160,17 +167,11 @@ def __init__(
self._current_cursor: Cursor | None = None

def cursor(self) -> Cursor:
if self._session is None:
raise RuntimeError("Connection is not ready, use wait_ready.")

self._maybe_init_tx(self._session)

self._current_cursor = self._cursor_cls(
session=self._session,
return self._cursor_cls(
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
)
return self._current_cursor

def wait_ready(self, timeout: int = 10) -> None:
try:
Expand All @@ -185,27 +186,33 @@ def wait_ready(self, timeout: int = 10) -> None:
)
raise InterfaceError(msg) from e

self._session = self._session_pool.acquire()
@handle_ydb_errors
def begin(self) -> None:
self._tx_context = None
if self.interactive_transaction:
self._session = self._session_pool.acquire()
self._tx_context = self._session.transaction(self._tx_mode)

@handle_ydb_errors
def commit(self) -> None:
if self._tx_context and self._tx_context.tx_id:
self._tx_context.commit()
self._session_pool.release(self._session)
self._tx_context = None
self._session = None

@handle_ydb_errors
def rollback(self) -> None:
if self._tx_context and self._tx_context.tx_id:
self._tx_context.rollback()
self._session_pool.release(self._session)
self._tx_context = None
self._session = None

@handle_ydb_errors
def close(self) -> None:
self.rollback()

if self._current_cursor:
self._current_cursor.close()

if self._session:
self._session_pool.release(self._session)

Expand Down Expand Up @@ -287,17 +294,11 @@ def __init__(
self._current_cursor: AsyncCursor | None = None

def cursor(self) -> AsyncCursor:
if self._session is None:
raise RuntimeError("Connection is not ready, use wait_ready.")

self._maybe_init_tx(self._session)

self._current_cursor = self._cursor_cls(
session=self._session,
return self._cursor_cls(
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
)
return self._current_cursor

async def wait_ready(self, timeout: int = 10) -> None:
try:
Expand All @@ -312,27 +313,33 @@ async def wait_ready(self, timeout: int = 10) -> None:
)
raise InterfaceError(msg) from e

self._session = await self._session_pool.acquire()
@handle_ydb_errors
async def begin(self) -> None:
self._tx_context = None
if self.interactive_transaction:
self._session = await self._session_pool.acquire()
self._tx_context = self._session.transaction(self._tx_mode)

@handle_ydb_errors
async def commit(self) -> None:
if self._tx_context and self._tx_context.tx_id:
if self._session and self._tx_context and self._tx_context.tx_id:
await self._tx_context.commit()
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None

@handle_ydb_errors
async def rollback(self) -> None:
if self._tx_context and self._tx_context.tx_id:
if self._session and self._tx_context and self._tx_context.tx_id:
await self._tx_context.rollback()
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None

@handle_ydb_errors
async def close(self) -> None:
await self.rollback()

if self._current_cursor:
await self._current_cursor.close()

if self._session:
await self._session_pool.release(self._session)

Expand Down
58 changes: 42 additions & 16 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def _fetchall_from_buffer(self) -> list:
class Cursor(BufferedCursor):
def __init__(
self,
session: ydb.QuerySession,
session_pool: ydb.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._session = session
self._session_pool = session_pool
self._tx_mode = tx_mode
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix
Expand All @@ -165,19 +165,32 @@ def fetchall(self) -> list:
def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> Iterator[ydb.convert.ResultSet]:
return self._session.execute(query=query, parameters=parameters)
def callee(
session: ydb.QuerySession,
) -> Iterator[ydb.convert.ResultSet]:
return session.execute(
query=query,
parameters=parameters,
)

return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
def _execute_session_query(
self,
query: str,
parameters: ParametersType | None = None,
) -> Iterator[ydb.convert.ResultSet]:
return self._session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)
def callee(
session: ydb.QuerySession,
) -> Iterator[ydb.convert.ResultSet]:
return session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)

return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
def _execute_transactional_query(
Expand Down Expand Up @@ -276,13 +289,13 @@ def __exit__(
class AsyncCursor(BufferedCursor):
def __init__(
self,
session: ydb.aio.QuerySession,
session_pool: ydb.aio.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._session = session
self._session_pool = session_pool
self._tx_mode = tx_mode
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix
Expand All @@ -303,19 +316,32 @@ async def fetchall(self) -> list:
async def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> AsyncIterator[ydb.convert.ResultSet]:
return await self._session.execute(query=query, parameters=parameters)
async def callee(
session: ydb.aio.QuerySession,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await session.execute(
query=query,
parameters=parameters,
)

return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
async def _execute_session_query(
self,
query: str,
parameters: ParametersType | None = None,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await self._session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)
async def callee(
session: ydb.aio.QuerySession,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
)

return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
async def _execute_transactional_query(
Expand Down
Loading