diff --git a/integrations/python/dataloader/CLAUDE.md b/integrations/python/dataloader/CLAUDE.md index cc3e6b313..d63180c90 100644 --- a/integrations/python/dataloader/CLAUDE.md +++ b/integrations/python/dataloader/CLAUDE.md @@ -75,18 +75,22 @@ Exported in `__init__.py`: - `OpenHouseCatalogError` — Error raised when catalog fails to load a table - `col()` — Column reference for building filter expressions - `always_true()` — Filter that matches all rows +- `to_sql()` — Render a filter as a SQL boolean expression (WHERE-clause predicate) for a target +- `SqlTarget` — Enum of supported SQL flavors (`SPARK`, `TRINO`, `DATA_FUSION`); used by `to_sql()` and `TableTransformer` ### Filter DSL (`filters.py`) Build row filters using `col()` with comparison operators (`==`, `!=`, `>`, `>=`, `<`, `<=`) and predicates (`is_null()`, `is_not_null()`, `is_nan()`, `is_not_nan()`, `is_in()`, `is_not_in()`, `starts_with()`, `not_starts_with()`, `between()`). Combine with `&` (AND), `|` (OR), `~` (NOT). Filters are converted to PyIceberg expressions internally for partition pruning and file-level filtering. +`to_sql(filter_expr, target=SqlTarget.SPARK)` renders a filter as a SQL boolean expression for the given target. Internally, `_filter_to_expr()` builds a sqlglot AST that is dialect-agnostic except for NaN checks (`is_nan` on Trino, `isnan` on Spark and DataFusion); the AST is then rendered per target with `.sql(dialect=target.value)` — the same path the loader uses for its internal DataFusion query via `to_sql(filters, SqlTarget.DATA_FUSION)`. The backing sqlglot dialect string is an internal detail — callers select a `SqlTarget`, never a dialect string. + ### Internal modules (not in `__init__.py`) -- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (e.g. `"spark"`) and implement `transform()` returning SQL or `None` +- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (a `SqlTarget`, e.g. 
`SqlTarget.SPARK`) and implement `transform()` returning SQL or `None` - `UDFRegistry` / `NoOpRegistry` — ABC for registering DataFusion UDFs; `NoOpRegistry` is the default no-op - `TableScanContext` — Frozen dataclass holding table metadata, FileIO, projected schema, row filter, and table ID; pickle-safe for distributed execution - `DataFusion` dialect in `datafusion_sql.py` — Custom SQLGlot dialect for transpiling SQL from other dialects (e.g. Spark) to DataFusion -- `to_datafusion_sql()` — Transpiles a SQL statement from a source dialect to DataFusion using SQLGlot +- `to_datafusion_sql()` — Transpiles a SQL statement from a source `SqlTarget` to DataFusion using SQLGlot ## Key Dependencies diff --git a/integrations/python/dataloader/README.md b/integrations/python/dataloader/README.md index b9b59e661..6cf32c059 100644 --- a/integrations/python/dataloader/README.md +++ b/integrations/python/dataloader/README.md @@ -52,6 +52,19 @@ filters = col("score").between(0.5, 1.0) filters = (col("age") >= 18) & (col("country").is_in(["US", "CA"])) & ~col("email").is_null() ``` +### Rendering a filter as SQL + +Use `to_sql()` to render a filter as a SQL boolean expression (a `WHERE`-clause +predicate) for a given `SqlTarget` (`SPARK`, `TRINO`, or `DATA_FUSION`): + +```python +from openhouse.dataloader import SqlTarget, col, to_sql + +to_sql(col("age") > 21) # `age` > 21 (defaults to Spark) +to_sql((col("country") == "US") & col("email").is_null(), SqlTarget.SPARK) +# `country` = 'US' AND `email` IS NULL +``` + ## Development ```bash diff --git a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py index df266c799..16d8659a9 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py @@ -2,7 +2,7 @@ from openhouse.dataloader.catalog import OpenHouseCatalog, OpenHouseCatalogError from 
openhouse.dataloader.data_loader import DataLoaderContext, JvmConfig, OpenHouseDataLoader -from openhouse.dataloader.filters import always_true, col +from openhouse.dataloader.filters import SqlTarget, always_true, col, to_sql __version__ = version("openhouse.dataloader") __all__ = [ @@ -11,6 +11,8 @@ "JvmConfig", "OpenHouseCatalog", "OpenHouseCatalogError", + "SqlTarget", "always_true", "col", + "to_sql", ] diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index da9a3e943..8ff0d2724 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -25,10 +25,11 @@ from openhouse.dataloader.filters import ( AlwaysTrue, Filter, + SqlTarget, _quote_identifier, - _to_datafusion_sql, _to_pyiceberg, always_true, + to_sql, ) from openhouse.dataloader.metrics import METER_NAME from openhouse.dataloader.scan_optimizer import optimize_scan @@ -302,7 +303,7 @@ def _build_query(self) -> str | None: outer_cols = ", ".join(_quote_identifier(c) for c in self._columns) if self._columns else "*" combined = f"SELECT {outer_cols} FROM ({sql}) AS _t" if self._filters and not isinstance(self._filters, AlwaysTrue): - combined += f" WHERE {_to_datafusion_sql(self._filters)}" + combined += f" WHERE {to_sql(self._filters, SqlTarget.DATA_FUSION)}" return combined def __iter__(self) -> Iterator[DataLoaderSplit]: diff --git a/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py b/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py index ca11fb673..6a9834cd6 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py @@ -8,6 +8,7 @@ from sqlglot.tokens import Tokenizer as _Tokenizer from sqlglot.tokens import TokenType +from 
openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier @@ -100,7 +101,7 @@ class Generator(_Generator): def to_datafusion_sql( sql: str, - source_dialect: str, + source_dialect: SqlTarget, *, table: TableIdentifier | None = None, ) -> str: @@ -110,28 +111,23 @@ def to_datafusion_sql( that table. Args: - sql: SQL statement in the source dialect. - source_dialect: sqlglot dialect name (e.g. "spark", "postgres"). + sql: SQL statement written in *source_dialect*. + source_dialect: The SqlTarget the *sql* is written in (e.g. ``SqlTarget.SPARK``). table: Expected table the SQL must reference. When set the function verifies the SQL contains exactly one table matching this identifier. Raises: - ValueError: If the dialect is unsupported, the SQL is invalid, the - input contains more than one statement, or (when *table* is set) - the table reference does not match. + ValueError: If the SQL is invalid, the input contains more than one + statement, or (when *table* is set) the table reference does not match. """ - if source_dialect not in Dialect.classes: - raise ValueError( - f"Unsupported source dialect '{source_dialect}'. 
Supported dialects: {', '.join(sorted(Dialect.classes))}" - ) - if source_dialect == DataFusion.DIALECT and table is None: + if source_dialect is SqlTarget.DATA_FUSION and table is None: return sql try: - statements = sqlglot.parse(sql, dialect=source_dialect) + statements = sqlglot.parse(sql, dialect=source_dialect.value) except sqlglot.errors.SqlglotError as e: - raise ValueError(f"Failed to transpile SQL from '{source_dialect}' to DataFusion: {e}") from e + raise ValueError(f"Failed to transpile SQL from '{source_dialect.value}' to DataFusion: {e}") from e if len(statements) != 1 or statements[0] is None: raise ValueError(f"Expected exactly one SQL statement, got {len(statements)}: {statements}") ast = statements[0] diff --git a/integrations/python/dataloader/src/openhouse/dataloader/filters.py b/integrations/python/dataloader/src/openhouse/dataloader/filters.py index a725eaa5e..99a76db79 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/filters.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/filters.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from datetime import date, datetime, time from decimal import Decimal +from enum import Enum from typing import Any from uuid import UUID @@ -325,101 +326,179 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def _literal_to_sql(value: object) -> str: - """Convert a Python literal to a SQL literal string using sqlglot. +def _non_finite_double(value: float | Decimal) -> exp.Cast: + """Build ``CAST('<spelling>' AS DOUBLE)`` for a non-finite float/decimal. + + Spark and DataFusion both parse the IEEE special-value spellings ``NaN`` / + ``Infinity`` / ``-Infinity`` (the lowercase ``str(float(...))`` forms like + ``'inf'`` are not reliably cast by Spark). 
+ """ + number = float(value) + if math.isnan(number): + spelling = "NaN" + elif number == math.inf: + spelling = "Infinity" + elif number == -math.inf: + spelling = "-Infinity" + else: + raise ValueError(f"_non_finite_double expects a non-finite value, got {value!r}") + return exp.Cast(this=exp.Literal.string(spelling), to=exp.DataType.build("DOUBLE")) + + +def _literal_to_expr(value: object) -> exp.Expression: + """Convert a Python literal to a sqlglot literal expression. Datetime/date/time values are emitted as plain string literals (ISO format). DataFusion implicitly coerces string literals to the column type at execution, and PyIceberg promotes StringLiteral to the matching typed literal during expression binding. """ if isinstance(value, str): - return exp.Literal.string(value).sql() + return exp.Literal.string(value) if isinstance(value, bool): - return exp.Boolean(this=True).sql() if value else exp.Boolean(this=False).sql() + return exp.true() if value else exp.false() if isinstance(value, datetime): - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, date): - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, time): if value.tzinfo is not None: raise TypeError( - "DataFusion does not support timezones for time data types. " + "The SQL target does not support timezones for time data types. " "The time should match the timezone used in the dataset." 
) - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, (int, float)): if isinstance(value, float) and not math.isfinite(value): - return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql() - return exp.Literal.number(value).sql() + return _non_finite_double(value) + return exp.Literal.number(value) if isinstance(value, Decimal): if not value.is_finite(): - return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql() - return exp.Literal.number(value).sql() + return _non_finite_double(value) + return exp.Literal.number(value) if isinstance(value, UUID): - return exp.Literal.string(str(value)).sql() + return exp.Literal.string(str(value)) raise TypeError(f"Unsupported literal type: {type(value).__name__}") -def _to_datafusion_sql(expr: Filter) -> str: - """Convert a Filter expression tree to a DataFusion SQL expression string.""" - match expr: +class SqlTarget(Enum): + """A SQL flavor used by this library. + + Used both to render filters (:func:`to_sql`) and to declare the SQL a + ``TableTransformer`` emits. Members name a concrete SQL flavor; the backing + sqlglot dialect string is an internal detail — callers use the member, not + the string. + """ + + SPARK = "spark" + TRINO = "trino" + DATA_FUSION = "datafusion" + + +def _column_expr(name: str) -> exp.Column: + """Build a quoted sqlglot column reference.""" + return exp.column(exp.to_identifier(name, quoted=True)) + + +def _like_prefix(column: str, prefix: str) -> exp.Expression: + r"""Build ``col LIKE 'prefix%' ESCAPE '\'`` with the prefix matched literally.""" + pattern = exp.Literal.string(_escape_like(prefix) + "%") + like = exp.Like(this=_column_expr(column), expression=pattern) + return exp.Escape(this=like, expression=exp.Literal.string("\\")) + + +def _isnan(column: str, target: SqlTarget) -> exp.Anonymous: + """Build the dialect-appropriate NaN-check function call. 
+ + Trino spells it ``is_nan``; Spark and DataFusion use ``isnan``. (sqlglot's + built-in IsNan node renders ``IS_NAN(...)``, which none of these accept, so we + emit an Anonymous function with the correct name for the target.) + """ + name = "is_nan" if target is SqlTarget.TRINO else "isnan" + return exp.Anonymous(this=name, expressions=[_column_expr(column)]) + + +def _filter_to_expr(filter_expr: Filter, target: SqlTarget) -> exp.Expression: + """Build a sqlglot expression tree for a Filter, ready to render for *target*. + + The tree is dialect-agnostic except where SQL genuinely diverges (NaN checks); + sqlglot handles per-dialect identifier quoting, operators, and literal + formatting at render time. See :func:`to_sql`. + """ + match filter_expr: case AlwaysTrue(): - return exp.Boolean(this=True).sql() + return exp.true() # Comparison case EqualTo(column, value): - return f"{_quote_identifier(column)} = {_literal_to_sql(value)}" + return exp.EQ(this=_column_expr(column), expression=_literal_to_expr(value)) case NotEqualTo(column, value): - return f"{_quote_identifier(column)} <> {_literal_to_sql(value)}" + return exp.NEQ(this=_column_expr(column), expression=_literal_to_expr(value)) case GreaterThan(column, value): - return f"{_quote_identifier(column)} > {_literal_to_sql(value)}" + return exp.GT(this=_column_expr(column), expression=_literal_to_expr(value)) case GreaterThanOrEqual(column, value): - return f"{_quote_identifier(column)} >= {_literal_to_sql(value)}" + return exp.GTE(this=_column_expr(column), expression=_literal_to_expr(value)) case LessThan(column, value): - return f"{_quote_identifier(column)} < {_literal_to_sql(value)}" + return exp.LT(this=_column_expr(column), expression=_literal_to_expr(value)) case LessThanOrEqual(column, value): - return f"{_quote_identifier(column)} <= {_literal_to_sql(value)}" + return exp.LTE(this=_column_expr(column), expression=_literal_to_expr(value)) # Null / NaN case IsNull(column): - return 
f"{_quote_identifier(column)} IS NULL" + return exp.Is(this=_column_expr(column), expression=exp.Null()) case IsNotNull(column): - return f"{_quote_identifier(column)} IS NOT NULL" + return exp.not_(exp.Is(this=_column_expr(column), expression=exp.Null())) case IsNaN(column): - return f"{_quote_identifier(column)} IS NAN" + return _isnan(column, target) case IsNotNaN(column): - return f"{_quote_identifier(column)} IS NOT NAN" + return exp.not_(_isnan(column, target)) # Set membership case In(column, values): - vals = ", ".join(_literal_to_sql(v) for v in values) - return f"{_quote_identifier(column)} IN ({vals})" + return exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values]) case NotIn(column, values): - vals = ", ".join(_literal_to_sql(v) for v in values) - return f"{_quote_identifier(column)} NOT IN ({vals})" + return exp.not_(exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values])) # String prefix case StartsWith(column, prefix): - escaped = _escape_like(prefix) - return f"{_quote_identifier(column)} LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'" + return _like_prefix(column, prefix) case NotStartsWith(column, prefix): - escaped = _escape_like(prefix) - return f"{_quote_identifier(column)} NOT LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'" + return exp.not_(_like_prefix(column, prefix)) # Range case Between(column, lower, upper): - return f"{_quote_identifier(column)} BETWEEN {_literal_to_sql(lower)} AND {_literal_to_sql(upper)}" + return exp.Between( + this=_column_expr(column), + low=_literal_to_expr(lower), + high=_literal_to_expr(upper), + ) # Logical combinators case And(left, right): - return f"({_to_datafusion_sql(left)} AND {_to_datafusion_sql(right)})" + return exp.and_(_filter_to_expr(left, target), _filter_to_expr(right, target)) case Or(left, right): - return f"({_to_datafusion_sql(left)} OR {_to_datafusion_sql(right)})" + return exp.or_(_filter_to_expr(left, target), 
_filter_to_expr(right, target)) case Not(operand): - return f"NOT ({_to_datafusion_sql(operand)})" + return exp.not_(_filter_to_expr(operand, target)) case _: - raise TypeError(f"Unsupported filter type: {type(expr).__name__}") + raise TypeError(f"Unsupported filter type: {type(filter_expr).__name__}") + + +def to_sql(filter_expr: Filter, target: SqlTarget = SqlTarget.SPARK) -> str: + """Render a filter as a SQL boolean expression for the given target. + + The result is a WHERE-clause-ready predicate (without a leading ``WHERE``). + + Example:: + + to_sql(col("age") > 21) # `age` > 21 + to_sql((col("a") == 1) & col("b").is_null()) # `a` = 1 AND `b` IS NULL + + Args: + filter_expr: The filter expression to render. + target: The SQL flavor to render for. Defaults to Spark. + """ + return _filter_to_expr(filter_expr, target).sql(dialect=target.value) def _to_pyiceberg(expr: Filter) -> ice.BooleanExpression: diff --git a/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py b/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py index ea6746b9f..8bca405c5 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping +from openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier @@ -13,11 +14,11 @@ class TableTransformer(ABC): projections and filters can reference them. Args: - dialect: The SQL dialect used by ``transform()`` (e.g. ``"spark"``). + dialect: The SQL flavor that ``transform()`` emits (e.g. ``SqlTarget.SPARK``). 
""" - def __init__(self, dialect: str) -> None: - self.dialect: str = dialect + def __init__(self, dialect: SqlTarget) -> None: + self.dialect: SqlTarget = dialect @abstractmethod def transform(self, table: TableIdentifier, context: Mapping[str, str]) -> str | None: diff --git a/integrations/python/dataloader/tests/integration_tests.py b/integrations/python/dataloader/tests/integration_tests.py index bb183a99c..000347712 100644 --- a/integrations/python/dataloader/tests/integration_tests.py +++ b/integrations/python/dataloader/tests/integration_tests.py @@ -21,7 +21,7 @@ from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader from openhouse.dataloader.catalog import OpenHouseCatalog from openhouse.dataloader.data_loader_split import to_sql_identifier -from openhouse.dataloader.filters import col +from openhouse.dataloader.filters import SqlTarget, col, to_sql from openhouse.dataloader.table_transformer import TableTransformer BASE_URL = "http://openhouse-tables:8080" @@ -85,6 +85,15 @@ def _wait_for_idle(self) -> None: def execute(self, sql: str) -> None: """Submit a SQL statement and wait for completion. Raises on error.""" + self._run(sql) + + def query(self, sql: str) -> list[list]: + """Submit a SELECT, wait for completion, and return its result rows.""" + output = self._run(sql) + return output["data"]["application/json"]["data"] + + def _run(self, sql: str) -> dict: + """Submit a SQL statement, wait for completion, and return its output. 
Raises on error.""" print(f" SQL: {sql}") resp = requests.post( f"{self._session_url}/statements", json={"code": sql}, headers=HEADERS, timeout=REQUEST_TIMEOUT @@ -103,7 +112,7 @@ def execute(self, sql: str) -> None: output = resp.json()["output"] if output["status"] == "error": raise RuntimeError(f"SQL failed: {output.get('evalue', output)}") - return + return output if state in ("error", "cancelled"): raise RuntimeError(f"Statement entered state: {state}") time.sleep(1) @@ -284,6 +293,37 @@ def read_token() -> str: assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3, 4.4] print(f"PASS: after second insert, read all {result.num_rows} rows") + # 6b. Spark-SQL filter parity: read rows with DataLoader filters, then read the + # same rows from Spark using to_sql(filters, SqlTarget.SPARK), and verify the + # two row sets are identical. diana (score > 2.0 but excluded by the name IN) + # proves the predicate is actually applied on both sides, not ignored. + parity_filter = (col(COL_SCORE) > 2.0) & col(COL_NAME).is_in(["alice", "bob", "charlie"]) + parity_loader = OpenHouseDataLoader( + catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=parity_filter + ) + dl_result = _read_all(parity_loader) + dl_rows = list( + zip( + dl_result.column(COL_ID).to_pylist(), + dl_result.column(COL_NAME).to_pylist(), + dl_result.column(COL_SCORE).to_pylist(), + strict=True, + ) + ) + + spark_where = to_sql(parity_filter, SqlTarget.SPARK) + spark_rows = [ + tuple(row) + for row in livy.query( + f"SELECT {COL_ID}, {COL_NAME}, {COL_SCORE} FROM {FQTN} WHERE {spark_where} ORDER BY {COL_ID}" + ) + ] + assert dl_rows, "Expected the parity filter to match at least one row" + assert dl_rows == spark_rows, ( + f"DataLoader rows {dl_rows} != Spark rows {spark_rows} (Spark WHERE: {spark_where})" + ) + print(f"PASS: DataLoader and Spark agree on {len(dl_rows)} rows for WHERE {spark_where}") + # 7. 
Pin to the old snapshot and verify only the original data is returned loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, snapshot_id=snap1) result = _read_all(loader) @@ -345,10 +385,10 @@ def read_token() -> str: # SQL roundtrip path (filters -> DataFusion SQL -> sqlglot -> scan_optimizer -> # PyIceberg expression). Without a transformer, _build_query() returns None # and the loader skips that path entirely, which would mean a CAST(literal, - # TIMESTAMP) regression in _literal_to_sql / scan_optimizer would go unnoticed. + # TIMESTAMP) regression in _literal_to_expr / scan_optimizer would go unnoticed. class _PartPassthroughTransformer(TableTransformer): def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f'SELECT "id", "ts" FROM {to_sql_identifier(table)}' diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index a66ee29ca..0d56a714e 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -17,7 +17,7 @@ from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader, __version__ from openhouse.dataloader.data_loader_split import DataLoaderSplit, to_sql_identifier -from openhouse.dataloader.filters import col +from openhouse.dataloader.filters import SqlTarget, col from openhouse.dataloader.table_transformer import TableTransformer from openhouse.dataloader.udf_registry import UDFRegistry @@ -346,7 +346,7 @@ class _NoneTransformer(TableTransformer): """Transformer that returns None (no transformation).""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return None @@ -356,7 +356,7 @@ class _MaskingTransformer(TableTransformer): """Transformer that masks the 
name column.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)}" @@ -443,7 +443,7 @@ class _SparkMaskingTransformer(TableTransformer): """Transformer using Spark SQL dialect.""" def __init__(self): - super().__init__(dialect="spark") + super().__init__(dialect=SqlTarget.SPARK) def transform(self, table, context): return f"SELECT id, CAST('MASKED' AS STRING) AS name, value FROM {to_sql_identifier(table)}" @@ -465,35 +465,13 @@ def test_iter_with_spark_dialect_transformer_transpiles(tmp_path): assert result.column("name").to_pylist() == ["MASKED", "MASKED", "MASKED"] -def test_iter_with_invalid_dialect_raises(tmp_path): - """Unsupported dialect raises ValueError during iteration.""" - - class _BadDialectTransformer(TableTransformer): - def __init__(self): - super().__init__(dialect="not_a_real_dialect") - - def transform(self, table, context): - return f"SELECT * FROM {to_sql_identifier(table)}" - - catalog = _make_real_catalog(tmp_path) - loader = OpenHouseDataLoader( - catalog=catalog, - database="db", - table="tbl", - context=DataLoaderContext(table_transformer=_BadDialectTransformer()), - ) - - with pytest.raises(ValueError, match="Unsupported source dialect"): - _materialize(loader) - - def test_iter_with_transformer_and_special_char_database(tmp_path): """Transformer works when the database name contains special characters.""" catalog = _make_real_catalog(tmp_path) class _QuotedMaskingTransformer(TableTransformer): def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)}" @@ -664,7 +642,7 @@ class _FilteringTransformer(TableTransformer): """Transformer that has a WHERE clause filtering on status.""" def __init__(self): - 
super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, name, value, status FROM {to_sql_identifier(table)} WHERE status = 'active'" @@ -707,7 +685,7 @@ class _MaskingFilteringTransformer(TableTransformer): """Transformer that masks name and filters on value.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)} WHERE value > 1.5" @@ -784,7 +762,7 @@ class _PassthroughTransformer(TableTransformer): """Transformer that selects all columns unchanged.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, name, value FROM {to_sql_identifier(table)}" @@ -807,7 +785,7 @@ class _MixedCaseTransformer(TableTransformer): """Transformer that selects mixed-case columns.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f'SELECT "purchaseAmount", "itemCount", "discountRate" FROM {to_sql_identifier(table)}' @@ -978,7 +956,7 @@ class _NestedUDFTransformer(TableTransformer): """Nested subquery with UDF at two levels — triggers alias rewrite bug.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) @@ -995,7 +973,7 @@ class _SelectStarUDFTransformer(TableTransformer): """SELECT * with UDF — triggers unquoted column bug after projection pushdown.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) @@ -1007,7 +985,7 @@ class 
_ProjectionUDFTransformer(TableTransformer): """UDF in projection with inner SELECT * — triggers unquoted column bug in projection.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) diff --git a/integrations/python/dataloader/tests/test_datafusion_sql.py b/integrations/python/dataloader/tests/test_datafusion_sql.py index bbdba3faa..db6e71cfb 100644 --- a/integrations/python/dataloader/tests/test_datafusion_sql.py +++ b/integrations/python/dataloader/tests/test_datafusion_sql.py @@ -5,6 +5,7 @@ import pytest from openhouse.dataloader.datafusion_sql import to_datafusion_sql +from openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier _DB_TBL = TableIdentifier(database="db", table="tbl") @@ -18,39 +19,38 @@ "sql, dialect, expected", [ # Spark → DataFusion - ("SELECT `col1`, `col2` FROM `my_table`", "spark", 'SELECT "col1", "col2" FROM "my_table"'), - ("SELECT SIZE(arr) FROM t", "spark", "SELECT cardinality(arr) FROM t"), - ("SELECT ARRAY(1, 2, 3)", "spark", "SELECT make_array(1, 2, 3)"), - ("SELECT UPPER(name) FROM t", "spark", "SELECT upper(name) FROM t"), - ("SELECT my_udf(col1, col2) FROM t", "spark", "SELECT my_udf(col1, col2) FROM t"), - ("SELECT IF(x > 0, 'pos', 'neg') FROM t", "spark", "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t"), + ("SELECT `col1`, `col2` FROM `my_table`", SqlTarget.SPARK, 'SELECT "col1", "col2" FROM "my_table"'), + ("SELECT SIZE(arr) FROM t", SqlTarget.SPARK, "SELECT cardinality(arr) FROM t"), + ("SELECT ARRAY(1, 2, 3)", SqlTarget.SPARK, "SELECT make_array(1, 2, 3)"), + ("SELECT UPPER(name) FROM t", SqlTarget.SPARK, "SELECT upper(name) FROM t"), + ("SELECT my_udf(col1, col2) FROM t", SqlTarget.SPARK, "SELECT my_udf(col1, col2) FROM t"), + ( + "SELECT IF(x > 0, 'pos', 'neg') FROM t", + SqlTarget.SPARK, + "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 
'neg' END FROM t", + ), ( "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t", - "spark", + SqlTarget.SPARK, "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t", ), ( "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) sub WHERE sub.name IS NOT NULL", - "spark", + SqlTarget.SPARK, "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) AS sub WHERE NOT sub.name IS NULL", ), - ("SELECT 'hello world' AS greeting", "spark", "SELECT 'hello world' AS greeting"), - ("SELECT CURRENT_TIMESTAMP()", "spark", "SELECT now()"), - ("SELECT CAST(x AS BINARY)", "spark", "SELECT TRY_CAST(x AS BYTEA)"), - # MySQL → DataFusion - ("SELECT CAST(x AS CHAR)", "mysql", "SELECT CAST(x AS VARCHAR)"), - ("SELECT CAST(x AS DATETIME)", "mysql", "SELECT CAST(x AS TIMESTAMP)"), - # Postgres → DataFusion - ("SELECT CAST(x AS TEXT)", "postgres", "SELECT CAST(x AS VARCHAR)"), + ("SELECT 'hello world' AS greeting", SqlTarget.SPARK, "SELECT 'hello world' AS greeting"), + ("SELECT CURRENT_TIMESTAMP()", SqlTarget.SPARK, "SELECT now()"), + ("SELECT CAST(x AS BINARY)", SqlTarget.SPARK, "SELECT TRY_CAST(x AS BYTEA)"), # DataFusion → DataFusion (noop) ( "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5", - "datafusion", + SqlTarget.DATA_FUSION, "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5", ), ], ) -def test_transpilation(sql: str, dialect: str, expected: str) -> None: +def test_transpilation(sql: str, dialect: SqlTarget, expected: str) -> None: assert to_datafusion_sql(sql, dialect) == expected @@ -62,19 +62,15 @@ def test_transpilation(sql: str, dialect: str, expected: str) -> None: class TestTranslatorEdgeCases: def test_multi_statement_raises(self) -> None: with pytest.raises(ValueError, match="Expected exactly one"): - to_datafusion_sql("SELECT 1; SELECT 2", "spark") - - def test_unsupported_dialect_raises(self) -> None: - with pytest.raises(ValueError, match="Unsupported source dialect 'nosuchdialect'"): - 
to_datafusion_sql("SELECT 1", "nosuchdialect") + to_datafusion_sql("SELECT 1; SELECT 2", SqlTarget.SPARK) def test_syntax_error_raises(self) -> None: with pytest.raises(ValueError, match="Failed to transpile SQL from 'spark' to DataFusion"): - to_datafusion_sql("SELECT * FROM", "spark") + to_datafusion_sql("SELECT * FROM", SqlTarget.SPARK) def test_datafusion_dialect_is_noop(self) -> None: sql = "SELECT make_array(1, 2, 3)" - assert to_datafusion_sql(sql, "datafusion") is sql + assert to_datafusion_sql(sql, SqlTarget.DATA_FUSION) is sql # --------------------------------------------------------------------------- @@ -84,42 +80,42 @@ def test_datafusion_dialect_is_noop(self) -> None: class TestTableValidation: def test_validates_single_table(self) -> None: - result = to_datafusion_sql('SELECT id FROM "db"."tbl"', "datafusion", table=_DB_TBL) + result = to_datafusion_sql('SELECT id FROM "db"."tbl"', SqlTarget.DATA_FUSION, table=_DB_TBL) assert result == 'SELECT id FROM "db"."tbl"' def test_wrong_table_name_raises(self) -> None: with pytest.raises(ValueError, match="references db.other, expected db.tbl"): - to_datafusion_sql('SELECT id FROM "db"."other"', "datafusion", table=_DB_TBL) + to_datafusion_sql('SELECT id FROM "db"."other"', SqlTarget.DATA_FUSION, table=_DB_TBL) def test_wrong_database_raises(self) -> None: with pytest.raises(ValueError, match="references other.tbl, expected db.tbl"): - to_datafusion_sql('SELECT id FROM "other"."tbl"', "datafusion", table=_DB_TBL) + to_datafusion_sql('SELECT id FROM "other"."tbl"', SqlTarget.DATA_FUSION, table=_DB_TBL) def test_multiple_tables_raises(self) -> None: with pytest.raises(ValueError, match="exactly 1 table, found 2"): to_datafusion_sql( 'SELECT * FROM "db"."tbl" JOIN "db"."tbl" AS t2 ON tbl.id = t2.id', - "datafusion", + SqlTarget.DATA_FUSION, table=_DB_TBL, ) def test_no_table_raises(self) -> None: with pytest.raises(ValueError, match="exactly 1 table, found 0"): - to_datafusion_sql("SELECT 1 AS x", "datafusion", 
table=_DB_TBL) + to_datafusion_sql("SELECT 1 AS x", SqlTarget.DATA_FUSION, table=_DB_TBL) def test_case_insensitive_table_match(self) -> None: - result = to_datafusion_sql('SELECT id FROM "DB"."TBL"', "datafusion", table=_DB_TBL) + result = to_datafusion_sql('SELECT id FROM "DB"."TBL"', SqlTarget.DATA_FUSION, table=_DB_TBL) assert result == 'SELECT id FROM "DB"."TBL"' def test_spark_table_validated_after_transpilation(self) -> None: - result = to_datafusion_sql("SELECT id FROM `db`.`tbl`", "spark", table=_DB_TBL) + result = to_datafusion_sql("SELECT id FROM `db`.`tbl`", SqlTarget.SPARK, table=_DB_TBL) assert result == 'SELECT id FROM "db"."tbl"' class TestNoOp: def test_no_filter_no_table_is_noop(self) -> None: sql = 'SELECT id FROM "db"."tbl"' - assert to_datafusion_sql(sql, "datafusion") is sql + assert to_datafusion_sql(sql, SqlTarget.DATA_FUSION) is sql # --------------------------------------------------------------------------- @@ -129,14 +125,14 @@ def test_no_filter_no_table_is_noop(self) -> None: def test_datafusion_execution() -> None: ctx = datafusion.SessionContext() - translated = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", "spark") + translated = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", SqlTarget.SPARK) batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 3 def test_datafusion_execution_median() -> None: ctx = datafusion.SessionContext() - translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", "spark") + translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", SqlTarget.SPARK) assert translated == "SELECT median(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 3 @@ -146,7 +142,7 @@ def test_datafusion_execution_percentile_cont() -> None: ctx = datafusion.SessionContext() translated = to_datafusion_sql( "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) 
FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", - "spark", + SqlTarget.SPARK, ) expected = ( "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY x NULLS FIRST)" @@ -161,7 +157,7 @@ def test_datafusion_execution_approx_percentile_cont() -> None: ctx = datafusion.SessionContext() translated = to_datafusion_sql( "SELECT PERCENTILE_APPROX(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", - "spark", + SqlTarget.SPARK, ) assert translated == "SELECT approx_percentile_cont(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] @@ -176,7 +172,7 @@ def double_it(arr: pa.Array) -> pa.Array: ctx.register_udf(datafusion.udf(double_it, [pa.int64()], pa.int64(), "stable", name="double_it")) - translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", "spark") + translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", SqlTarget.SPARK) assert translated == "SELECT double_it(x) FROM (VALUES (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 10 diff --git a/integrations/python/dataloader/tests/test_filters.py b/integrations/python/dataloader/tests/test_filters.py index 2bc74e612..2f194d2c5 100644 --- a/integrations/python/dataloader/tests/test_filters.py +++ b/integrations/python/dataloader/tests/test_filters.py @@ -4,6 +4,7 @@ from uuid import UUID import pytest +import sqlglot from pyiceberg import expressions as ice from openhouse.dataloader import col @@ -28,10 +29,11 @@ NotIn, NotStartsWith, Or, + SqlTarget, StartsWith, - _to_datafusion_sql, _to_pyiceberg, always_true, + to_sql, ) @@ -355,81 +357,81 @@ def test_f_col_builds_filter(self): class TestDataFusionLiteralConversion: def test_datetime_greater_than_or_equal(self): dt = datetime(2026, 4, 27, tzinfo=UTC) - result = _to_datafusion_sql(col("datepartition") >= dt) + result = to_sql(col("datepartition") >= dt, SqlTarget.DATA_FUSION) assert result == "\"datepartition\" >= 
'2026-04-27T00:00:00+00:00'" def test_datetime_equal(self): dt = datetime(2026, 4, 27, 12, 30, 45, tzinfo=UTC) - result = _to_datafusion_sql(col("ts") == dt) + result = to_sql(col("ts") == dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" = '2026-04-27T12:30:45+00:00'" def test_datetime_with_microseconds(self): dt = datetime(2026, 4, 27, 12, 30, 45, 123456, tzinfo=UTC) - result = _to_datafusion_sql(col("ts") == dt) + result = to_sql(col("ts") == dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" = '2026-04-27T12:30:45.123456+00:00'" def test_datetime_non_utc_timezone_preserved(self): dt = datetime(2026, 4, 27, 12, 0, 0, tzinfo=timezone(timedelta(hours=5))) - result = _to_datafusion_sql(col("ts") >= dt) + result = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" >= '2026-04-27T12:00:00+05:00'" def test_datetime_naive_no_offset(self): dt = datetime(2026, 4, 27, 12, 0, 0) - result = _to_datafusion_sql(col("ts") >= dt) + result = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" >= '2026-04-27T12:00:00'" def test_date_greater_than_or_equal(self): d = date(2026, 4, 27) - result = _to_datafusion_sql(col("datepartition") >= d) + result = to_sql(col("datepartition") >= d, SqlTarget.DATA_FUSION) assert result == "\"datepartition\" >= '2026-04-27'" def test_datetime_between(self): dt1 = datetime(2026, 4, 27, tzinfo=UTC) dt2 = datetime(2026, 5, 1, tzinfo=UTC) - result = _to_datafusion_sql(col("ts").between(dt1, dt2)) + result = to_sql(col("ts").between(dt1, dt2), SqlTarget.DATA_FUSION) assert result == "\"ts\" BETWEEN '2026-04-27T00:00:00+00:00' AND '2026-05-01T00:00:00+00:00'" def test_datetime_in_compound_filter(self): dt = datetime(2026, 4, 27, tzinfo=UTC) f = (col("datepartition") >= dt) & (col("status") == "active") - result = _to_datafusion_sql(f) + result = to_sql(f, SqlTarget.DATA_FUSION) assert "'2026-04-27T00:00:00+00:00'" in result assert "\"status\" = 'active'" in result def test_time_equal(self): t = time(14, 30, 0) 
- result = _to_datafusion_sql(col("event_time") == t) + result = to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) assert result == "\"event_time\" = '14:30:00'" def test_time_with_microseconds(self): t = time(14, 30, 0, 500000) - result = _to_datafusion_sql(col("event_time") == t) + result = to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) assert result == "\"event_time\" = '14:30:00.500000'" def test_time_with_timezone_rejected(self): t = time(14, 30, 0, tzinfo=timezone(timedelta(hours=5))) with pytest.raises(TypeError, match="does not support timezones for time"): - _to_datafusion_sql(col("event_time") == t) + to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) def test_decimal_greater_than(self): d = Decimal("99.95") - result = _to_datafusion_sql(col("price") > d) + result = to_sql(col("price") > d, SqlTarget.DATA_FUSION) assert result == '"price" > 99.95' def test_decimal_between(self): - result = _to_datafusion_sql(col("price").between(Decimal("10.00"), Decimal("50.00"))) + result = to_sql(col("price").between(Decimal("10.00"), Decimal("50.00")), SqlTarget.DATA_FUSION) assert result == '"price" BETWEEN 10.00 AND 50.00' @pytest.mark.parametrize( ("value", "expected"), [ - (float("nan"), "CAST('nan' AS DOUBLE)"), - (float("inf"), "CAST('inf' AS DOUBLE)"), - (float("-inf"), "CAST('-inf' AS DOUBLE)"), + (float("nan"), "CAST('NaN' AS DOUBLE)"), + (float("inf"), "CAST('Infinity' AS DOUBLE)"), + (float("-inf"), "CAST('-Infinity' AS DOUBLE)"), ], ) def test_non_finite_float(self, value, expected): - result = _to_datafusion_sql(col("x") == value) + result = to_sql(col("x") == value, SqlTarget.DATA_FUSION) assert result == f'"x" = {expected}' @pytest.mark.parametrize( @@ -441,12 +443,12 @@ def test_non_finite_float(self, value, expected): ], ) def test_non_finite_decimal(self, value, expected): - result = _to_datafusion_sql(col("x") == value) + result = to_sql(col("x") == value, SqlTarget.DATA_FUSION) assert result == f'"x" = {expected}' def 
test_uuid_equal(self): u = UUID("12345678-1234-5678-1234-567812345678") - result = _to_datafusion_sql(col("id") == u) + result = to_sql(col("id") == u, SqlTarget.DATA_FUSION) assert result == "\"id\" = '12345678-1234-5678-1234-567812345678'" @@ -491,14 +493,14 @@ def _query(self, ctx, where: str): return pa.Table.from_batches(batches) def test_datetime_filter(self, ctx): - where = _to_datafusion_sql(col("ts") >= datetime(2026, 4, 27, tzinfo=UTC)) + where = to_sql(col("ts") >= datetime(2026, 4, 27, tzinfo=UTC), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 ts_values = [v.as_py() for v in table.column("ts")] assert ts_values == [datetime(2026, 4, 27, tzinfo=UTC), datetime(2026, 4, 29, tzinfo=UTC)] def test_datetime_less_than(self, ctx): - where = _to_datafusion_sql(col("ts") < datetime(2026, 4, 27, tzinfo=UTC)) + where = to_sql(col("ts") < datetime(2026, 4, 27, tzinfo=UTC), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("ts")[0].as_py() == datetime(2026, 4, 25, tzinfo=UTC) @@ -506,27 +508,27 @@ def test_datetime_less_than(self, ctx): def test_datetime_non_utc_timezone(self, ctx): # 2026-04-27 05:00:00+05:00 == 2026-04-27 00:00:00 UTC, same as row 2 dt = datetime(2026, 4, 27, 5, 0, 0, tzinfo=timezone(timedelta(hours=5))) - where = _to_datafusion_sql(col("ts") >= dt) + where = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 ts_values = [v.as_py() for v in table.column("ts")] assert ts_values == [datetime(2026, 4, 27, tzinfo=UTC), datetime(2026, 4, 29, tzinfo=UTC)] def test_date_filter(self, ctx): - where = _to_datafusion_sql(col("dt") >= date(2026, 4, 27)) + where = to_sql(col("dt") >= date(2026, 4, 27), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 dt_values = [v.as_py() for v in table.column("dt")] assert dt_values == [date(2026, 4, 27), date(2026, 4, 29)] def test_time_filter(self, 
ctx): - where = _to_datafusion_sql(col("t") == time(14, 30, 0)) + where = to_sql(col("t") == time(14, 30, 0), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("t")[0].as_py() == time(14, 30, 0) def test_decimal_filter(self, ctx): - where = _to_datafusion_sql(col("price") > Decimal("10.00")) + where = to_sql(col("price") > Decimal("10.00"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 price_values = [v.as_py() for v in table.column("price")] @@ -551,27 +553,28 @@ def test_non_finite_filter(self, value, expected_count, check): batch = pa.record_batch({"x": pa.array([1.0, float("nan"), float("inf"), float("-inf"), 5.0])}) ctx.register_record_batches("t", [[batch]]) - where = _to_datafusion_sql(col("x") == value) + where = to_sql(col("x") == value, SqlTarget.DATA_FUSION) batches = ctx.sql(f'SELECT * FROM "t" WHERE {where}').collect() table = pa.Table.from_batches(batches) assert table.num_rows == expected_count assert check(table.column("x")[0].as_py()) def test_high_precision_decimal_filter(self, ctx): - where = _to_datafusion_sql(col("price") > Decimal("49.9899999999999999")) + where = to_sql(col("price") > Decimal("49.9899999999999999"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("price")[0].as_py() == Decimal("99.99") def test_uuid_filter(self, ctx): - where = _to_datafusion_sql(col("id") == UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")) + where = to_sql(col("id") == UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("id")[0].as_py() == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" def test_datetime_between_filter(self, ctx): - where = _to_datafusion_sql( - col("ts").between(datetime(2026, 4, 26, tzinfo=UTC), datetime(2026, 4, 28, tzinfo=UTC)) + where = to_sql( + col("ts").between(datetime(2026, 4, 26, tzinfo=UTC), datetime(2026, 
4, 28, tzinfo=UTC)), + SqlTarget.DATA_FUSION, ) table = self._query(ctx, where) assert table.num_rows == 1 @@ -579,7 +582,7 @@ def test_datetime_between_filter(self, ctx): def test_compound_filter(self, ctx): f = (col("ts") >= datetime(2026, 4, 27, tzinfo=UTC)) & (col("price") > Decimal("50.00")) - where = _to_datafusion_sql(f) + where = to_sql(f, SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("ts")[0].as_py() == datetime(2026, 4, 29, tzinfo=UTC) @@ -598,16 +601,216 @@ def test_datetime_microseconds(self, ctx): } ) ctx2.register_record_batches("t2", [[batch]]) - where = _to_datafusion_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC)) + where = to_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC), SqlTarget.DATA_FUSION) table = ctx2.sql(f'SELECT * FROM "t2" WHERE {where}').collect() assert len(table[0]) == 1 assert table[0].column("ts")[0].as_py() == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC) - where_no_match = _to_datafusion_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, tzinfo=UTC)) + where_no_match = to_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, tzinfo=UTC), SqlTarget.DATA_FUSION) table2 = ctx2.sql(f'SELECT * FROM "t2" WHERE {where_no_match}').collect() assert sum(len(b) for b in table2) == 0 +class TestSparkSqlConversion: + """Render filters to Spark SQL via the public ``to_sql`` entry point.""" + + def test_default_target_is_spark(self): + # to_sql defaults to the Spark target (backtick-quoted identifiers). 
+ assert to_sql(col("x") == 5) == "`x` = 5" + + def test_always_true(self): + assert to_sql(always_true(), SqlTarget.SPARK) == "TRUE" + + # --- Comparison --- + + def test_equal(self): + assert to_sql(col("x") == 5, SqlTarget.SPARK) == "`x` = 5" + + def test_equal_string(self): + assert to_sql(col("name") == "alice", SqlTarget.SPARK) == "`name` = 'alice'" + + def test_equal_bool(self): + assert to_sql(col("flag") == True, SqlTarget.SPARK) == "`flag` = TRUE" # noqa: E712 + + def test_not_equal(self): + assert to_sql(col("x") != 5, SqlTarget.SPARK) == "`x` <> 5" + + def test_greater_than(self): + assert to_sql(col("x") > 5, SqlTarget.SPARK) == "`x` > 5" + + def test_greater_than_or_equal(self): + assert to_sql(col("x") >= 5, SqlTarget.SPARK) == "`x` >= 5" + + def test_less_than(self): + assert to_sql(col("x") < 5, SqlTarget.SPARK) == "`x` < 5" + + def test_less_than_or_equal(self): + assert to_sql(col("x") <= 5, SqlTarget.SPARK) == "`x` <= 5" + + # --- Null / NaN --- + + def test_is_null(self): + assert to_sql(col("x").is_null(), SqlTarget.SPARK) == "`x` IS NULL" + + def test_is_not_null(self): + assert to_sql(col("x").is_not_null(), SqlTarget.SPARK) == "NOT `x` IS NULL" + + def test_is_nan(self): + # NaN uses Spark's isnan() function, not the (unsupported) `IS NAN` syntax. 
+ assert to_sql(col("x").is_nan(), SqlTarget.SPARK) == "ISNAN(`x`)" + + def test_is_not_nan(self): + assert to_sql(col("x").is_not_nan(), SqlTarget.SPARK) == "NOT ISNAN(`x`)" + + # --- Set membership --- + + def test_in(self): + assert to_sql(col("x").is_in([1, 2, 3]), SqlTarget.SPARK) == "`x` IN (1, 2, 3)" + + def test_in_strings(self): + assert to_sql(col("x").is_in(["a", "b"]), SqlTarget.SPARK) == "`x` IN ('a', 'b')" + + def test_not_in(self): + assert to_sql(col("x").is_not_in([1, 2, 3]), SqlTarget.SPARK) == "NOT `x` IN (1, 2, 3)" + + # --- String prefix (LIKE with escaped wildcards) --- + + def test_starts_with(self): + assert to_sql(col("name").starts_with("John"), SqlTarget.SPARK) == r"`name` LIKE 'John%' ESCAPE '\\'" + + def test_starts_with_escapes_wildcards(self): + # %, _ and \ in the prefix are escaped so they match literally. + assert to_sql(col("name").starts_with("a%b_c"), SqlTarget.SPARK) == r"`name` LIKE 'a\\%b\\_c%' ESCAPE '\\'" + + def test_not_starts_with(self): + assert to_sql(col("name").not_starts_with("John"), SqlTarget.SPARK) == r"NOT `name` LIKE 'John%' ESCAPE '\\'" + + # --- Range --- + + def test_between(self): + assert to_sql(col("x").between(1, 10), SqlTarget.SPARK) == "`x` BETWEEN 1 AND 10" + + # --- Logical combinators --- + + def test_and(self): + assert to_sql((col("x") > 5) & (col("y") == "a"), SqlTarget.SPARK) == "`x` > 5 AND `y` = 'a'" + + def test_or(self): + assert to_sql((col("x") > 5) | (col("y") == "a"), SqlTarget.SPARK) == "`x` > 5 OR `y` = 'a'" + + def test_not(self): + assert to_sql(~col("z").is_null(), SqlTarget.SPARK) == "NOT `z` IS NULL" + + def test_complex_composition(self): + expr = (col("x") > 5) & (col("y") == "a") | ~col("z").is_null() + assert to_sql(expr, SqlTarget.SPARK) == "(`x` > 5 AND `y` = 'a') OR NOT `z` IS NULL" + + # --- Literals --- + + def test_datetime(self): + dt = datetime(2026, 4, 27, tzinfo=UTC) + assert to_sql(col("ts") >= dt, SqlTarget.SPARK) == "`ts` >= '2026-04-27T00:00:00+00:00'" + + def 
test_date(self): + assert to_sql(col("d") >= date(2026, 4, 27), SqlTarget.SPARK) == "`d` >= '2026-04-27'" + + def test_decimal(self): + assert to_sql(col("price") > Decimal("99.95"), SqlTarget.SPARK) == "`price` > 99.95" + + def test_non_finite_float_uses_canonical_cast(self): + # Canonical NaN/Infinity spelling (Spark does not parse lowercase 'nan'/'inf'). + assert to_sql(col("x") == float("nan"), SqlTarget.SPARK) == "`x` = CAST('NaN' AS DOUBLE)" + assert to_sql(col("x") == float("inf"), SqlTarget.SPARK) == "`x` = CAST('Infinity' AS DOUBLE)" + + def test_uuid(self): + u = UUID("12345678-1234-5678-1234-567812345678") + assert to_sql(col("id") == u, SqlTarget.SPARK) == "`id` = '12345678-1234-5678-1234-567812345678'" + + # --- Cross-cutting invariants --- + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("name") == "alice", + col("x").is_null(), + col("x").is_not_null(), + col("x").is_nan(), + col("x").is_not_nan(), + col("x").is_in([1, 2, 3]), + col("x").is_not_in(["a", "b"]), + col("name").starts_with("a%b_c"), + col("name").not_starts_with("John"), + col("x").between(1, 10), + (col("x") > 5) & (col("y") == "a") | ~col("z").is_null(), + col("ts") >= datetime(2026, 4, 27, tzinfo=UTC), + col("x") == float("nan"), + ], + ) + def test_output_is_valid_spark_sql(self, expr): + # Every rendering must parse back as valid Spark SQL. + sqlglot.parse_one(to_sql(expr, SqlTarget.SPARK), dialect="spark") + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("x") >= 5, + col("x").is_null(), + col("x").is_not_null(), + col("x").is_in([1, 2, 3]), + (col("x") > 5) & col("y").is_null(), + ], + ) + def test_spark_and_datafusion_differ_only_in_quoting(self, expr): + # For filters without string-literal escaping, Spark vs DataFusion output + # differs only in identifier quoting (backtick vs double-quote). 
+ spark = to_sql(expr, SqlTarget.SPARK) + assert spark.replace("`", '"') == to_sql(expr, SqlTarget.DATA_FUSION) + + def test_raises_on_unknown_filter(self): + class CustomFilter(Filter): + def __repr__(self) -> str: + return "custom" + + with pytest.raises(TypeError, match="Unsupported filter type"): + to_sql(CustomFilter(), SqlTarget.SPARK) + + +class TestTrinoSqlConversion: + """Render filters to Trino SQL (double-quoted identifiers, is_nan()).""" + + def test_uses_double_quoted_identifiers(self): + assert to_sql(col("x") == 5, SqlTarget.TRINO) == '"x" = 5' + + def test_is_nan_uses_trino_function(self): + # Trino's NaN function is is_nan (underscore), unlike Spark/DataFusion's isnan. + assert to_sql(col("x").is_nan(), SqlTarget.TRINO) == 'IS_NAN("x")' + assert to_sql(col("x").is_not_nan(), SqlTarget.TRINO) == 'NOT IS_NAN("x")' + + def test_in_and_between(self): + assert to_sql(col("x").is_in([1, 2]), SqlTarget.TRINO) == '"x" IN (1, 2)' + assert to_sql(col("x").between(1, 10), SqlTarget.TRINO) == '"x" BETWEEN 1 AND 10' + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("name") == "alice", + col("x").is_null(), + col("x").is_not_null(), + col("x").is_nan(), + col("x").is_not_nan(), + col("x").is_in([1, 2, 3]), + col("name").starts_with("a%b_c"), + col("x").between(1, 10), + (col("x") > 5) & (col("y") == "a") | ~col("z").is_null(), + ], + ) + def test_output_is_valid_trino_sql(self, expr): + sqlglot.parse_one(to_sql(expr, SqlTarget.TRINO), dialect="trino") + + class TestPyIcebergUnsupportedType: def test_raises_on_unknown_filter(self): class CustomFilter(Filter): @@ -623,4 +826,4 @@ def __repr__(self) -> str: return "custom" with pytest.raises(TypeError, match="Unsupported filter type"): - _to_datafusion_sql(CustomFilter()) + to_sql(CustomFilter(), SqlTarget.DATA_FUSION) diff --git a/integrations/python/dataloader/tests/test_scan_optimizer.py b/integrations/python/dataloader/tests/test_scan_optimizer.py index a1047bd12..57282d197 100644 --- 
a/integrations/python/dataloader/tests/test_scan_optimizer.py +++ b/integrations/python/dataloader/tests/test_scan_optimizer.py @@ -17,7 +17,8 @@ LessThanOrEqual, NotEqualTo, Or, - _to_datafusion_sql, + SqlTarget, + to_sql, ) from openhouse.dataloader.scan_optimizer import optimize_scan as _optimize_scan @@ -170,7 +171,7 @@ def test_comparison_types(): def test_datetime_string_literals_pushed_as_strings(): - """`filters._literal_to_sql()` emits plain string literals for datetime/date/time + """`filters._literal_to_expr()` emits plain string literals for datetime/date/time (see PR #569 + follow-up). The scan optimizer treats them as ordinary string literals; PyIceberg promotes them to typed literals during expression binding against the table schema, restoring partition pruning. @@ -206,7 +207,7 @@ def test_non_convertible_predicates_not_pushed(): def test_filter_dsl_to_sql_round_trip(): - """Each Filter type survives _to_datafusion_sql → optimize_scan round trip.""" + """Each Filter type survives to_sql(..., DATA_FUSION) → optimize_scan round trip.""" cases = [ EqualTo("x", 1), NotEqualTo("x", 1), @@ -223,7 +224,7 @@ def test_filter_dsl_to_sql_round_trip(): Or(EqualTo("x", 1), EqualTo("x", 2)), ] for filter_dsl in cases: - sql = f'SELECT "a" FROM "db"."tbl" WHERE {_to_datafusion_sql(filter_dsl)}' + sql = f'SELECT "a" FROM "db"."tbl" WHERE {to_sql(filter_dsl, SqlTarget.DATA_FUSION)}' plan = optimize_scan(sql) assert plan.row_filter == filter_dsl, f"Round trip failed for {filter_dsl!r}: got {plan.row_filter!r}"