Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions integrations/python/dataloader/CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,22 @@ Exported in `__init__.py`:
- `OpenHouseCatalogError` — Error raised when catalog fails to load a table
- `col()` — Column reference for building filter expressions
- `always_true()` — Filter that matches all rows
- `to_sql()` — Render a filter as a SQL boolean expression (WHERE-clause predicate) for a target
- `SqlTarget` — Enum of supported SQL flavors (`SPARK`, `TRINO`, `DATA_FUSION`); used by `to_sql()` and `TableTransformer`

### Filter DSL (`filters.py`)

Build row filters using `col()` with comparison operators (`==`, `!=`, `>`, `>=`, `<`, `<=`) and predicates (`is_null()`, `is_not_null()`, `is_nan()`, `is_not_nan()`, `is_in()`, `is_not_in()`, `starts_with()`, `not_starts_with()`, `between()`). Combine with `&` (AND), `|` (OR), `~` (NOT). Filters are converted to PyIceberg expressions internally for partition pruning and file-level filtering.

`to_sql(filter, target=SqlTarget.SPARK)` renders a filter as a SQL boolean expression for the given target. Internally, `_filter_to_expr()` builds a sqlglot AST that is dialect-agnostic except where SQL genuinely diverges (the NaN check is emitted as `is_nan` for Trino and `isnan` for Spark and DataFusion); that AST is then rendered per target with `.sql(dialect=target.value)` — the same path the loader uses for its internal DataFusion query via `to_sql(filters, SqlTarget.DATA_FUSION)`. The backing sqlglot dialect string is an internal detail — callers select a `SqlTarget`, never a dialect string.

### Internal modules (not in `__init__.py`)

- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (e.g. `"spark"`) and implement `transform()` returning SQL or `None`
- `TableTransformer` — ABC for SQL-based table transforms; subclass must provide a `dialect` (a `SqlTarget`, e.g. `SqlTarget.SPARK`) and implement `transform()` returning SQL or `None`
- `UDFRegistry` / `NoOpRegistry` — ABC for registering DataFusion UDFs; `NoOpRegistry` is the default no-op
- `TableScanContext` — Frozen dataclass holding table metadata, FileIO, projected schema, row filter, and table ID; pickle-safe for distributed execution
- `DataFusion` dialect in `datafusion_sql.py` — Custom SQLGlot dialect for transpiling SQL from other dialects (e.g. Spark) to DataFusion
- `to_datafusion_sql()` — Transpiles a SQL statement from a source dialect to DataFusion using SQLGlot
- `to_datafusion_sql()` — Transpiles a SQL statement from a source `SqlTarget` to DataFusion using SQLGlot

## Key Dependencies

Expand Down
13 changes: 13 additions & 0 deletions integrations/python/dataloader/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ filters = col("score").between(0.5, 1.0)
filters = (col("age") >= 18) & (col("country").is_in(["US", "CA"])) & ~col("email").is_null()
```

### Rendering a filter as SQL

Use `to_sql()` to render a filter as a SQL boolean expression (a `WHERE`-clause
predicate) for a given `SqlTarget` (`SPARK`, `TRINO`, or `DATA_FUSION`):

```python
from openhouse.dataloader import SqlTarget, col, to_sql

to_sql(col("age") > 21) # `age` > 21 (defaults to Spark)
to_sql((col("country") == "US") & col("email").is_null(), SqlTarget.SPARK)
# `country` = 'US' AND `email` IS NULL
```

## Development

```bash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from openhouse.dataloader.catalog import OpenHouseCatalog, OpenHouseCatalogError
from openhouse.dataloader.data_loader import DataLoaderContext, JvmConfig, OpenHouseDataLoader
from openhouse.dataloader.filters import always_true, col
from openhouse.dataloader.filters import SqlTarget, always_true, col, to_sql

__version__ = version("openhouse.dataloader")
__all__ = [
Expand All @@ -11,6 +11,8 @@
"JvmConfig",
"OpenHouseCatalog",
"OpenHouseCatalogError",
"SqlTarget",
"always_true",
"col",
"to_sql",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
from openhouse.dataloader.filters import (
AlwaysTrue,
Filter,
SqlTarget,
_quote_identifier,
_to_datafusion_sql,
_to_pyiceberg,
always_true,
to_sql,
)
from openhouse.dataloader.metrics import METER_NAME
from openhouse.dataloader.scan_optimizer import optimize_scan
Expand Down Expand Up @@ -302,7 +303,7 @@ def _build_query(self) -> str | None:
outer_cols = ", ".join(_quote_identifier(c) for c in self._columns) if self._columns else "*"
combined = f"SELECT {outer_cols} FROM ({sql}) AS _t"
if self._filters and not isinstance(self._filters, AlwaysTrue):
combined += f" WHERE {_to_datafusion_sql(self._filters)}"
combined += f" WHERE {to_sql(self._filters, SqlTarget.DATA_FUSION)}"
return combined

def __iter__(self) -> Iterator[DataLoaderSplit]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlglot.tokens import Tokenizer as _Tokenizer
from sqlglot.tokens import TokenType

from openhouse.dataloader.filters import SqlTarget
from openhouse.dataloader.table_identifier import TableIdentifier


Expand Down Expand Up @@ -100,7 +101,7 @@ class Generator(_Generator):

def to_datafusion_sql(
sql: str,
source_dialect: str,
source_dialect: SqlTarget,
*,
table: TableIdentifier | None = None,
) -> str:
Expand All @@ -110,28 +111,23 @@ def to_datafusion_sql(
that table.

Args:
sql: SQL statement in the source dialect.
source_dialect: sqlglot dialect name (e.g. "spark", "postgres").
sql: SQL statement written in *source_dialect*.
source_dialect: The SqlTarget the *sql* is written in (e.g. ``SqlTarget.SPARK``).
table: Expected table the SQL must reference. When set the function
verifies the SQL contains exactly one table matching this
identifier.

Raises:
ValueError: If the dialect is unsupported, the SQL is invalid, the
input contains more than one statement, or (when *table* is set)
the table reference does not match.
ValueError: If the SQL is invalid, the input contains more than one
statement, or (when *table* is set) the table reference does not match.
"""
if source_dialect not in Dialect.classes:
raise ValueError(
f"Unsupported source dialect '{source_dialect}'. Supported dialects: {', '.join(sorted(Dialect.classes))}"
)
if source_dialect == DataFusion.DIALECT and table is None:
if source_dialect is SqlTarget.DATA_FUSION and table is None:
return sql

try:
statements = sqlglot.parse(sql, dialect=source_dialect)
statements = sqlglot.parse(sql, dialect=source_dialect.value)
except sqlglot.errors.SqlglotError as e:
raise ValueError(f"Failed to transpile SQL from '{source_dialect}' to DataFusion: {e}") from e
raise ValueError(f"Failed to transpile SQL from '{source_dialect.value}' to DataFusion: {e}") from e
if len(statements) != 1 or statements[0] is None:
raise ValueError(f"Expected exactly one SQL statement, got {len(statements)}: {statements}")
ast = statements[0]
Expand Down
159 changes: 119 additions & 40 deletions integrations/python/dataloader/src/openhouse/dataloader/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
from uuid import UUID

Expand Down Expand Up @@ -325,101 +326,179 @@ def _escape_like(value: str) -> str:
return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _literal_to_sql(value: object) -> str:
"""Convert a Python literal to a SQL literal string using sqlglot.
def _non_finite_double(value: float | Decimal) -> exp.Cast:
"""Build ``CAST('<spelling>' AS DOUBLE)`` for a non-finite float/decimal.

Spark and DataFusion both parse the IEEE special-value spellings ``NaN`` /
``Infinity`` / ``-Infinity`` (the lowercase ``str(float(...))`` forms like
``'inf'`` are not reliably cast by Spark).
"""
number = float(value)
if math.isnan(number):
spelling = "NaN"
elif number == math.inf:
spelling = "Infinity"
elif number == -math.inf:
spelling = "-Infinity"
else:
raise ValueError(f"_non_finite_double expects a non-finite value, got {value!r}")
return exp.Cast(this=exp.Literal.string(spelling), to=exp.DataType.build("DOUBLE"))


def _literal_to_expr(value: object) -> exp.Expression:
"""Convert a Python literal to a sqlglot literal expression.

Datetime/date/time values are emitted as plain string literals (ISO format).
DataFusion implicitly coerces string literals to the column type at execution,
and PyIceberg promotes StringLiteral to the matching typed literal during expression binding.
"""
if isinstance(value, str):
return exp.Literal.string(value).sql()
return exp.Literal.string(value)
if isinstance(value, bool):
return exp.Boolean(this=True).sql() if value else exp.Boolean(this=False).sql()
return exp.true() if value else exp.false()
if isinstance(value, datetime):
return exp.Literal.string(value.isoformat()).sql()
return exp.Literal.string(value.isoformat())
if isinstance(value, date):
return exp.Literal.string(value.isoformat()).sql()
return exp.Literal.string(value.isoformat())
if isinstance(value, time):
if value.tzinfo is not None:
raise TypeError(
"DataFusion does not support timezones for time data types. "
"The SQL target does not support timezones for time data types. "
"The time should match the timezone used in the dataset."
)
return exp.Literal.string(value.isoformat()).sql()
return exp.Literal.string(value.isoformat())
if isinstance(value, (int, float)):
if isinstance(value, float) and not math.isfinite(value):
return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql()
return exp.Literal.number(value).sql()
return _non_finite_double(value)
return exp.Literal.number(value)
if isinstance(value, Decimal):
if not value.is_finite():
return exp.Cast(this=exp.Literal.string(str(value)), to=exp.DataType.build("DOUBLE")).sql()
return exp.Literal.number(value).sql()
return _non_finite_double(value)
return exp.Literal.number(value)
if isinstance(value, UUID):
return exp.Literal.string(str(value)).sql()
return exp.Literal.string(str(value))
raise TypeError(f"Unsupported literal type: {type(value).__name__}")


def _to_datafusion_sql(expr: Filter) -> str:
"""Convert a Filter expression tree to a DataFusion SQL expression string."""
match expr:
class SqlTarget(Enum):
    """Closed set of SQL flavors this library can name.

    A member identifies the flavor a filter is rendered for (see ``to_sql``)
    and the flavor a ``TableTransformer`` declares for the SQL it emits.
    Each member's value is the sqlglot dialect string backing it; that string
    is an implementation detail, so callers should pass the member itself
    rather than its value.
    """

    # Values are sqlglot dialect names.
    SPARK = "spark"
    TRINO = "trino"
    DATA_FUSION = "datafusion"


def _column_expr(name: str) -> exp.Column:
    """Return a sqlglot column reference whose identifier is marked as quoted.

    Args:
        name: The column name to reference.
    """
    quoted_identifier = exp.to_identifier(name, quoted=True)
    return exp.column(quoted_identifier)


def _like_prefix(column: str, prefix: str) -> exp.Expression:
    r"""Return the expression ``column LIKE '<prefix>%' ESCAPE '\'``.

    The prefix is passed through ``_escape_like`` before the trailing ``%``
    wildcard is appended, so the prefix text itself is matched literally.

    Args:
        column: Name of the column being matched.
        prefix: Literal prefix the column value must start with.
    """
    pattern_literal = exp.Literal.string(_escape_like(prefix) + "%")
    like_node = exp.Like(this=_column_expr(column), expression=pattern_literal)
    escape_char = exp.Literal.string("\\")
    return exp.Escape(this=like_node, expression=escape_char)


def _isnan(column: str, target: SqlTarget) -> exp.Anonymous:
    """Return a NaN-check function call on *column*, named for *target*.

    The function is emitted as ``is_nan`` when the target is Trino and as
    ``isnan`` for every other target (Spark and DataFusion). It is built as
    an ``exp.Anonymous`` call so the chosen name is carried explicitly rather
    than relying on sqlglot's built-in IsNan node.

    Args:
        column: Name of the column to test.
        target: SQL flavor the expression will be rendered for.
    """
    if target is SqlTarget.TRINO:
        function_name = "is_nan"
    else:
        function_name = "isnan"
    return exp.Anonymous(this=function_name, expressions=[_column_expr(column)])


def _filter_to_expr(filter_expr: Filter, target: SqlTarget) -> exp.Expression:
"""Build a sqlglot expression tree for a Filter, ready to render for *target*.

The tree is dialect-agnostic except where SQL genuinely diverges (NaN checks);
sqlglot handles per-dialect identifier quoting, operators, and literal
formatting at render time. See :func:`to_sql`.
"""
match filter_expr:
case AlwaysTrue():
return exp.Boolean(this=True).sql()
return exp.true()

# Comparison
case EqualTo(column, value):
return f"{_quote_identifier(column)} = {_literal_to_sql(value)}"
return exp.EQ(this=_column_expr(column), expression=_literal_to_expr(value))
case NotEqualTo(column, value):
return f"{_quote_identifier(column)} <> {_literal_to_sql(value)}"
return exp.NEQ(this=_column_expr(column), expression=_literal_to_expr(value))
case GreaterThan(column, value):
return f"{_quote_identifier(column)} > {_literal_to_sql(value)}"
return exp.GT(this=_column_expr(column), expression=_literal_to_expr(value))
case GreaterThanOrEqual(column, value):
return f"{_quote_identifier(column)} >= {_literal_to_sql(value)}"
return exp.GTE(this=_column_expr(column), expression=_literal_to_expr(value))
case LessThan(column, value):
return f"{_quote_identifier(column)} < {_literal_to_sql(value)}"
return exp.LT(this=_column_expr(column), expression=_literal_to_expr(value))
case LessThanOrEqual(column, value):
return f"{_quote_identifier(column)} <= {_literal_to_sql(value)}"
return exp.LTE(this=_column_expr(column), expression=_literal_to_expr(value))

# Null / NaN
case IsNull(column):
return f"{_quote_identifier(column)} IS NULL"
return exp.Is(this=_column_expr(column), expression=exp.Null())
case IsNotNull(column):
return f"{_quote_identifier(column)} IS NOT NULL"
return exp.not_(exp.Is(this=_column_expr(column), expression=exp.Null()))
case IsNaN(column):
return f"{_quote_identifier(column)} IS NAN"
return _isnan(column, target)
case IsNotNaN(column):
return f"{_quote_identifier(column)} IS NOT NAN"
return exp.not_(_isnan(column, target))

# Set membership
case In(column, values):
vals = ", ".join(_literal_to_sql(v) for v in values)
return f"{_quote_identifier(column)} IN ({vals})"
return exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values])
case NotIn(column, values):
vals = ", ".join(_literal_to_sql(v) for v in values)
return f"{_quote_identifier(column)} NOT IN ({vals})"
return exp.not_(exp.In(this=_column_expr(column), expressions=[_literal_to_expr(v) for v in values]))

# String prefix
case StartsWith(column, prefix):
escaped = _escape_like(prefix)
return f"{_quote_identifier(column)} LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'"
return _like_prefix(column, prefix)
case NotStartsWith(column, prefix):
escaped = _escape_like(prefix)
return f"{_quote_identifier(column)} NOT LIKE {_literal_to_sql(escaped + '%')} ESCAPE '\\'"
return exp.not_(_like_prefix(column, prefix))

# Range
case Between(column, lower, upper):
return f"{_quote_identifier(column)} BETWEEN {_literal_to_sql(lower)} AND {_literal_to_sql(upper)}"
return exp.Between(
this=_column_expr(column),
low=_literal_to_expr(lower),
high=_literal_to_expr(upper),
)

# Logical combinators
case And(left, right):
return f"({_to_datafusion_sql(left)} AND {_to_datafusion_sql(right)})"
return exp.and_(_filter_to_expr(left, target), _filter_to_expr(right, target))
case Or(left, right):
return f"({_to_datafusion_sql(left)} OR {_to_datafusion_sql(right)})"
return exp.or_(_filter_to_expr(left, target), _filter_to_expr(right, target))
case Not(operand):
return f"NOT ({_to_datafusion_sql(operand)})"
return exp.not_(_filter_to_expr(operand, target))

case _:
raise TypeError(f"Unsupported filter type: {type(expr).__name__}")
raise TypeError(f"Unsupported filter type: {type(filter_expr).__name__}")


def to_sql(filter_expr: Filter, target: SqlTarget = SqlTarget.SPARK) -> str:
    """Render *filter_expr* as a SQL boolean expression for *target*.

    The returned string is a predicate suitable for a WHERE clause; it does
    not include the ``WHERE`` keyword itself.

    Example::

        to_sql(col("age") > 21)                       # `age` > 21
        to_sql((col("a") == 1) & col("b").is_null())  # `a` = 1 AND `b` IS NULL

    Args:
        filter_expr: The filter expression to render.
        target: The SQL flavor to render for. Defaults to Spark.

    Returns:
        The SQL text produced by rendering the filter's sqlglot tree with the
        dialect named by ``target.value``.
    """
    tree = _filter_to_expr(filter_expr, target)
    return tree.sql(dialect=target.value)


def _to_pyiceberg(expr: Filter) -> ice.BooleanExpression:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping

from openhouse.dataloader.filters import SqlTarget
from openhouse.dataloader.table_identifier import TableIdentifier


Expand All @@ -13,11 +14,11 @@ class TableTransformer(ABC):
projections and filters can reference them.

Args:
dialect: The SQL dialect used by ``transform()`` (e.g. ``"spark"``).
dialect: The SQL flavor that ``transform()`` emits (e.g. ``SqlTarget.SPARK``).
"""

def __init__(self, dialect: str) -> None:
self.dialect: str = dialect
def __init__(self, dialect: SqlTarget) -> None:
self.dialect: SqlTarget = dialect

@abstractmethod
def transform(self, table: TableIdentifier, context: Mapping[str, str]) -> str | None:
Expand Down
Loading
Loading