Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a custom SELECT with SAMPLE BY support #27

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ max-branches = 20
max-args = 10

[tool.ruff.per-file-ignores]
'tests/test_dialect.py' = ['S101']
'tests/test_dialect.py' = ['S101', 'PLR2004']
'tests/test_types.py' = ['S101']
'tests/test_superset.py' = ['S101']
'tests/conftest.py' = ['S608']
Expand Down
6 changes: 6 additions & 0 deletions src/questdb_connect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_engine,
create_superset_engine,
)
from questdb_connect.dml import QDBSelect, select
from questdb_connect.identifier_preparer import QDBIdentifierPreparer
from questdb_connect.inspector import QDBInspector
from questdb_connect.keywords_functions import get_functions_list, get_keywords_list
Expand Down Expand Up @@ -51,6 +52,11 @@
threadsafety = 2
paramstyle = "pyformat"

__all__ = (
"select",
"QDBSelect",
)


class Error(Exception):
pass
Expand Down
70 changes: 70 additions & 0 deletions src/questdb_connect/compilers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc

import sqlalchemy
from sqlalchemy.sql.base import elements

from .common import quote_identifier, remove_public_schema
from .types import QDBTypeMixin
Expand Down Expand Up @@ -30,9 +31,78 @@ def get_column_specification(self, column: sqlalchemy.Column, **_):


class QDBSQLCompiler(sqlalchemy.sql.compiler.SQLCompiler, abc.ABC):
def visit_sample_by(self, sample_by, **kw):
"""Compile a SAMPLE BY clause."""
text = ""

# Basic SAMPLE BY
if sample_by.unit:
text = f"SAMPLE BY {sample_by.value}{sample_by.unit}"
else:
text = f"SAMPLE BY {sample_by.value}"

if sample_by.from_timestamp:
# Format datetime to ISO format that QuestDB expects
text += f" FROM '{sample_by.from_timestamp.isoformat()}'"
if sample_by.to_timestamp:
text += f" TO '{sample_by.to_timestamp.isoformat()}'"

# Add FILL if specified
if sample_by.fill is not None:
if isinstance(sample_by.fill, str):
text += f" FILL({sample_by.fill})"
else:
text += f" FILL({sample_by.fill:g})"

# Add ALIGN TO clause
text += f" ALIGN TO {sample_by.align_to}"

# Add TIME ZONE if specified
if sample_by.timezone:
text += f" TIME ZONE '{sample_by.timezone}'"

# Add WITH OFFSET if specified
if sample_by.offset:
text += f" WITH OFFSET '{sample_by.offset}'"

return text

def group_by_clause(self, select, **kw):
"""Customize GROUP BY to also render SAMPLE BY."""
text = ""

# Add SAMPLE BY first if present
if _has_sample_by(select):
text += " " + self.process(select._sample_by_clause, **kw)

# Use parent's GROUP BY implementation
group_by_text = super().group_by_clause(select, **kw)
if group_by_text:
text += group_by_text

return text

def visit_select(self, select, **kw):
"""Add SAMPLE BY support to the standard SELECT compilation."""

# If we have SAMPLE BY but no GROUP BY,
# add a dummy GROUP BY clause to trigger the rendering
if (
_has_sample_by(select)
and not select._group_by_clauses
):
select = select._clone()
select._group_by_clauses = [elements.TextClause("")]

text = super().visit_select(select, **kw)
return text

def _is_safe_for_fast_insert_values_helper(self):
return True

def visit_textclause(self, textclause, add_to_result_map=None, **kw):
textclause.text = remove_public_schema(textclause.text)
return super().visit_textclause(textclause, add_to_result_map, **kw)

def _has_sample_by(select):
return hasattr(select, '_sample_by_clause') and select._sample_by_clause is not None
120 changes: 120 additions & 0 deletions src/questdb_connect/dml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, Union

from sqlalchemy import select as sa_select
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Select as StandardSelect

if TYPE_CHECKING:
from datetime import date, datetime

from sqlalchemy.sql.visitors import Visitable


class SampleByClause(ClauseElement):
"""Represents the QuestDB SAMPLE BY clause."""

__visit_name__ = "sample_by"
stringify_dialect = "questdb"

def __init__(
self,
value: Union[int, float],
unit: Optional[str] = None,
fill: Optional[Union[str, float]] = None,
align_to: str = "CALENDAR", # default per docs
timezone: Optional[str] = None,
offset: Optional[str] = None,
from_timestamp: Optional[Union[datetime, date]] = None,
to_timestamp: Optional[Union[datetime, date]] = None
):
self.value = value
self.unit = unit.lower() if unit else None
self.fill = fill
self.align_to = align_to.upper()
self.timezone = timezone
self.offset = offset
self.from_timestamp = from_timestamp
self.to_timestamp = to_timestamp

def __str__(self) -> str:
if self.unit:
return f"SAMPLE BY {self.value}{self.unit}"
return f"SAMPLE BY {self.value}"

def get_children(self, **kwargs: Any) -> Sequence[Visitable]:
return []


class QDBSelect(StandardSelect):
"""QuestDB-specific implementation of SELECT.

Adds methods for QuestDB-specific syntaxes such as SAMPLE BY.

The :class:`_questdb.QDBSelect` object is created using the
:func:`sqlalchemy.dialects.questdb.select` function.
"""

stringify_dialect = "questdb"
_sample_by_clause: Optional[SampleByClause] = None

def get_children(self, **kwargs: Any) -> Sequence[Visitable]:
children = super().get_children(**kwargs)
if self._sample_by_clause is not None:
children = [*children, self._sample_by_clause]
return children

def sample_by(
self,
value: Union[int, float],
unit: Optional[str] = None,
fill: Optional[Union[str, float]] = None,
align_to: str = "CALENDAR",
timezone: Optional[str] = None,
offset: Optional[str] = None,
from_timestamp: Optional[Union[datetime, date]] = None,
to_timestamp: Optional[Union[datetime, date]] = None,
) -> QDBSelect:
"""Add a SAMPLE BY clause.

:param value: time interval value
:param unit: 's' for seconds, 'm' for minutes, 'h' for hours, etc.
:param fill: fill strategy - NONE, NULL, PREV, LINEAR, or constant value
:param align_to: CALENDAR or FIRST OBSERVATION
:param timezone: Optional timezone for calendar alignment
:param offset: Optional offset in format '+/-HH:mm'
:param from_timestamp: Optional start timestamp for the sample
:param to_timestamp: Optional end timestamp for the sample
"""

# Create a copy of our object with _generative
s = self.__class__.__new__(self.__class__)
s.__dict__ = self.__dict__.copy()

# Set the sample by clause
s._sample_by_clause = SampleByClause(
value, unit, fill, align_to, timezone, offset, from_timestamp, to_timestamp
)
return s


def select(*entities: Any, **kwargs: Any) -> QDBSelect:
"""Construct a QuestDB-specific variant :class:`_questdb.Select` construct.

.. container:: inherited_member

The :func:`sqlalchemy.dialects.questdb.select` function creates
a :class:`sqlalchemy.dialects.questdb.Select`. This class is based
on the dialect-agnostic :class:`_sql.Select` construct which may
be constructed using the :func:`_sql.select` function in
SQLAlchemy Core.

The :class:`_questdb.Select` construct includes additional method
:meth:`_questdb.Select.sample_by` for QuestDB's SAMPLE BY clause.
"""
stmt = sa_select(*entities, **kwargs)
# Convert the SQLAlchemy Select into our QDBSelect
qdbs = QDBSelect.__new__(QDBSelect)
qdbs.__dict__ = stmt.__dict__.copy()
return qdbs
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from typing import NamedTuple

import pytest
Expand Down Expand Up @@ -125,6 +126,32 @@ def collect_select_all(session, expected_rows) -> str:
if rs.rowcount == expected_rows:
return '\n'.join(str(row) for row in rs)

def wait_until_table_is_ready(test_engine, table_name, expected_rows, timeout=10):
"""
Wait until a table has the expected number of rows, with timeout.
Args:
test_engine: SQLAlchemy engine
table_name: Name of the table to check
expected_rows: Expected number of rows
timeout: Maximum time to wait in seconds (default: 10 seconds)
Returns:
bool: True if table is ready, False if timeout occurred
Raises:
sqlalchemy.exc.SQLAlchemyError: If there's a database error
"""
start_time = time.time()

while time.time() - start_time < timeout:
with test_engine.connect() as conn:
result = conn.execute(text(f'SELECT count(*) FROM {table_name}'))
row = result.fetchone()
if row and row[0] == expected_rows:
return True

print(f'Waiting for table {table_name} to have {expected_rows} rows, current: {row[0] if row else 0}')
time.sleep(0.01) # Wait 10ms between checks
return False


def collect_select_all_raw_connection(test_engine, expected_rows) -> str:
conn = test_engine.raw_connection()
Expand Down
Loading
Loading