From 468936d8182b6b31dae7bb09b624b9f998d74831 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Tue, 9 Jun 2026 20:48:47 +0000 Subject: [PATCH 1/6] feat(dataloader): render filters as Spark SQL via to_sql() Add a public to_sql(filter, target=SqlTarget.SPARK) that renders a Filter DSL object as a SQL boolean expression for a target. The underlying sqlglot dialect is hidden behind the SqlTarget enum so callers never pass a dialect string. Internally, _filter_to_expr() builds a single dialect-agnostic sqlglot AST from the Filter tree; to_sql() renders it with .sql(dialect=target.value) and _to_datafusion_sql() renders the same AST with the default dialect (signature unchanged). This replaces the hand-rolled f-string DataFusion builder. Side effect: NaN checks now render isnan(col) instead of the unparseable 'col IS NAN' (DataFusion 53 rejects IS NAN), and non-finite float literals use canonical NaN/Infinity spellings accepted by both DataFusion and Spark. Tests: TestSparkSqlConversion asserts per-node Spark output, validates each rendering parses as Spark via sqlglot, and checks Spark vs DataFusion differ only in identifier quoting. make verify passes (ruff, mypy, 332 tests). --- integrations/python/dataloader/CLAUDE.md | 4 + integrations/python/dataloader/README.md | 13 ++ .../src/openhouse/dataloader/__init__.py | 4 +- .../src/openhouse/dataloader/filters.py | 152 +++++++++++---- .../python/dataloader/tests/test_filters.py | 175 +++++++++++++++++- 5 files changed, 303 insertions(+), 45 deletions(-) diff --git a/integrations/python/dataloader/CLAUDE.md b/integrations/python/dataloader/CLAUDE.md index cc3e6b313..deb66d267 100644 --- a/integrations/python/dataloader/CLAUDE.md +++ b/integrations/python/dataloader/CLAUDE.md @@ -75,11 +75,15 @@ Exported in `__init__.py`: - `OpenHouseCatalogError` — Error raised when catalog fails to load a table - `col()` — Column reference for building filter expressions - `always_true()` — Filter that matches all rows +- `to_sql()` — Render a filter as a SQL boolean expression (WHERE-clause predicate) for a target +- `SqlTarget` — Enum naming the SQL target for `to_sql()` (currently `SPARK`) ### Filter DSL (`filters.py`) Build row filters using `col()` with comparison operators (`==`, `!=`, `>`, `>=`, `<`, `<=`) and predicates (`is_null()`, `is_not_null()`, `is_nan()`, `is_not_nan()`, `is_in()`, `is_not_in()`, `starts_with()`, `not_starts_with()`, `between()`). Combine with `&` (AND), `|` (OR), `~` (NOT). Filters are converted to PyIceberg expressions internally for partition pruning and file-level filtering. +`to_sql(filter, target=SqlTarget.SPARK)` renders a filter as a SQL boolean expression for the given target (e.g. Spark SQL). Internally, `_filter_to_expr()` builds a single dialect-agnostic sqlglot AST that is rendered per target with `.sql(dialect=...)`; `_to_datafusion_sql()` uses the same AST with the default dialect. The underlying sqlglot dialect is an internal detail — callers select a `SqlTarget`, never a dialect string. + ### Internal modules (not in `__init__.py`) - `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (e.g. `"spark"`) and implement `transform()` returning SQL or `None` diff --git a/integrations/python/dataloader/README.md b/integrations/python/dataloader/README.md index b9b59e661..a33f64252 100644 --- a/integrations/python/dataloader/README.md +++ b/integrations/python/dataloader/README.md @@ -52,6 +52,19 @@ filters = col("score").between(0.5, 1.0) filters = (col("age") >= 18) & (col("country").is_in(["US", "CA"])) & ~col("email").is_null() ``` +### Rendering a filter as SQL + +Use `to_sql()` to render a filter as a SQL boolean expression (a `WHERE`-clause +predicate) for a given target. Spark is currently the supported target: + +```python +from openhouse.dataloader import SqlTarget, col, to_sql + +to_sql(col("age") > 21) # `age` > 21 (defaults to Spark) +to_sql((col("country") == "US") & col("email").is_null(), SqlTarget.SPARK) +# `country` = 'US' AND `email` IS NULL +``` + ## Development ```bash diff --git a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py index df266c799..16d8659a9 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py @@ -2,7 +2,7 @@ from openhouse.dataloader.catalog import OpenHouseCatalog, OpenHouseCatalogError from openhouse.dataloader.data_loader import DataLoaderContext, JvmConfig, OpenHouseDataLoader -from openhouse.dataloader.filters import always_true, col +from openhouse.dataloader.filters import SqlTarget, always_true, col, to_sql __version__ = version("openhouse.dataloader") __all__ = [ @@ -11,6 +11,8 @@ "JvmConfig", "OpenHouseCatalog", "OpenHouseCatalogError", + "SqlTarget", "always_true", "col", + "to_sql", ] diff --git a/integrations/python/dataloader/src/openhouse/dataloader/filters.py b/integrations/python/dataloader/src/openhouse/dataloader/filters.py index a725eaa5e..7ef5701b2 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/filters.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/filters.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from datetime import date, datetime, time from decimal import Decimal +from enum import Enum from typing import Any from uuid import UUID @@ -325,101 +326,170 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def _literal_to_sql(value: object) -> str: - """Convert a Python literal to a SQL literal string using sqlglot. +def _non_finite_double(value: float | Decimal) -> exp.Cast: + """Build ``CAST('' AS DOUBLE)`` for a non-finite float/decimal. + + Uses the canonical ``NaN``/``Infinity``/``-Infinity`` spellings, which both + DataFusion and Spark parse. The lowercase forms from ``str(float(...))`` + (e.g. ``'inf'``) are not reliably cast by Spark. + """ + if isinstance(value, Decimal): + text = str(value) # 'NaN' / 'Infinity' / '-Infinity' + elif math.isnan(value): + text = "NaN" + elif value > 0: + text = "Infinity" + else: + text = "-Infinity" + return exp.Cast(this=exp.Literal.string(text), to=exp.DataType.build("DOUBLE")) + + +def _literal_to_expr(value: object) -> exp.Expression: + """Convert a Python literal to a sqlglot literal expression. Datetime/date/time values are emitted as plain string literals (ISO format). DataFusion implicitly coerces string literals to the column type at execution, and PyIceberg promotes StringLiteral to the matching typed literal during expression binding. """ if isinstance(value, str): - return exp.Literal.string(value).sql() + return exp.Literal.string(value) if isinstance(value, bool): - return exp.Boolean(this=True).sql() if value else exp.Boolean(this=False).sql() + return exp.true() if value else exp.false() if isinstance(value, datetime): - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, date): - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, time): if value.tzinfo is not None: raise TypeError( - "DataFusion does not support timezones for time data types. " + "The SQL target does not support timezones for time data types. " "The time should match the timezone used in the dataset." ) - return exp.Literal.string(value.isoformat()).sql() + return exp.Literal.string(value.isoformat()) if isinstance(value, (int, float)): if isinstance(value, float) and not math.isfinite(value): - return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql() - return exp.Literal.number(value).sql() + return _non_finite_double(value) + return exp.Literal.number(value) if isinstance(value, Decimal): if not value.is_finite(): - return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql() - return exp.Literal.number(value).sql() + return _non_finite_double(value) + return exp.Literal.number(value) if isinstance(value, UUID): - return exp.Literal.string(str(value)).sql() + return exp.Literal.string(str(value)) raise TypeError(f"Unsupported literal type: {type(value).__name__}") -def _to_datafusion_sql(expr: Filter) -> str: - """Convert a Filter expression tree to a DataFusion SQL expression string.""" - match expr: +def _column_expr(name: str) -> exp.Column: + """Build a quoted sqlglot column reference.""" + return exp.column(exp.to_identifier(name, quoted=True)) + + +def _like_prefix(column: str, prefix: str) -> exp.Expression: + r"""Build ``col LIKE 'prefix%' ESCAPE '\'`` with the prefix matched literally.""" + pattern = exp.Literal.string(_escape_like(prefix) + "%") + like = exp.Like(this=_column_expr(column), expression=pattern) + return exp.Escape(this=like, expression=exp.Literal.string("\\")) + + +def _filter_to_expr(filter_expr: Filter) -> exp.Expression: + """Build a dialect-agnostic sqlglot expression tree for a Filter. + + Render it for a target by calling ``.sql(dialect=...)``; see :func:`to_sql` + and :func:`_to_datafusion_sql`. The tree is built once and sqlglot handles + per-dialect identifier quoting, operators, and literal formatting. + """ + match filter_expr: case AlwaysTrue(): - return exp.Boolean(this=True).sql() + return exp.true() # Comparison case EqualTo(column, value): - return f"{_quote_identifier(column)} = {_literal_to_sql(value)}" + return exp.EQ(this=_column_expr(column), expression=_literal_to_expr(value)) case NotEqualTo(column, value): - return f"{_quote_identifier(column)} <> {_literal_to_sql(value)}" + return exp.NEQ(this=_column_expr(column), expression=_literal_to_expr(value)) case GreaterThan(column, value): - return f"{_quote_identifier(column)} > {_literal_to_sql(value)}" + return exp.GT(this=_column_expr(column), expression=_literal_to_expr(value)) case GreaterThanOrEqual(column, value): - return f"{_quote_identifier(column)} >= {_literal_to_sql(value)}" + return exp.GTE(this=_column_expr(column), expression=_literal_to_expr(value)) case LessThan(column, value): - return f"{_quote_identifier(column)} < {_literal_to_sql(value)}" + return exp.LT(this=_column_expr(column), expression=_literal_to_expr(value)) case LessThanOrEqual(column, value): - return f"{_quote_identifier(column)} <= {_literal_to_sql(value)}" + return exp.LTE(this=_column_expr(column), expression=_literal_to_expr(value)) - # Null / NaN + # Null / NaN. NaN uses the isnan() function (an Anonymous node, not sqlglot's + # built-in IsNan which renders the unsupported `IS_NAN(...)`); isnan() is + # accepted by both DataFusion and Spark. case IsNull(column): - return f"{_quote_identifier(column)} IS NULL" + return exp.Is(this=_column_expr(column), expression=exp.Null()) case IsNotNull(column): - return f"{_quote_identifier(column)} IS NOT NULL" + return exp.not_(exp.Is(this=_column_expr(column), expression=exp.Null())) case IsNaN(column): - return f"{_quote_identifier(column)} IS NAN" + return exp.Anonymous(this="isnan", expressions=[_column_expr(column)]) case IsNotNaN(column): - return f"{_quote_identifier(column)} IS NOT NAN" + return exp.not_(exp.Anonymous(this="isnan", expressions=[_column_expr(column)])) # Set membership case In(column, values): - vals = ", ".join(_literal_to_sql(v) for v in values) - return f"{_quote_identifier(column)} IN ({vals})" + return exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values]) case NotIn(column, values): - vals = ", ".join(_literal_to_sql(v) for v in values) - return f"{_quote_identifier(column)} NOT IN ({vals})" + return exp.not_(exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values])) # String prefix case StartsWith(column, prefix): - escaped = _escape_like(prefix) - return f"{_quote_identifier(column)} LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'" + return _like_prefix(column, prefix) case NotStartsWith(column, prefix): - escaped = _escape_like(prefix) - return f"{_quote_identifier(column)} NOT LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'" + return exp.not_(_like_prefix(column, prefix)) # Range case Between(column, lower, upper): - return f"{_quote_identifier(column)} BETWEEN {_literal_to_sql(lower)} AND {_literal_to_sql(upper)}" + return exp.Between( + this=_column_expr(column), + low=_literal_to_expr(lower), + high=_literal_to_expr(upper), + ) # Logical combinators case And(left, right): - return f"({_to_datafusion_sql(left)} AND {_to_datafusion_sql(right)})" + return exp.and_(_filter_to_expr(left), _filter_to_expr(right)) case Or(left, right): - return f"({_to_datafusion_sql(left)} OR {_to_datafusion_sql(right)})" + return exp.or_(_filter_to_expr(left), _filter_to_expr(right)) case Not(operand): - return f"NOT ({_to_datafusion_sql(operand)})" + return exp.not_(_filter_to_expr(operand)) case _: - raise TypeError(f"Unsupported filter type: {type(expr).__name__}") + raise TypeError(f"Unsupported filter type: {type(filter_expr).__name__}") + + +class SqlTarget(Enum): + """Target for :func:`to_sql`. + + Members name a concrete SQL flavor; the underlying SQL dialect is an internal + implementation detail and is not part of the public contract. + """ + + SPARK = "spark" + + +def to_sql(filter_expr: Filter, target: SqlTarget = SqlTarget.SPARK) -> str: + """Render a filter as a SQL boolean expression for the given target. + + The result is a WHERE-clause-ready predicate (without a leading ``WHERE``). + + Example:: + + to_sql(col("age") > 21) # `age` > 21 + to_sql((col("a") == 1) & col("b").is_null()) # `a` = 1 AND `b` IS NULL + + Args: + filter_expr: The filter expression to render. + target: The SQL flavor to render for. Defaults to Spark. + """ + return _filter_to_expr(filter_expr).sql(dialect=target.value) + + +def _to_datafusion_sql(expr: Filter) -> str: + """Render a Filter as a DataFusion SQL boolean expression string (default dialect).""" + return _filter_to_expr(expr).sql() def _to_pyiceberg(expr: Filter) -> ice.BooleanExpression: diff --git a/integrations/python/dataloader/tests/test_filters.py b/integrations/python/dataloader/tests/test_filters.py index 2bc74e612..4a90a7d88 100644 --- a/integrations/python/dataloader/tests/test_filters.py +++ b/integrations/python/dataloader/tests/test_filters.py @@ -4,6 +4,7 @@ from uuid import UUID import pytest +import sqlglot from pyiceberg import expressions as ice from openhouse.dataloader import col @@ -28,10 +29,12 @@ NotIn, NotStartsWith, Or, + SqlTarget, StartsWith, _to_datafusion_sql, _to_pyiceberg, always_true, + to_sql, ) @@ -423,9 +426,9 @@ def test_decimal_between(self): @pytest.mark.parametrize( ("value", "expected"), [ - (float("nan"), "CAST('nan' AS DOUBLE)"), - (float("inf"), "CAST('inf' AS DOUBLE)"), - (float("-inf"), "CAST('-inf' AS DOUBLE)"), + (float("nan"), "CAST('NaN' AS DOUBLE)"), + (float("inf"), "CAST('Infinity' AS DOUBLE)"), + (float("-inf"), "CAST('-Infinity' AS DOUBLE)"), ], ) def test_non_finite_float(self, value, expected): @@ -608,6 +611,172 @@ def test_datetime_microseconds(self, ctx): assert sum(len(b) for b in table2) == 0 +class TestSparkSqlConversion: + """Render filters to Spark SQL via the public ``to_sql`` entry point.""" + + def test_default_target_is_spark(self): + # to_sql defaults to the Spark target (backtick-quoted identifiers). + assert to_sql(col("x") == 5) == "`x` = 5" + + def test_always_true(self): + assert to_sql(always_true(), SqlTarget.SPARK) == "TRUE" + + # --- Comparison --- + + def test_equal(self): + assert to_sql(col("x") == 5, SqlTarget.SPARK) == "`x` = 5" + + def test_equal_string(self): + assert to_sql(col("name") == "alice", SqlTarget.SPARK) == "`name` = 'alice'" + + def test_equal_bool(self): + assert to_sql(col("flag") == True, SqlTarget.SPARK) == "`flag` = TRUE" # noqa: E712 + + def test_not_equal(self): + assert to_sql(col("x") != 5, SqlTarget.SPARK) == "`x` <> 5" + + def test_greater_than(self): + assert to_sql(col("x") > 5, SqlTarget.SPARK) == "`x` > 5" + + def test_greater_than_or_equal(self): + assert to_sql(col("x") >= 5, SqlTarget.SPARK) == "`x` >= 5" + + def test_less_than(self): + assert to_sql(col("x") < 5, SqlTarget.SPARK) == "`x` < 5" + + def test_less_than_or_equal(self): + assert to_sql(col("x") <= 5, SqlTarget.SPARK) == "`x` <= 5" + + # --- Null / NaN --- + + def test_is_null(self): + assert to_sql(col("x").is_null(), SqlTarget.SPARK) == "`x` IS NULL" + + def test_is_not_null(self): + assert to_sql(col("x").is_not_null(), SqlTarget.SPARK) == "NOT `x` IS NULL" + + def test_is_nan(self): + # NaN uses Spark's isnan() function, not the (unsupported) `IS NAN` syntax. + assert to_sql(col("x").is_nan(), SqlTarget.SPARK) == "ISNAN(`x`)" + + def test_is_not_nan(self): + assert to_sql(col("x").is_not_nan(), SqlTarget.SPARK) == "NOT ISNAN(`x`)" + + # --- Set membership --- + + def test_in(self): + assert to_sql(col("x").is_in([1, 2, 3]), SqlTarget.SPARK) == "`x` IN (1, 2, 3)" + + def test_in_strings(self): + assert to_sql(col("x").is_in(["a", "b"]), SqlTarget.SPARK) == "`x` IN ('a', 'b')" + + def test_not_in(self): + assert to_sql(col("x").is_not_in([1, 2, 3]), SqlTarget.SPARK) == "NOT `x` IN (1, 2, 3)" + + # --- String prefix (LIKE with escaped wildcards) --- + + def test_starts_with(self): + assert to_sql(col("name").starts_with("John"), SqlTarget.SPARK) == r"`name` LIKE 'John%' ESCAPE '\\'" + + def test_starts_with_escapes_wildcards(self): + # %, _ and \ in the prefix are escaped so they match literally. + assert to_sql(col("name").starts_with("a%b_c"), SqlTarget.SPARK) == r"`name` LIKE 'a\\%b\\_c%' ESCAPE '\\'" + + def test_not_starts_with(self): + assert to_sql(col("name").not_starts_with("John"), SqlTarget.SPARK) == r"NOT `name` LIKE 'John%' ESCAPE '\\'" + + # --- Range --- + + def test_between(self): + assert to_sql(col("x").between(1, 10), SqlTarget.SPARK) == "`x` BETWEEN 1 AND 10" + + # --- Logical combinators --- + + def test_and(self): + assert to_sql((col("x") > 5) & (col("y") == "a"), SqlTarget.SPARK) == "`x` > 5 AND `y` = 'a'" + + def test_or(self): + assert to_sql((col("x") > 5) | (col("y") == "a"), SqlTarget.SPARK) == "`x` > 5 OR `y` = 'a'" + + def test_not(self): + assert to_sql(~col("z").is_null(), SqlTarget.SPARK) == "NOT `z` IS NULL" + + def test_complex_composition(self): + expr = (col("x") > 5) & (col("y") == "a") | ~col("z").is_null() + assert to_sql(expr, SqlTarget.SPARK) == "(`x` > 5 AND `y` = 'a') OR NOT `z` IS NULL" + + # --- Literals --- + + def test_datetime(self): + dt = datetime(2026, 4, 27, tzinfo=UTC) + assert to_sql(col("ts") >= dt, SqlTarget.SPARK) == "`ts` >= '2026-04-27T00:00:00+00:00'" + + def test_date(self): + assert to_sql(col("d") >= date(2026, 4, 27), SqlTarget.SPARK) == "`d` >= '2026-04-27'" + + def test_decimal(self): + assert to_sql(col("price") > Decimal("99.95"), SqlTarget.SPARK) == "`price` > 99.95" + + def test_non_finite_float_uses_canonical_cast(self): + # Canonical NaN/Infinity spelling (Spark does not parse lowercase 'nan'/'inf'). + assert to_sql(col("x") == float("nan"), SqlTarget.SPARK) == "`x` = CAST('NaN' AS DOUBLE)" + assert to_sql(col("x") == float("inf"), SqlTarget.SPARK) == "`x` = CAST('Infinity' AS DOUBLE)" + + def test_uuid(self): + u = UUID("12345678-1234-5678-1234-567812345678") + assert to_sql(col("id") == u, SqlTarget.SPARK) == "`id` = '12345678-1234-5678-1234-567812345678'" + + # --- Cross-cutting invariants --- + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("name") == "alice", + col("x").is_null(), + col("x").is_not_null(), + col("x").is_nan(), + col("x").is_not_nan(), + col("x").is_in([1, 2, 3]), + col("x").is_not_in(["a", "b"]), + col("name").starts_with("a%b_c"), + col("name").not_starts_with("John"), + col("x").between(1, 10), + (col("x") > 5) & (col("y") == "a") | ~col("z").is_null(), + col("ts") >= datetime(2026, 4, 27, tzinfo=UTC), + col("x") == float("nan"), + ], + ) + def test_output_is_valid_spark_sql(self, expr): + # Every rendering must parse back as valid Spark SQL. + sqlglot.parse_one(to_sql(expr, SqlTarget.SPARK), dialect="spark") + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("x") >= 5, + col("x").is_null(), + col("x").is_not_null(), + col("x").is_in([1, 2, 3]), + (col("x") > 5) & col("y").is_null(), + ], + ) + def test_spark_and_datafusion_differ_only_in_quoting(self, expr): + # For filters without string-literal escaping, Spark vs DataFusion output + # differs only in identifier quoting (backtick vs double-quote). + spark = to_sql(expr, SqlTarget.SPARK) + assert spark.replace("`", '"') == _to_datafusion_sql(expr) + + def test_raises_on_unknown_filter(self): + class CustomFilter(Filter): + def __repr__(self) -> str: + return "custom" + + with pytest.raises(TypeError, match="Unsupported filter type"): + to_sql(CustomFilter(), SqlTarget.SPARK) + + class TestPyIcebergUnsupportedType: def test_raises_on_unknown_filter(self): class CustomFilter(Filter): From fc0435ed7690d75bd3037746eb7b5c57597d927a Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Mon, 15 Jun 2026 16:26:58 +0000 Subject: [PATCH 2/6] feat(dataloader): make SqlTarget the unified SQL-flavor vocabulary Use SqlTarget across the library instead of raw dialect strings: - Add SqlTarget.TRINO and SqlTarget.DATA_FUSION (now {SPARK, TRINO, DATA_FUSION}). - Type TableTransformer.dialect as SqlTarget (was str); data_loader bridges to the string-based to_datafusion_sql() transpiler via transformer.dialect.value. - Drop the standalone _to_datafusion_sql(); the loader's internal DataFusion query now uses to_sql(filters, SqlTarget.DATA_FUSION). Everything funnels through to_sql() -> _filter_to_expr(). NaN is dialect-specific: Trino spells it is_nan() while Spark/DataFusion use isnan(). Thread the target through _filter_to_expr and pick the right function name per target. Breaking change for TableTransformer subclasses (an internal API): pass a SqlTarget, e.g. SqlTarget.SPARK. Updated all subclasses + tests; removed the now-impossible bad-dialect transformer test (string validation is still covered by test_datafusion_sql). make verify passes (ruff, mypy, 344 tests). --- integrations/python/dataloader/CLAUDE.md | 6 +- integrations/python/dataloader/README.md | 2 +- .../src/openhouse/dataloader/data_loader.py | 7 +- .../src/openhouse/dataloader/filters.py | 66 +++++++++++-------- .../openhouse/dataloader/table_transformer.py | 7 +- .../dataloader/tests/integration_tests.py | 6 +- .../dataloader/tests/test_data_loader.py | 46 ++++--------- .../python/dataloader/tests/test_filters.py | 40 ++++++++++- .../dataloader/tests/test_scan_optimizer.py | 9 +-- 9 files changed, 108 insertions(+), 81 deletions(-) diff --git a/integrations/python/dataloader/CLAUDE.md b/integrations/python/dataloader/CLAUDE.md index deb66d267..a1be62086 100644 --- a/integrations/python/dataloader/CLAUDE.md +++ b/integrations/python/dataloader/CLAUDE.md @@ -76,17 +76,17 @@ Exported in `__init__.py`: - `col()` — Column reference for building filter expressions - `always_true()` — Filter that matches all rows - `to_sql()` — Render a filter as a SQL boolean expression (WHERE-clause predicate) for a target -- `SqlTarget` — Enum naming the SQL target for `to_sql()` (currently `SPARK`) +- `SqlTarget` — Enum of supported SQL flavors (`SPARK`, `TRINO`, `DATA_FUSION`); used by `to_sql()` and `TableTransformer` ### Filter DSL (`filters.py`) Build row filters using `col()` with comparison operators (`==`, `!=`, `>`, `>=`, `<`, `<=`) and predicates (`is_null()`, `is_not_null()`, `is_nan()`, `is_not_nan()`, `is_in()`, `is_not_in()`, `starts_with()`, `not_starts_with()`, `between()`). Combine with `&` (AND), `|` (OR), `~` (NOT). Filters are converted to PyIceberg expressions internally for partition pruning and file-level filtering. -`to_sql(filter, target=SqlTarget.SPARK)` renders a filter as a SQL boolean expression for the given target (e.g. Spark SQL). Internally, `_filter_to_expr()` builds a single dialect-agnostic sqlglot AST that is rendered per target with `.sql(dialect=...)`; `_to_datafusion_sql()` uses the same AST with the default dialect. The underlying sqlglot dialect is an internal detail — callers select a `SqlTarget`, never a dialect string. +`to_sql(filter, target=SqlTarget.SPARK)` renders a filter as a SQL boolean expression for the given target. Internally, `_filter_to_expr()` builds a single dialect-agnostic sqlglot AST that is rendered per target with `.sql(dialect=target.value)` — the same path the loader uses for its internal DataFusion query via `to_sql(filters, SqlTarget.DATA_FUSION)`. The backing sqlglot dialect string is an internal detail — callers select a `SqlTarget`, never a dialect string. ### Internal modules (not in `__init__.py`) -- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (e.g. `"spark"`) and implement `transform()` returning SQL or `None` +- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (a `SqlTarget`, e.g. `SqlTarget.SPARK`) and implement `transform()` returning SQL or `None` - `UDFRegistry` / `NoOpRegistry` — ABC for registering DataFusion UDFs; `NoOpRegistry` is the default no-op - `TableScanContext` — Frozen dataclass holding table metadata, FileIO, projected schema, row filter, and table ID; pickle-safe for distributed execution - `DataFusion` dialect in `datafusion_sql.py` — Custom SQLGlot dialect for transpiling SQL from other dialects (e.g. Spark) to DataFusion diff --git a/integrations/python/dataloader/README.md b/integrations/python/dataloader/README.md index a33f64252..6cf32c059 100644 --- a/integrations/python/dataloader/README.md +++ b/integrations/python/dataloader/README.md @@ -55,7 +55,7 @@ filters = (col("age") >= 18) & (col("country").is_in(["US", "CA"])) & ~col("emai ### Rendering a filter as SQL Use `to_sql()` to render a filter as a SQL boolean expression (a `WHERE`-clause -predicate) for a given target. Spark is currently the supported target: +predicate) for a given `SqlTarget` (`SPARK`, `TRINO`, or `DATA_FUSION`): ```python from openhouse.dataloader import SqlTarget, col, to_sql diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index da9a3e943..19300f261 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -25,10 +25,11 @@ from openhouse.dataloader.filters import ( AlwaysTrue, Filter, + SqlTarget, _quote_identifier, - _to_datafusion_sql, _to_pyiceberg, always_true, + to_sql, ) from openhouse.dataloader.metrics import METER_NAME from openhouse.dataloader.scan_optimizer import optimize_scan @@ -298,11 +299,11 @@ def _build_query(self) -> str | None: sql = transformer.transform(self._table_id, execution_context) if sql is None: return None - sql = to_datafusion_sql(sql, transformer.dialect, table=self._table_id) + sql = to_datafusion_sql(sql, transformer.dialect.value, table=self._table_id) outer_cols = ", ".join(_quote_identifier(c) for c in self._columns) if self._columns else "*" combined = f"SELECT {outer_cols} FROM ({sql}) AS _t" if self._filters and not isinstance(self._filters, AlwaysTrue): - combined += f" WHERE {_to_datafusion_sql(self._filters)}" + combined += f" WHERE {to_sql(self._filters, SqlTarget.DATA_FUSION)}" return combined def __iter__(self) -> Iterator[DataLoaderSplit]: diff --git a/integrations/python/dataloader/src/openhouse/dataloader/filters.py b/integrations/python/dataloader/src/openhouse/dataloader/filters.py index 7ef5701b2..9ea662407 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/filters.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/filters.py @@ -379,6 +379,20 @@ def _literal_to_expr(value: object) -> exp.Expression: raise TypeError(f"Unsupported literal type: {type(value).__name__}") +class SqlTarget(Enum): + """A SQL flavor used by this library. + + Used both to render filters (:func:`to_sql`) and to declare the SQL a + ``TableTransformer`` emits. Members name a concrete SQL flavor; the backing + sqlglot dialect string is an internal detail — callers use the member, not + the string. + """ + + SPARK = "spark" + TRINO = "trino" + DATA_FUSION = "datafusion" + + def _column_expr(name: str) -> exp.Column: """Build a quoted sqlglot column reference.""" return exp.column(exp.to_identifier(name, quoted=True)) @@ -391,12 +405,23 @@ def _like_prefix(column: str, prefix: str) -> exp.Expression: return exp.Escape(this=like, expression=exp.Literal.string("\\")) -def _filter_to_expr(filter_expr: Filter) -> exp.Expression: - """Build a dialect-agnostic sqlglot expression tree for a Filter. +def _isnan(column: str, target: SqlTarget) -> exp.Anonymous: + """Build the dialect-appropriate NaN-check function call. + + Trino spells it ``is_nan``; Spark and DataFusion use ``isnan``. (sqlglot's + built-in IsNan node renders ``IS_NAN(...)``, which none of these accept, so we + emit an Anonymous function with the correct name for the target.) + """ + name = "is_nan" if target is SqlTarget.TRINO else "isnan" + return exp.Anonymous(this=name, expressions=[_column_expr(column)]) + + +def _filter_to_expr(filter_expr: Filter, target: SqlTarget) -> exp.Expression: + """Build a sqlglot expression tree for a Filter, ready to render for *target*. - Render it for a target by calling ``.sql(dialect=...)``; see :func:`to_sql` - and :func:`_to_datafusion_sql`. The tree is built once and sqlglot handles - per-dialect identifier quoting, operators, and literal formatting. + The tree is dialect-agnostic except where SQL genuinely diverges (NaN checks); + sqlglot handles per-dialect identifier quoting, operators, and literal + formatting at render time. See :func:`to_sql`. """ match filter_expr: case AlwaysTrue(): @@ -416,17 +441,15 @@ def _filter_to_expr(filter_expr: Filter) -> exp.Expression: case LessThanOrEqual(column, value): return exp.LTE(this=_column_expr(column), expression=_literal_to_expr(value)) - # Null / NaN. NaN uses the isnan() function (an Anonymous node, not sqlglot's - # built-in IsNan which renders the unsupported `IS_NAN(...)`); isnan() is - # accepted by both DataFusion and Spark. + # Null / NaN case IsNull(column): return exp.Is(this=_column_expr(column), expression=exp.Null()) case IsNotNull(column): return exp.not_(exp.Is(this=_column_expr(column), expression=exp.Null())) case IsNaN(column): - return exp.Anonymous(this="isnan", expressions=[_column_expr(column)]) + return _isnan(column, target) case IsNotNaN(column): - return exp.not_(exp.Anonymous(this="isnan", expressions=[_column_expr(column)])) + return exp.not_(_isnan(column, target)) # Set membership case In(column, values): @@ -450,26 +473,16 @@ def _filter_to_expr(filter_expr: Filter) -> exp.Expression: # Logical combinators case And(left, right): - return exp.and_(_filter_to_expr(left), _filter_to_expr(right)) + return exp.and_(_filter_to_expr(left, target), _filter_to_expr(right, target)) case Or(left, right): - return exp.or_(_filter_to_expr(left), _filter_to_expr(right)) + return exp.or_(_filter_to_expr(left, target), _filter_to_expr(right, target)) case Not(operand): - return exp.not_(_filter_to_expr(operand)) + return exp.not_(_filter_to_expr(operand, target)) case _: raise TypeError(f"Unsupported filter type: {type(filter_expr).__name__}") -class SqlTarget(Enum): - """Target for :func:`to_sql`. - - Members name a concrete SQL flavor; the underlying SQL dialect is an internal - implementation detail and is not part of the public contract. - """ - - SPARK = "spark" - - def to_sql(filter_expr: Filter, target: SqlTarget = SqlTarget.SPARK) -> str: """Render a filter as a SQL boolean expression for the given target. @@ -484,12 +497,7 @@ def to_sql(filter_expr: Filter, target: SqlTarget = SqlTarget.SPARK) -> str: filter_expr: The filter expression to render. target: The SQL flavor to render for. Defaults to Spark. """ - return _filter_to_expr(filter_expr).sql(dialect=target.value) - - -def _to_datafusion_sql(expr: Filter) -> str: - """Render a Filter as a DataFusion SQL boolean expression string (default dialect).""" - return _filter_to_expr(expr).sql() + return _filter_to_expr(filter_expr, target).sql(dialect=target.value) def _to_pyiceberg(expr: Filter) -> ice.BooleanExpression: diff --git a/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py b/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py index ea6746b9f..8bca405c5 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/table_transformer.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping +from openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier @@ -13,11 +14,11 @@ class TableTransformer(ABC): projections and filters can reference them. Args: - dialect: The SQL dialect used by ``transform()`` (e.g. ``"spark"``). + dialect: The SQL flavor that ``transform()`` emits (e.g. ``SqlTarget.SPARK``). """ - def __init__(self, dialect: str) -> None: - self.dialect: str = dialect + def __init__(self, dialect: SqlTarget) -> None: + self.dialect: SqlTarget = dialect @abstractmethod def transform(self, table: TableIdentifier, context: Mapping[str, str]) -> str | None: diff --git a/integrations/python/dataloader/tests/integration_tests.py b/integrations/python/dataloader/tests/integration_tests.py index bb183a99c..c99a5110b 100644 --- a/integrations/python/dataloader/tests/integration_tests.py +++ b/integrations/python/dataloader/tests/integration_tests.py @@ -21,7 +21,7 @@ from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader from openhouse.dataloader.catalog import OpenHouseCatalog from openhouse.dataloader.data_loader_split import to_sql_identifier -from openhouse.dataloader.filters import col +from openhouse.dataloader.filters import SqlTarget, col from openhouse.dataloader.table_transformer import TableTransformer BASE_URL = "http://openhouse-tables:8080" @@ -345,10 +345,10 @@ def read_token() -> str: # SQL roundtrip path (filters -> DataFusion SQL -> sqlglot -> scan_optimizer -> # PyIceberg expression). Without a transformer, _build_query() returns None # and the loader skips that path entirely, which would mean a CAST(literal, - # TIMESTAMP) regression in _literal_to_sql / scan_optimizer would go unnoticed. + # TIMESTAMP) regression in _literal_to_expr / scan_optimizer would go unnoticed. class _PartPassthroughTransformer(TableTransformer): def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f'SELECT "id", "ts" FROM {to_sql_identifier(table)}' diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index a66ee29ca..0d56a714e 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -17,7 +17,7 @@ from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader, __version__ from openhouse.dataloader.data_loader_split import DataLoaderSplit, to_sql_identifier -from openhouse.dataloader.filters import col +from openhouse.dataloader.filters import SqlTarget, col from openhouse.dataloader.table_transformer import TableTransformer from openhouse.dataloader.udf_registry import UDFRegistry @@ -346,7 +346,7 @@ class _NoneTransformer(TableTransformer): """Transformer that returns None (no transformation).""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return None @@ -356,7 +356,7 @@ class _MaskingTransformer(TableTransformer): """Transformer that masks the name column.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)}" @@ -443,7 +443,7 @@ class _SparkMaskingTransformer(TableTransformer): """Transformer using Spark SQL dialect.""" def __init__(self): - super().__init__(dialect="spark") + super().__init__(dialect=SqlTarget.SPARK) def transform(self, table, context): return f"SELECT id, CAST('MASKED' AS STRING) AS name, value FROM {to_sql_identifier(table)}" @@ -465,35 +465,13 @@ def test_iter_with_spark_dialect_transformer_transpiles(tmp_path): assert result.column("name").to_pylist() == ["MASKED", "MASKED", "MASKED"] -def test_iter_with_invalid_dialect_raises(tmp_path): - """Unsupported dialect raises ValueError during iteration.""" - - class _BadDialectTransformer(TableTransformer): - def __init__(self): - super().__init__(dialect="not_a_real_dialect") - - def transform(self, table, context): - return f"SELECT * FROM {to_sql_identifier(table)}" - - catalog = _make_real_catalog(tmp_path) - loader = OpenHouseDataLoader( - catalog=catalog, - database="db", - table="tbl", - context=DataLoaderContext(table_transformer=_BadDialectTransformer()), - ) - - with pytest.raises(ValueError, match="Unsupported source dialect"): - _materialize(loader) - - def test_iter_with_transformer_and_special_char_database(tmp_path): """Transformer works when the database name contains special characters.""" catalog = _make_real_catalog(tmp_path) class _QuotedMaskingTransformer(TableTransformer): def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)}" @@ -664,7 +642,7 @@ class _FilteringTransformer(TableTransformer): """Transformer that has a WHERE clause filtering on status.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, name, value, status FROM {to_sql_identifier(table)} WHERE status = 'active'" @@ -707,7 +685,7 @@ class _MaskingFilteringTransformer(TableTransformer): """Transformer that masks name and filters on value.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, 'MASKED' as name, value FROM {to_sql_identifier(table)} WHERE value > 1.5" @@ -784,7 +762,7 @@ class _PassthroughTransformer(TableTransformer): """Transformer that selects all columns unchanged.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f"SELECT id, name, value FROM {to_sql_identifier(table)}" @@ -807,7 +785,7 @@ class _MixedCaseTransformer(TableTransformer): """Transformer that selects mixed-case columns.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): return f'SELECT "purchaseAmount", "itemCount", "discountRate" FROM {to_sql_identifier(table)}' @@ -978,7 +956,7 @@ class _NestedUDFTransformer(TableTransformer): """Nested subquery with UDF at two levels — triggers alias rewrite bug.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) @@ -995,7 +973,7 @@ class _SelectStarUDFTransformer(TableTransformer): """SELECT * with UDF — triggers unquoted column bug after projection pushdown.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) @@ -1007,7 +985,7 @@ class _ProjectionUDFTransformer(TableTransformer): """UDF in projection with inner SELECT * — triggers unquoted column bug in projection.""" def __init__(self): - super().__init__(dialect="datafusion") + super().__init__(dialect=SqlTarget.DATA_FUSION) def transform(self, table, context): tbl = to_sql_identifier(table) diff --git a/integrations/python/dataloader/tests/test_filters.py b/integrations/python/dataloader/tests/test_filters.py index 4a90a7d88..7f599b551 100644 --- a/integrations/python/dataloader/tests/test_filters.py +++ b/integrations/python/dataloader/tests/test_filters.py @@ -31,13 +31,17 @@ Or, SqlTarget, StartsWith, - _to_datafusion_sql, _to_pyiceberg, always_true, to_sql, ) +def _to_datafusion_sql(expr: Filter) -> str: + """Test helper: render a filter to DataFusion SQL through the public ``to_sql``.""" + return to_sql(expr, SqlTarget.DATA_FUSION) + + class TestColumnCreation: def test_col_returns_column(self): c = col("x") @@ -777,6 +781,40 @@ def __repr__(self) -> str: to_sql(CustomFilter(), SqlTarget.SPARK) +class TestTrinoSqlConversion: + """Render filters to Trino SQL (double-quoted identifiers, is_nan()).""" + + def test_uses_double_quoted_identifiers(self): + assert to_sql(col("x") == 5, SqlTarget.TRINO) == '"x" = 5' + + def test_is_nan_uses_trino_function(self): + # Trino's NaN function is is_nan (underscore), unlike Spark/DataFusion's isnan. + assert to_sql(col("x").is_nan(), SqlTarget.TRINO) == 'IS_NAN("x")' + assert to_sql(col("x").is_not_nan(), SqlTarget.TRINO) == 'NOT IS_NAN("x")' + + def test_in_and_between(self): + assert to_sql(col("x").is_in([1, 2]), SqlTarget.TRINO) == '"x" IN (1, 2)' + assert to_sql(col("x").between(1, 10), SqlTarget.TRINO) == '"x" BETWEEN 1 AND 10' + + @pytest.mark.parametrize( + "expr", + [ + col("x") == 5, + col("name") == "alice", + col("x").is_null(), + col("x").is_not_null(), + col("x").is_nan(), + col("x").is_not_nan(), + col("x").is_in([1, 2, 3]), + col("name").starts_with("a%b_c"), + col("x").between(1, 10), + (col("x") > 5) & (col("y") == "a") | ~col("z").is_null(), + ], + ) + def test_output_is_valid_trino_sql(self, expr): + sqlglot.parse_one(to_sql(expr, SqlTarget.TRINO), dialect="trino") + + class TestPyIcebergUnsupportedType: def test_raises_on_unknown_filter(self): class CustomFilter(Filter): diff --git a/integrations/python/dataloader/tests/test_scan_optimizer.py b/integrations/python/dataloader/tests/test_scan_optimizer.py index a1047bd12..57282d197 100644 --- a/integrations/python/dataloader/tests/test_scan_optimizer.py +++ b/integrations/python/dataloader/tests/test_scan_optimizer.py @@ -17,7 +17,8 @@ LessThanOrEqual, NotEqualTo, Or, - _to_datafusion_sql, + SqlTarget, + to_sql, ) from openhouse.dataloader.scan_optimizer import optimize_scan as _optimize_scan @@ -170,7 +171,7 @@ def test_comparison_types(): def test_datetime_string_literals_pushed_as_strings(): - """`filters._literal_to_sql()` emits plain string literals for datetime/date/time + """`filters._literal_to_expr()` emits plain string literals for datetime/date/time (see PR #569 + follow-up). The scan optimizer treats them as ordinary string literals; PyIceberg promotes them to typed literals during expression binding against the table schema, restoring partition pruning. @@ -206,7 +207,7 @@ def test_non_convertible_predicates_not_pushed(): def test_filter_dsl_to_sql_round_trip(): - """Each Filter type survives _to_datafusion_sql → optimize_scan round trip.""" + """Each Filter type survives to_sql(..., DATA_FUSION) → optimize_scan round trip.""" cases = [ EqualTo("x", 1), NotEqualTo("x", 1), @@ -223,7 +224,7 @@ def test_filter_dsl_to_sql_round_trip(): Or(EqualTo("x", 1), EqualTo("x", 2)), ] for filter_dsl in cases: - sql = f'SELECT "a" FROM "db"."tbl" WHERE {_to_datafusion_sql(filter_dsl)}' + sql = f'SELECT "a" FROM "db"."tbl" WHERE {to_sql(filter_dsl, SqlTarget.DATA_FUSION)}' plan = optimize_scan(sql) assert plan.row_filter == filter_dsl, f"Round trip failed for {filter_dsl!r}: got {plan.row_filter!r}" From 636682d17a0d1abdeff989b4bd409adb9ee90ae7 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Mon, 15 Jun 2026 18:20:32 +0000 Subject: [PATCH 3/6] test(dataloader): drop _to_datafusion_sql test helper, call to_sql directly Inline to_sql(expr, SqlTarget.DATA_FUSION) at the 31 call sites that used the leftover _to_datafusion_sql test helper, and remove the helper. The DataFusion filter rendering is now exercised through the public to_sql() path with no indirection. No production change; the to_datafusion_sql statement transpiler is untouched. make verify passes (344 tests). --- .../python/dataloader/tests/test_filters.py | 70 +++++++++---------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/integrations/python/dataloader/tests/test_filters.py b/integrations/python/dataloader/tests/test_filters.py index 7f599b551..2f194d2c5 100644 --- a/integrations/python/dataloader/tests/test_filters.py +++ b/integrations/python/dataloader/tests/test_filters.py @@ -37,11 +37,6 @@ ) -def _to_datafusion_sql(expr: Filter) -> str: - """Test helper: render a filter to DataFusion SQL through the public ``to_sql``.""" - return to_sql(expr, SqlTarget.DATA_FUSION) - - class TestColumnCreation: def test_col_returns_column(self): c = col("x") @@ -362,69 +357,69 @@ def test_f_col_builds_filter(self): class TestDataFusionLiteralConversion: def test_datetime_greater_than_or_equal(self): dt = datetime(2026, 4, 27, tzinfo=UTC) - result = _to_datafusion_sql(col("datepartition") >= dt) + result = to_sql(col("datepartition") >= dt, SqlTarget.DATA_FUSION) assert result == "\"datepartition\" >= '2026-04-27T00:00:00+00:00'" def test_datetime_equal(self): dt = datetime(2026, 4, 27, 12, 30, 45, tzinfo=UTC) - result = _to_datafusion_sql(col("ts") == dt) + result = to_sql(col("ts") == dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" = '2026-04-27T12:30:45+00:00'" def test_datetime_with_microseconds(self): dt = datetime(2026, 4, 27, 12, 30, 45, 123456, tzinfo=UTC) - result = _to_datafusion_sql(col("ts") == dt) + result = to_sql(col("ts") == dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" = '2026-04-27T12:30:45.123456+00:00'" def test_datetime_non_utc_timezone_preserved(self): dt = datetime(2026, 4, 27, 12, 0, 0, tzinfo=timezone(timedelta(hours=5))) - result = _to_datafusion_sql(col("ts") >= dt) + result = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" >= '2026-04-27T12:00:00+05:00'" def test_datetime_naive_no_offset(self): dt = datetime(2026, 4, 27, 12, 0, 0) - result = _to_datafusion_sql(col("ts") >= dt) + result = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) assert result == "\"ts\" >= '2026-04-27T12:00:00'" def test_date_greater_than_or_equal(self): d = date(2026, 4, 27) - result = _to_datafusion_sql(col("datepartition") >= d) + result = to_sql(col("datepartition") >= d, SqlTarget.DATA_FUSION) assert result == "\"datepartition\" >= '2026-04-27'" def test_datetime_between(self): dt1 = datetime(2026, 4, 27, tzinfo=UTC) dt2 = datetime(2026, 5, 1, tzinfo=UTC) - result = _to_datafusion_sql(col("ts").between(dt1, dt2)) + result = to_sql(col("ts").between(dt1, dt2), SqlTarget.DATA_FUSION) assert result == "\"ts\" BETWEEN '2026-04-27T00:00:00+00:00' AND '2026-05-01T00:00:00+00:00'" def test_datetime_in_compound_filter(self): dt = datetime(2026, 4, 27, tzinfo=UTC) f = (col("datepartition") >= dt) & (col("status") == "active") - result = _to_datafusion_sql(f) + result = to_sql(f, SqlTarget.DATA_FUSION) assert "'2026-04-27T00:00:00+00:00'" in result assert "\"status\" = 'active'" in result def test_time_equal(self): t = time(14, 30, 0) - result = _to_datafusion_sql(col("event_time") == t) + result = to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) assert result == "\"event_time\" = '14:30:00'" def test_time_with_microseconds(self): t = time(14, 30, 0, 500000) - result = _to_datafusion_sql(col("event_time") == t) + result = to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) assert result == "\"event_time\" = '14:30:00.500000'" def test_time_with_timezone_rejected(self): t = time(14, 30, 0, tzinfo=timezone(timedelta(hours=5))) with pytest.raises(TypeError, match="does not support timezones for time"): - _to_datafusion_sql(col("event_time") == t) + to_sql(col("event_time") == t, SqlTarget.DATA_FUSION) def test_decimal_greater_than(self): d = Decimal("99.95") - result = _to_datafusion_sql(col("price") > d) + result = to_sql(col("price") > d, SqlTarget.DATA_FUSION) assert result == '"price" > 99.95' def test_decimal_between(self): - result = _to_datafusion_sql(col("price").between(Decimal("10.00"), Decimal("50.00"))) + result = to_sql(col("price").between(Decimal("10.00"), Decimal("50.00")), SqlTarget.DATA_FUSION) assert result == '"price" BETWEEN 10.00 AND 50.00' @pytest.mark.parametrize( @@ -436,7 +431,7 @@ def test_decimal_between(self): ], ) def test_non_finite_float(self, value, expected): - result = _to_datafusion_sql(col("x") == value) + result = to_sql(col("x") == value, SqlTarget.DATA_FUSION) assert result == f'"x" = {expected}' @pytest.mark.parametrize( @@ -448,12 +443,12 @@ def test_non_finite_float(self, value, expected): ], ) def test_non_finite_decimal(self, value, expected): - result = _to_datafusion_sql(col("x") == value) + result = to_sql(col("x") == value, SqlTarget.DATA_FUSION) assert result == f'"x" = {expected}' def test_uuid_equal(self): u = UUID("12345678-1234-5678-1234-567812345678") - result = _to_datafusion_sql(col("id") == u) + result = to_sql(col("id") == u, SqlTarget.DATA_FUSION) assert result == "\"id\" = '12345678-1234-5678-1234-567812345678'" @@ -498,14 +493,14 @@ def _query(self, ctx, where: str): return pa.Table.from_batches(batches) def test_datetime_filter(self, ctx): - where = _to_datafusion_sql(col("ts") >= datetime(2026, 4, 27, tzinfo=UTC)) + where = to_sql(col("ts") >= datetime(2026, 4, 27, tzinfo=UTC), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 ts_values = [v.as_py() for v in table.column("ts")] assert ts_values == [datetime(2026, 4, 27, tzinfo=UTC), datetime(2026, 4, 29, tzinfo=UTC)] def test_datetime_less_than(self, ctx): - where = _to_datafusion_sql(col("ts") < datetime(2026, 4, 27, tzinfo=UTC)) + where = to_sql(col("ts") < datetime(2026, 4, 27, tzinfo=UTC), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("ts")[0].as_py() == datetime(2026, 4, 25, tzinfo=UTC) @@ -513,27 +508,27 @@ def test_datetime_less_than(self, ctx): def test_datetime_non_utc_timezone(self, ctx): # 2026-04-27 05:00:00+05:00 == 2026-04-27 00:00:00 UTC, same as row 2 dt = datetime(2026, 4, 27, 5, 0, 0, tzinfo=timezone(timedelta(hours=5))) - where = _to_datafusion_sql(col("ts") >= dt) + where = to_sql(col("ts") >= dt, SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 ts_values = [v.as_py() for v in table.column("ts")] assert ts_values == [datetime(2026, 4, 27, tzinfo=UTC), datetime(2026, 4, 29, tzinfo=UTC)] def test_date_filter(self, ctx): - where = _to_datafusion_sql(col("dt") >= date(2026, 4, 27)) + where = to_sql(col("dt") >= date(2026, 4, 27), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 dt_values = [v.as_py() for v in table.column("dt")] assert dt_values == [date(2026, 4, 27), date(2026, 4, 29)] def test_time_filter(self, ctx): - where = _to_datafusion_sql(col("t") == time(14, 30, 0)) + where = to_sql(col("t") == time(14, 30, 0), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("t")[0].as_py() == time(14, 30, 0) def test_decimal_filter(self, ctx): - where = _to_datafusion_sql(col("price") > Decimal("10.00")) + where = to_sql(col("price") > Decimal("10.00"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 2 price_values = [v.as_py() for v in table.column("price")] @@ -558,27 +553,28 @@ def test_non_finite_filter(self, value, expected_count, check): batch = pa.record_batch({"x": pa.array([1.0, float("nan"), float("inf"), float("-inf"), 5.0])}) ctx.register_record_batches("t", [[batch]]) - where = _to_datafusion_sql(col("x") == value) + where = to_sql(col("x") == value, SqlTarget.DATA_FUSION) batches = ctx.sql(f'SELECT * FROM "t" WHERE {where}').collect() table = pa.Table.from_batches(batches) assert table.num_rows == expected_count assert check(table.column("x")[0].as_py()) def test_high_precision_decimal_filter(self, ctx): - where = _to_datafusion_sql(col("price") > Decimal("49.9899999999999999")) + where = to_sql(col("price") > Decimal("49.9899999999999999"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("price")[0].as_py() == Decimal("99.99") def test_uuid_filter(self, ctx): - where = _to_datafusion_sql(col("id") == UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")) + where = to_sql(col("id") == UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"), SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("id")[0].as_py() == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" def test_datetime_between_filter(self, ctx): - where = _to_datafusion_sql( - col("ts").between(datetime(2026, 4, 26, tzinfo=UTC), datetime(2026, 4, 28, tzinfo=UTC)) + where = to_sql( + col("ts").between(datetime(2026, 4, 26, tzinfo=UTC), datetime(2026, 4, 28, tzinfo=UTC)), + SqlTarget.DATA_FUSION, ) table = self._query(ctx, where) assert table.num_rows == 1 @@ -586,7 +582,7 @@ def test_datetime_between_filter(self, ctx): def test_compound_filter(self, ctx): f = (col("ts") >= datetime(2026, 4, 27, tzinfo=UTC)) & (col("price") > Decimal("50.00")) - where = _to_datafusion_sql(f) + where = to_sql(f, SqlTarget.DATA_FUSION) table = self._query(ctx, where) assert table.num_rows == 1 assert table.column("ts")[0].as_py() == datetime(2026, 4, 29, tzinfo=UTC) @@ -605,12 +601,12 @@ def test_datetime_microseconds(self, ctx): } ) ctx2.register_record_batches("t2", [[batch]]) - where = _to_datafusion_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC)) + where = to_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC), SqlTarget.DATA_FUSION) table = ctx2.sql(f'SELECT * FROM "t2" WHERE {where}').collect() assert len(table[0]) == 1 assert table[0].column("ts")[0].as_py() == datetime(2026, 4, 27, 12, 0, 0, 500000, tzinfo=UTC) - where_no_match = _to_datafusion_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, tzinfo=UTC)) + where_no_match = to_sql(col("ts") == datetime(2026, 4, 27, 12, 0, 0, tzinfo=UTC), SqlTarget.DATA_FUSION) table2 = ctx2.sql(f'SELECT * FROM "t2" WHERE {where_no_match}').collect() assert sum(len(b) for b in table2) == 0 @@ -770,7 +766,7 @@ def test_spark_and_datafusion_differ_only_in_quoting(self, expr): # For filters without string-literal escaping, Spark vs DataFusion output # differs only in identifier quoting (backtick vs double-quote). spark = to_sql(expr, SqlTarget.SPARK) - assert spark.replace("`", '"') == _to_datafusion_sql(expr) + assert spark.replace("`", '"') == to_sql(expr, SqlTarget.DATA_FUSION) def test_raises_on_unknown_filter(self): class CustomFilter(Filter): @@ -830,4 +826,4 @@ def __repr__(self) -> str: return "custom" with pytest.raises(TypeError, match="Unsupported filter type"): - _to_datafusion_sql(CustomFilter()) + to_sql(CustomFilter(), SqlTarget.DATA_FUSION) From 7876bfb83de339829dc4de5f6624b5647766dbdf Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Mon, 15 Jun 2026 21:18:13 +0000 Subject: [PATCH 4/6] refactor(dataloader): type to_datafusion_sql source_dialect as SqlTarget Make to_datafusion_sql accept a SqlTarget instead of a raw dialect string, so the transformer transpilation path is SqlTarget-typed end to end (data_loader now passes transformer.dialect directly, no .value). - Drop the 'unsupported source dialect' validation: every SqlTarget value is a valid sqlglot dialect, so the check is dead. Compare the no-op case with 'is SqlTarget.DATA_FUSION'; use .value for the sqlglot parse + error message. - Tests: convert spark/datafusion cases to SqlTarget; remove the mysql/postgres transpilation cases and the unsupported-dialect test (not expressible/reachable once the param is a SqlTarget). make verify passes (ruff, mypy, 340 tests). --- integrations/python/dataloader/CLAUDE.md | 2 +- .../src/openhouse/dataloader/data_loader.py | 2 +- .../openhouse/dataloader/datafusion_sql.py | 22 +++--- .../dataloader/tests/test_datafusion_sql.py | 72 +++++++++---------- 4 files changed, 45 insertions(+), 53 deletions(-) diff --git a/integrations/python/dataloader/CLAUDE.md b/integrations/python/dataloader/CLAUDE.md index a1be62086..d63180c90 100644 --- a/integrations/python/dataloader/CLAUDE.md +++ b/integrations/python/dataloader/CLAUDE.md @@ -90,7 +90,7 @@ Build row filters using `col()` with comparison operators (`==`, `!=`, `>`, `>=` - `UDFRegistry` / `NoOpRegistry` — ABC for registering DataFusion UDFs; `NoOpRegistry` is the default no-op - `TableScanContext` — Frozen dataclass holding table metadata, FileIO, projected schema, row filter, and table ID; pickle-safe for distributed execution - `DataFusion` dialect in `datafusion_sql.py` — Custom SQLGlot dialect for transpiling SQL from other dialects (e.g. Spark) to DataFusion -- `to_datafusion_sql()` — Transpiles a SQL statement from a source dialect to DataFusion using SQLGlot +- `to_datafusion_sql()` — Transpiles a SQL statement from a source `SqlTarget` to DataFusion using SQLGlot ## Key Dependencies diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 19300f261..8ff0d2724 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -299,7 +299,7 @@ def _build_query(self) -> str | None: sql = transformer.transform(self._table_id, execution_context) if sql is None: return None - sql = to_datafusion_sql(sql, transformer.dialect.value, table=self._table_id) + sql = to_datafusion_sql(sql, transformer.dialect, table=self._table_id) outer_cols = ", ".join(_quote_identifier(c) for c in self._columns) if self._columns else "*" combined = f"SELECT {outer_cols} FROM ({sql}) AS _t" if self._filters and not isinstance(self._filters, AlwaysTrue): diff --git a/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py b/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py index ca11fb673..6a9834cd6 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/datafusion_sql.py @@ -8,6 +8,7 @@ from sqlglot.tokens import Tokenizer as _Tokenizer from sqlglot.tokens import TokenType +from openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier @@ -100,7 +101,7 @@ class Generator(_Generator): def to_datafusion_sql( sql: str, - source_dialect: str, + source_dialect: SqlTarget, *, table: TableIdentifier | None = None, ) -> str: @@ -110,28 +111,23 @@ def to_datafusion_sql( that table. Args: - sql: SQL statement in the source dialect. - source_dialect: sqlglot dialect name (e.g. "spark", "postgres"). + sql: SQL statement written in *source_dialect*. + source_dialect: The SqlTarget the *sql* is written in (e.g. ``SqlTarget.SPARK``). table: Expected table the SQL must reference. When set the function verifies the SQL contains exactly one table matching this identifier. Raises: - ValueError: If the dialect is unsupported, the SQL is invalid, the - input contains more than one statement, or (when *table* is set) - the table reference does not match. + ValueError: If the SQL is invalid, the input contains more than one + statement, or (when *table* is set) the table reference does not match. """ - if source_dialect not in Dialect.classes: - raise ValueError( - f"Unsupported source dialect '{source_dialect}'. Supported dialects: {', '.join(sorted(Dialect.classes))}" - ) - if source_dialect == DataFusion.DIALECT and table is None: + if source_dialect is SqlTarget.DATA_FUSION and table is None: return sql try: - statements = sqlglot.parse(sql, dialect=source_dialect) + statements = sqlglot.parse(sql, dialect=source_dialect.value) except sqlglot.errors.SqlglotError as e: - raise ValueError(f"Failed to transpile SQL from '{source_dialect}' to DataFusion: {e}") from e + raise ValueError(f"Failed to transpile SQL from '{source_dialect.value}' to DataFusion: {e}") from e if len(statements) != 1 or statements[0] is None: raise ValueError(f"Expected exactly one SQL statement, got {len(statements)}: {statements}") ast = statements[0] diff --git a/integrations/python/dataloader/tests/test_datafusion_sql.py b/integrations/python/dataloader/tests/test_datafusion_sql.py index bbdba3faa..db6e71cfb 100644 --- a/integrations/python/dataloader/tests/test_datafusion_sql.py +++ b/integrations/python/dataloader/tests/test_datafusion_sql.py @@ -5,6 +5,7 @@ import pytest from openhouse.dataloader.datafusion_sql import to_datafusion_sql +from openhouse.dataloader.filters import SqlTarget from openhouse.dataloader.table_identifier import TableIdentifier _DB_TBL = TableIdentifier(database="db", table="tbl") @@ -18,39 +19,38 @@ "sql, dialect, expected", [ # Spark → DataFusion - ("SELECT `col1`, `col2` FROM `my_table`", "spark", 'SELECT "col1", "col2" FROM "my_table"'), - ("SELECT SIZE(arr) FROM t", "spark", "SELECT cardinality(arr) FROM t"), - ("SELECT ARRAY(1, 2, 3)", "spark", "SELECT make_array(1, 2, 3)"), - ("SELECT UPPER(name) FROM t", "spark", "SELECT upper(name) FROM t"), - ("SELECT my_udf(col1, col2) FROM t", "spark", "SELECT my_udf(col1, col2) FROM t"), - ("SELECT IF(x > 0, 'pos', 'neg') FROM t", "spark", "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t"), + ("SELECT `col1`, `col2` FROM `my_table`", SqlTarget.SPARK, 'SELECT "col1", "col2" FROM "my_table"'), + ("SELECT SIZE(arr) FROM t", SqlTarget.SPARK, "SELECT cardinality(arr) FROM t"), + ("SELECT ARRAY(1, 2, 3)", SqlTarget.SPARK, "SELECT make_array(1, 2, 3)"), + ("SELECT UPPER(name) FROM t", SqlTarget.SPARK, "SELECT upper(name) FROM t"), + ("SELECT my_udf(col1, col2) FROM t", SqlTarget.SPARK, "SELECT my_udf(col1, col2) FROM t"), + ( + "SELECT IF(x > 0, 'pos', 'neg') FROM t", + SqlTarget.SPARK, + "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t", + ), ( "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t", - "spark", + SqlTarget.SPARK, "SELECT CASE WHEN status = 1 THEN 'active' ELSE 'inactive' END FROM t", ), ( "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) sub WHERE sub.name IS NOT NULL", - "spark", + SqlTarget.SPARK, "SELECT * FROM (SELECT id, name FROM t WHERE id > 10) AS sub WHERE NOT sub.name IS NULL", ), - ("SELECT 'hello world' AS greeting", "spark", "SELECT 'hello world' AS greeting"), - ("SELECT CURRENT_TIMESTAMP()", "spark", "SELECT now()"), - ("SELECT CAST(x AS BINARY)", "spark", "SELECT TRY_CAST(x AS BYTEA)"), - # MySQL → DataFusion - ("SELECT CAST(x AS CHAR)", "mysql", "SELECT CAST(x AS VARCHAR)"), - ("SELECT CAST(x AS DATETIME)", "mysql", "SELECT CAST(x AS TIMESTAMP)"), - # Postgres → DataFusion - ("SELECT CAST(x AS TEXT)", "postgres", "SELECT CAST(x AS VARCHAR)"), + ("SELECT 'hello world' AS greeting", SqlTarget.SPARK, "SELECT 'hello world' AS greeting"), + ("SELECT CURRENT_TIMESTAMP()", SqlTarget.SPARK, "SELECT now()"), + ("SELECT CAST(x AS BINARY)", SqlTarget.SPARK, "SELECT TRY_CAST(x AS BYTEA)"), # DataFusion → DataFusion (noop) ( "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5", - "datafusion", + SqlTarget.DATA_FUSION, "SELECT cardinality(arr) FROM t WHERE x > 10 ORDER BY x LIMIT 5", ), ], ) -def test_transpilation(sql: str, dialect: str, expected: str) -> None: +def test_transpilation(sql: str, dialect: SqlTarget, expected: str) -> None: assert to_datafusion_sql(sql, dialect) == expected @@ -62,19 +62,15 @@ def test_transpilation(sql: str, dialect: str, expected: str) -> None: class TestTranslatorEdgeCases: def test_multi_statement_raises(self) -> None: with pytest.raises(ValueError, match="Expected exactly one"): - to_datafusion_sql("SELECT 1; SELECT 2", "spark") - - def test_unsupported_dialect_raises(self) -> None: - with pytest.raises(ValueError, match="Unsupported source dialect 'nosuchdialect'"): - to_datafusion_sql("SELECT 1", "nosuchdialect") + to_datafusion_sql("SELECT 1; SELECT 2", SqlTarget.SPARK) def test_syntax_error_raises(self) -> None: with pytest.raises(ValueError, match="Failed to transpile SQL from 'spark' to DataFusion"): - to_datafusion_sql("SELECT * FROM", "spark") + to_datafusion_sql("SELECT * FROM", SqlTarget.SPARK) def test_datafusion_dialect_is_noop(self) -> None: sql = "SELECT make_array(1, 2, 3)" - assert to_datafusion_sql(sql, "datafusion") is sql + assert to_datafusion_sql(sql, SqlTarget.DATA_FUSION) is sql # --------------------------------------------------------------------------- @@ -84,42 +80,42 @@ def test_datafusion_dialect_is_noop(self) -> None: class TestTableValidation: def test_validates_single_table(self) -> None: - result = to_datafusion_sql('SELECT id FROM "db"."tbl"', "datafusion", table=_DB_TBL) + result = to_datafusion_sql('SELECT id FROM "db"."tbl"', SqlTarget.DATA_FUSION, table=_DB_TBL) assert result == 'SELECT id FROM "db"."tbl"' def test_wrong_table_name_raises(self) -> None: with pytest.raises(ValueError, match="references db.other, expected db.tbl"): - to_datafusion_sql('SELECT id FROM "db"."other"', "datafusion", table=_DB_TBL) + to_datafusion_sql('SELECT id FROM "db"."other"', SqlTarget.DATA_FUSION, table=_DB_TBL) def test_wrong_database_raises(self) -> None: with pytest.raises(ValueError, match="references other.tbl, expected db.tbl"): - to_datafusion_sql('SELECT id FROM "other"."tbl"', "datafusion", table=_DB_TBL) + to_datafusion_sql('SELECT id FROM "other"."tbl"', SqlTarget.DATA_FUSION, table=_DB_TBL) def test_multiple_tables_raises(self) -> None: with pytest.raises(ValueError, match="exactly 1 table, found 2"): to_datafusion_sql( 'SELECT * FROM "db"."tbl" JOIN "db"."tbl" AS t2 ON tbl.id = t2.id', - "datafusion", + SqlTarget.DATA_FUSION, table=_DB_TBL, ) def test_no_table_raises(self) -> None: with pytest.raises(ValueError, match="exactly 1 table, found 0"): - to_datafusion_sql("SELECT 1 AS x", "datafusion", table=_DB_TBL) + to_datafusion_sql("SELECT 1 AS x", SqlTarget.DATA_FUSION, table=_DB_TBL) def test_case_insensitive_table_match(self) -> None: - result = to_datafusion_sql('SELECT id FROM "DB"."TBL"', "datafusion", table=_DB_TBL) + result = to_datafusion_sql('SELECT id FROM "DB"."TBL"', SqlTarget.DATA_FUSION, table=_DB_TBL) assert result == 'SELECT id FROM "DB"."TBL"' def test_spark_table_validated_after_transpilation(self) -> None: - result = to_datafusion_sql("SELECT id FROM `db`.`tbl`", "spark", table=_DB_TBL) + result = to_datafusion_sql("SELECT id FROM `db`.`tbl`", SqlTarget.SPARK, table=_DB_TBL) assert result == 'SELECT id FROM "db"."tbl"' class TestNoOp: def test_no_filter_no_table_is_noop(self) -> None: sql = 'SELECT id FROM "db"."tbl"' - assert to_datafusion_sql(sql, "datafusion") is sql + assert to_datafusion_sql(sql, SqlTarget.DATA_FUSION) is sql # --------------------------------------------------------------------------- @@ -129,14 +125,14 @@ def test_no_filter_no_table_is_noop(self) -> None: def test_datafusion_execution() -> None: ctx = datafusion.SessionContext() - translated = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", "spark") + translated = to_datafusion_sql("SELECT SIZE(ARRAY(1, 2, 3))", SqlTarget.SPARK) batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 3 def test_datafusion_execution_median() -> None: ctx = datafusion.SessionContext() - translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", "spark") + translated = to_datafusion_sql("SELECT MEDIAN(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", SqlTarget.SPARK) assert translated == "SELECT median(x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 3 @@ -146,7 +142,7 @@ def test_datafusion_execution_percentile_cont() -> None: ctx = datafusion.SessionContext() translated = to_datafusion_sql( "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", - "spark", + SqlTarget.SPARK, ) expected = ( "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY x NULLS FIRST)" @@ -161,7 +157,7 @@ def test_datafusion_execution_approx_percentile_cont() -> None: ctx = datafusion.SessionContext() translated = to_datafusion_sql( "SELECT PERCENTILE_APPROX(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)", - "spark", + SqlTarget.SPARK, ) assert translated == "SELECT approx_percentile_cont(x, 0.5) FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] @@ -176,7 +172,7 @@ def double_it(arr: pa.Array) -> pa.Array: ctx.register_udf(datafusion.udf(double_it, [pa.int64()], pa.int64(), "stable", name="double_it")) - translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", "spark") + translated = to_datafusion_sql("SELECT double_it(x) FROM (VALUES (5)) AS t(x)", SqlTarget.SPARK) assert translated == "SELECT double_it(x) FROM (VALUES (5)) AS t(x)" batch = ctx.sql(translated).collect()[0] assert batch.column(0)[0].as_py() == 10 From 5a10946a5efbf6e5fefe3e8733ed57873764f56b Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Tue, 16 Jun 2026 15:18:08 +0000 Subject: [PATCH 5/6] test(dataloader): integration parity check between DataLoader and Spark filters Add a case to the Livy-driven integration test that reads rows via OpenHouseDataLoader with a filter, converts the same filter to Spark SQL via to_sql(filters, SqlTarget.SPARK), reads from Spark with that WHERE clause, and asserts the two row sets match. The filter excludes one matching-by-score row via a name IN, so an ignored predicate on either side fails the assertion. LivySession gains a query() that returns result rows (shared submit/poll loop factored into _run); execute() is unchanged behaviorally. Static checks pass (ruff, py_compile, unit suite); the Dockerized integration suite must be run via make integration-tests. --- .../dataloader/tests/integration_tests.py | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/integrations/python/dataloader/tests/integration_tests.py b/integrations/python/dataloader/tests/integration_tests.py index c99a5110b..000347712 100644 --- a/integrations/python/dataloader/tests/integration_tests.py +++ b/integrations/python/dataloader/tests/integration_tests.py @@ -21,7 +21,7 @@ from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader from openhouse.dataloader.catalog import OpenHouseCatalog from openhouse.dataloader.data_loader_split import to_sql_identifier -from openhouse.dataloader.filters import SqlTarget, col +from openhouse.dataloader.filters import SqlTarget, col, to_sql from openhouse.dataloader.table_transformer import TableTransformer BASE_URL = "http://openhouse-tables:8080" @@ -85,6 +85,15 @@ def _wait_for_idle(self) -> None: def execute(self, sql: str) -> None: """Submit a SQL statement and wait for completion. Raises on error.""" + self._run(sql) + + def query(self, sql: str) -> list[list]: + """Submit a SELECT, wait for completion, and return its result rows.""" + output = self._run(sql) + return output["data"]["application/json"]["data"] + + def _run(self, sql: str) -> dict: + """Submit a SQL statement, wait for completion, and return its output. Raises on error.""" print(f" SQL: {sql}") resp = requests.post( f"{self._session_url}/statements", json={"code": sql}, headers=HEADERS, timeout=REQUEST_TIMEOUT @@ -103,7 +112,7 @@ def execute(self, sql: str) -> None: output = resp.json()["output"] if output["status"] == "error": raise RuntimeError(f"SQL failed: {output.get('evalue', output)}") - return + return output if state in ("error", "cancelled"): raise RuntimeError(f"Statement entered state: {state}") time.sleep(1) @@ -284,6 +293,37 @@ def read_token() -> str: assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3, 4.4] print(f"PASS: after second insert, read all {result.num_rows} rows") + # 6b. Spark-SQL filter parity: read rows with DataLoader filters, then read the + # same rows from Spark using to_sql(filters, SqlTarget.SPARK), and verify the + # two row sets are identical. diana (score > 2.0 but excluded by the name IN) + # proves the predicate is actually applied on both sides, not ignored. + parity_filter = (col(COL_SCORE) > 2.0) & col(COL_NAME).is_in(["alice", "bob", "charlie"]) + parity_loader = OpenHouseDataLoader( + catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=parity_filter + ) + dl_result = _read_all(parity_loader) + dl_rows = list( + zip( + dl_result.column(COL_ID).to_pylist(), + dl_result.column(COL_NAME).to_pylist(), + dl_result.column(COL_SCORE).to_pylist(), + strict=True, + ) + ) + + spark_where = to_sql(parity_filter, SqlTarget.SPARK) + spark_rows = [ + tuple(row) + for row in livy.query( + f"SELECT {COL_ID}, {COL_NAME}, {COL_SCORE} FROM {FQTN} WHERE {spark_where} ORDER BY {COL_ID}" + ) + ] + assert dl_rows, "Expected the parity filter to match at least one row" + assert dl_rows == spark_rows, ( + f"DataLoader rows {dl_rows} != Spark rows {spark_rows} (Spark WHERE: {spark_where})" + ) + print(f"PASS: DataLoader and Spark agree on {len(dl_rows)} rows for WHERE {spark_where}") + # 7. Pin to the old snapshot and verify only the original data is returned loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, snapshot_id=snap1) result = _read_all(loader) From 2aa65ed3157664117e86c26a818ca7033952d41e Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Wed, 17 Jun 2026 17:24:53 +0000 Subject: [PATCH 6/6] refactor(dataloader): make _non_finite_double check each special value explicitly Normalize the float/Decimal input to a float once, then identify NaN / +Infinity / -Infinity with explicit positive checks instead of inferring -Infinity from a fallthrough. The impossible case now raises (a guard) rather than silently producing '-Infinity', and there's a single spelling source (no separate str(Decimal) branch). Behavior is unchanged; 340 tests pass. --- .../src/openhouse/dataloader/filters.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/filters.py b/integrations/python/dataloader/src/openhouse/dataloader/filters.py index 9ea662407..99a76db79 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/filters.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/filters.py @@ -327,21 +327,22 @@ def _escape_like(value: str) -> str: def _non_finite_double(value: float | Decimal) -> exp.Cast: - """Build ``CAST('' AS DOUBLE)`` for a non-finite float/decimal. + """Build ``CAST('' AS DOUBLE)`` for a non-finite float/decimal. - Uses the canonical ``NaN``/``Infinity``/``-Infinity`` spellings, which both - DataFusion and Spark parse. The lowercase forms from ``str(float(...))`` - (e.g. ``'inf'``) are not reliably cast by Spark. + Spark and DataFusion both parse the IEEE special-value spellings ``NaN`` / + ``Infinity`` / ``-Infinity`` (the lowercase ``str(float(...))`` forms like + ``'inf'`` are not reliably cast by Spark). """ - if isinstance(value, Decimal): - text = str(value) # 'NaN' / 'Infinity' / '-Infinity' - elif math.isnan(value): - text = "NaN" - elif value > 0: - text = "Infinity" + number = float(value) + if math.isnan(number): + spelling = "NaN" + elif number == math.inf: + spelling = "Infinity" + elif number == -math.inf: + spelling = "-Infinity" else: - text = "-Infinity" - return exp.Cast(this=exp.Literal.string(text), to=exp.DataType.build("DOUBLE")) + raise ValueError(f"_non_finite_double expects a non-finite value, got {value!r}") + return exp.Cast(this=exp.Literal.string(spelling), to=exp.DataType.build("DOUBLE")) def _literal_to_expr(value: object) -> exp.Expression: