Skip to content

Commit

Permalink
feat: v1.0.14
Browse files Browse the repository at this point in the history
  • Loading branch information
ddc committed Dec 7, 2024
1 parent 142dbb8 commit dfb2905
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 122 deletions.
75 changes: 22 additions & 53 deletions ddcDatabases/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import AsyncGenerator, Generator, Optional

import sqlalchemy as sa
from sqlalchemy import RowMapping
from sqlalchemy.engine import create_engine, Engine, URL
Expand All @@ -12,6 +13,7 @@
create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker

from .exceptions import (
DBDeleteAllDataException,
DBExecuteException,
Expand All @@ -23,30 +25,22 @@


class BaseConn:

def __init__(
self,
host,
port,
user,
database,
autoflush,
expire_on_commit,
connection_url,
engine_args,
autoflush,
expire_on_commit,
sync_driver,
async_driver,
):
self.host = host
self.port = port
self.user = user
self.database = database
self.autoflush = autoflush
self.expire_on_commit = expire_on_commit
self.connection_url = connection_url
self.engine_args = engine_args
self.autoflush = autoflush
self.expire_on_commit = expire_on_commit
self.sync_driver = sync_driver
self.async_driver = async_driver

self.temp_engine: Optional[Engine | AsyncEngine] = None
self.session: Optional[Session | AsyncSession] = None

Expand Down Expand Up @@ -88,10 +82,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

@contextmanager
def engine(self) -> Generator:
_connection_url = URL.create(
**self.connection_url,
drivername=self.sync_driver,
)
_connection_url = URL.create(**self.connection_url, drivername=self.sync_driver)
_engine_args = {
"url": _connection_url,
**self.engine_args,
Expand All @@ -102,10 +93,7 @@ def engine(self) -> Generator:

@asynccontextmanager
async def async_engine(self) -> AsyncGenerator:
_connection_url = URL.create(
**self.connection_url,
drivername=self.async_driver,
)
_connection_url = URL.create(**self.connection_url, drivername=self.async_driver)
_engine_args = {
"url": _connection_url,
**self.engine_args,
Expand All @@ -115,33 +103,20 @@ async def async_engine(self) -> AsyncGenerator:
await _engine.dispose()

def _test_connection_sync(self, session: Session) -> None:
host_url = URL.create(
drivername=self.sync_driver,
username=self.user,
host=self.host,
port=self.port,
database=self.database,
)
test_connection = TestConnections(
sync_session=session, host_url=host_url,
)
del self.connection_url["password"]
_connection_url = URL.create(**self.connection_url, drivername=self.sync_driver)
test_connection = TestConnections(sync_session=session, host_url=_connection_url)
test_connection.test_connection_sync()

async def _test_connection_async(self, session: AsyncSession) -> None:
host_url = URL.create(
drivername=self.async_driver,
username=self.user,
host=self.host,
port=self.port,
database=self.database,
)
test_connection = TestConnections(
async_session=session, host_url=host_url,
)
del self.connection_url["password"]
_connection_url = URL.create(**self.connection_url, drivername=self.async_driver)
test_connection = TestConnections(async_session=session, host_url=_connection_url)
await test_connection.test_connection_async()


class TestConnections:

def __init__(
self,
sync_session: Session = None,
Expand All @@ -156,31 +131,24 @@ def __init__(
def test_connection_sync(self) -> None:
try:
self.sync_session.execute(sa.text("SELECT 1"))
sys.stdout.write(
f"[{self.dt}]:[INFO]:Connection to database successful | {self.host_url}\n"
)
sys.stdout.write(f"[{self.dt}]:[INFO]:Connection to database successful | {self.host_url}\n")
except Exception as e:
self.sync_session.close()
sys.stderr.write(
f"[{self.dt}]:[ERROR]:Connection to datatabse failed | {self.host_url} | {repr(e)}\n"
)
sys.stderr.write(f"[{self.dt}]:[ERROR]:Connection to datatabse failed | {self.host_url} | {repr(e)}\n")
raise

async def test_connection_async(self) -> None:
try:
await self.async_session.execute(sa.text("SELECT 1"))
sys.stdout.write(
f"[{self.dt}]:[INFO]:Connection to database successful | {self.host_url}\n"
)
sys.stdout.write(f"[{self.dt}]:[INFO]:Connection to database successful | {self.host_url}\n")
except Exception as e:
await self.async_session.close()
sys.stderr.write(
f"[{self.dt}]:[ERROR]:Connection to datatabse failed | {self.host_url} | {repr(e)}\n"
)
sys.stderr.write(f"[{self.dt}]:[ERROR]:Connection to datatabse failed | {self.host_url} | {repr(e)}\n")
raise


class DBUtils:

def __init__(self, session):
self.session = session

Expand Down Expand Up @@ -245,6 +213,7 @@ def execute(self, stmt) -> None:


class DBUtilsAsync:

def __init__(self, session):
self.session = session

Expand Down
59 changes: 22 additions & 37 deletions ddcDatabases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# -*- coding: utf-8 -*-
from typing import Optional
from sqlalchemy.engine import Engine, URL
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
)
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from .db_utils import BaseConn, TestConnections
from .settings import MSSQLSettings
Expand All @@ -30,27 +27,24 @@ def __init__(
expire_on_commit: Optional[bool] = None,
):
_settings = MSSQLSettings()
self.host = host or _settings.host
self.user = user or _settings.user
self.password = password or _settings.password
self.port = port or int(_settings.port)
self.database = database or _settings.database
if not _settings.user or not _settings.password:
raise RuntimeError("Missing username or password")

self.schema = schema or _settings.db_schema
self.echo = echo or _settings.echo
self.pool_size = pool_size or int(_settings.pool_size)
self.max_overflow = max_overflow or int(_settings.max_overflow)

self.autoflush = autoflush
self.expire_on_commit = expire_on_commit
self.async_driver = _settings.async_driver
self.sync_driver = _settings.sync_driver
self.odbcdriver_version = int(_settings.odbcdriver_version)
self.connection_url = {
"username": self.user,
"password": self.password,
"host": self.host,
"port": self.port,
"database": self.database,
"host": host or _settings.host,
"port": port or int(_settings.port),
"database": database or _settings.database,
"username": user or _settings.user,
"password": password or _settings.password,
"query": {
"driver": f"ODBC Driver {self.odbcdriver_version} for SQL Server",
"TrustServerCertificate": "yes",
Expand All @@ -62,42 +56,33 @@ def __init__(
"echo": self.echo,
}

if not self.user or not self.password:
raise RuntimeError("Missing username or password")

super().__init__(
host=self.host,
port=self.port,
user=self.user,
database=self.database,
autoflush=self.autoflush,
expire_on_commit=self.expire_on_commit,
connection_url=self.connection_url,
engine_args=self.engine_args,
autoflush=self.autoflush,
expire_on_commit=self.expire_on_commit,
sync_driver=self.sync_driver,
async_driver=self.async_driver,
)

def _test_connection_sync(self, session: Session) -> None:
host_url = URL.create(
del self.connection_url["password"]
del self.connection_url["query"]
_connection_url = URL.create(
**self.connection_url,
drivername=self.sync_driver,
username=self.user,
host=self.host,
port=self.port,
database=self.database,
query={"schema": self.schema},
)
test_connection = TestConnections(sync_session=session, host_url=host_url)
test_connection = TestConnections(sync_session=session, host_url=_connection_url)
test_connection.test_connection_sync()

async def _test_connection_async(self, session: AsyncSession) -> None:
host_url = URL.create(
del self.connection_url["password"]
del self.connection_url["query"]
_connection_url = URL.create(
**self.connection_url,
drivername=self.async_driver,
username=self.user,
host=self.host,
port=self.port,
database=self.database,
query={"schema": self.schema},
)
test_connection = TestConnections(async_session=session, host_url=host_url)
test_connection = TestConnections(async_session=session, host_url=_connection_url)
await test_connection.test_connection_async()
36 changes: 10 additions & 26 deletions ddcDatabases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
# -*- encoding: utf-8 -*-
from typing import Optional
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
)
from sqlalchemy.orm import Session
from .db_utils import BaseConn
from .settings import PostgreSQLSettings

Expand All @@ -27,40 +21,30 @@ def __init__(
expire_on_commit: Optional[bool] = None,
):
_settings = PostgreSQLSettings()
self.host = host or _settings.host
self.user = user or _settings.user
self.password = password or _settings.password
self.port = port or int(_settings.port)
self.database = database or _settings.database
self.echo = echo or _settings.echo
if not _settings.user or not _settings.password:
raise RuntimeError("Missing username or password")

self.echo = echo or _settings.echo
self.autoflush = autoflush
self.expire_on_commit = expire_on_commit
self.async_driver = _settings.async_driver
self.sync_driver = _settings.sync_driver
self.connection_url = {
"username": self.user,
"password": self.password,
"host": self.host,
"port": self.port,
"database": self.database,
"host": host or _settings.host,
"port": port or int(_settings.port),
"database": database or _settings.database,
"username": user or _settings.user,
"password": password or _settings.password,
}
self.engine_args = {
"echo": self.echo,
}

if not self.user or not self.password:
raise RuntimeError("Missing username or password")

super().__init__(
host=self.host,
port=self.port,
user=self.user,
database=self.database,
autoflush=self.autoflush,
expire_on_commit=self.expire_on_commit,
connection_url=self.connection_url,
engine_args=self.engine_args,
autoflush=self.autoflush,
expire_on_commit=self.expire_on_commit,
sync_driver=self.sync_driver,
async_driver=self.async_driver,
)
6 changes: 1 addition & 5 deletions tests/dal/sqlite_dal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sqlalchemy as sa
from ddcDatabases import DBUtils
from ddcDatabases.exceptions import DBFetchAllException
from tests.models.sqlite_model import ModelTest
from tests.models.test_model import ModelTest


class SqliteDal:
Expand All @@ -14,10 +14,6 @@ def update_name(self, name: str, test_id: int):
stmt = sa.update(ModelTest).where(ModelTest.id == test_id).values(name=name)
self.db_utils.execute(stmt)

def update_enable(self, status: bool, test_id: int):
stmt = sa.update(ModelTest).where(ModelTest.id == test_id).values(enable=status)
self.db_utils.execute(stmt)

def get(self, test_id: int):
try:
stmt = sa.select(*self.columns).where(ModelTest.id == test_id)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/unit/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ddcDatabases import Sqlite
from tests.dal.sqlite_dal import SqliteDal
from tests.data.base_data import db_filename
from tests.models.sqlite_model import ModelTest
from tests.models.test_model import ModelTest


class TestSQLite:
Expand Down

0 comments on commit dfb2905

Please sign in to comment.