diff --git a/src/questdb_connect/__init__.py b/src/questdb_connect/__init__.py index 82042ef..305254b 100644 --- a/src/questdb_connect/__init__.py +++ b/src/questdb_connect/__init__.py @@ -11,6 +11,7 @@ create_engine, create_superset_engine, ) +from questdb_connect.dml import QDBSelect, select from questdb_connect.identifier_preparer import QDBIdentifierPreparer from questdb_connect.inspector import QDBInspector from questdb_connect.keywords_functions import get_functions_list, get_keywords_list @@ -51,6 +52,11 @@ threadsafety = 2 paramstyle = "pyformat" +__all__ = ( + "select", + "QDBSelect", +) + class Error(Exception): pass diff --git a/src/questdb_connect/compilers.py b/src/questdb_connect/compilers.py index e45430e..7d5de56 100644 --- a/src/questdb_connect/compilers.py +++ b/src/questdb_connect/compilers.py @@ -1,6 +1,7 @@ import abc import sqlalchemy +from sqlalchemy.sql.base import elements from .common import quote_identifier, remove_public_schema from .types import QDBTypeMixin @@ -33,6 +34,72 @@ class QDBSQLCompiler(sqlalchemy.sql.compiler.SQLCompiler, abc.ABC): # Maximum value for 64-bit signed integer (2^63 - 1) BIGINT_MAX = 9223372036854775807 + def visit_sample_by(self, sample_by, **kw): + """Compile a SAMPLE BY clause.""" + text = "" + + # Basic SAMPLE BY + if sample_by.unit: + text = f"SAMPLE BY {sample_by.value}{sample_by.unit}" + else: + text = f"SAMPLE BY {sample_by.value}" + + if sample_by.from_timestamp: + # Format datetime to ISO format that QuestDB expects + text += f" FROM '{sample_by.from_timestamp.isoformat()}'" + if sample_by.to_timestamp: + text += f" TO '{sample_by.to_timestamp.isoformat()}'" + + # Add FILL if specified + if sample_by.fill is not None: + if isinstance(sample_by.fill, str): + text += f" FILL({sample_by.fill})" + else: + text += f" FILL({sample_by.fill:g})" + + # Add ALIGN TO clause + text += f" ALIGN TO {sample_by.align_to}" + + # Add TIME ZONE if specified + if sample_by.timezone: + text += f" TIME ZONE '{sample_by.timezone}'" + + # Add WITH OFFSET if specified + if sample_by.offset: + text += f" WITH OFFSET '{sample_by.offset}'" + + return text + + def group_by_clause(self, select, **kw): + """Customize GROUP BY to also render SAMPLE BY.""" + text = "" + + # Add SAMPLE BY first if present + if _has_sample_by(select): + text += " " + self.process(select._sample_by_clause, **kw) + + # Use parent's GROUP BY implementation + group_by_text = super().group_by_clause(select, **kw) + if group_by_text: + text += group_by_text + + return text + + def visit_select(self, select, **kw): + """Add SAMPLE BY support to the standard SELECT compilation.""" + + # If we have SAMPLE BY but no GROUP BY, + # add a dummy GROUP BY clause to trigger the rendering + if ( + _has_sample_by(select) + and not select._group_by_clauses + ): + select = select._clone() + select._group_by_clauses = [elements.TextClause("")] + + text = super().visit_select(select, **kw) + return text + def _is_safe_for_fast_insert_values_helper(self): return True @@ -69,3 +136,6 @@ def limit_clause(self, select, **kw): text += f"{self.process(offset, **kw)},{self.BIGINT_MAX}" return text + +def _has_sample_by(select): + return hasattr(select, '_sample_by_clause') and select._sample_by_clause is not None \ No newline at end of file diff --git a/src/questdb_connect/dml.py b/src/questdb_connect/dml.py new file mode 100644 index 0000000..506bb00 --- /dev/null +++ b/src/questdb_connect/dml.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Sequence, Union + +from sqlalchemy import select as sa_select +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql import Select as StandardSelect + +if TYPE_CHECKING: + from datetime import date, datetime + + from sqlalchemy.sql.visitors import Visitable + + +class SampleByClause(ClauseElement): + """Represents the QuestDB SAMPLE BY clause.""" + + __visit_name__ = "sample_by" + stringify_dialect = "questdb" + + def __init__( + self, + value: Union[int, float], + unit: Optional[str] = None, + fill: Optional[Union[str, float]] = None, + align_to: str = "CALENDAR", # default per docs + timezone: Optional[str] = None, + offset: Optional[str] = None, + from_timestamp: Optional[Union[datetime, date]] = None, + to_timestamp: Optional[Union[datetime, date]] = None + ): + self.value = value + self.unit = unit.lower() if unit else None + self.fill = fill + self.align_to = align_to.upper() + self.timezone = timezone + self.offset = offset + self.from_timestamp = from_timestamp + self.to_timestamp = to_timestamp + + def __str__(self) -> str: + if self.unit: + return f"SAMPLE BY {self.value}{self.unit}" + return f"SAMPLE BY {self.value}" + + def get_children(self, **kwargs: Any) -> Sequence[Visitable]: + return [] + + +class QDBSelect(StandardSelect): + """QuestDB-specific implementation of SELECT. + + Adds methods for QuestDB-specific syntaxes such as SAMPLE BY. + + The :class:`_questdb.QDBSelect` object is created using the + :func:`sqlalchemy.dialects.questdb.select` function. + """ + + stringify_dialect = "questdb" + _sample_by_clause: Optional[SampleByClause] = None + + def get_children(self, **kwargs: Any) -> Sequence[Visitable]: + children = super().get_children(**kwargs) + if self._sample_by_clause is not None: + children = [*children, self._sample_by_clause] + return children + + def sample_by( + self, + value: Union[int, float], + unit: Optional[str] = None, + fill: Optional[Union[str, float]] = None, + align_to: str = "CALENDAR", + timezone: Optional[str] = None, + offset: Optional[str] = None, + from_timestamp: Optional[Union[datetime, date]] = None, + to_timestamp: Optional[Union[datetime, date]] = None, + ) -> QDBSelect: + """Add a SAMPLE BY clause. + + :param value: time interval value + :param unit: 's' for seconds, 'm' for minutes, 'h' for hours, etc. + :param fill: fill strategy - NONE, NULL, PREV, LINEAR, or constant value + :param align_to: CALENDAR or FIRST OBSERVATION + :param timezone: Optional timezone for calendar alignment + :param offset: Optional offset in format '+/-HH:mm' + :param from_timestamp: Optional start timestamp for the sample + :param to_timestamp: Optional end timestamp for the sample + """ + + # Create a copy of our object with _generative + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + + # Set the sample by clause + s._sample_by_clause = SampleByClause( + value, unit, fill, align_to, timezone, offset, from_timestamp, to_timestamp + ) + return s + + +def select(*entities: Any, **kwargs: Any) -> QDBSelect: + """Construct a QuestDB-specific variant :class:`_questdb.Select` construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.questdb.select` function creates + a :class:`sqlalchemy.dialects.questdb.Select`. This class is based + on the dialect-agnostic :class:`_sql.Select` construct which may + be constructed using the :func:`_sql.select` function in + SQLAlchemy Core. + + The :class:`_questdb.Select` construct includes additional method + :meth:`_questdb.Select.sample_by` for QuestDB's SAMPLE BY clause. + """ + stmt = sa_select(*entities, **kwargs) + # Convert the SQLAlchemy Select into our QDBSelect + qdbs = QDBSelect.__new__(QDBSelect) + qdbs.__dict__ = stmt.__dict__.copy() + return qdbs \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 222925f..b12c2d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,6 +126,32 @@ def collect_select_all(session, expected_rows) -> str: if rs.rowcount == expected_rows: return '\n'.join(str(row) for row in rs) +def wait_until_table_is_ready(test_engine, table_name, expected_rows, timeout=10): + """ + Wait until a table has the expected number of rows, with timeout. + Args: + test_engine: SQLAlchemy engine + table_name: Name of the table to check + expected_rows: Expected number of rows + timeout: Maximum time to wait in seconds (default: 10 seconds) + Returns: + bool: True if table is ready, False if timeout occurred + Raises: + sqlalchemy.exc.SQLAlchemyError: If there's a database error + """ + start_time = time.time() + + while time.time() - start_time < timeout: + with test_engine.connect() as conn: + result = conn.execute(text(f'SELECT count(*) FROM {table_name}')) + row = result.fetchone() + if row and row[0] == expected_rows: + return True + + print(f'Waiting for table {table_name} to have {expected_rows} rows, current: {row[0] if row else 0}') + time.sleep(0.01) # Wait 10ms between checks + return False + def wait_until_table_is_ready(test_engine, table_name, expected_rows, timeout=10): """ diff --git a/tests/test_dialect.py b/tests/test_dialect.py index 26aa19f..d1df8a4 100644 --- a/tests/test_dialect.py +++ b/tests/test_dialect.py @@ -1,5 +1,6 @@ import datetime +import questdb_connect import questdb_connect as qdbc import sqlalchemy as sqla from sqlalchemy.orm import Session @@ -13,6 +14,173 @@ ) +def test_sample_by_in_subquery(test_engine, test_model): + """Test SAMPLE BY usage within subqueries.""" + base_ts = datetime.datetime(2023, 4, 12, 0, 0, 0) + session = Session(test_engine) + try: + # Insert test data - one row every minute for 2 hours + num_rows = 120 # 2 hours * 60 minutes + models = [ + test_model( + col_boolean=True, + col_byte=8, + col_short=12, + col_int=idx, + col_long=14, + col_float=15.234, + col_double=16.88993244, + col_symbol='coconut', + col_string='banana', + col_char='C', + col_uuid='6d5eb038-63d1-4971-8484-30c16e13de5b', + col_date=base_ts.date(), + col_ts=base_ts + datetime.timedelta(minutes=idx), + col_geohash='dfvgsj2vptwu', + col_long256='0xa3b400fcf6ed707d710d5d4e672305203ed3cc6254d1cefe313e4a465861f42a', + col_varchar='pineapple' + ) for idx in range(num_rows) + ] + session.bulk_save_objects(models) + session.commit() + + metadata = sqla.MetaData() + table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine) + wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, num_rows) + + with test_engine.connect() as conn: + # Subquery with SAMPLE BY + subq = ( + questdb_connect.select( + table.c.col_ts, + sqla.func.avg(table.c.col_int).label('avg_int') + ) + .sample_by(30, 'm') # 30 minute samples in subquery + .subquery() + ) + + # Main query selecting from subquery with extra conditions + query = ( + questdb_connect.select( + subq.c.col_ts, + subq.c.avg_int + ) + .where(subq.c.avg_int > 30) + .order_by(subq.c.col_ts) + ) + + result = conn.execute(query) + rows = result.fetchall() + + # Should only get samples from second half of the data + # where averages are > 30 + assert len(rows) == 3 # expecting 2 30-min samples > 30 + assert all(row.avg_int > 30 for row in rows) + assert rows[0].avg_int < rows[1].avg_int # ordered by timestamp + + # Test nested aggregation + outer_query = ( + questdb_connect.select( + sqla.func.sum(subq.c.avg_int).label('total_avg') + ) + .select_from(subq) + ) + + result = conn.execute(outer_query) + row = result.fetchone() + # Sum of all 30-min sample averages + assert row.total_avg == 238; + + finally: + if session: + session.close() + +def test_sample_by_clause(test_engine, test_model): + """Test SAMPLE BY clause functionality.""" + base_ts = datetime.datetime(2023, 4, 12, 0, 0, 0) + session = Session(test_engine) + try: + # Insert test data - one row every minute for 2 hours + num_rows = 120 # 2 hours * 60 minutes + models = [ + test_model( + col_boolean=True, + col_byte=8, + col_short=12, + col_int=idx, + col_long=14, + col_float=15.234, + col_double=16.88993244, + col_symbol='coconut', + col_string='banana', + col_char='C', + col_uuid='6d5eb038-63d1-4971-8484-30c16e13de5b', + col_date=base_ts.date(), + # Add idx minutes to base timestamp + col_ts=base_ts + datetime.timedelta(minutes=idx), + col_geohash='dfvgsj2vptwu', + col_long256='0xa3b400fcf6ed707d710d5d4e672305203ed3cc6254d1cefe313e4a465861f42a', + col_varchar='pineapple' + ) for idx in range(num_rows) + ] + session.bulk_save_objects(models) + session.commit() + + metadata = sqla.MetaData() + table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine) + wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, num_rows) + + with test_engine.connect() as conn: + # Simple SAMPLE BY + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(30, 'm') # 30 minute samples + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 4 # 2 hours should give us 4 30-minute samples + + # Verify sample averages + # First 30 min should average 0-29, second 30-59, etc. + expected_averages = [14.5, 44.5, 74.5, 104.5] # (min+max)/2 for each 30-min period + for row, expected_avg in zip(rows, expected_averages): + assert abs(row.avg_int - expected_avg) < 0.1 + + # SAMPLE BY with ORDER BY + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(1, 'h') # 1 hour samples + .order_by(sqla.desc('avg_int')) + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 2 # 2 one-hour samples + assert rows[0].avg_int > rows[1].avg_int # Descending order + + # SAMPLE BY with WHERE clause + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .where(table.c.col_int > 30) + .sample_by(1, 'h') + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 2 + assert all(row.avg_int > 30 for row in rows) + + # SAMPLE BY with LIMIT + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(15, 'm') # 15 minute samples + .limit(3) + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 3 # Should limit to first 3 samples + finally: + if session: + session.close() + def test_insert(test_engine, test_model): with test_engine.connect() as conn: assert test_engine.dialect.has_table(conn, ALL_TYPES_TABLE_NAME) @@ -222,6 +390,233 @@ def test_bulk_insert(test_engine, test_model): assert collect_select_all_raw_connection(test_engine, expected_rows=num_rows) == expected +def test_sample_by_from_to(test_engine, test_model): + """Test SAMPLE BY with FROM-TO extension.""" + base_ts = datetime.datetime(2023, 4, 12, 0, 0, 0) + day_before = base_ts - datetime.timedelta(days=1) + day_after = base_ts + datetime.timedelta(days=1) + session = Session(test_engine) + try: + num_rows = 6 # 6 hours only + models = [ + test_model( + col_int=idx, + col_ts=base_ts + datetime.timedelta(hours=idx), + ) for idx in range(num_rows) + ] + + session.bulk_save_objects(models) + session.commit() + + metadata = sqla.MetaData() + table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine) + wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, len(models)) + + with test_engine.connect() as conn: + # Test FROM-TO with FILL + query = ( + questdb_connect.select( + table.c.col_ts, + sqla.func.avg(table.c.col_int).label('avg_int') + ) + .sample_by( + 1, 'h', + fill="NULL", + from_timestamp=day_before, # day before data starts + to_timestamp=day_after # day after data ends + ) + ) + result = conn.execute(query) + rows = result.fetchall() + + assert len(rows) == 48 # 48 hours in total + + # First rows should be NULL (before our data starts) + assert rows[0].avg_int is None + assert rows[1].avg_int is None + assert rows[2].avg_int is None + assert rows[3].avg_int is None + + # Middle rows should have data + assert any(row.avg_int is not None for row in rows[4:-4]) + + # Last rows should be NULL (after our data ends) + assert rows[-4].avg_int is None + assert rows[-3].avg_int is None + assert rows[-2].avg_int is None + assert rows[-1].avg_int is None + + # Test FROM only + query = ( + questdb_connect.select( + table.c.col_ts, + sqla.func.avg(table.c.col_int).label('avg_int') + ) + .sample_by( + 1, 'h', + fill="NULL", + from_timestamp=day_before # day before data starts + ) + ) + result = conn.execute(query) + rows = result.fetchall() + + # First rows should be NULL + assert rows[0].avg_int is None + assert rows[1].avg_int is None + assert rows[2].avg_int is None + assert rows[3].avg_int is None + + # Test TO only + query = ( + questdb_connect.select( + table.c.col_ts, + sqla.func.avg(table.c.col_int).label('avg_int') + ) + .sample_by( + 1, 'h', + fill="NULL", + to_timestamp=day_after # day after data ends + ) + ) + result = conn.execute(query) + rows = result.fetchall() + + # Last rows should be NULL + assert rows[-4].avg_int is None + assert rows[-3].avg_int is None + assert rows[-2].avg_int is None + assert rows[-1].avg_int is None + + finally: + if session: + session.close() + +def test_plain_select_core_api(test_engine, test_model): + """ + Test plain select with core API. Plain select means select implementation from sqlalchemy.sql.selectable, + not from questdb_connect. + """ + + session = Session(test_engine) + try: + num_rows = 3 + models = [ + test_model( + col_int=idx, + col_ts=datetime.datetime(2023, 4, 12, 0, 0, 0) + datetime.timedelta(hours=idx), + ) for idx in range(num_rows) + ] + session.bulk_save_objects(models) + session.commit() + + metadata = sqla.MetaData() + table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine) + wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, len(models)) + + with test_engine.connect() as conn: + query = ( + # important: use sqla.select, not questdb_connect.select! + sqla.select(table.c.col_ts, table.c.col_int) + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 3 + finally: + if session: + session.close() + +def test_sample_by_options(test_engine, test_model): + """Test SAMPLE BY with ALIGN TO and FILL options.""" + base_ts = datetime.datetime(2023, 4, 12, 0, 0, 0) + session = Session(test_engine) + try: + # Insert test data - one row every hour for a day + num_rows = 24 + models = [ + test_model( + col_int=idx, + col_ts=base_ts + datetime.timedelta(hours=idx), + ) for idx in range(num_rows) + ] + # Add some gaps by removing every 3rd record + models = [m for i, m in enumerate(models) if i % 3 != 0] + + session.bulk_save_objects(models) + session.commit() + + metadata = sqla.MetaData() + table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine) + wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, len(models)) + + with test_engine.connect() as conn: + # Test FILL(NULL) + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(15, 'm', fill="NULL") + ) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 89 + # Should have NULLs for missing data points + assert any(row.avg_int is None for row in rows) + + + # Test FILL(PREV) + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(15, 'm', fill="PREV") + ) + result = conn.execute(query) + rows = result.fetchall() + assert all(row.avg_int is not None for row in rows) + + # Test FILL with constant + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(15, 'm', fill=999.99) + .limit(10) + ) + result = conn.execute(query) + rows = result.fetchall() + assert any(row.avg_int == 999.99 for row in rows) + + # Test ALIGN TO FIRST OBSERVATION + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(15, 'm', align_to="FIRST OBSERVATION") + .limit(10) + ) + result = conn.execute(query) + first_row = result.fetchone() + # First timestamp should match our first data point + assert first_row.col_ts == models[0].col_ts + + # Test with timezone + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(1, 'd', align_to="CALENDAR", timezone="Europe/Prague") + ) + result = conn.execute(query) + rows = result.fetchall() + # First row should be at midnight Prague time, that is 22:00 UTC the previous day + assert rows[0].col_ts.hour == 22 + assert rows[1].col_ts.hour == 22 + + # Test with offset + query = ( + questdb_connect.select(table.c.col_ts, sqla.func.avg(table.c.col_int).label('avg_int')) + .sample_by(1, 'd', align_to="CALENDAR", offset="02:00") + ) + result = conn.execute(query) + rows = result.fetchall() + # First row should start at 02:00 + assert rows[0].col_ts.hour == 2 + + finally: + if session: + session.close() + def test_dialect_get_schema_names(test_engine): dialect = qdbc.QuestDBDialect() with test_engine.connect() as conn: