From d2b2507b589644c6bab997d381400e3ed4855aac Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 25 May 2026 22:52:01 +0200 Subject: [PATCH 1/6] feature - implement wave 1 function catalog foundations --- .../language/reference/builders/aggregates.md | 22 +- docs/language/reference/functions/index.md | 7 +- docs/release_notes/v0_1.md | 4 +- src/aggregate_builders.incn | 42 +++- src/dataset/mod.incn | 4 +- src/function_registry.incn | 176 ++++++++++++++- src/functions/aggregates/avg.incn | 44 ++++ src/functions/aggregates/count.incn | 44 +++- src/functions/aggregates/max.incn | 44 ++++ src/functions/aggregates/min.incn | 44 ++++ src/functions/aggregates/mod.incn | 5 +- src/functions/math/abs.incn | 51 +++++ src/functions/math/ceil.incn | 51 +++++ src/functions/math/floor.incn | 51 +++++ src/functions/math/mod.incn | 6 + src/functions/math/round.incn | 52 +++++ src/functions/mod.incn | 20 +- src/functions/registry.incn | 15 +- src/lib.incn | 20 +- src/substrait/extensions.incn | 9 +- src/substrait/function_extensions.incn | 7 + src/substrait/inspect.incn | 29 ++- src/substrait/mod.incn | 1 + src/substrait/relations.incn | 84 ++++++-- tests/test_common_scalar_functions.incn | 32 +++ tests/test_dataset.incn | 14 ++ tests/test_function_registry.incn | 204 +++++++++++++++++- tests/test_session_aggregates.incn | 34 ++- tests/test_session_projection.incn | 56 ++++- tests/test_substrait_plan.incn | 43 +++- 30 files changed, 1136 insertions(+), 79 deletions(-) create mode 100644 src/functions/aggregates/avg.incn create mode 100644 src/functions/aggregates/max.incn create mode 100644 src/functions/aggregates/min.incn create mode 100644 src/functions/math/abs.incn create mode 100644 src/functions/math/ceil.incn create mode 100644 src/functions/math/floor.incn create mode 100644 src/functions/math/mod.incn create mode 100644 src/functions/math/round.incn create mode 100644 tests/test_common_scalar_functions.incn diff --git a/docs/language/reference/builders/aggregates.md 
b/docs/language/reference/builders/aggregates.md index b066e43..092540a 100644 --- a/docs/language/reference/builders/aggregates.md +++ b/docs/language/reference/builders/aggregates.md @@ -9,17 +9,31 @@ Current aggregate authoring is explicit and scalar-expression-based. | `col` | `def col(name: str) -> ColumnExpr` | Column reference builder used by aggregates, filters, and projections. | | `lit` | `def lit(value: int \| float \| str \| bool) -> ColumnExpr` | Canonical scalar literal helper. | | `sum` | `def sum(expr: ColumnExpr) -> AggregateMeasure` | Sum one scalar expression. | -| `count` | `def count() -> AggregateMeasure` | Count rows in the current relation or group. | +| `count` | `def count() -> AggregateMeasure` | Count rows. | +| `count_expr` | `def count_expr(expr: ColumnExpr) -> AggregateMeasure` | Count non-null expression values; compatibility spelling for the future `count(expr)` form. | +| `avg` | `def avg(expr: ColumnExpr) -> AggregateMeasure` | Average one numeric scalar expression. | +| `min` | `def min(expr: ColumnExpr) -> AggregateMeasure` | Return the minimum non-null value for one orderable scalar expression. | +| `max` | `def max(expr: ColumnExpr) -> AggregateMeasure` | Return the maximum non-null value for one orderable scalar expression. | ## Example ```incan -from pub::inql.functions import add, col, count, lit, sum - -grouped = orders.group_by([col("customer_id")]).agg([sum(add(col("amount"), lit(5))), count()]) +from pub::inql.functions import add, avg, col, count, count_expr, lit, max, min, sum + +grouped = orders.group_by([col("customer_id")]).agg([ + sum(add(col("amount"), lit(5))), + count(), + count_expr(col("discount_code")), + avg(col("amount")), + min(col("created_at")), + max(col("created_at")), +]) ``` ## Notes - Aggregate inputs use the same scalar-expression model as filters, projections, and grouping keys. +- `count()` counts rows. 
`count_expr(expr)` counts non-null values produced by the expression and lowers to the same + canonical `count` Substrait extension function. +- `sum`, `avg`, `min`, and `max` skip null values. They return backend-null results when no non-null input value exists. - Future `.column` sugar and scoped aggregate symbols should lower to this same surface rather than replacing its semantics. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 38d7cb2..36a54ba 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -10,10 +10,12 @@ Today the concrete shipped surfaces are documented here: The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), function class, null behavior, alias policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. 
Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/{family}/{function}.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. +RFC 024 policy category is separate from function class. Function class describes the semantic shape (`scalar`, `aggregate`, `ordering`, and later table-valued or partition-transform shapes). Policy category describes where the function belongs: portable core, explicitly namespaced extension-only, opt-in compatibility alias, engine-specific, or rejected compatibility request. 
Name-only registry lookup remains core-scoped; extension and engine-specific entries use namespace-qualified lookup so compatibility names cannot silently shadow portable core names. Rejected requests are documented as rejection metadata, not as lowerable registry entries or fake Substrait mappings. + The registered helper surface currently includes: | Function | Registry class | Mapping | @@ -28,7 +30,8 @@ The registered helper surface currently includes: | `is_null(...)`, `is_not_null(...)`, `is_nan(...)`, `is_not_nan(...)` | scalar | registered predicate mappings; `is_not_nan(...)` lowers as `not(is_nan(...))` | | `coalesce(...)`, `nullif(...)`, `case_when(...)` | scalar | registered Substrait mappings; `case_when(...)` lowers as built-in `IfThen` | | `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) | +| `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | -| `sum(...)`, `count()` | aggregate | registered Substrait extension functions | +| `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | Future ANSI-style families should grow under this section instead of bloating `dataset_types` or `dataset_methods`. diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 3ac9691..f06d757 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -10,10 +10,12 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). 
- **Carriers:** `DataSet[T]` hierarchy including bounded vs unbounded traits and concrete frame/stream types. - **Plans:** Apache Substrait as the logical interchange contract. - **Authoring:** method-chain lowering into a real Substrait boundary today, with `query {}` work still ahead. -- **Aggregates:** builder-based `col`, `sum`, and `count` helpers now lower grouped and global aggregates through Prism, Substrait, and Session execution. +- **Aggregates:** builder-based `col`, `sum`, `count`, `count_expr`, `avg`, `min`, and `max` helpers now lower grouped and global aggregates through Prism, Substrait, and Session execution. `count()` counts rows, and `count_expr(expr)` counts non-null expression values while preserving the future `count(expr)` semantics. - **Scalar expressions:** RFC 012 unifies filter predicates, computed projection values, grouping keys, and aggregate inputs around one `ColumnExpr` surface with canonical `lit(...)` and typed literal helpers. - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. +- **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. 
+- **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. - **Substrait internals:** RFC 002 helpers are now split into focused owner modules for relation building, plan assembly, inspection, schema registry, extension bookkeeping, and expression lowering instead of one `substrait.plan` godmodule. - **Prism:** `LazyFrame` lowering applies safe canonical rewrites (`Filter(true)` elimination and adjacent `Limit`/`Project`/`OrderBy` collapse) before RFC 002 plan emission. diff --git a/src/aggregate_builders.incn b/src/aggregate_builders.incn index 9be0647..320f60b 100644 --- a/src/aggregate_builders.incn +++ b/src/aggregate_builders.incn @@ -6,6 +6,7 @@ Today they provide a library-owned way to express grouping columns and aggregate typechecker changes in the Incan compiler. 
""" +from function_registry import function_ref_for from projection_builders import ColumnExpr, col as col_expr @@ -15,6 +16,9 @@ pub enum AggregateKind(str): Sum = "sum" Count = "count" + Avg = "avg" + Min = "min" + Max = "max" @derive(Clone) @@ -22,7 +26,21 @@ pub model AggregateMeasure: """Aggregate measure description carried through dataset, Prism, and Substrait boundaries.""" pub kind: AggregateKind + pub function_ref: str + pub canonical_name: str pub expr: ColumnExpr + pub has_expr: bool + + +def _aggregate_measure(canonical_name: str, kind: AggregateKind, expr: ColumnExpr, has_expr: bool) -> AggregateMeasure: + """Build one registry-backed aggregate measure description.""" + return AggregateMeasure( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + expr=expr, + has_expr=has_expr, + ) pub def col(name: str) -> ColumnExpr: @@ -32,9 +50,29 @@ pub def col(name: str) -> ColumnExpr: pub def sum(expr: ColumnExpr) -> AggregateMeasure: """Build one `sum` aggregate measure over a scalar expression.""" - return AggregateMeasure(kind=AggregateKind.Sum, expr=expr) + return _aggregate_measure("sum", AggregateKind.Sum, expr, true) pub def count() -> AggregateMeasure: """Build one zero-argument `count` aggregate measure.""" - return AggregateMeasure(kind=AggregateKind.Count, expr=col_expr("")) + return _aggregate_measure("count", AggregateKind.Count, col_expr(""), false) + + +pub def count_expr(expr: ColumnExpr) -> AggregateMeasure: + """Build one expression-count aggregate measure.""" + return _aggregate_measure("count_expr", AggregateKind.Count, expr, true) + + +pub def avg(expr: ColumnExpr) -> AggregateMeasure: + """Build one `avg` aggregate measure over a scalar expression.""" + return _aggregate_measure("avg", AggregateKind.Avg, expr, true) + + +pub def min(expr: ColumnExpr) -> AggregateMeasure: + """Build one `min` aggregate measure over a scalar expression.""" + return _aggregate_measure("min", AggregateKind.Min, expr, 
true) + + +pub def max(expr: ColumnExpr) -> AggregateMeasure: + """Build one `max` aggregate measure over a scalar expression.""" + return _aggregate_measure("max", AggregateKind.Max, expr, true) diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index 0d34d4c..fa850bd 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -28,7 +28,7 @@ Illustrative current-shape examples: ```incan from pub::inql import LazyFrame -from pub::inql.functions import add, col, count, gt, lit, sum +from pub::inql.functions import add, avg, col, count, gt, lit, max, min, sum from models import Order, OrderSummary def high_value_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: @@ -38,7 +38,7 @@ def enrich_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: return orders.with_column("amount_x2", add(col("amount"), lit(2))) def summarize_orders(orders: LazyFrame[Order]) -> LazyFrame[OrderSummary]: - return orders.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) + return orders.group_by([col("customer_id")]).agg([sum(col("amount")), count(), avg(col("amount")), min(col("amount")), max(col("amount"))]) ``` See also: diff --git a/src/function_registry.incn b/src/function_registry.incn index b4bcb51..edfc0cb 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -8,7 +8,18 @@ and signatures are checked API metadata facts, not second copies in this runtime from substrait.function_extensions import function_extension_uri -const FUNCTION_REF_PREFIX: str = "inql.functions." 
+const CORE_FUNCTION_NAMESPACE: str = "inql.functions" + + +@derive(Clone) +pub enum FunctionPolicyCategory(str): + """RFC 024 portability and extension-policy category for one function-like metadata item.""" + + PortableCore = "portable_core" + ExtensionOnly = "extension_only" + CompatibilityAlias = "compatibility_alias" + EngineSpecific = "engine_specific" + Rejected = "rejected" @derive(Clone) @@ -22,7 +33,6 @@ pub enum FunctionClass(str): Generator = "generator" TableValued = "table_valued" PartitionTransform = "partition_transform" - ExtensionOnly = "extension_only" @derive(Clone) @@ -124,6 +134,8 @@ pub model SubstraitMapping: pub model FunctionSpec: """Machine-readable function facts supplied to the registry decorator.""" + pub namespace: str + pub policy_category: FunctionPolicyCategory pub function_class: FunctionClass pub aliases: list[str] pub alias_policy: FunctionAliasPolicy @@ -139,7 +151,9 @@ pub model FunctionRegistryEntry: """Runtime projection for one registered InQL function.""" pub function_ref: str + pub namespace: str pub canonical_name: str + pub policy_category: FunctionPolicyCategory pub function_class: FunctionClass pub aliases: list[str] pub alias_policy: FunctionAliasPolicy @@ -150,6 +164,16 @@ pub model FunctionRegistryEntry: pub substrait: SubstraitMapping +@derive(Clone) +pub model RejectedFunctionPolicy: + """Documented rejection metadata for a likely compatibility request that is not a function entry.""" + + pub requested_name: str + pub policy_category: FunctionPolicyCategory + pub reason: str + pub alternative: str + + @derive(Clone) pub class FunctionRegistry: """Mutable package-owned function registry populated by declaration-site decorators.""" @@ -169,16 +193,19 @@ pub class FunctionRegistry: def _add_entry(mut self, canonical_name: str, spec: FunctionSpec) -> None: """Record one registry entry or fail fast on invalid duplicate metadata.""" assert _valid_canonical_name(canonical_name), (f"function canonical name must be one 
non-empty identifier segment, found `{canonical_name}`") + assert _valid_namespace(spec.namespace), f"function namespace must be non-empty, found `{spec.namespace}`" - function_ref = function_ref_for(canonical_name) + function_ref = namespaced_function_ref(spec.namespace, canonical_name) for entry in self.entries: assert entry.function_ref != function_ref, f"duplicate function reference `{function_ref}`" - assert entry.canonical_name != canonical_name, f"duplicate canonical function name `{canonical_name}`" + assert entry.namespace != spec.namespace or entry.canonical_name != canonical_name, f"duplicate canonical function name `{canonical_name}` in namespace `{spec.namespace}`" self.entries.append( FunctionRegistryEntry( function_ref=function_ref, + namespace=spec.namespace, canonical_name=canonical_name, + policy_category=spec.policy_category, function_class=spec.function_class, aliases=spec.aliases, alias_policy=spec.alias_policy, @@ -199,9 +226,13 @@ pub class FunctionRegistry: return None def entry_by_name(self, canonical_name: str) -> Option[FunctionRegistryEntry]: - """Return the registry entry for one canonical public function name when it is known.""" + """Return the core registry entry for one canonical public function name when it is known.""" + return self.entry_by_namespace_and_name(core_function_namespace(), canonical_name) + + def entry_by_namespace_and_name(self, namespace: str, canonical_name: str) -> Option[FunctionRegistryEntry]: + """Return the registry entry for one namespace-qualified canonical function name when it is known.""" for entry in self.entries: - if entry.canonical_name == canonical_name: + if entry.namespace == namespace and entry.canonical_name == canonical_name: return Some(entry) return None @@ -224,7 +255,19 @@ pub class FunctionRegistry: pub def function_ref_for(canonical_name: str) -> str: """Return the durable registry reference for one canonical InQL function name.""" - return f"{FUNCTION_REF_PREFIX}{canonical_name}" + 
return namespaced_function_ref(core_function_namespace(), canonical_name) + + +pub def core_function_namespace() -> str: + """Return the namespace used by portable core package functions.""" + return CORE_FUNCTION_NAMESPACE + + +pub def namespaced_function_ref(namespace: str, canonical_name: str) -> str: + """Return the durable registry reference for a function in one explicit namespace.""" + assert _valid_namespace(namespace), f"function namespace must be non-empty, found `{namespace}`" + assert _valid_canonical_name(canonical_name), (f"function canonical name must be one non-empty identifier segment, found `{canonical_name}`") + return f"{namespace}.{canonical_name}" pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: @@ -287,6 +330,33 @@ pub def sort_field_mapping(direction: str) -> SubstraitMapping: ) +pub def function_policy_spec( + namespace: str, + policy_category: FunctionPolicyCategory, + function_class: FunctionClass, + aliases: list[str], + alias_policy: FunctionAliasPolicy, + lifecycle: FunctionLifecycle, + determinism: FunctionDeterminism, + null_behavior: FunctionNullBehavior, + error_behavior: FunctionErrorBehavior, + substrait: SubstraitMapping, +) -> FunctionSpec: + """Build one function spec with explicit RFC 024 namespace and policy metadata.""" + return FunctionSpec( + namespace=namespace, + policy_category=policy_category, + function_class=function_class, + aliases=aliases, + alias_policy=alias_policy, + lifecycle=lifecycle, + determinism=determinism, + null_behavior=null_behavior, + error_behavior=error_behavior, + substrait=substrait, + ) + + pub def deterministic_spec( function_class: FunctionClass, lifecycle: FunctionLifecycle, @@ -295,6 +365,8 @@ pub def deterministic_spec( ) -> FunctionSpec: """Build one deterministic core-import function spec.""" return FunctionSpec( + namespace=core_function_namespace(), + policy_category=FunctionPolicyCategory.PortableCore, function_class=function_class, aliases=[], 
alias_policy=FunctionAliasPolicy.CoreImport, @@ -306,6 +378,96 @@ pub def deterministic_spec( ) +pub def extension_only_spec( + namespace: str, + function_class: FunctionClass, + lifecycle: FunctionLifecycle, + determinism: FunctionDeterminism, + null_behavior: FunctionNullBehavior, + error_behavior: FunctionErrorBehavior, + substrait: SubstraitMapping, +) -> FunctionSpec: + """Build one explicitly namespaced extension-only function spec.""" + return function_policy_spec( + namespace, + FunctionPolicyCategory.ExtensionOnly, + function_class, + [], + FunctionAliasPolicy.OptInCompatibility, + lifecycle, + determinism, + null_behavior, + error_behavior, + substrait, + ) + + +pub def compatibility_alias_spec( + namespace: str, + function_class: FunctionClass, + aliases: list[str], + lifecycle: FunctionLifecycle, + determinism: FunctionDeterminism, + null_behavior: FunctionNullBehavior, + error_behavior: FunctionErrorBehavior, + substrait: SubstraitMapping, +) -> FunctionSpec: + """Build one opt-in compatibility alias spec for semantics that remain typeable by InQL.""" + return function_policy_spec( + namespace, + FunctionPolicyCategory.CompatibilityAlias, + function_class, + aliases, + FunctionAliasPolicy.OptInCompatibility, + lifecycle, + determinism, + null_behavior, + error_behavior, + substrait, + ) + + +pub def engine_specific_spec( + namespace: str, + function_class: FunctionClass, + lifecycle: FunctionLifecycle, + determinism: FunctionDeterminism, + null_behavior: FunctionNullBehavior, + error_behavior: FunctionErrorBehavior, + substrait: SubstraitMapping, +) -> FunctionSpec: + """Build one explicitly namespaced engine-specific function spec.""" + return function_policy_spec( + namespace, + FunctionPolicyCategory.EngineSpecific, + function_class, + [], + FunctionAliasPolicy.OptInCompatibility, + lifecycle, + determinism, + null_behavior, + error_behavior, + substrait, + ) + + +pub def rejected_function_policy(requested_name: str, reason: str, alternative: str) 
-> RejectedFunctionPolicy: + """Build metadata for a rejected compatibility request without creating a lowerable function entry.""" + assert len(requested_name) > 0, "rejected function request must have a name" + assert len(reason) > 0, "rejected function request must explain the policy reason" + return RejectedFunctionPolicy( + requested_name=requested_name, + policy_category=FunctionPolicyCategory.Rejected, + reason=reason, + alternative=alternative, + ) + + def _valid_canonical_name(canonical_name: str) -> bool: """Return whether one canonical function name can derive a durable single-segment reference.""" return len(canonical_name) > 0 and not canonical_name.contains(".") + + +def _valid_namespace(namespace: str) -> bool: + """Return whether a namespace can prefix durable function references.""" + return len(namespace) > 0 diff --git a/src/functions/aggregates/avg.incn b/src/functions/aggregates/avg.incn new file mode 100644 index 0000000..70c4758 --- /dev/null +++ b/src/functions/aggregates/avg.incn @@ -0,0 +1,44 @@ +""" +Average aggregate helper. + +`avg` records its aggregate signature and Substrait extension anchor at the helper declaration site. +""" + +from aggregate_builders import AggregateKind, AggregateMeasure, avg as avg_builder +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import AVG_FUNCTION_ANCHOR + + +@function_registry.add("avg", deterministic_spec( + FunctionClass.Aggregate, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + extension_mapping("avg", AVG_FUNCTION_ANCHOR), +)) +pub def avg(expr: ColumnExpr) -> AggregateMeasure: + """ + Build an average aggregate measure. 
+ + Examples: + mean_revenue = avg(col("amount")) + + Parameters: + expr: Numeric expression to aggregate. + """ + return avg_builder(expr) + + +module tests: + from projection_builders import col + def test_avg_builds_avg_aggregate_measure() -> None: + measure = avg(col("amount")) + assert measure.kind == AggregateKind.Avg diff --git a/src/functions/aggregates/count.incn b/src/functions/aggregates/count.incn index b3b5503..ffb6955 100644 --- a/src/functions/aggregates/count.incn +++ b/src/functions/aggregates/count.incn @@ -1,19 +1,26 @@ """ Count aggregate helper. -`count` is registered as a zero-argument aggregate with a concrete Substrait extension mapping. +`count` is registered as an aggregate with a concrete Substrait extension mapping. `count()` counts rows. `count_expr` +is a compatibility helper for expression counts until the decorated public helper surface can expose `count(expr)` +directly. """ -from aggregate_builders import AggregateKind, AggregateMeasure, count as count_builder +from aggregate_builders import AggregateKind, AggregateMeasure, count as count_builder, count_expr as count_expr_builder from function_registry import ( FunctionClass, + FunctionDeterminism, + FunctionErrorBehavior, FunctionLifecycle, FunctionNullBehavior, + compatibility_alias_spec, + core_function_namespace, deterministic_spec, extension_mapping, v0_1, ) from functions.registry import function_registry +from projection_builders import ColumnExpr, column_expr_name from substrait.function_extensions import COUNT_FUNCTION_ANCHOR @@ -33,7 +40,40 @@ pub def count() -> AggregateMeasure: return count_builder() +@function_registry.add("count_expr", compatibility_alias_spec( + core_function_namespace(), + FunctionClass.Aggregate, + ["count"], + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.NullSkippingAggregate, + FunctionErrorBehavior.Typechecked, + extension_mapping("count", COUNT_FUNCTION_ANCHOR), +)) +pub 
def count_expr(expr: ColumnExpr) -> AggregateMeasure: + """ + Build a count aggregate measure over non-null expression values. + + Examples: + populated_codes = count_expr(col("discount_code")) + + Parameters: + expr: Expression whose non-null values should be counted. + """ + return count_expr_builder(expr) + + module tests: + from projection_builders import col def test_count_builds_count_aggregate_measure() -> None: measure = count() assert measure.kind == AggregateKind.Count + assert measure.canonical_name == "count" + assert not measure.has_expr + assert column_expr_name(measure.expr) == "" + def test_count_with_expression_builds_count_aggregate_measure() -> None: + measure = count_expr(col("amount")) + assert measure.kind == AggregateKind.Count + assert measure.canonical_name == "count_expr" + assert measure.has_expr + assert column_expr_name(measure.expr) == "amount" diff --git a/src/functions/aggregates/max.incn b/src/functions/aggregates/max.incn new file mode 100644 index 0000000..b4c4836 --- /dev/null +++ b/src/functions/aggregates/max.incn @@ -0,0 +1,44 @@ +""" +Maximum aggregate helper. + +`max` records its aggregate signature and Substrait extension anchor at the helper declaration site. +""" + +from aggregate_builders import AggregateKind, AggregateMeasure, max as max_builder +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAX_FUNCTION_ANCHOR + + +@function_registry.add("max", deterministic_spec( + FunctionClass.Aggregate, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + extension_mapping("max", MAX_FUNCTION_ANCHOR), +)) +pub def max(expr: ColumnExpr) -> AggregateMeasure: + """ + Build a maximum aggregate measure. 
+ + Examples: + last_created = max(col("created_at")) + + Parameters: + expr: Orderable expression to aggregate. + """ + return max_builder(expr) + + +module tests: + from projection_builders import col + def test_max_builds_max_aggregate_measure() -> None: + measure = max(col("amount")) + assert measure.kind == AggregateKind.Max diff --git a/src/functions/aggregates/min.incn b/src/functions/aggregates/min.incn new file mode 100644 index 0000000..ad6eb5a --- /dev/null +++ b/src/functions/aggregates/min.incn @@ -0,0 +1,44 @@ +""" +Minimum aggregate helper. + +`min` records its aggregate signature and Substrait extension anchor at the helper declaration site. +""" + +from aggregate_builders import AggregateKind, AggregateMeasure, min as min_builder +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MIN_FUNCTION_ANCHOR + + +@function_registry.add("min", deterministic_spec( + FunctionClass.Aggregate, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + extension_mapping("min", MIN_FUNCTION_ANCHOR), +)) +pub def min(expr: ColumnExpr) -> AggregateMeasure: + """ + Build a minimum aggregate measure. + + Examples: + first_created = min(col("created_at")) + + Parameters: + expr: Orderable expression to aggregate. 
+ """ + return min_builder(expr) + + +module tests: + from projection_builders import col + def test_min_builds_min_aggregate_measure() -> None: + measure = min(col("amount")) + assert measure.kind == AggregateKind.Min diff --git a/src/functions/aggregates/mod.incn b/src/functions/aggregates/mod.incn index 7c607df..92de101 100644 --- a/src/functions/aggregates/mod.incn +++ b/src/functions/aggregates/mod.incn @@ -1,4 +1,7 @@ """Aggregate function helpers.""" pub from functions.aggregates.sum import sum -pub from functions.aggregates.count import count +pub from functions.aggregates.count import count, count_expr +pub from functions.aggregates.avg import avg +pub from functions.aggregates.min import min +pub from functions.aggregates.max import max diff --git a/src/functions/math/abs.incn b/src/functions/math/abs.incn new file mode 100644 index 0000000..de9b0f1 --- /dev/null +++ b/src/functions/math/abs.incn @@ -0,0 +1,51 @@ +""" +Absolute-value math helper. + +`abs` is a registry-backed scalar application for the common numeric function catalog. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import ABS_FUNCTION_ANCHOR + + +@function_registry.add("abs", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("abs", ABS_FUNCTION_ANCHOR), +)) +pub def abs(expr: ColumnExpr) -> ColumnExpr: + """ + Build an absolute-value expression. + + Examples: + magnitude = abs(col("delta")) + + Parameters: + expr: Numeric expression to transform. 
+ """ + return registered_application("abs", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_abs_builds_registered_application() -> None: + expr = abs(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "abs" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/math/ceil.incn b/src/functions/math/ceil.incn new file mode 100644 index 0000000..f7d37a8 --- /dev/null +++ b/src/functions/math/ceil.incn @@ -0,0 +1,51 @@ +""" +Ceiling math helper. + +`ceil` is a registry-backed scalar application for the common numeric function catalog. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import CEIL_FUNCTION_ANCHOR + + +@function_registry.add("ceil", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("ceil", CEIL_FUNCTION_ANCHOR), +)) +pub def ceil(expr: ColumnExpr) -> ColumnExpr: + """ + Build a ceiling expression. + + Examples: + next_whole_amount = ceil(col("amount")) + + Parameters: + expr: Numeric expression to round toward positive infinity. 
+ """ + return registered_application("ceil", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_ceil_builds_registered_application() -> None: + expr = ceil(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "ceil" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/math/floor.incn b/src/functions/math/floor.incn new file mode 100644 index 0000000..f59b441 --- /dev/null +++ b/src/functions/math/floor.incn @@ -0,0 +1,51 @@ +""" +Floor math helper. + +`floor` is a registry-backed scalar application for the common numeric function catalog. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import FLOOR_FUNCTION_ANCHOR + + +@function_registry.add("floor", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("floor", FLOOR_FUNCTION_ANCHOR), +)) +pub def floor(expr: ColumnExpr) -> ColumnExpr: + """ + Build a floor expression. + + Examples: + previous_whole_amount = floor(col("amount")) + + Parameters: + expr: Numeric expression to round toward negative infinity. 
+ """ + return registered_application("floor", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_floor_builds_registered_application() -> None: + expr = floor(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "floor" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/math/mod.incn b/src/functions/math/mod.incn new file mode 100644 index 0000000..58e66ee --- /dev/null +++ b/src/functions/math/mod.incn @@ -0,0 +1,6 @@ +"""Common math scalar helpers.""" + +pub from functions.math.abs import abs +pub from functions.math.ceil import ceil +pub from functions.math.floor import floor +pub from functions.math.round import round diff --git a/src/functions/math/round.incn b/src/functions/math/round.incn new file mode 100644 index 0000000..f271e27 --- /dev/null +++ b/src/functions/math/round.incn @@ -0,0 +1,52 @@ +""" +Nearest-integer rounding math helper. + +`round` is a registry-backed scalar application for the common numeric function catalog. The current helper is the +single-argument form only; precision arguments remain a later catalog slice. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import ROUND_FUNCTION_ANCHOR + + +@function_registry.add("round", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("round", ROUND_FUNCTION_ANCHOR), +)) +pub def round(expr: ColumnExpr) -> ColumnExpr: + """ + Build a nearest-integer rounding expression. 
+ + Examples: + rounded_amount = round(col("amount")) + + Parameters: + expr: Numeric expression to round to the nearest integer value. + """ + return registered_application("round", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_round_builds_registered_application() -> None: + expr = round(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "round" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/mod.incn b/src/functions/mod.incn index b5eb257..3f5f3da 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -11,6 +11,7 @@ from functions.registry import ( function_registry_entries as raw_function_registry_entries, function_registry_entry as raw_function_registry_entry, function_registry_entry_by_name as raw_function_registry_entry_by_name, + function_registry_entry_by_namespace_and_name as raw_function_registry_entry_by_namespace_and_name, function_registry_entry_count as raw_function_registry_entry_count, function_registry_function_refs as raw_function_registry_function_refs, registered_substrait_mapped_function_refs as raw_registered_substrait_mapped_function_refs, @@ -27,8 +28,15 @@ pub from functions.literals.int_lit import int_lit pub from functions.literals.lit import lit pub from functions.literals.str_expr import str_expr pub from functions.literals.str_lit import str_lit -pub from functions.aggregates.count import count +pub from functions.aggregates.count import count, count_expr pub from functions.aggregates.sum import sum +pub from functions.aggregates.avg import avg +pub from functions.aggregates.min import min +pub from functions.aggregates.max import max +pub from functions.math.abs import abs +pub from functions.math.ceil import ceil +pub from functions.math.floor import floor +pub from functions.math.round 
import round pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -76,10 +84,18 @@ pub def function_registry_entry(function_ref: str) -> Option[FunctionRegistryEnt pub def function_registry_entry_by_name(canonical_name: str) -> Option[FunctionRegistryEntry]: - """Return a loaded registry entry for one canonical public function name when it is known.""" + """Return a loaded core registry entry for one canonical public function name when it is known.""" return raw_function_registry_entry_by_name(canonical_name) +pub def function_registry_entry_by_namespace_and_name( + namespace: str, + canonical_name: str, +) -> Option[FunctionRegistryEntry]: + """Return a loaded registry entry for one namespace-qualified function name when it is known.""" + return raw_function_registry_entry_by_namespace_and_name(namespace, canonical_name) + + pub def function_registry_function_refs() -> list[str]: """Return loaded function references in runtime registry order.""" return raw_function_registry_function_refs() diff --git a/src/functions/registry.incn b/src/functions/registry.incn index f0c5fc5..e3888a1 100644 --- a/src/functions/registry.incn +++ b/src/functions/registry.incn @@ -57,11 +57,16 @@ pub def function_registry_entry(function_ref: str) -> Option[FunctionRegistryEnt pub def function_registry_entry_by_name(canonical_name: str) -> Option[FunctionRegistryEntry]: - """Return a loaded registry entry for one canonical public function name when it is known.""" - for entry in function_registry.entries: - if entry.canonical_name == canonical_name: - return Some(entry) - return None + """Return a loaded core registry entry for one canonical public function name when it is known.""" + return function_registry.entry_by_name(canonical_name) + + +pub def function_registry_entry_by_namespace_and_name( + namespace: str, + canonical_name: str, +) -> Option[FunctionRegistryEntry]: + """Return a loaded registry 
entry for one namespace-qualified function name when it is known.""" + return function_registry.entry_by_namespace_and_name(namespace, canonical_name) pub def function_registry_function_refs() -> list[str]: diff --git a/src/lib.incn b/src/lib.incn index 713e064..82bf427 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -32,6 +32,7 @@ pub from functions import ( function_registry_entries, function_registry_entry, function_registry_entry_by_name, + function_registry_entry_by_namespace_and_name, function_registry_entry_count, function_registry_function_refs, registered_substrait_mapped_function_refs, @@ -47,8 +48,15 @@ pub from functions.literals.int_lit import int_lit pub from functions.literals.lit import lit pub from functions.literals.str_expr import str_expr pub from functions.literals.str_lit import str_lit -pub from functions.aggregates.count import count +pub from functions.aggregates.count import count, count_expr pub from functions.aggregates.sum import sum +pub from functions.aggregates.avg import avg +pub from functions.aggregates.min import min +pub from functions.aggregates.max import max +pub from functions.math.abs import abs +pub from functions.math.ceil import ceil +pub from functions.math.floor import floor +pub from functions.math.round import round pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -92,15 +100,24 @@ pub from function_registry import ( FunctionErrorBehavior, FunctionLifecycle, FunctionNullBehavior, + FunctionPolicyCategory, FunctionRegistry, FunctionRegistryEntry, FunctionSpec, FunctionVersion, + RejectedFunctionPolicy, SubstraitMapping, SubstraitMappingKind, + compatibility_alias_spec, + core_function_namespace, deterministic_spec, + engine_specific_spec, + extension_only_spec, extension_mapping, function_ref_for, + function_policy_spec, + namespaced_function_ref, + rejected_function_policy, rewrite_mapping, sort_field_mapping, structural_mapping, 
@@ -166,6 +183,7 @@ pub from substrait.plans import ( substrait_release_tag, ) pub from substrait.inspect import ( + aggregate_measure_function_names, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index 98cca28..4db2b86 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -236,7 +236,10 @@ def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect aggregate-function anchors used by one relation subtree in stable declaration order.""" mut anchors: list[u32] = [] for spec in _function_extension_specs(): - if spec.kind == ExtensionFunctionKind.Aggregate and _rel_uses_aggregate_function_anchor(rel.clone(), spec.anchor): + if (spec.kind == ExtensionFunctionKind.Aggregate and _rel_uses_aggregate_function_anchor( + rel.clone(), + spec.anchor, + ) and not anchors.contains(spec.anchor)): anchors.append(spec.anchor) return anchors @@ -250,7 +253,9 @@ def _scalar_extension_anchors_for_rel(rel: Rel) -> list[u32]: mut anchors: list[u32] = [] for spec in _function_extension_specs(): - if spec.kind == ExtensionFunctionKind.Scalar and _rel_uses_scalar_function_anchor(rel.clone(), spec.anchor): + if (spec.kind == ExtensionFunctionKind.Scalar and _rel_uses_scalar_function_anchor(rel.clone(), spec.anchor) and not anchors.contains( + spec.anchor, + )): anchors.append(spec.anchor) return anchors diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 409cf9d..0ad86f1 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -47,6 +47,13 @@ pub const NEGATE_FUNCTION_ANCHOR: u32 = 20 pub const COALESCE_FUNCTION_ANCHOR: u32 = 21 pub const NULLIF_FUNCTION_ANCHOR: u32 = 22 pub const BETWEEN_FUNCTION_ANCHOR: u32 = 23 +pub const AVG_FUNCTION_ANCHOR: u32 = 24 +pub const MIN_FUNCTION_ANCHOR: u32 = 25 +pub const MAX_FUNCTION_ANCHOR: u32 = 26 +pub const 
ABS_FUNCTION_ANCHOR: u32 = 27 +pub const CEIL_FUNCTION_ANCHOR: u32 = 28 +pub const FLOOR_FUNCTION_ANCHOR: u32 = 29 +pub const ROUND_FUNCTION_ANCHOR: u32 = 30 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 0ffa860..f6f7d07 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -19,7 +19,7 @@ from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit -from substrait.function_extensions import COUNT_FUNCTION_ANCHOR +from substrait.extensions import aggregate_function_name_from_anchor from substrait.schema_registry import named_table_columns, unknown_named_struct from substrait.traversal import relation_children @@ -45,7 +45,7 @@ def _set_op_value(value: SetOp) -> RustI32: def _measure_output_name(measure: AggregateMeasure) -> str: """Return the projected output column name for one aggregate measure.""" - if measure.kind == AggregateKind.Count: + if measure.kind == AggregateKind.Count and not measure.has_expr: return measure.kind.value() return measure.kind.value() + "_" + scalar_expr_output_name(measure.expr, "expr") @@ -87,20 +87,19 @@ def _aggregate_measure_output_name(input_columns: list[str], measure: Measure) - """Resolve one aggregate output name from the lowered Substrait aggregate measure payload.""" match measure.measure: Some(agg_fn) => - if agg_fn.function_reference == COUNT_FUNCTION_ANCHOR: - return "count" + function_name = aggregate_function_name_from_anchor(agg_fn.function_reference) if len(agg_fn.arguments) == 0: - return "sum" + return function_name match agg_fn.arguments[0].arg_type: Some(arg_type) => match 
arg_type: ArgType.Value(expr) => field_index = field_index_from_expression(expr) if field_index >= 0 and field_index < len(input_columns): - return "sum_" + input_columns[field_index] - return "sum" - _ => return "sum" - _ => return "sum" + return function_name + "_" + input_columns[field_index] + return function_name + _ => return function_name + _ => return function_name None => return "" @@ -188,6 +187,18 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +pub def aggregate_measure_function_names(rel: Rel) -> list[str]: + """Return aggregate function names used by a top-level AggregateRel, otherwise empty.""" + match rel.rel_type: + Some(RelType.Aggregate(aggregate_rel)) => + mut names: list[str] = [] + for measure in aggregate_rel.measures: + if let Some(agg_fn) = measure.measure: + names.append(aggregate_function_name_from_anchor(agg_fn.function_reference)) + return names + _ => return [] + + pub def root_rel(plan: Plan) -> Rel: """Return the logical root relation from a plan.""" if len(plan.relations) == 0: diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 523a858..2ed8dc8 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -53,6 +53,7 @@ pub from substrait.plans import ( ) pub from substrait.inspect import ( aggregate_group_columns, + aggregate_measure_function_names, aggregate_measure_output_names, plan_contains_relation_kind, plan_has_extension_urn, diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index 948b3b5..739e421 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -44,7 +44,7 @@ from rust::substrait::proto::rel_common import Direct, Emit, EmitKind from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure -from function_registry import SubstraitMappingKind +from function_registry import 
FunctionClass, SubstraitMappingKind from functions.registry import function_registry_entry from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( @@ -57,14 +57,16 @@ from substrait.expr_lowering import ( string_expr, ) from substrait.errors import SubstraitLoweringError, invalid_scalar_expression -from substrait.function_extensions import COUNT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR from substrait.inspect import relation_output_columns from substrait.schema_registry import named_table_base_schema, unknown_named_struct model ResolvedAggregateMeasure: kind: AggregateKind + function_ref: str + canonical_name: str expr: Expression + has_expr: bool @derive(Clone) @@ -193,38 +195,76 @@ def _set_operation_from_name(operation: str) -> Result[SubstraitSetOperation, Su return Err(invalid_scalar_expression(f"unknown set operation '{operation}'")) -def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Measure: +def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Result[Measure, SubstraitLoweringError]: """Lower one resolved aggregate measure into a Substrait aggregate measure.""" mut arguments: list[FunctionArgument] = [] - mut function_reference = COUNT_FUNCTION_ANCHOR - if measure.kind == AggregateKind.Sum: - function_reference = SUM_FUNCTION_ANCHOR + if measure.has_expr: arguments = [FunctionArgument(arg_type=Some(ArgType.Value(measure.expr.clone())))] - return Measure( - measure=Some( - AggregateFunction( - function_reference=function_reference, - arguments=arguments, - sorts=[], - output_type=None, - invocation=AggregationInvocation.All.into(), - phase=AggregationPhase.Unspecified.into(), - args=[], - options=[], + return Ok( + Measure( + measure=Some( + AggregateFunction( + function_reference=_aggregate_function_reference(measure)?, + arguments=arguments, + sorts=[], + output_type=None, + invocation=AggregationInvocation.All.into(), + 
phase=AggregationPhase.Unspecified.into(), + args=[], + options=[], + ), ), + filter=None, ), - filter=None, ) +def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: + """Resolve one aggregate measure through declaration-side registry metadata.""" + match function_registry_entry(measure.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Aggregate: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as an aggregate function")) + if entry.substrait.kind != SubstraitMappingKind.ExtensionFunction: + return Err( + invalid_scalar_expression( + f"{entry.function_ref} does not declare a concrete Substrait aggregate mapping", + ), + ) + return Ok(entry.substrait.anchor) + None => + return Err(invalid_scalar_expression(f"missing aggregate registry entry for `{measure.canonical_name}`")) + + +def _is_argument_free_count(measure: AggregateMeasure) -> bool: + """Return whether one count measure represents `count()` rather than `count(expr)`.""" + return measure.kind == AggregateKind.Count and not measure.has_expr + + def _resolved_measure( measure: AggregateMeasure, input_columns: list[str], ) -> Result[ResolvedAggregateMeasure, SubstraitLoweringError]: """Resolve one aggregate measure against the current input-column list.""" - if measure.kind == AggregateKind.Count: - return Ok(ResolvedAggregateMeasure(kind=measure.kind, expr=bool_expr(true))) - return Ok(ResolvedAggregateMeasure(kind=measure.kind, expr=scalar_expr(input_columns, measure.expr)?)) + if _is_argument_free_count(measure): + return Ok( + ResolvedAggregateMeasure( + kind=measure.kind, + function_ref=measure.function_ref, + canonical_name=measure.canonical_name, + expr=bool_expr(true), + has_expr=false, + ), + ) + return Ok( + ResolvedAggregateMeasure( + kind=measure.kind, + function_ref=measure.function_ref, + canonical_name=measure.canonical_name, + expr=scalar_expr(input_columns, measure.expr)?, + 
has_expr=true, + ), + ) def _lowered_rel_or_raise(result: Result[Rel, SubstraitLoweringError]) -> Rel: @@ -491,7 +531,7 @@ pub def try_aggregate_rel_of_columns( grouping_expressions=grouping_expr_copies, expression_references=grouping_reference_indexes(len(grouping_exprs)), )] - lowered_measures = [_resolved_measure_to_substrait(_resolved_measure(measure, input_columns)?) for measure in measures] + lowered_measures = [_resolved_measure_to_substrait(_resolved_measure(measure, input_columns)?)? for measure in measures] return Ok( _rel_aggregate( AggregateRel( diff --git a/tests/test_common_scalar_functions.incn b/tests/test_common_scalar_functions.incn new file mode 100644 index 0000000..039b238 --- /dev/null +++ b/tests/test_common_scalar_functions.incn @@ -0,0 +1,32 @@ +"""Test: common scalar helper surface.""" + +from functions import abs, ceil, col, floor, round +from function_registry import function_ref_for +from projection_builders import ( + ColumnExpr, + ColumnExprKind, + column_expr_argument_count, + column_expr_function_name, + column_expr_function_ref, + column_expr_kind, +) + + +def _assert_math_application(expr: ColumnExpr, expected_name: str) -> None: + """Assert one math helper uses the shared scalar application node.""" + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction, f"{expected_name} should use the scalar function kind" + assert column_expr_function_name(expr) == expected_name, f"{expected_name} should preserve its canonical name" + assert column_expr_function_ref(expr) == function_ref_for(expected_name), "math helper should preserve its registry function ref" + assert column_expr_argument_count(expr) == 1, f"{expected_name} should carry one scalar argument" + + +def test_common_scalar_functions__math_helpers_share_scalar_application_node() -> None: + """Assert the first common math slice uses the registry-backed scalar expression model.""" + # -- Arrange -- + amount = col("amount") + + # -- Act / Assert -- + 
_assert_math_application(abs(amount), "abs") + _assert_math_application(ceil(amount), "ceil") + _assert_math_application(floor(amount), "floor") + _assert_math_application(round(amount), "round") diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 25c11ad..ca25527 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -2,19 +2,24 @@ from std.testing import fail_t, parametrize from metadata import inql_version +from aggregate_builders import AggregateKind from dataset import DataSet, BoundedDataSet, UnboundedDataSet, DataFrame, LazyFrame, DataStream, lazy_frame_named_table from dataset.materialization import DataFrameMaterialization from dataset.ops import filter_ds, join_ds from functions import ( always_false, always_true, + avg, bool_lit, col, count, + count_expr, float_expr, int_expr, int_lit, lit, + max, + min, mul, str_expr, str_lit, @@ -147,10 +152,19 @@ def test_smoke__dataset_types_are_published() -> None: # -- Act -- sum_result = sum(amount) count_result = count() + expression_count_result = count_expr(amount) + avg_result = avg(amount) + min_result = min(amount) + max_result = max(amount) # -- Assert -- assert column_expr_name(sum_result.expr) == "amount", "sum should preserve the selected input column expression" assert column_expr_name(count_result.expr) == "", "count should remain a zero-argument aggregate helper" + assert expression_count_result.has_expr, "count(expr) should mark expression-count measures explicitly" + assert column_expr_name(expression_count_result.expr) == "amount", "count(expr) should preserve the selected input column expression" + assert avg_result.kind == AggregateKind.Avg, "avg should build an aggregate measure" + assert min_result.kind == AggregateKind.Min, "min should build an aggregate measure" + assert max_result.kind == AggregateKind.Max, "max should build an aggregate measure" @parametrize("helper_name, expected_kind", [("int_expr", "IntLiteral"), ("float_expr", "FloatLiteral"), 
("str_expr", "StringLiteral"), ("int_lit", "IntLiteral"), ("str_lit", "StringLiteral"), ("bool_lit", "BoolLiteral")]) diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index fe8fd07..5db41b1 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -3,6 +3,7 @@ from std.testing import fail_t from aggregate_builders import AggregateKind from functions import ( + abs, add, always_false, always_true, @@ -10,20 +11,24 @@ from functions import ( asc, asc_nulls_first, asc_nulls_last, + avg, between, bool_expr, bool_lit, case_when, cast, + ceil, col, coalesce, count, + count_expr, desc, desc_nulls_first, desc_nulls_last, div, eq, equal_null, + floor, float_expr, function_registry_canonical_names, function_registry_entries, @@ -43,6 +48,8 @@ from functions import ( lit, lt, lte, + max, + min, modulo, mul, ne, @@ -51,6 +58,7 @@ from functions import ( nullif, or_, registered_substrait_mapped_function_refs, + round, str_expr, str_lit, sub, @@ -60,21 +68,37 @@ from functions import ( from function_registry import ( FunctionAliasPolicy, FunctionClass, + FunctionDeterminism, + FunctionErrorBehavior, + FunctionLifecycle, + FunctionPolicyCategory, + FunctionRegistry, FunctionRegistryEntry, FunctionNullBehavior, SubstraitMappingKind, + core_function_namespace, + deterministic_spec, + extension_only_spec, + extension_mapping, function_ref_for, + namespaced_function_ref, + rejected_function_policy, + rewrite_mapping, v0_1, ) from projection_builders import ColumnExprKind, column_expr_kind from substrait.function_extensions import ( + ABS_FUNCTION_ANCHOR, ADD_FUNCTION_ANCHOR, AND_FUNCTION_ANCHOR, + AVG_FUNCTION_ANCHOR, BETWEEN_FUNCTION_ANCHOR, + CEIL_FUNCTION_ANCHOR, COALESCE_FUNCTION_ANCHOR, COUNT_FUNCTION_ANCHOR, DIVIDE_FUNCTION_ANCHOR, EQUAL_FUNCTION_ANCHOR, + FLOOR_FUNCTION_ANCHOR, GT_FUNCTION_ANCHOR, GTE_FUNCTION_ANCHOR, IS_NAN_FUNCTION_ANCHOR, @@ -83,6 +107,8 @@ from substrait.function_extensions import ( 
IS_NULL_FUNCTION_ANCHOR, LT_FUNCTION_ANCHOR, LTE_FUNCTION_ANCHOR, + MAX_FUNCTION_ANCHOR, + MIN_FUNCTION_ANCHOR, MODULUS_FUNCTION_ANCHOR, MULTIPLY_FUNCTION_ANCHOR, NEGATE_FUNCTION_ANCHOR, @@ -90,6 +116,7 @@ from substrait.function_extensions import ( NOT_FUNCTION_ANCHOR, NULLIF_FUNCTION_ANCHOR, OR_FUNCTION_ANCHOR, + ROUND_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, function_extension_uri, @@ -120,14 +147,43 @@ def _entry_by_name_or_fail(canonical_name: str) -> FunctionRegistryEntry: None => return fail_t("missing function registry entry by name") -def _expected_rfc015_registry_names() -> list[str]: - """Return the expected registered public helper names after RFC 015.""" - return ["col", "lit", "sum", "count", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last"] +def _local_entry_or_fail(registry: FunctionRegistry, function_ref: str) -> FunctionRegistryEntry: + """Return one entry from a local registry fixture or fail the test.""" + lookup_ref = f"{function_ref}" + match registry.entry(lookup_ref): + Some(entry) => return entry + None => return fail_t("missing local function registry entry") + + +def _local_entry_by_name_or_fail(registry: FunctionRegistry, canonical_name: str) -> FunctionRegistryEntry: + """Return one core-scoped local registry entry by canonical name or fail the test.""" + lookup_name = f"{canonical_name}" + match registry.entry_by_name(lookup_name): + Some(entry) => return entry + None => return fail_t("missing local core function registry entry by name") + + +def _local_entry_by_namespace_and_name_or_fail( + registry: FunctionRegistry, + 
namespace: str, + canonical_name: str, +) -> FunctionRegistryEntry: + """Return one namespace-scoped local registry entry by canonical name or fail the test.""" + lookup_namespace = f"{namespace}" + lookup_name = f"{canonical_name}" + match registry.entry_by_namespace_and_name(lookup_namespace, lookup_name): + Some(entry) => return entry + None => return fail_t("missing local namespaced function registry entry") + + +def _expected_registry_names() -> list[str]: + """Return the expected registered public helper names.""" + return ["col", "lit", "sum", "count", "count_expr", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between"] + return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round"] def _exercise_current_public_helpers() -> None: @@ -146,6 +202,10 @@ def _exercise_current_public_helpers() -> None: always_false() sum(amount) count() + count_expr(status) + avg(amount) + min(amount) + max(amount) add(amount, lit(1)) mul(amount, int_lit(2)) eq(status, str_lit("paid")) @@ -179,6 
+239,10 @@ def _exercise_current_public_helpers() -> None: asc_nulls_last(amount) desc_nulls_first(amount) desc_nulls_last(amount) + abs(amount) + ceil(amount) + floor(amount) + round(amount) return @@ -237,7 +301,7 @@ def test_function_registry__covers_current_public_helpers() -> None: """Assert that RFC 015 registry metadata covers the current public helper surface.""" # -- Arrange -- _exercise_current_public_helpers() - expected_names = _expected_rfc015_registry_names() + expected_names = _expected_registry_names() # -- Act -- entry_count = function_registry_entry_count() @@ -256,7 +320,7 @@ def test_function_registry__lifecycle_metadata_is_versioned() -> None: entries = function_registry_entries() # -- Act / Assert -- - assert len(entries) == len(_expected_rfc015_registry_names()), "lifecycle fixture should cover the current registry surface" + assert len(entries) == len(_expected_registry_names()), "lifecycle fixture should cover the current registry surface" for entry in entries: assert entry.lifecycle.since.major == v0_1.major, f"{entry.function_ref} should expose its introduction major version" assert entry.lifecycle.since.minor == v0_1.minor, f"{entry.function_ref} should expose its introduction minor version" @@ -283,6 +347,108 @@ def test_function_registry__lookup_exposes_canonical_metadata() -> None: assert count_entry.function_class == FunctionClass.Aggregate, "count should be classified as aggregate" +def test_function_registry__core_helpers_expose_portable_policy_metadata() -> None: + """Assert current helper entries distinguish portability policy from semantic function class.""" + # -- Arrange -- + _exercise_current_public_helpers() + entries = function_registry_entries() + + # -- Act / Assert -- + for entry in entries: + assert entry.namespace == core_function_namespace(), f"{entry.function_ref} should live in the core function namespace" + if entry.canonical_name == "count_expr": + assert entry.policy_category == 
FunctionPolicyCategory.CompatibilityAlias, "count_expr should be marked as a compatibility helper" + assert entry.alias_policy == FunctionAliasPolicy.OptInCompatibility, "compatibility helpers should be opt-in by policy" + continue + assert entry.policy_category == FunctionPolicyCategory.PortableCore, f"{entry.function_ref} should be portable core" + + +def test_function_registry__extension_policy_is_separate_from_scalar_class() -> None: + """Assert extension-only functions can be scalar while remaining explicitly namespaced.""" + # -- Arrange -- + mut registry = FunctionRegistry.new() + extension_namespace = "inql_ext.geo" + + # -- Act -- + registry._add_entry( + "st_srid", + extension_only_spec( + extension_namespace, + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.DependsOnInputs, + FunctionErrorBehavior.Typechecked, + extension_mapping("st_srid", 9000), + ), + ) + extension_entry = _local_entry_or_fail(registry, namespaced_function_ref(extension_namespace, "st_srid")) + + # -- Assert -- + assert extension_entry.namespace == extension_namespace, "extension-only entries should preserve explicit namespace" + assert extension_entry.function_ref == "inql_ext.geo.st_srid", "extension-only entries should derive namespaced function refs" + assert extension_entry.policy_category == FunctionPolicyCategory.ExtensionOnly, "extension-only entries should expose policy category" + assert extension_entry.function_class == FunctionClass.Scalar, "extension-only policy should not replace semantic function class" + assert extension_entry.substrait.kind == SubstraitMappingKind.ExtensionFunction, "extension-only entries may declare Substrait extension mappings" + + +def test_function_registry__canonical_name_lookup_stays_core_scoped() -> None: + """Assert explicit namespaces keep extension names from shadowing portable core names.""" + # -- Arrange -- + mut registry = 
FunctionRegistry.new() + extension_namespace = "inql_ext.crypto" + + # -- Act -- + registry._add_entry( + "hash", + deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + rewrite_mapping("portable hash fixture"), + ), + ) + registry._add_entry( + "hash", + extension_only_spec( + extension_namespace, + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.DependsOnInputs, + FunctionErrorBehavior.Typechecked, + extension_mapping("hash", 9001), + ), + ) + + core_entry = _local_entry_by_name_or_fail(registry, "hash") + extension_entry = _local_entry_by_namespace_and_name_or_fail(registry, extension_namespace, "hash") + + # -- Assert -- + assert core_entry.function_ref == function_ref_for("hash"), "canonical name lookup should use the core function reference" + assert core_entry.namespace == core_function_namespace(), "canonical name lookup should keep resolving core names" + assert core_entry.policy_category == FunctionPolicyCategory.PortableCore, "core-scoped lookup should return the portable core function" + assert extension_entry.function_ref == namespaced_function_ref(extension_namespace, "hash"), "namespace-qualified lookup should use the extension function reference" + assert extension_entry.namespace == extension_namespace, "namespace-qualified lookup should still find the extension entry" + assert extension_entry.policy_category == FunctionPolicyCategory.ExtensionOnly, "extension entry should keep its extension-only policy" + + +def test_function_registry__rejected_compatibility_requests_are_metadata_only() -> None: + """Assert rejected requests are documented metadata, not fake function entries.""" + # -- Arrange / Act -- + rejection = rejected_function_policy( + "spark_partition_id", + "physical Spark partition ids are not portable data logic", + "", + ) + + # -- Assert -- + 
assert rejection.policy_category == FunctionPolicyCategory.Rejected, "rejected requests should expose the rejected policy category" + assert rejection.requested_name == "spark_partition_id", "rejected metadata should preserve the requested compatibility name" + assert rejection.reason == "physical Spark partition ids are not portable data logic", "rejected metadata should preserve the policy reason" + assert rejection.alternative == "", "rejected metadata may omit an alternative when none exists" + + def test_function_registry__substrait_extension_mappings_are_structured() -> None: """Assert Substrait extension-backed helpers carry stable mapping facts.""" # -- Arrange -- @@ -297,6 +463,10 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non assert _contains_text(mapped_refs, function_ref_for(canonical_name)), "mapped helper should be in the Substrait extension mapping set" _assert_extension_mapping("sum", "sum", SUM_FUNCTION_ANCHOR) _assert_extension_mapping("count", "count", COUNT_FUNCTION_ANCHOR) + _assert_extension_mapping("count_expr", "count", COUNT_FUNCTION_ANCHOR) + _assert_extension_mapping("avg", "avg", AVG_FUNCTION_ANCHOR) + _assert_extension_mapping("min", "min", MIN_FUNCTION_ANCHOR) + _assert_extension_mapping("max", "max", MAX_FUNCTION_ANCHOR) _assert_extension_mapping("add", "add", ADD_FUNCTION_ANCHOR) _assert_extension_mapping("mul", "multiply", MULTIPLY_FUNCTION_ANCHOR) _assert_extension_mapping("eq", "equal", EQUAL_FUNCTION_ANCHOR) @@ -319,6 +489,10 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("coalesce", "coalesce", COALESCE_FUNCTION_ANCHOR) _assert_extension_mapping("nullif", "nullif", NULLIF_FUNCTION_ANCHOR) _assert_extension_mapping("between", "between", BETWEEN_FUNCTION_ANCHOR) + _assert_extension_mapping("abs", "abs", ABS_FUNCTION_ANCHOR) + _assert_extension_mapping("ceil", "ceil", CEIL_FUNCTION_ANCHOR) + _assert_extension_mapping("floor", 
"floor", FLOOR_FUNCTION_ANCHOR) + _assert_extension_mapping("round", "round", ROUND_FUNCTION_ANCHOR) def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: @@ -381,25 +555,35 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: # -- Act -- sum_measure = sum(amount) count_measure = count() + expression_count_measure = count_expr(status) + avg_measure = avg(amount) + min_measure = min(amount) + max_measure = max(amount) add_expr = add(amount, lit(7)) eq_expr = eq(status, str_lit("paid")) gt_expr = gt(amount, int_lit(10)) - core_exprs = [and_(eq_expr, gt_expr), asc(amount), asc_nulls_first(amount), asc_nulls_last(amount), between( + core_exprs = [abs(amount), and_(eq_expr, gt_expr), asc(amount), asc_nulls_first(amount), asc_nulls_last(amount), between( amount, int_lit(1), int_lit(10), - ), bool_expr(true), coalesce([status, str_expr("unknown")]), desc(amount), desc_nulls_first(amount), desc_nulls_last( + ), bool_expr(true), ceil(amount), coalesce([status, str_expr("unknown")]), desc(amount), desc_nulls_first(amount), desc_nulls_last( amount, - ), div(amount, lit(2)), equal_null(status, str_lit("paid")), float_expr(1.5), gte(amount, int_lit(10)), in_( + ), div(amount, lit(2)), equal_null(status, str_lit("paid")), float_expr(1.5), floor(amount), gte(amount, int_lit(10)), in_( status, [str_lit("paid"), str_lit("open")], - ), lt(amount, int_lit(10)), lte(amount, int_lit(10)), modulo(amount, lit(2))] + ), lt(amount, int_lit(10)), lte(amount, int_lit(10)), modulo(amount, lit(2)), round(amount)] # -- Assert -- assert column_expr_kind(amount) == ColumnExprKind.Column, "col should still build a column reference" assert column_expr_kind(lit(true)) == ColumnExprKind.BoolLiteral, "lit should still build typed literals" assert sum_measure.kind == AggregateKind.Sum, "sum wrapper should preserve aggregate kind" assert count_measure.kind == AggregateKind.Count, "count wrapper should preserve aggregate kind" + assert not 
count_measure.has_expr, "count wrapper should preserve argument-free count semantics" + assert expression_count_measure.kind == AggregateKind.Count, "count(expr) wrapper should preserve aggregate kind" + assert expression_count_measure.has_expr, "count(expr) wrapper should preserve expression-count semantics" + assert avg_measure.kind == AggregateKind.Avg, "avg wrapper should preserve aggregate kind" + assert min_measure.kind == AggregateKind.Min, "min wrapper should preserve aggregate kind" + assert max_measure.kind == AggregateKind.Max, "max wrapper should preserve aggregate kind" assert column_expr_kind(add_expr) == ColumnExprKind.ScalarFunction, "add should use the shared scalar function kind" assert column_expr_kind(eq_expr) == ColumnExprKind.ScalarFunction, "eq should use the shared scalar function kind" assert column_expr_kind(gt_expr) == ColumnExprKind.ScalarFunction, "gt should use the shared scalar function kind" diff --git a/tests/test_session_aggregates.incn b/tests/test_session_aggregates.incn index f40c31e..a8060ef 100644 --- a/tests/test_session_aggregates.incn +++ b/tests/test_session_aggregates.incn @@ -1,6 +1,6 @@ """End-to-end Session aggregate execution tests over the DataFusion backend.""" -from functions import col, count, sum +from functions import avg, col, count, count_expr, max, min, sum from dataset import LazyFrame from session import Session from std.testing import assert_is_ok @@ -46,6 +46,38 @@ def test_session_aggregates__grouped_collect_executes_sum_and_count() -> None: assert payload.contains("1"), "grouped aggregate output should contain the row-count for customer B" +def test_session_aggregates__grouped_collect_executes_core_aggregates() -> None: + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + grouped = lazy.group_by([col("customer_id")]).agg( + 
[sum(col("amount")), count(), count_expr(col("amount")), avg(col("amount")), min(col("amount")), max( + col("amount"), + )], + ) + df = assert_is_ok(session.collect(grouped), "mixed grouped aggregate collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 2, "mixed grouped aggregate should produce one row per customer" + assert resolved == ["customer_id", "sum_amount", "count", "count_amount", "avg_amount", "min_amount", "max_amount"], "mixed grouped aggregate should expose stable output columns" + assert payload.contains("A"), "mixed grouped aggregate output should contain customer A" + assert payload.contains("B"), "mixed grouped aggregate output should contain customer B" + assert payload.contains("25"), "customer A sum should be materialized" + assert payload.contains("12.5"), "customer A average should be materialized" + assert payload.contains("10"), "customer A minimum should be materialized" + assert payload.contains("15"), "customer A maximum should be materialized" + assert payload.contains("7"), "customer B aggregate values should be materialized" + assert payload.contains("2"), "customer A count and expression count should be materialized" + assert payload.contains("1"), "customer B count and expression count should be materialized" + + def test_session_aggregates__global_collect_executes_count() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_session_projection.incn b/tests/test_session_projection.incn index c39809f..0c2777e 100644 --- a/tests/test_session_projection.incn +++ b/tests/test_session_projection.incn @@ -1,6 +1,26 @@ """End-to-end Session projection execution tests over the DataFusion backend.""" -from functions import add, case_when, cast, coalesce, col, desc, div, gt, lit, modulo, mul, neg, nullif, sub, try_cast +from functions import ( + abs, + add, + case_when, + cast, + ceil, + coalesce, + col, + desc, + div, + floor, + gt, 
+ lit, + modulo, + mul, + neg, + nullif, + round, + sub, + try_cast, +) from dataset import DataFrame, LazyFrame from session import Session, SessionErrorKind from std.testing import assert_is_err, assert_is_ok, fail_t @@ -131,6 +151,40 @@ def test_session_projection__collect_executes_core_scalar_projection_functions() assert payload.contains("amount_bucket"), "case_when projection should materialize its alias" +def test_session_projection__collect_executes_common_math_scalar_projection_functions() -> None: + """collect should execute the first RFC 018 math scalar helpers through DataFusion.""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("abs_delta", abs(sub(lit(5), col("amount")))).with_column( + "ceil_quarter", + ceil(div(col("amount"), lit(4.0))), + ).with_column("floor_quarter", floor(div(col("amount"), lit(4.0)))).with_column( + "round_quarter", + round(div(col("amount"), lit(4.0))), + ) + df = _collect_or_fail(session, projected) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "math scalar projections should preserve the input rows" + assert len(resolved) == 6, "projection should expose all appended math outputs" + assert payload.contains("abs_delta"), "abs projection should materialize its alias" + assert payload.contains("ceil_quarter"), "ceil projection should materialize its alias" + assert payload.contains("floor_quarter"), "floor projection should materialize its alias" + assert payload.contains("round_quarter"), "round projection should materialize its alias" + assert payload.contains("10"), "abs projection should include abs(5 - 15)" + assert payload.contains("4"), "ceil projection should include ceil(15 / 4.0)" + assert payload.contains("1"), "floor projection should 
include floor(7 / 4.0)" + assert payload.contains("3"), "round projection should include round(10 / 4.0)" + + def test_session_projection__collect_executes_identity_select() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 1cc59cf..60b3f49 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -2,21 +2,26 @@ from std.testing import fail_t from functions import ( + abs, add, always_true, and_, asc, asc_nulls_last, + avg, between, case_when, cast, + ceil, col, coalesce, count, + count_expr, desc, div, eq, equal_null, + floor, gt, gte, in_, @@ -27,6 +32,8 @@ from functions import ( lit, lt, lte, + max, + min, modulo, mul, ne, @@ -34,6 +41,7 @@ from functions import ( not_, nullif, or_, + round, sub, sum, try_cast, @@ -47,6 +55,7 @@ from substrait.function_extensions import ( registered_substrait_extension_uris, ) from substrait.inspect import ( + aggregate_measure_function_names, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -353,41 +362,65 @@ def test_plan__core_scalar_extension_mappings_lower_to_substrait() -> None: _assert_scalar_expr_lowers(coalesce([col("status"), lit("unknown")])) _assert_scalar_expr_lowers(nullif(col("status"), lit(""))) _assert_scalar_expr_lowers(between(col("amount"), lit(1), lit(10))) + _assert_scalar_expr_lowers(abs(col("amount"))) + _assert_scalar_expr_lowers(ceil(div(col("amount"), lit(4.0)))) + _assert_scalar_expr_lowers(floor(div(col("amount"), lit(4.0)))) + _assert_scalar_expr_lowers(round(div(col("amount"), lit(4.0)))) def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None: # -- Arrange -- _register_orders_schema() base = read_named_table_rel("orders") - aggregated = aggregate_rel(base, [col("id")], [sum(col("id")), count()]) - plan = plan_from_root_relation(aggregated, ["id", "sum_id", "count"]) + aggregated = aggregate_rel( + base, + [col("id")], + 
[sum(col("id")), count(), count_expr(col("id")), avg(col("id")), min(col("id")), max(col("id"))], + ) + plan = plan_from_root_relation(aggregated, ["id", "sum_id", "count", "count_id", "avg_id", "min_id", "max_id"]) # -- Act -- output_columns = relation_output_columns(aggregated) + aggregate_functions = aggregate_measure_function_names(aggregated) # -- Assert -- assert relation_kind_name(aggregated) == "AggregateRel", "aggregate lowering should emit AggregateRel" assert plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), "aggregate function plans should register the shared function extension URN" - assert len(output_columns) == 3, "aggregate output columns should include group key plus measure outputs" + assert len(output_columns) == 7, "aggregate output columns should include group key plus measure outputs" assert output_columns[0] == "id", "group key should remain the first aggregate output column" assert output_columns[1] == "sum_id", "sum measure output columns should use the stable prefixed name" assert output_columns[2] == "count", "count measure output columns should use the stable count name" + assert output_columns[3] == "count_id", "expression-count output columns should include the argument name" + assert output_columns[4] == "avg_id", "avg measure output columns should use the stable prefixed name" + assert output_columns[5] == "min_id", "min measure output columns should use the stable prefixed name" + assert output_columns[6] == "max_id", "max measure output columns should use the stable prefixed name" + assert aggregate_functions == ["sum", "count", "count", "avg", "min", "max"], "all core aggregate helpers should lower to registered aggregate functions" def test_plan__aggregate_rel_accepts_scalar_group_and_measure_expressions() -> None: # -- Arrange -- _register_orders_schema() base = read_named_table_rel("orders") - aggregated = aggregate_rel(base, [add(col("id"), lit(1))], [sum(add(col("id"), lit(2)))]) + aggregated = 
aggregate_rel( + base, + [add(col("id"), lit(1))], + [sum(add(col("id"), lit(2))), count_expr(add(col("id"), lit(3))), avg(add(col("id"), lit(4))), min( + add(col("id"), lit(5)), + ), max(add(col("id"), lit(6)))], + ) # -- Act -- output_columns = relation_output_columns(aggregated) # -- Assert -- assert relation_kind_name(aggregated) == "AggregateRel", "scalar aggregate lowering should emit AggregateRel" - assert len(output_columns) == 2, "scalar aggregate output should include computed group key and measure" + assert len(output_columns) == 6, "scalar aggregate output should include computed group key and measures" assert output_columns[0] == "group_0", "computed grouping expressions should get stable fallback names" assert output_columns[1] == "sum", "computed aggregate measures should use stable fallback names" + assert output_columns[2] == "count", "computed expression-count measures should use stable fallback names" + assert output_columns[3] == "avg", "computed avg measures should use stable fallback names" + assert output_columns[4] == "min", "computed min measures should use stable fallback names" + assert output_columns[5] == "max", "computed max measures should use stable fallback names" def test_plan__set_rel_uses_operation_enum() -> None: From f0908231c51700097068a2ffdd97d44b9703e52a Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 25 May 2026 23:29:51 +0200 Subject: [PATCH 2/6] feature - implement RFC 017 aggregate modifiers --- .../language/reference/builders/aggregates.md | 23 ++++- docs/language/reference/functions/index.md | 3 +- docs/release_notes/v0_1.md | 1 + docs/rfcs/017_aggregate_modifiers.md | 16 ++-- src/aggregate_builders.incn | 61 +++++++++++++ src/function_registry.incn | 48 ++++++++++ src/functions/aggregates/count_distinct.incn | 55 ++++++++++++ src/functions/aggregates/count_if.incn | 55 ++++++++++++ src/functions/aggregates/mod.incn | 2 + src/functions/mod.incn | 2 + src/lib.incn | 5 ++ src/prism/store.incn | 41 +++++++-- 
src/substrait/extensions.incn | 77 +++++++++++++++- src/substrait/inspect.incn | 88 +++++++++++++++++-- src/substrait/mod.incn | 3 + src/substrait/relations.incn | 75 ++++++++++++++-- tests/fixtures/aggregate_modifiers.csv | 7 ++ tests/test_dataset.incn | 7 ++ tests/test_function_registry.incn | 38 +++++++- tests/test_prism.incn | 24 ++++- tests/test_session_aggregates.incn | 55 +++++++++++- tests/test_substrait_plan.incn | 61 ++++++++++++- 22 files changed, 702 insertions(+), 45 deletions(-) create mode 100644 src/functions/aggregates/count_distinct.incn create mode 100644 src/functions/aggregates/count_if.incn create mode 100644 tests/fixtures/aggregate_modifiers.csv diff --git a/docs/language/reference/builders/aggregates.md b/docs/language/reference/builders/aggregates.md index 092540a..626bf94 100644 --- a/docs/language/reference/builders/aggregates.md +++ b/docs/language/reference/builders/aggregates.md @@ -9,21 +9,36 @@ Current aggregate authoring is explicit and scalar-expression-based. | `col` | `def col(name: str) -> ColumnExpr` | Column reference builder used by aggregates, filters, and projections. | | `lit` | `def lit(value: int \| float \| str \| bool) -> ColumnExpr` | Canonical scalar literal helper. | | `sum` | `def sum(expr: ColumnExpr) -> AggregateMeasure` | Sum one scalar expression. | -| `count` | `def count() -> AggregateMeasure` | Count rows. | +| `count` | `def count() -> AggregateMeasure` | Count rows. | | `count_expr` | `def count_expr(expr: ColumnExpr) -> AggregateMeasure` | Count non-null expression values; compatibility spelling for the future `count(expr)` form. | +| `count_distinct` | `def count_distinct(expr: ColumnExpr) -> AggregateMeasure` | Count distinct non-null expression values. | +| `count_if` | `def count_if(predicate: ColumnExpr) -> AggregateMeasure` | Count rows where the predicate is true. | | `avg` | `def avg(expr: ColumnExpr) -> AggregateMeasure` | Average one numeric scalar expression. 
| | `min` | `def min(expr: ColumnExpr) -> AggregateMeasure` | Return the minimum non-null value for one orderable scalar expression. | | `max` | `def max(expr: ColumnExpr) -> AggregateMeasure` | Return the maximum non-null value for one orderable scalar expression. | +## Modifiers + +Aggregate measures support method-style modifiers: + +| Modifier | Signature | Meaning | +| --- | --- | --- | +| `distinct` | `measure.distinct() -> AggregateMeasure` | Apply SQL-style `DISTINCT` to aggregate input values. | +| `filter` | `measure.filter(predicate: ColumnExpr) -> AggregateMeasure` | Apply an aggregate-local boolean predicate before aggregation. | +| `order_by` | `measure.order_by(ordering: list[ColumnExpr]) -> AggregateMeasure` | Record ordered aggregate input. Core aggregates reject ordered input until an order-sensitive aggregate lands. | + ## Example ```incan -from pub::inql.functions import add, avg, col, count, count_expr, lit, max, min, sum +from pub::inql.functions import add, avg, col, count, count_distinct, count_expr, count_if, eq, lit, max, min, str_lit, sum grouped = orders.group_by([col("customer_id")]).agg([ sum(add(col("amount"), lit(5))), count(), count_expr(col("discount_code")), + count_distinct(col("product_id")), + count_if(eq(col("status"), str_lit("paid"))), + sum(col("amount")).filter(eq(col("status"), str_lit("paid"))), avg(col("amount")), min(col("created_at")), max(col("created_at")), @@ -35,5 +50,9 @@ grouped = orders.group_by([col("customer_id")]).agg([ - Aggregate inputs use the same scalar-expression model as filters, projections, and grouping keys. - `count()` counts rows. `count_expr(expr)` counts non-null values produced by the expression and lowers to the same canonical `count` Substrait extension function. +- `count_distinct(expr)` is compatibility sugar for `count_expr(expr).distinct()`. +- `count_if(predicate)` is compatibility sugar for `count().filter(predicate)`. 
Rows where the predicate is false or + null do not contribute to the aggregate. - `sum`, `avg`, `min`, and `max` skip null values. They return backend-null results when no non-null input value exists. +- Unsupported aggregate modifiers fail at lowering or backend planning; they are not ignored. - Future `.column` sugar and scoped aggregate symbols should lower to this same surface rather than replacing its semantics. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 36a54ba..8b32f31 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -10,7 +10,7 @@ Today the concrete shipped surfaces are documented here: The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. 
Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/<family>/<helper>.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. 
@@ -33,5 +33,6 @@ The registered helper surface currently includes: | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | +| `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | Future ANSI-style families should grow under this section instead of bloating `dataset_types` or `dataset_methods`. diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index f06d757..eacddef 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -11,6 +11,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Plans:** Apache Substrait as the logical interchange contract. - **Authoring:** method-chain lowering into a real Substrait boundary today, with `query {}` work still ahead. - **Aggregates:** builder-based `col`, `sum`, `count`, `count_expr`, `avg`, `min`, and `max` helpers now lower grouped and global aggregates through Prism, Substrait, and Session execution. `count()` counts rows, and `count_expr(expr)` counts non-null expression values while preserving the future `count(expr)` semantics. +- **Aggregate modifiers:** RFC 017 adds `AggregateMeasure.distinct()`, aggregate-local `filter(...)`, ordered-input representation with explicit rejection for current core aggregates, and compatibility helpers `count_distinct(...)` and `count_if(...)`. 
Distinct and filtered aggregate forms lower through Substrait and execute through the DataFusion-backed Session path. - **Scalar expressions:** RFC 012 unifies filter predicates, computed projection values, grouping keys, and aggregate inputs around one `ColumnExpr` surface with canonical `lit(...)` and typed literal helpers. - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. diff --git a/docs/rfcs/017_aggregate_modifiers.md b/docs/rfcs/017_aggregate_modifiers.md index d4fdcf1..42efe27 100644 --- a/docs/rfcs/017_aggregate_modifiers.md +++ b/docs/rfcs/017_aggregate_modifiers.md @@ -1,6 +1,6 @@ # InQL RFC 017: Aggregate modifiers -- **Status:** Draft +- **Status:** Implemented - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -11,8 +11,8 @@ - InQL RFC 016 (core aggregate functions) - **Issue:** [InQL #34](https://github.com/dannys-code-corner/InQL/issues/34) - **RFC PR:** — -- **Written against:** Incan v0.2 -- **Shipped in:** — +- **Written against:** Incan v0.3-era InQL +- **Shipped in:** v0.1 ## Summary @@ -111,9 +111,11 @@ Existing aggregate helpers remain valid. New compatibility helpers such as `coun - **Execution / interchange** — Prism and Substrait lowering must preserve filter, distinct, and ordering semantics or reject unsupported forms. 
- **Documentation** — aggregate docs should prefer the modifier model and list compatibility helper aliases. -## Unresolved questions +## Design Decisions -- Should `count_if(null)` count zero rows or follow a stricter boolean-null diagnostic rule? -- Which aggregate functions must allow ordered input in the initial modifier contract, especially `listagg`, `percentile_cont`, and `percentile_disc`? +### Resolved - +- `count_if(predicate)` follows aggregate `FILTER` semantics: rows where the predicate is false or null do not + contribute to the aggregate. +- The initial modifier contract records ordered aggregate input but no current core aggregate allows it. Ordered input + is rejected explicitly until an order-sensitive aggregate such as `listagg` or ordered percentile functions lands. diff --git a/src/aggregate_builders.incn b/src/aggregate_builders.incn index 320f60b..755f66b 100644 --- a/src/aggregate_builders.incn +++ b/src/aggregate_builders.incn @@ -30,16 +30,77 @@ pub model AggregateMeasure: pub canonical_name: str pub expr: ColumnExpr pub has_expr: bool + pub is_distinct: bool + pub filter_expr: ColumnExpr + pub has_filter: bool + pub ordering: list[ColumnExpr] + + def distinct(self) -> Self: + """Return this aggregate measure with `DISTINCT` input semantics.""" + return _aggregate_measure_with_modifiers( + self.canonical_name, + self.kind, + self.expr, + self.has_expr, + true, + self.filter_expr, + self.has_filter, + self.ordering, + ) + + def filter(self, predicate: ColumnExpr) -> Self: + """Return this aggregate measure with an aggregate-local row filter.""" + return _aggregate_measure_with_modifiers( + self.canonical_name, + self.kind, + self.expr, + self.has_expr, + self.is_distinct, + predicate, + true, + self.ordering, + ) + + def order_by(self, ordering: list[ColumnExpr]) -> Self: + """Return this aggregate measure with ordered aggregate input.""" + return _aggregate_measure_with_modifiers( + self.canonical_name, + self.kind, + self.expr, + 
self.has_expr, + self.is_distinct, + self.filter_expr, + self.has_filter, + ordering, + ) def _aggregate_measure(canonical_name: str, kind: AggregateKind, expr: ColumnExpr, has_expr: bool) -> AggregateMeasure: """Build one registry-backed aggregate measure description.""" + return _aggregate_measure_with_modifiers(canonical_name, kind, expr, has_expr, false, col_expr(""), false, []) + + +def _aggregate_measure_with_modifiers( + canonical_name: str, + kind: AggregateKind, + expr: ColumnExpr, + has_expr: bool, + is_distinct: bool, + filter_expr: ColumnExpr, + has_filter: bool, + ordering: list[ColumnExpr], +) -> AggregateMeasure: + """Build one aggregate measure preserving function identity and modifier state.""" return AggregateMeasure( kind=kind, function_ref=function_ref_for(canonical_name), canonical_name=canonical_name, expr=expr, has_expr=has_expr, + is_distinct=is_distinct, + filter_expr=filter_expr, + has_filter=has_filter, + ordering=ordering, ) diff --git a/src/function_registry.incn b/src/function_registry.incn index edfc0cb..b5642f9 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -118,6 +118,15 @@ pub model FunctionLifecycle: pub deprecated: Option[FunctionDeprecation] +@derive(Clone) +pub model AggregateModifierPolicy: + """Aggregate-specific modifier support recorded with registry metadata.""" + + pub allows_distinct: bool + pub allows_filter: bool + pub allows_ordered_input: bool + + @derive(Clone) pub model SubstraitMapping: """Portable interchange mapping metadata for one registered function.""" @@ -143,6 +152,7 @@ pub model FunctionSpec: pub determinism: FunctionDeterminism pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior + pub aggregate_modifiers: AggregateModifierPolicy pub substrait: SubstraitMapping @@ -161,6 +171,7 @@ pub model FunctionRegistryEntry: pub determinism: FunctionDeterminism pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior + pub 
aggregate_modifiers: AggregateModifierPolicy pub substrait: SubstraitMapping @@ -213,6 +224,7 @@ pub class FunctionRegistry: determinism=spec.determinism, null_behavior=spec.null_behavior, error_behavior=spec.error_behavior, + aggregate_modifiers=spec.aggregate_modifiers, substrait=spec.substrait, ), ) @@ -330,6 +342,36 @@ pub def sort_field_mapping(direction: str) -> SubstraitMapping: ) +pub def aggregate_modifier_policy( + allows_distinct: bool, + allows_filter: bool, + allows_ordered_input: bool, +) -> AggregateModifierPolicy: + """Build one aggregate modifier policy.""" + return AggregateModifierPolicy( + allows_distinct=allows_distinct, + allows_filter=allows_filter, + allows_ordered_input=allows_ordered_input, + ) + + +pub def no_aggregate_modifiers() -> AggregateModifierPolicy: + """Return the default modifier policy for non-aggregate functions.""" + return aggregate_modifier_policy(false, false, false) + + +pub def core_aggregate_modifier_policy() -> AggregateModifierPolicy: + """Return the default RFC 017 policy for core order-insensitive aggregates.""" + return aggregate_modifier_policy(true, true, false) + + +def _aggregate_modifier_policy_for_class(function_class: FunctionClass) -> AggregateModifierPolicy: + """Return the default modifier policy for a semantic function class.""" + if function_class == FunctionClass.Aggregate: + return core_aggregate_modifier_policy() + return no_aggregate_modifiers() + + pub def function_policy_spec( namespace: str, policy_category: FunctionPolicyCategory, @@ -340,6 +382,7 @@ pub def function_policy_spec( determinism: FunctionDeterminism, null_behavior: FunctionNullBehavior, error_behavior: FunctionErrorBehavior, + aggregate_modifiers: AggregateModifierPolicy, substrait: SubstraitMapping, ) -> FunctionSpec: """Build one function spec with explicit RFC 024 namespace and policy metadata.""" @@ -353,6 +396,7 @@ pub def function_policy_spec( determinism=determinism, null_behavior=null_behavior, 
error_behavior=error_behavior, + aggregate_modifiers=aggregate_modifiers, substrait=substrait, ) @@ -374,6 +418,7 @@ pub def deterministic_spec( determinism=FunctionDeterminism.Deterministic, null_behavior=null_behavior, error_behavior=FunctionErrorBehavior.Typechecked, + aggregate_modifiers=_aggregate_modifier_policy_for_class(function_class), substrait=substrait, ) @@ -398,6 +443,7 @@ pub def extension_only_spec( determinism, null_behavior, error_behavior, + no_aggregate_modifiers(), substrait, ) @@ -423,6 +469,7 @@ pub def compatibility_alias_spec( determinism, null_behavior, error_behavior, + _aggregate_modifier_policy_for_class(function_class), substrait, ) @@ -447,6 +494,7 @@ pub def engine_specific_spec( determinism, null_behavior, error_behavior, + no_aggregate_modifiers(), substrait, ) diff --git a/src/functions/aggregates/count_distinct.incn b/src/functions/aggregates/count_distinct.incn new file mode 100644 index 0000000..15eb88f --- /dev/null +++ b/src/functions/aggregates/count_distinct.incn @@ -0,0 +1,55 @@ +""" +Distinct count compatibility helper. + +`count_distinct(expr)` is sugar over the canonical aggregate modifier model: `count_expr(expr).distinct()`. 
+""" + +from aggregate_builders import AggregateKind, AggregateMeasure +from function_registry import ( + FunctionClass, + FunctionDeterminism, + FunctionErrorBehavior, + FunctionLifecycle, + FunctionNullBehavior, + compatibility_alias_spec, + core_function_namespace, + rewrite_mapping, + v0_1, +) +from functions.aggregates.count import count_expr +from functions.registry import function_registry +from projection_builders import ColumnExpr, col, column_expr_name + + +@function_registry.add("count_distinct", compatibility_alias_spec( + core_function_namespace(), + FunctionClass.Aggregate, + ["count"], + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.NullSkippingAggregate, + FunctionErrorBehavior.Typechecked, + rewrite_mapping("count_expr(expr).distinct()"), +)) +pub def count_distinct(expr: ColumnExpr) -> AggregateMeasure: + """ + Count distinct non-null values produced by one expression. + + Examples: + products = count_distinct(col("product_id")) + + Parameters: + expr: Expression whose distinct non-null values should be counted. + """ + return count_expr(expr).distinct() + + +module tests: + def test_count_distinct_builds_distinct_count_measure() -> None: + """count_distinct should build a distinct expression-count measure.""" + measure = count_distinct(col("product_id")) + assert measure.kind == AggregateKind.Count + assert measure.has_expr + assert measure.is_distinct + assert not measure.has_filter + assert column_expr_name(measure.expr) == "product_id" diff --git a/src/functions/aggregates/count_if.incn b/src/functions/aggregates/count_if.incn new file mode 100644 index 0000000..92e7dcf --- /dev/null +++ b/src/functions/aggregates/count_if.incn @@ -0,0 +1,55 @@ +""" +Filtered count compatibility helper. + +`count_if(predicate)` is sugar over the canonical aggregate modifier model: `count().filter(predicate)`. 
+""" + +from aggregate_builders import AggregateKind, AggregateMeasure +from function_registry import ( + FunctionClass, + FunctionDeterminism, + FunctionErrorBehavior, + FunctionLifecycle, + FunctionNullBehavior, + compatibility_alias_spec, + core_function_namespace, + rewrite_mapping, + v0_1, +) +from functions.aggregates.count import count +from functions.registry import function_registry +from projection_builders import ColumnExpr, col, column_expr_name + + +@function_registry.add("count_if", compatibility_alias_spec( + core_function_namespace(), + FunctionClass.Aggregate, + ["count"], + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.NullSkippingAggregate, + FunctionErrorBehavior.Typechecked, + rewrite_mapping("count().filter(predicate)"), +)) +pub def count_if(predicate: ColumnExpr) -> AggregateMeasure: + """ + Count rows where a predicate evaluates to true. + + Examples: + completed = count_if(col("is_completed")) + + Parameters: + predicate: Boolean expression that decides whether each input row contributes to the count. 
+ """ + return count().filter(predicate) + + +module tests: + def test_count_if_builds_filtered_count_measure() -> None: + """count_if should build a filtered row-count measure.""" + measure = count_if(col("is_completed")) + assert measure.kind == AggregateKind.Count + assert not measure.has_expr + assert measure.has_filter + assert not measure.is_distinct + assert column_expr_name(measure.filter_expr) == "is_completed" diff --git a/src/functions/aggregates/mod.incn b/src/functions/aggregates/mod.incn index 92de101..c0180bd 100644 --- a/src/functions/aggregates/mod.incn +++ b/src/functions/aggregates/mod.incn @@ -2,6 +2,8 @@ pub from functions.aggregates.sum import sum pub from functions.aggregates.count import count, count_expr +pub from functions.aggregates.count_distinct import count_distinct +pub from functions.aggregates.count_if import count_if pub from functions.aggregates.avg import avg pub from functions.aggregates.min import min pub from functions.aggregates.max import max diff --git a/src/functions/mod.incn b/src/functions/mod.incn index 3f5f3da..4ec6028 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -29,6 +29,8 @@ pub from functions.literals.lit import lit pub from functions.literals.str_expr import str_expr pub from functions.literals.str_lit import str_lit pub from functions.aggregates.count import count, count_expr +pub from functions.aggregates.count_distinct import count_distinct +pub from functions.aggregates.count_if import count_if pub from functions.aggregates.sum import sum pub from functions.aggregates.avg import avg pub from functions.aggregates.min import min diff --git a/src/lib.incn b/src/lib.incn index 82bf427..7b96f4a 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -49,6 +49,8 @@ pub from functions.literals.lit import lit pub from functions.literals.str_expr import str_expr pub from functions.literals.str_lit import str_lit pub from functions.aggregates.count import count, count_expr +pub from 
functions.aggregates.count_distinct import count_distinct +pub from functions.aggregates.count_if import count_if pub from functions.aggregates.sum import sum pub from functions.aggregates.avg import avg pub from functions.aggregates.min import min @@ -183,7 +185,10 @@ pub from substrait.plans import ( substrait_release_tag, ) pub from substrait.inspect import ( + aggregate_measure_filter_flags, aggregate_measure_function_names, + aggregate_measure_invocation_names, + aggregate_measure_sort_counts, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, diff --git a/src/prism/store.incn b/src/prism/store.incn index d2e38a5..d620574 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -230,16 +230,8 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if not _column_expr_lists_structurally_equal(candidate.sort_columns, source_node.sort_columns): return false - if len(candidate.aggregate_measures) != len(source_node.aggregate_measures): + if not _aggregate_measure_lists_structurally_equal(candidate.aggregate_measures, source_node.aggregate_measures): return false - for idx in range(len(candidate.aggregate_measures)): - if candidate.aggregate_measures[idx].kind != source_node.aggregate_measures[idx].kind: - return false - if not _column_exprs_structurally_equal( - candidate.aggregate_measures[idx].expr, - source_node.aggregate_measures[idx].expr, - ): - return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -248,6 +240,37 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return candidate.input_ids == remapped_input_ids +def _aggregate_measure_lists_structurally_equal(left: list[AggregateMeasure], right: list[AggregateMeasure]) -> bool: + """Return whether two aggregate-measure lists carry identical semantics.""" + if len(left) != len(right): + return false + for idx in range(len(left)): 
+ if not _aggregate_measures_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _aggregate_measures_structurally_equal(left: AggregateMeasure, right: AggregateMeasure) -> bool: + """Return whether two aggregate measures carry identical registry identity and modifier state.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.has_expr != right.has_expr: + return false + if left.is_distinct != right.is_distinct: + return false + if left.has_filter != right.has_filter: + return false + if not _column_exprs_structurally_equal(left.expr, right.expr): + return false + if not _column_exprs_structurally_equal(left.filter_expr, right.filter_expr): + return false + return _column_expr_lists_structurally_equal(left.ordering, right.ordering) + + def _filter_predicates_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: """Return whether two filter scalar expressions are structurally equivalent.""" return _column_exprs_structurally_equal(left, right) diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index 4db2b86..f4efeb4 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -6,7 +6,7 @@ expression trees. 
""" from rust::incan_stdlib::errors import raise_value_error -from rust::substrait::proto import Expression, Rel +from rust::substrait::proto import AggregateFunction, Expression, FunctionArgument, Rel, SortField from rust::substrait::proto::extensions import SimpleExtensionDeclaration, SimpleExtensionUrn from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType from rust::substrait::proto::function_argument import ArgType @@ -180,6 +180,59 @@ def _expr_uses_if_then(expr: Expression) -> bool: return false +def _function_argument_uses_scalar_function_anchor(argument: FunctionArgument, expected_anchor: u32) -> bool: + """Return whether one function argument contains the requested scalar-function anchor.""" + if let Some(ArgType.Value(value)) = argument.arg_type: + return _expr_uses_scalar_function_anchor(value, expected_anchor) + return false + + +def _sort_field_uses_scalar_function_anchor(sort: SortField, expected_anchor: u32) -> bool: + """Return whether one sort field expression contains the requested scalar-function anchor.""" + if let Some(expr) = sort.expr: + return _expr_uses_scalar_function_anchor(expr, expected_anchor) + return false + + +def _aggregate_function_uses_scalar_function_anchor( + aggregate_function: AggregateFunction, + expected_anchor: u32, +) -> bool: + """Return whether one aggregate function payload contains the requested scalar-function anchor.""" + for argument in aggregate_function.arguments: + if _function_argument_uses_scalar_function_anchor(argument, expected_anchor): + return true + for sort in aggregate_function.sorts: + if _sort_field_uses_scalar_function_anchor(sort, expected_anchor): + return true + return false + + +def _function_argument_uses_if_then(argument: FunctionArgument) -> bool: + """Return whether one function argument contains a Substrait IfThen Rex shape.""" + if let Some(ArgType.Value(value)) = argument.arg_type: + return _expr_uses_if_then(value) + return false + + 
+def _sort_field_uses_if_then(sort: SortField) -> bool: + """Return whether one sort field expression contains a Substrait IfThen Rex shape.""" + if let Some(expr) = sort.expr: + return _expr_uses_if_then(expr) + return false + + +def _aggregate_function_uses_if_then(aggregate_function: AggregateFunction) -> bool: + """Return whether one aggregate function payload contains a Substrait IfThen Rex shape.""" + for argument in aggregate_function.arguments: + if _function_argument_uses_if_then(argument): + return true + for sort in aggregate_function.sorts: + if _sort_field_uses_if_then(sort): + return true + return false + + def _rel_uses_aggregate_function_anchor(rel: Rel, expected_anchor: u32) -> bool: """Return whether one relation subtree uses the requested aggregate-function anchor.""" if let Some(RelType.Aggregate(aggregate_rel)) = rel.rel_type.clone(): @@ -205,6 +258,17 @@ def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: u32) -> bool: for expr in project_rel.expressions: if _expr_uses_scalar_function_anchor(expr, expected_anchor): return true + Some(RelType.Aggregate(aggregate_rel)) => + for grouping_expr in aggregate_rel.grouping_expressions: + if _expr_uses_scalar_function_anchor(grouping_expr, expected_anchor): + return true + for measure in aggregate_rel.measures: + if let Some(filter_expr) = measure.filter: + if _expr_uses_scalar_function_anchor(filter_expr, expected_anchor): + return true + if let Some(agg_fn) = measure.measure: + if _aggregate_function_uses_scalar_function_anchor(agg_fn, expected_anchor): + return true _ => pass for child in relation_children(rel): @@ -224,6 +288,17 @@ def _rel_uses_if_then(rel: Rel) -> bool: for expr in project_rel.expressions: if _expr_uses_if_then(expr): return true + Some(RelType.Aggregate(aggregate_rel)) => + for grouping_expr in aggregate_rel.grouping_expressions: + if _expr_uses_if_then(grouping_expr): + return true + for measure in aggregate_rel.measures: + if let Some(filter_expr) = measure.filter: 
+ if _expr_uses_if_then(filter_expr): + return true + if let Some(agg_fn) = measure.measure: + if _aggregate_function_uses_if_then(agg_fn): + return true _ => pass for child in relation_children(rel): diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index f6f7d07..37005e0 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -9,6 +9,7 @@ from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32 from rust::substrait::proto import AggregateRel, Expression, Plan, ReadRel, Rel, RelCommon from rust::substrait::proto::aggregate_rel import Measure +from rust::substrait::proto::aggregate_function import AggregationInvocation from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::plan_rel import RelType as PlanRelType from rust::substrait::proto::read_rel import NamedTable as ReadNamedTable, ReadType @@ -43,11 +44,30 @@ def _set_op_value(value: SetOp) -> RustI32: return value.into() +def _aggregation_invocation_value(value: AggregationInvocation) -> RustI32: + """Convert one Substrait aggregation invocation into its stored Rust `i32` discriminant.""" + return value.into() + + +def _measure_has_filter(measure: Measure) -> bool: + """Return whether one aggregate measure carries an aggregate-local filter.""" + match measure.filter: + Some(_) => return true + None => return false + + def _measure_output_name(measure: AggregateMeasure) -> str: """Return the projected output column name for one aggregate measure.""" - if measure.kind == AggregateKind.Count and not measure.has_expr: - return measure.kind.value() - return measure.kind.value() + "_" + scalar_expr_output_name(measure.expr, "expr") + mut output_name = measure.kind.value() + if measure.is_distinct and measure.has_expr: + output_name = output_name + "_distinct_" + scalar_expr_output_name(measure.expr, "expr") + elif measure.has_expr: + output_name = output_name + "_" + scalar_expr_output_name(measure.expr, "expr") + if 
measure.has_filter: + output_name = output_name + "_filtered" + if len(measure.ordering) > 0: + output_name = output_name + "_ordered" + return output_name pub def aggregate_measure_output_names(measures: list[AggregateMeasure]) -> list[str]: @@ -85,21 +105,32 @@ def _grouping_output_columns(input_columns: list[str], grouping_expressions: lis def _aggregate_measure_output_name(input_columns: list[str], measure: Measure) -> str: """Resolve one aggregate output name from the lowered Substrait aggregate measure payload.""" + measure_has_filter = _measure_has_filter(measure.clone()) match measure.measure: Some(agg_fn) => function_name = aggregate_function_name_from_anchor(agg_fn.function_reference) + is_distinct = agg_fn.invocation == _aggregation_invocation_value(AggregationInvocation.Distinct) + mut suffix = "" + if measure_has_filter: + suffix = suffix + "_filtered" + if len(agg_fn.sorts) > 0: + suffix = suffix + "_ordered" if len(agg_fn.arguments) == 0: - return function_name + return function_name + suffix match agg_fn.arguments[0].arg_type: Some(arg_type) => match arg_type: ArgType.Value(expr) => field_index = field_index_from_expression(expr) if field_index >= 0 and field_index < len(input_columns): - return function_name + "_" + input_columns[field_index] - return function_name - _ => return function_name - _ => return function_name + if is_distinct: + return function_name + "_distinct_" + input_columns[field_index] + suffix + return function_name + "_" + input_columns[field_index] + suffix + if is_distinct: + return function_name + "_distinct" + suffix + return function_name + suffix + _ => return function_name + suffix + _ => return function_name + suffix None => return "" @@ -199,6 +230,47 @@ pub def aggregate_measure_function_names(rel: Rel) -> list[str]: _ => return [] +def _aggregation_invocation_name(value: RustI32) -> str: + """Return the stable aggregate invocation name for one Substrait enum value.""" + if value == 
_aggregation_invocation_value(AggregationInvocation.Distinct): + return "Distinct" + if value == _aggregation_invocation_value(AggregationInvocation.All): + return "All" + return "Unspecified" + + +pub def aggregate_measure_invocation_names(rel: Rel) -> list[str]: + """Return aggregate invocation names used by a top-level AggregateRel, otherwise empty.""" + match rel.rel_type: + Some(RelType.Aggregate(aggregate_rel)) => + mut names: list[str] = [] + for measure in aggregate_rel.measures: + if let Some(agg_fn) = measure.measure: + names.append(_aggregation_invocation_name(agg_fn.invocation)) + return names + _ => return [] + + +pub def aggregate_measure_filter_flags(rel: Rel) -> list[bool]: + """Return whether each top-level aggregate measure carries an aggregate-local filter.""" + match rel.rel_type: + Some(RelType.Aggregate(aggregate_rel)) => + return [_measure_has_filter(measure) for measure in aggregate_rel.measures] + _ => return [] + + +pub def aggregate_measure_sort_counts(rel: Rel) -> list[int]: + """Return sort-field counts for each top-level aggregate measure.""" + match rel.rel_type: + Some(RelType.Aggregate(aggregate_rel)) => + mut counts: list[int] = [] + for measure in aggregate_rel.measures: + if let Some(agg_fn) = measure.measure: + counts.append(len(agg_fn.sorts)) + return counts + _ => return [] + + pub def root_rel(plan: Plan) -> Rel: """Return the logical root relation from a plan.""" if len(plan.relations) == 0: diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 2ed8dc8..16e0f38 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -53,8 +53,11 @@ pub from substrait.plans import ( ) pub from substrait.inspect import ( aggregate_group_columns, + aggregate_measure_filter_flags, aggregate_measure_function_names, + aggregate_measure_invocation_names, aggregate_measure_output_names, + aggregate_measure_sort_counts, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, diff --git 
a/src/substrait/relations.incn b/src/substrait/relations.incn index 739e421..849beba 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -44,7 +44,7 @@ from rust::substrait::proto::rel_common import Direct, Emit, EmitKind from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure -from function_registry import FunctionClass, SubstraitMappingKind +from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMappingKind from functions.registry import function_registry_entry from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( @@ -61,12 +61,17 @@ from substrait.inspect import relation_output_columns from substrait.schema_registry import named_table_base_schema, unknown_named_struct +@derive(Clone) model ResolvedAggregateMeasure: kind: AggregateKind function_ref: str canonical_name: str expr: Expression has_expr: bool + is_distinct: bool + filter_expr: Expression + has_filter: bool + ordering: list[SortField] @derive(Clone) @@ -197,34 +202,67 @@ def _set_operation_from_name(operation: str) -> Result[SubstraitSetOperation, Su def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Result[Measure, SubstraitLoweringError]: """Lower one resolved aggregate measure into a Substrait aggregate measure.""" + _validate_aggregate_modifiers(measure.clone())? 
mut arguments: list[FunctionArgument] = [] if measure.has_expr: arguments = [FunctionArgument(arg_type=Some(ArgType.Value(measure.expr.clone())))] + mut invocation: RustI32 = AggregationInvocation.All.into() + if measure.is_distinct: + invocation = AggregationInvocation.Distinct.into() + mut measure_filter: Option[Expression] = None + if measure.has_filter: + measure_filter = Some(measure.filter_expr.clone()) return Ok( Measure( measure=Some( AggregateFunction( - function_reference=_aggregate_function_reference(measure)?, + function_reference=_aggregate_function_reference(measure.clone())?, arguments=arguments, - sorts=[], + sorts=measure.ordering, output_type=None, - invocation=AggregationInvocation.All.into(), + invocation=invocation, phase=AggregationPhase.Unspecified.into(), args=[], options=[], ), ), - filter=None, + filter=measure_filter, ), ) -def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: - """Resolve one aggregate measure through declaration-side registry metadata.""" +def _aggregate_registry_entry( + measure: ResolvedAggregateMeasure, +) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one aggregate measure registry entry and validate its semantic class.""" match function_registry_entry(measure.function_ref): Some(entry) => if entry.function_class != FunctionClass.Aggregate: return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as an aggregate function")) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing aggregate registry entry for `{measure.canonical_name}`")) + + +def _validate_aggregate_modifiers(measure: ResolvedAggregateMeasure) -> Result[None, SubstraitLoweringError]: + """Validate RFC 017 modifiers against aggregate registry metadata.""" + entry = _aggregate_registry_entry(measure.clone())? 
+ if measure.is_distinct: + if not entry.aggregate_modifiers.allows_distinct: + return Err(invalid_scalar_expression(f"{entry.function_ref} does not allow DISTINCT aggregate input")) + if not measure.has_expr: + return Err(invalid_scalar_expression("DISTINCT aggregate input requires an aggregate expression")) + if measure.has_filter and not entry.aggregate_modifiers.allows_filter: + return Err(invalid_scalar_expression(f"{entry.function_ref} does not allow aggregate-local FILTER")) + if len(measure.ordering) > 0 and not entry.aggregate_modifiers.allows_ordered_input: + return Err(invalid_scalar_expression(f"{entry.function_ref} does not allow ordered aggregate input")) + return Ok(None) + + +def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: + """Resolve one aggregate measure through declaration-side registry metadata.""" + match _aggregate_registry_entry(measure): + Ok(entry) => if entry.substrait.kind != SubstraitMappingKind.ExtensionFunction: return Err( invalid_scalar_expression( @@ -232,8 +270,7 @@ def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u ), ) return Ok(entry.substrait.anchor) - None => - return Err(invalid_scalar_expression(f"missing aggregate registry entry for `{measure.canonical_name}`")) + Err(err) => return Err(err) def _is_argument_free_count(measure: AggregateMeasure) -> bool: @@ -246,6 +283,8 @@ def _resolved_measure( input_columns: list[str], ) -> Result[ResolvedAggregateMeasure, SubstraitLoweringError]: """Resolve one aggregate measure against the current input-column list.""" + filter_expr = _resolved_measure_filter_expr(measure.clone(), input_columns)? + ordering = [_sort_field(input_columns, key)? 
for key in measure.ordering] if _is_argument_free_count(measure): return Ok( ResolvedAggregateMeasure( @@ -254,6 +293,10 @@ def _resolved_measure( canonical_name=measure.canonical_name, expr=bool_expr(true), has_expr=false, + is_distinct=measure.is_distinct, + filter_expr=filter_expr, + has_filter=measure.has_filter, + ordering=ordering, ), ) return Ok( @@ -263,10 +306,24 @@ def _resolved_measure( canonical_name=measure.canonical_name, expr=scalar_expr(input_columns, measure.expr)?, has_expr=true, + is_distinct=measure.is_distinct, + filter_expr=filter_expr, + has_filter=measure.has_filter, + ordering=ordering, ), ) +def _resolved_measure_filter_expr( + measure: AggregateMeasure, + input_columns: list[str], +) -> Result[Expression, SubstraitLoweringError]: + """Resolve one aggregate-local filter expression or a harmless placeholder.""" + if measure.has_filter: + return Ok(filter_predicate_expr(input_columns, measure.filter_expr)?) + return Ok(bool_expr(true)) + + def _lowered_rel_or_raise(result: Result[Rel, SubstraitLoweringError]) -> Rel: """Return one lowered relation or raise the structured lowering diagnostic for infallible call sites.""" match result: diff --git a/tests/fixtures/aggregate_modifiers.csv b/tests/fixtures/aggregate_modifiers.csv new file mode 100644 index 0000000..b11475c --- /dev/null +++ b/tests/fixtures/aggregate_modifiers.csv @@ -0,0 +1,7 @@ +customer_id,amount,status,product_id +A,10,paid,p1 +A,15,paid,p1 +A,3,cancelled,p2 +B,7,paid,p3 +B,9,cancelled,p3 +B,11,paid,p4 diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index ca25527..8d762b1 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -13,7 +13,10 @@ from functions import ( bool_lit, col, count, + count_distinct, count_expr, + count_if, + eq, float_expr, int_expr, int_lit, @@ -153,6 +156,8 @@ def test_smoke__dataset_types_are_published() -> None: sum_result = sum(amount) count_result = count() expression_count_result = count_expr(amount) + 
distinct_count_result = count_distinct(amount) + filtered_count_result = count_if(eq(col("status"), str_lit("paid"))) avg_result = avg(amount) min_result = min(amount) max_result = max(amount) @@ -162,6 +167,8 @@ def test_smoke__dataset_types_are_published() -> None: assert column_expr_name(count_result.expr) == "", "count should remain a zero-argument aggregate helper" assert expression_count_result.has_expr, "count(expr) should mark expression-count measures explicitly" assert column_expr_name(expression_count_result.expr) == "amount", "count(expr) should preserve the selected input column expression" + assert distinct_count_result.is_distinct, "count_distinct should build a distinct aggregate measure" + assert filtered_count_result.has_filter, "count_if should build a filtered aggregate measure" assert avg_result.kind == AggregateKind.Avg, "avg should build an aggregate measure" assert min_result.kind == AggregateKind.Min, "min should build an aggregate measure" assert max_result.kind == AggregateKind.Max, "max should build an aggregate measure" diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 5db41b1..8d673f3 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -21,7 +21,9 @@ from functions import ( col, coalesce, count, + count_distinct, count_expr, + count_if, desc, desc_nulls_first, desc_nulls_last, @@ -178,7 +180,7 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", 
"between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round"] def _expected_substrait_mapped_names() -> list[str]: @@ -203,6 +205,8 @@ def _exercise_current_public_helpers() -> None: sum(amount) count() count_expr(status) + count_distinct(status) + count_if(eq(status, str_lit("paid"))) avg(amount) min(amount) max(amount) @@ -356,13 +360,33 @@ def test_function_registry__core_helpers_expose_portable_policy_metadata() -> No # -- Act / Assert -- for entry in entries: assert entry.namespace == core_function_namespace(), f"{entry.function_ref} should live in the core function namespace" - if entry.canonical_name == "count_expr": - assert entry.policy_category == FunctionPolicyCategory.CompatibilityAlias, "count_expr should be marked as a compatibility helper" + if entry.canonical_name == "count_expr" or entry.canonical_name == "count_distinct" or entry.canonical_name == "count_if": + assert entry.policy_category == FunctionPolicyCategory.CompatibilityAlias, f"{entry.canonical_name} should be marked as a compatibility helper" assert entry.alias_policy == FunctionAliasPolicy.OptInCompatibility, "compatibility helpers should be opt-in by policy" continue assert entry.policy_category == FunctionPolicyCategory.PortableCore, f"{entry.function_ref} should be portable core" +def 
test_function_registry__aggregate_helpers_expose_modifier_policy() -> None: + """Assert aggregate registry metadata declares RFC 017 modifier support.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act -- + count_entry = _entry_by_name_or_fail("count") + sum_entry = _entry_by_name_or_fail("sum") + abs_entry = _entry_by_name_or_fail("abs") + + # -- Assert -- + assert count_entry.aggregate_modifiers.allows_distinct, "aggregate helpers should allow DISTINCT when they have input expressions" + assert count_entry.aggregate_modifiers.allows_filter, "aggregate helpers should allow aggregate-local filters" + assert not count_entry.aggregate_modifiers.allows_ordered_input, "core aggregates should reject ordered input until an order-sensitive aggregate lands" + assert sum_entry.aggregate_modifiers.allows_distinct, "numeric aggregates should allow DISTINCT" + assert sum_entry.aggregate_modifiers.allows_filter, "numeric aggregates should allow aggregate-local filters" + assert not abs_entry.aggregate_modifiers.allows_distinct, "scalar helpers should not expose aggregate modifier support" + assert not abs_entry.aggregate_modifiers.allows_filter, "scalar helpers should not expose aggregate modifier support" + + def test_function_registry__extension_policy_is_separate_from_scalar_class() -> None: """Assert extension-only functions can be scalar while remaining explicitly namespaced.""" # -- Arrange -- @@ -464,6 +488,8 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("sum", "sum", SUM_FUNCTION_ANCHOR) _assert_extension_mapping("count", "count", COUNT_FUNCTION_ANCHOR) _assert_extension_mapping("count_expr", "count", COUNT_FUNCTION_ANCHOR) + _assert_rewrite_mapping("count_distinct", "count_expr(expr).distinct()") + _assert_rewrite_mapping("count_if", "count().filter(predicate)") _assert_extension_mapping("avg", "avg", AVG_FUNCTION_ANCHOR) _assert_extension_mapping("min", "min", MIN_FUNCTION_ANCHOR) 
_assert_extension_mapping("max", "max", MAX_FUNCTION_ANCHOR) @@ -556,6 +582,8 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: sum_measure = sum(amount) count_measure = count() expression_count_measure = count_expr(status) + distinct_count_measure = count_distinct(status) + filtered_count_measure = count_if(eq(status, str_lit("paid"))) avg_measure = avg(amount) min_measure = min(amount) max_measure = max(amount) @@ -581,6 +609,10 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: assert not count_measure.has_expr, "count wrapper should preserve argument-free count semantics" assert expression_count_measure.kind == AggregateKind.Count, "count(expr) wrapper should preserve aggregate kind" assert expression_count_measure.has_expr, "count(expr) wrapper should preserve expression-count semantics" + assert distinct_count_measure.kind == AggregateKind.Count, "count_distinct wrapper should preserve aggregate kind" + assert distinct_count_measure.is_distinct, "count_distinct should lower to the aggregate distinct modifier" + assert filtered_count_measure.kind == AggregateKind.Count, "count_if wrapper should preserve aggregate kind" + assert filtered_count_measure.has_filter, "count_if should lower to the aggregate-local filter modifier" assert avg_measure.kind == AggregateKind.Avg, "avg wrapper should preserve aggregate kind" assert min_measure.kind == AggregateKind.Min, "min wrapper should preserve aggregate kind" assert max_measure.kind == AggregateKind.Max, "max wrapper should preserve aggregate kind" diff --git a/tests/test_prism.incn b/tests/test_prism.incn index e668c2a..f0c7490 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,6 +1,6 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, lit, mul, sum from 
prism import ( PrismCursor, prism_cursor_apply_filter, @@ -181,6 +181,28 @@ def test_prism__cross_store_join_dedups_equivalent_rhs_multistep_branches() -> N assert plan_contains_relation_kind(plan, str("ReadRel")), "deduped multistep adoption should still preserve read roots" +def test_prism__cross_store_adoption_keeps_distinct_aggregate_modifier_state() -> None: + # -- Arrange -- + _register_projection_test_schema(str("orders")) + _register_projection_test_schema(str("orders_archive")) + left: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_false()) + right_base: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")) + right_plain: PrismCursor[Order] = right_base.group_by([col("id")]).agg([count_expr(col("id"))]) + right_distinct: PrismCursor[Order] = right_base.group_by([col("id")]).agg([count_expr(col("id")).distinct()]) + + # -- Act -- + right_joined: PrismCursor[Order] = right_plain.join(right_distinct.clone(), true) + joined: PrismCursor[Order] = left.join(right_joined.clone(), true) + plan = joined.to_substrait_plan() + + # -- Assert -- + assert prism_cursors_share_store(left, right_joined) is false, "final join should still adopt the rhs aggregate tree" + assert prism_cursor_store_node_count(joined) == 8, "cross-store adoption must not dedup aggregates that differ by modifier state" + assert prism_cursor_authored_node_count(joined) == 8, "joined cursor should preserve both modified and unmodified aggregate branches" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "aggregate modifier adoption should still lower through JoinRel" + assert plan_contains_relation_kind(plan, str("AggregateRel")), "both rhs aggregate branches should survive into lowering" + + def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Arrange -- _register_projection_test_schema(str("orders")) diff --git a/tests/test_session_aggregates.incn b/tests/test_session_aggregates.incn index a8060ef..871bf6b 100644 
--- a/tests/test_session_aggregates.incn +++ b/tests/test_session_aggregates.incn @@ -1,9 +1,9 @@ """End-to-end Session aggregate execution tests over the DataFusion backend.""" -from functions import avg, col, count, count_expr, max, min, sum -from dataset import LazyFrame +from functions import avg, col, count, count_distinct, count_expr, count_if, eq, max, min, str_lit, sum +from dataset import DataFrame, LazyFrame from session import Session -from std.testing import assert_is_ok +from std.testing import assert_is_ok, fail_t @derive(Clone) @@ -12,15 +12,34 @@ pub model AggregateOrder: pub amount: int +@derive(Clone) +pub model AggregateModifierOrder: + pub customer_id: str + pub amount: int + pub status: str + pub product_id: str + + @derive(Clone) pub model Order: pub id: int +const AGGREGATE_MODIFIERS_CSV_FIXTURE: str = "tests/fixtures/aggregate_modifiers.csv" const AGGREGATE_ORDERS_CSV_FIXTURE: str = "tests/fixtures/aggregate_orders.csv" const ORDERS_CSV_FIXTURE: str = "tests/fixtures/orders.csv" +def _collect_modifier_or_fail( + mut session: Session, + grouped: LazyFrame[AggregateModifierOrder], +) -> DataFrame[AggregateModifierOrder]: + """Collect a modified aggregate frame or fail with the backend diagnostic.""" + match session.collect(grouped): + Ok(df) => return df + Err(err) => return fail_t(err.error_message()) + + def test_session_aggregates__grouped_collect_executes_sum_and_count() -> None: # -- Arrange -- mut session = Session.default() @@ -78,6 +97,36 @@ def test_session_aggregates__grouped_collect_executes_core_aggregates() -> None: assert payload.contains("1"), "customer B count and expression count should be materialized" +def test_session_aggregates__grouped_collect_executes_distinct_and_filter_modifiers() -> None: + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateModifierOrder] = assert_is_ok( + session.read_csv("aggregate_modifiers", AGGREGATE_MODIFIERS_CSV_FIXTURE), + "aggregate modifiers fixture 
should load", + ) + paid = eq(col("status"), str_lit("paid")) + grouped = lazy.group_by([col("customer_id")]).agg( + [count_distinct(col("product_id")), count_if(paid.clone()), sum(col("amount")).filter(paid)], + ) + df = _collect_modifier_or_fail(session, grouped) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 2, "modified grouped aggregate should produce one row per customer" + assert resolved == ["customer_id", "count_distinct_product_id", "count_filtered", "sum_amount_filtered"], "modified aggregate output columns should be stable" + assert payload.contains("A"), "modified aggregate output should contain customer A" + assert payload.contains("B"), "modified aggregate output should contain customer B" + assert payload.contains("25"), "customer A paid sum should be materialized" + assert payload.contains("18"), "customer B paid sum should be materialized" + assert payload.contains("2"), "distinct and filtered counts should be materialized" + assert not payload.contains("28"), "customer A filtered sum should exclude cancelled rows" + assert not payload.contains("27"), "customer B filtered sum should exclude cancelled rows" + assert not payload.contains("3"), "distinct counts should not count duplicated product ids" + + def test_session_aggregates__global_collect_executes_count() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 60b3f49..0e69548 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -1,6 +1,6 @@ """Tests for RFC 002 proto-backed Substrait emission and conformance alignment.""" -from std.testing import fail_t +from std.testing import assert_is_err, fail_t from functions import ( abs, add, @@ -16,7 +16,9 @@ from functions import ( col, coalesce, count, + count_distinct, count_expr, + count_if, desc, div, eq, @@ -56,6 +58,9 @@ from substrait.function_extensions import ( ) 
from substrait.inspect import ( aggregate_measure_function_names, + aggregate_measure_filter_flags, + aggregate_measure_invocation_names, + aggregate_measure_sort_counts, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -98,6 +103,7 @@ from substrait.relations import ( set_rel, set_rel_of_kind, sort_rel_of_columns, + try_aggregate_rel_of_columns, ) from substrait.schema_registry import named_table_columns, register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind @@ -423,6 +429,59 @@ def test_plan__aggregate_rel_accepts_scalar_group_and_measure_expressions() -> N assert output_columns[5] == "max", "computed max measures should use stable fallback names" +def test_plan__aggregate_rel_lowers_distinct_and_filter_modifiers() -> None: + # -- Arrange -- + _register_orders_schema() + base = read_named_table_rel("orders") + aggregated = aggregate_rel( + base, + [col("id")], + [count_distinct(col("id")), count_if(gt(col("id"), lit(1))), sum(col("id")).distinct().filter( + gt(col("id"), lit(0)), + )], + ) + plan = plan_from_root_relation(aggregated, ["id", "count_distinct_id", "count_filtered", "sum_distinct_id_filtered"]) + + # -- Act -- + output_columns = relation_output_columns(aggregated) + aggregate_functions = aggregate_measure_function_names(aggregated) + aggregate_invocations = aggregate_measure_invocation_names(aggregated) + aggregate_filter_flags = aggregate_measure_filter_flags(aggregated) + aggregate_sort_counts = aggregate_measure_sort_counts(aggregated) + + # -- Assert -- + assert relation_kind_name(aggregated) == "AggregateRel", "aggregate modifiers should still lower through AggregateRel" + assert plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), "modified aggregate plans should register aggregate extensions" + assert output_columns == ["id", "count_distinct_id", "count_filtered", "sum_distinct_id_filtered"], "modified aggregate output names should remain stable" + assert 
aggregate_functions == ["count", "count", "sum"], "compatibility helpers should lower to canonical aggregate functions" + assert aggregate_invocations == ["Distinct", "All", "Distinct"], "distinct modifiers should lower to Substrait aggregate invocations" + assert aggregate_filter_flags == [false, true, true], "aggregate-local filters should lower onto the target measures" + assert aggregate_sort_counts == [0, 0, 0], "core aggregates should not emit ordered input fields" + + +def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: + # -- Arrange -- + _register_orders_schema() + base = read_named_table_rel("orders") + + # -- Act -- + distinct_count_result = try_aggregate_rel_of_columns(base.clone(), ["id"], [col("id")], [count().distinct()]) + ordered_sum_result = try_aggregate_rel_of_columns( + base, + ["id"], + [col("id")], + [sum(col("id")).order_by([asc(col("id"))])], + ) + distinct_err = assert_is_err(distinct_count_result, "count().distinct() should require an input expression") + ordered_err = assert_is_err(ordered_sum_result, "core aggregates should reject ordered input for now") + + # -- Assert -- + assert distinct_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "invalid distinct count should be a scalar lowering diagnostic" + assert distinct_err.message.contains("DISTINCT"), "invalid distinct diagnostic should mention DISTINCT" + assert ordered_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "unsupported ordered aggregate should be a scalar lowering diagnostic" + assert ordered_err.message.contains("ordered aggregate input"), "ordered aggregate diagnostic should identify the unsupported modifier" + + def test_plan__set_rel_uses_operation_enum() -> None: # -- Arrange -- left = read_named_table_rel("orders_current") From 76cdb21ee4c54d4ee50a53adc6984908c3b10e92 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 26 May 2026 00:11:15 +0200 Subject: [PATCH 3/6] feature - implement nested scalar functions (#37) 
--- docs/language/reference/functions/index.md | 4 +- docs/language/reference/functions/nested.md | 58 +++++++++ docs/release_notes/v0_1.md | 1 + docs/rfcs/020_nested_data_functions.md | 44 +++---- src/functions/mod.incn | 22 ++++ src/functions/nested/array.incn | 38 ++++++ src/functions/nested/array_contains.incn | 34 ++++++ src/functions/nested/array_distinct.incn | 33 ++++++ src/functions/nested/array_except.incn | 34 ++++++ src/functions/nested/array_flatten.incn | 37 ++++++ src/functions/nested/array_intersect.incn | 34 ++++++ src/functions/nested/array_join.incn | 34 ++++++ src/functions/nested/array_position.incn | 34 ++++++ src/functions/nested/array_reverse.incn | 33 ++++++ src/functions/nested/array_slice.incn | 35 ++++++ src/functions/nested/array_sort.incn | 33 ++++++ src/functions/nested/array_union.incn | 34 ++++++ src/functions/nested/arrays_overlap.incn | 34 ++++++ src/functions/nested/cardinality.incn | 33 ++++++ src/functions/nested/common.incn | 33 ++++++ src/functions/nested/element_at.incn | 34 ++++++ src/functions/nested/map_contains_key.incn | 35 ++++++ src/functions/nested/map_entries.incn | 33 ++++++ src/functions/nested/map_extract.incn | 34 ++++++ src/functions/nested/map_from_arrays.incn | 34 ++++++ src/functions/nested/map_keys.incn | 33 ++++++ src/functions/nested/map_values.incn | 33 ++++++ src/functions/nested/mod.incn | 24 ++++ src/functions/nested/named_struct.incn | 34 ++++++ src/lib.incn | 22 ++++ src/substrait/function_extensions.incn | 21 ++++ tests/test_function_registry.incn | 92 ++++++++++++++- tests/test_nested_data_functions.incn | 124 ++++++++++++++++++++ tests/test_session_projection.incn | 42 +++++++ tests/test_substrait_plan.incn | 54 +++++++++ 35 files changed, 1267 insertions(+), 24 deletions(-) create mode 100644 docs/language/reference/functions/nested.md create mode 100644 src/functions/nested/array.incn create mode 100644 src/functions/nested/array_contains.incn create mode 100644 
src/functions/nested/array_distinct.incn create mode 100644 src/functions/nested/array_except.incn create mode 100644 src/functions/nested/array_flatten.incn create mode 100644 src/functions/nested/array_intersect.incn create mode 100644 src/functions/nested/array_join.incn create mode 100644 src/functions/nested/array_position.incn create mode 100644 src/functions/nested/array_reverse.incn create mode 100644 src/functions/nested/array_slice.incn create mode 100644 src/functions/nested/array_sort.incn create mode 100644 src/functions/nested/array_union.incn create mode 100644 src/functions/nested/arrays_overlap.incn create mode 100644 src/functions/nested/cardinality.incn create mode 100644 src/functions/nested/common.incn create mode 100644 src/functions/nested/element_at.incn create mode 100644 src/functions/nested/map_contains_key.incn create mode 100644 src/functions/nested/map_entries.incn create mode 100644 src/functions/nested/map_extract.incn create mode 100644 src/functions/nested/map_from_arrays.incn create mode 100644 src/functions/nested/map_keys.incn create mode 100644 src/functions/nested/map_values.incn create mode 100644 src/functions/nested/mod.incn create mode 100644 src/functions/nested/named_struct.incn create mode 100644 tests/test_nested_data_functions.incn diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 8b32f31..f6347a8 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -7,10 +7,11 @@ Today the concrete shipped surfaces are documented here: - [Filter builders](../builders/filters.md) - [Aggregate builders](../builders/aggregates.md) - [Projection builders](../builders/projections.md) +- [Nested data functions](nested.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. 
-The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. 
Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -31,6 +32,7 @@ The registered helper surface currently includes: | `coalesce(...)`, `nullif(...)`, `case_when(...)` | scalar | registered Substrait mappings; `case_when(...)` lowers as built-in `IfThen` | | `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) | | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | +| `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, 
`count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | diff --git a/docs/language/reference/functions/nested.md b/docs/language/reference/functions/nested.md new file mode 100644 index 0000000..644e1ad --- /dev/null +++ b/docs/language/reference/functions/nested.md @@ -0,0 +1,58 @@ +# Nested Data Functions (Reference) + +Nested data helpers build and inspect row-level arrays, maps, and structs. They are scalar expressions: every helper returns one value for each input row and does not change relation cardinality. + +Generator or table-valued operations such as row-expanding `explode(...)` are separate from this page. + +## Arrays + +| Function | Meaning | +| --- | --- | +| `array(values)` | Build an array expression from one or more scalar expressions. | +| `cardinality(value)` | Return the size of an array or map. | +| `array_contains(array_expr, value)` | Return whether an array contains a value. | +| `arrays_overlap(left, right)` | Return whether two arrays have any elements in common. | +| `array_position(array_expr, value)` | Return the one-based position of a value. | +| `element_at(array_expr, index)` | Return an array element by one-based index. | +| `array_sort(array_expr)` | Sort one array value. | +| `array_distinct(array_expr)` | Remove duplicate elements from one array value. | +| `array_except(left, right)` | Return elements from `left` that are not in `right`. | +| `array_intersect(left, right)` | Return elements shared by both arrays. | +| `array_union(left, right)` | Return the union of both arrays. | +| `array_join(array_expr, delimiter)` | Join a string array into one string. 
| +| `array_slice(array_expr, start, stop)` | Return a one-based array slice using the backend adapter's slice contract. | +| `array_reverse(array_expr)` | Reverse one array value. | +| `array_flatten(array_expr)` | Flatten an array-of-arrays into one row-level array value. | + +## Maps And Structs + +| Function | Meaning | +| --- | --- | +| `map_from_arrays(keys, values)` | Build a map from key and value arrays. | +| `map_extract(map_expr, key)` | Return the values associated with a key. | +| `map_contains_key(map_expr, key)` | Return whether `map_extract(...)` finds at least one value for the key. | +| `map_keys(map_expr)` | Return the map's keys as an array. | +| `map_values(map_expr)` | Return the map's values as an array. | +| `map_entries(map_expr)` | Return map entries. | +| `named_struct(field_names, values)` | Build a struct expression with explicit field names. | + +## Example + +```incan +from pub::inql.functions import array, array_contains, cardinality, col, element_at, lit + +projected = ( + events + .with_column("tags", array([lit("paid"), col("source")])) + .with_column("tag_count", cardinality(col("tags"))) + .with_column("has_paid_tag", array_contains(col("tags"), lit("paid"))) + .with_column("first_tag", element_at(col("tags"), lit(1))) +) +``` + +## Semantics + +- Array indexing is one-based for `element_at(...)`, `array_position(...)`, and `array_slice(...)`. +- `element_at(...)` currently maps to the portable array-element adapter path. Out-of-range behavior follows the current backend adapter's recoverable result until InQL has a richer static/runtime error-policy split for strict versus try-style element access. +- `array_flatten(...)` is intentionally named to avoid colliding with future table-valued or generator `flatten(...)` forms. +- Grouping or ordering by nested values is not documented as portable until equality and ordering semantics for arrays, maps, and structs are specified. 
diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index eacddef..2543685 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -15,6 +15,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Scalar expressions:** RFC 012 unifies filter predicates, computed projection values, grouping keys, and aggregate inputs around one `ColumnExpr` surface with canonical `lit(...)` and typed literal helpers. - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. +- **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. 
- **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/020_nested_data_functions.md b/docs/rfcs/020_nested_data_functions.md index c3f0d35..2680ec3 100644 --- a/docs/rfcs/020_nested_data_functions.md +++ b/docs/rfcs/020_nested_data_functions.md @@ -1,6 +1,6 @@ # InQL RFC 020: Nested data functions -- **Status:** Draft +- **Status:** Implemented - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -11,12 +11,12 @@ - InQL RFC 021 (generator and table-valued functions) - **Issue:** [InQL #37](https://github.com/dannys-code-corner/InQL/issues/37) - **RFC PR:** — -- **Written against:** Incan v0.2 -- **Shipped in:** — +- **Written against:** Incan v0.3-era InQL +- **Shipped in:** v0.1 ## Summary -This RFC defines InQL functions for nested scalar values: arrays, maps, and structs. It covers construction, element access, cardinality, containment, sorting, set-like array operations, map entry access, and higher-order collection functions as a later extension point. Nested functions remain scalar when they produce one value per input row; cardinality-changing operations such as `explode` belong to a separate generator RFC. +This RFC defines InQL functions for nested scalar values: arrays, maps, and structs. It covers construction, element access, cardinality, containment, overlap checks, sorting, set-like array operations, scalar array flattening, map entry access, and higher-order collection functions as a later extension point. 
Nested functions remain scalar when they produce one value per input row; cardinality-changing operations such as `explode` belong to a separate generator RFC. ## Motivation @@ -28,7 +28,7 @@ The split matters. `array_contains(.items, "x")` is a row-level scalar predicate - Define scalar functions for arrays, maps, and structs. - Distinguish nested scalar operations from generators. -- Define element access and safe element access. +- Define element access with an explicit one-based indexing policy. - Define collection size, containment, sorting, and set-like operations. - Leave lambda-based higher-order functions as a later design decision unless the host language surface is ready. @@ -41,16 +41,16 @@ The split matters. `array_contains(.items, "x")` is a row-level scalar predicate ## Guide-level explanation (how authors think about it) -Authors should be able to inspect and manipulate nested values without changing relation cardinality: +Authors can inspect and manipulate nested values without changing relation cardinality: ```incan -from pub::inql.functions import array_contains, cardinality, col, element_at, map_keys +from pub::inql.functions import array_contains, cardinality, col, element_at, lit, map_keys enriched = ( events - .filter(array_contains(col("tags"), "purchase")) + .filter(array_contains(col("tags"), lit("purchase"))) .with_column("tag_count", cardinality(col("tags"))) - .with_column("first_item", element_at(col("items"), 1)) + .with_column("first_item", element_at(col("items"), lit(1))) .with_column("metadata_keys", map_keys(col("metadata"))) ) ``` @@ -59,17 +59,17 @@ If an author wants one output row per item, that is a generator/table-valued ope ## Reference-level explanation (precise rules) -InQL should define array construction with `array`, struct construction with `struct` or `named_struct`, and map construction with `create_map` or an equivalent canonical name. 
+InQL defines array construction with `array`, struct construction with `named_struct`, and map construction with `map_from_arrays`. -InQL should define `cardinality` as the canonical size function for arrays and maps. Compatibility aliases such as `size`, `array_size`, and `array_length` may resolve to `cardinality` where semantics match. +InQL defines `cardinality` as the canonical size function for arrays and maps. Compatibility aliases such as `size`, `array_size`, and `array_length` may resolve to `cardinality` where semantics match, but the initial implemented surface keeps the canonical spelling. -InQL should define element access functions including `element_at`, `try_element_at`, and `get`. Strict element access must fail or diagnose according to its registry error policy when an index or key is invalid. `try_element_at` must produce the recoverable result defined by its registry entry. +InQL defines array element access with `element_at(array_expr, index)`. Indexes are one-based. Current lowering maps to the portable array-element adapter path and uses the backend adapter's recoverable out-of-range behavior until InQL has a richer static/runtime error-policy split for strict versus try-style element access. -InQL should define array predicates and transforms including `array_contains`, `array_position`, `array_sort`, `array_distinct`, `array_except`, `array_intersect`, `array_union`, `array_join`, `arrays_overlap`, `flatten`, `slice`, and `reverse` where type and null semantics are specified. +InQL defines array predicates and transforms including `array_contains`, `array_position`, `array_sort`, `array_distinct`, `array_except`, `array_intersect`, `array_union`, `array_join`, `arrays_overlap`, `array_flatten`, `array_slice`, and `array_reverse` where type and null semantics are specified by the registry and backend adapter boundary. 
The scalar array-flattening helper is named `array_flatten` so table-valued or generator `flatten` remains available for RFC 021. -InQL should define map functions including `map_contains_key`, `map_entries`, `map_from_arrays`, `map_from_entries`, `map_keys`, and `map_values`. +InQL defines map functions including `map_contains_key`, `map_entries`, `map_extract`, `map_from_arrays`, `map_keys`, and `map_values`. -InQL should account for object-style warehouse functions such as `object_construct`, `object_construct_keep_null`, `object_delete`, `object_insert`, `object_keys`, and `object_pick`. These should be modeled through typed object/map semantics where possible and through a variant/semi-structured family only when dynamic value semantics are required. +Object-style warehouse functions such as `object_construct`, `object_construct_keep_null`, `object_delete`, `object_insert`, `object_keys`, and `object_pick` are accounted for as semi-structured and dynamic-object concerns. They should be modeled through typed object/map semantics where possible and through the RFC 022 semi-structured family only when dynamic value semantics are required. Higher-order functions such as `transform`, `filter`, `exists`, `forall`, `aggregate`, `reduce`, `zip_with`, `map_filter`, `transform_keys`, and `transform_values` must not reach Planned status until lambda or equivalent callback semantics are specified for InQL expressions. @@ -87,7 +87,7 @@ Index origin, invalid-index behavior, null container behavior, null element beha ### Interaction with other InQL surfaces -Nested functions may appear wherever scalar expressions of their result type are valid. Grouping by nested values may be restricted until equality and ordering semantics for nested values are fully specified. +Nested functions may appear wherever scalar expressions of their result type are valid. 
Grouping by nested values is not documented as portable until equality and ordering semantics for nested values are fully specified. ### Compatibility / migration @@ -113,10 +113,12 @@ No current InQL APIs are expected to break. Nested functions should be additive - **Execution / interchange** — Prism and Substrait lowering must preserve nested value semantics or diagnose unsupported operations. - **Documentation** — docs should separate nested scalar operations from generator functions. -## Unresolved questions +## Design Decisions -- Should element access use one-based indexing for SQL/Spark compatibility or zero-based indexing for host-language familiarity? -- What should strict `element_at` do on out-of-range indexes? -- Should grouping and ordering over arrays, maps, and structs be allowed initially? +### Resolved - +- Element access, array position results, and array slice boundaries are one-based for SQL/Spark compatibility. +- `element_at(...)` uses the current adapter's recoverable array-element behavior for out-of-range indexes. A separate strict/try split is deferred until registry error policy can distinguish static validation failures from runtime recoverable results. +- Grouping and ordering over arrays, maps, and structs are not documented as portable in the initial implementation. +- Scalar `array_flatten(...)` is separate from RFC 021 table-valued or generator flattening. +- Higher-order collection functions remain deferred until InQL expression callback or lambda semantics are specified. 
diff --git a/src/functions/mod.incn b/src/functions/mod.incn index 4ec6028..e0e8754 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -39,6 +39,28 @@ pub from functions.math.abs import abs pub from functions.math.ceil import ceil pub from functions.math.floor import floor pub from functions.math.round import round +pub from functions.nested.array import array +pub from functions.nested.array_contains import array_contains +pub from functions.nested.array_distinct import array_distinct +pub from functions.nested.array_except import array_except +pub from functions.nested.array_flatten import array_flatten +pub from functions.nested.array_intersect import array_intersect +pub from functions.nested.array_join import array_join +pub from functions.nested.array_position import array_position +pub from functions.nested.array_reverse import array_reverse +pub from functions.nested.array_slice import array_slice +pub from functions.nested.array_sort import array_sort +pub from functions.nested.array_union import array_union +pub from functions.nested.arrays_overlap import arrays_overlap +pub from functions.nested.cardinality import cardinality +pub from functions.nested.element_at import element_at +pub from functions.nested.map_contains_key import map_contains_key +pub from functions.nested.map_entries import map_entries +pub from functions.nested.map_extract import map_extract +pub from functions.nested.map_from_arrays import map_from_arrays +pub from functions.nested.map_keys import map_keys +pub from functions.nested.map_values import map_values +pub from functions.nested.named_struct import named_struct pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/functions/nested/array.incn b/src/functions/nested/array.incn new file mode 100644 index 0000000..3c681e7 --- /dev/null +++ b/src/functions/nested/array.incn @@ -0,0 +1,38 @@ +""" +Array construction 
helper. + +`array` builds a row-level nested scalar value and does not change relation cardinality. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application, require_non_empty_args +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAKE_ARRAY_FUNCTION_ANCHOR + + +@function_registry.add("array", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("make_array", MAKE_ARRAY_FUNCTION_ANCHOR), +)) +pub def array(values: list[ColumnExpr]) -> ColumnExpr: + """ + Build an array expression from one or more scalar values. + + Examples: + tags = array([str_lit("paid"), col("status")]) + + Parameters: + values: Element expressions in array order. + """ + require_non_empty_args(values) + return nested_application("array", values) diff --git a/src/functions/nested/array_contains.incn b/src/functions/nested/array_contains.incn new file mode 100644 index 0000000..fc0f480 --- /dev/null +++ b/src/functions/nested/array_contains.incn @@ -0,0 +1,34 @@ +"""Array containment predicate helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_HAS_FUNCTION_ANCHOR + + +@function_registry.add("array_contains", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_has", ARRAY_HAS_FUNCTION_ANCHOR), +)) +pub def 
array_contains(array_expr: ColumnExpr, value: ColumnExpr) -> ColumnExpr: + """ + Return whether an array contains a value. + + Examples: + has_purchase = array_contains(col("tags"), str_lit("purchase")) + + Parameters: + array_expr: Array expression to search. + value: Value expression to find. + """ + return nested_application("array_contains", [array_expr, value]) diff --git a/src/functions/nested/array_distinct.incn b/src/functions/nested/array_distinct.incn new file mode 100644 index 0000000..2464b9d --- /dev/null +++ b/src/functions/nested/array_distinct.incn @@ -0,0 +1,33 @@ +"""Array distinct helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_DISTINCT_FUNCTION_ANCHOR + + +@function_registry.add("array_distinct", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_distinct", ARRAY_DISTINCT_FUNCTION_ANCHOR), +)) +pub def array_distinct(array_expr: ColumnExpr) -> ColumnExpr: + """ + Return an array with duplicate values removed. + + Examples: + unique_tags = array_distinct(col("tags")) + + Parameters: + array_expr: Array expression to de-duplicate. 
+ """ + return nested_application("array_distinct", [array_expr]) diff --git a/src/functions/nested/array_except.incn b/src/functions/nested/array_except.incn new file mode 100644 index 0000000..10a2caf --- /dev/null +++ b/src/functions/nested/array_except.incn @@ -0,0 +1,34 @@ +"""Array set-difference helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_EXCEPT_FUNCTION_ANCHOR + + +@function_registry.add("array_except", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_except", ARRAY_EXCEPT_FUNCTION_ANCHOR), +)) +pub def array_except(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Return values from one array that are not present in another array. + + Examples: + missing_tags = array_except(col("expected_tags"), col("actual_tags")) + + Parameters: + left: Array expression that supplies candidate values. + right: Array expression containing values to remove. + """ + return nested_application("array_except", [left, right]) diff --git a/src/functions/nested/array_flatten.incn b/src/functions/nested/array_flatten.incn new file mode 100644 index 0000000..8403ee1 --- /dev/null +++ b/src/functions/nested/array_flatten.incn @@ -0,0 +1,37 @@ +""" +Array flattening helper. + +`array_flatten` is scalar: it flattens an array value in one input row and does not produce more rows. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_FLATTEN_FUNCTION_ANCHOR + + +@function_registry.add("array_flatten", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("flatten", ARRAY_FLATTEN_FUNCTION_ANCHOR), +)) +pub def array_flatten(array_expr: ColumnExpr) -> ColumnExpr: + """ + Flatten an array-of-arrays into one row-level array value. + + Examples: + flattened = array_flatten(col("nested_tags")) + + Parameters: + array_expr: Array expression to flatten. + """ + return nested_application("array_flatten", [array_expr]) diff --git a/src/functions/nested/array_intersect.incn b/src/functions/nested/array_intersect.incn new file mode 100644 index 0000000..a457da4 --- /dev/null +++ b/src/functions/nested/array_intersect.incn @@ -0,0 +1,34 @@ +"""Array set-intersection helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_INTERSECT_FUNCTION_ANCHOR + + +@function_registry.add("array_intersect", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_intersect", ARRAY_INTERSECT_FUNCTION_ANCHOR), +)) +pub def array_intersect(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Return values that are present in both arrays. 
+ + Examples: + shared_tags = array_intersect(col("expected_tags"), col("actual_tags")) + + Parameters: + left: First array expression. + right: Second array expression. + """ + return nested_application("array_intersect", [left, right]) diff --git a/src/functions/nested/array_join.incn b/src/functions/nested/array_join.incn new file mode 100644 index 0000000..600a1d9 --- /dev/null +++ b/src/functions/nested/array_join.incn @@ -0,0 +1,34 @@ +"""Array-to-string join helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_TO_STRING_FUNCTION_ANCHOR + + +@function_registry.add("array_join", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_to_string", ARRAY_TO_STRING_FUNCTION_ANCHOR), +)) +pub def array_join(array_expr: ColumnExpr, delimiter: ColumnExpr) -> ColumnExpr: + """ + Join array elements into one string with a delimiter. + + Examples: + tag_text = array_join(col("tags"), str_lit(",")) + + Parameters: + array_expr: Array expression to render. + delimiter: String delimiter expression placed between elements. 
+ """ + return nested_application("array_join", [array_expr, delimiter]) diff --git a/src/functions/nested/array_position.incn b/src/functions/nested/array_position.incn new file mode 100644 index 0000000..e443508 --- /dev/null +++ b/src/functions/nested/array_position.incn @@ -0,0 +1,34 @@ +"""Array first-position helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_POSITION_FUNCTION_ANCHOR + + +@function_registry.add("array_position", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_position", ARRAY_POSITION_FUNCTION_ANCHOR), +)) +pub def array_position(array_expr: ColumnExpr, value: ColumnExpr) -> ColumnExpr: + """ + Return the one-based first position of a value in an array, or null when absent. + + Examples: + first_paid = array_position(col("tags"), str_lit("paid")) + + Parameters: + array_expr: Array expression to search. + value: Value expression to find. 
+ """ + return nested_application("array_position", [array_expr, value]) diff --git a/src/functions/nested/array_reverse.incn b/src/functions/nested/array_reverse.incn new file mode 100644 index 0000000..8e4c13c --- /dev/null +++ b/src/functions/nested/array_reverse.incn @@ -0,0 +1,33 @@ +"""Array reverse helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_REVERSE_FUNCTION_ANCHOR + + +@function_registry.add("array_reverse", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_reverse", ARRAY_REVERSE_FUNCTION_ANCHOR), +)) +pub def array_reverse(array_expr: ColumnExpr) -> ColumnExpr: + """ + Return an array with elements in reverse order. + + Examples: + newest_first = array_reverse(col("events")) + + Parameters: + array_expr: Array expression to reverse. 
+ """ + return nested_application("array_reverse", [array_expr]) diff --git a/src/functions/nested/array_slice.incn b/src/functions/nested/array_slice.incn new file mode 100644 index 0000000..08aa946 --- /dev/null +++ b/src/functions/nested/array_slice.incn @@ -0,0 +1,35 @@ +"""Array slice helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_SLICE_FUNCTION_ANCHOR + + +@function_registry.add("array_slice", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_slice", ARRAY_SLICE_FUNCTION_ANCHOR), +)) +pub def array_slice(array_expr: ColumnExpr, start: ColumnExpr, stop: ColumnExpr) -> ColumnExpr: + """ + Return a one-based array slice. + + Examples: + first_two = array_slice(col("tags"), int_lit(1), int_lit(2)) + + Parameters: + array_expr: Array expression to slice. + start: One-based start index. + stop: One-based stop index following the backend adapter's `array_slice` contract. 
+ """ + return nested_application("array_slice", [array_expr, start, stop]) diff --git a/src/functions/nested/array_sort.incn b/src/functions/nested/array_sort.incn new file mode 100644 index 0000000..73c7f75 --- /dev/null +++ b/src/functions/nested/array_sort.incn @@ -0,0 +1,33 @@ +"""Array sort helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_SORT_FUNCTION_ANCHOR + + +@function_registry.add("array_sort", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_sort", ARRAY_SORT_FUNCTION_ANCHOR), +)) +pub def array_sort(array_expr: ColumnExpr) -> ColumnExpr: + """ + Return an array with comparable elements sorted into backend-default order. + + Examples: + sorted_tags = array_sort(col("tags")) + + Parameters: + array_expr: Array expression to sort. 
+ """ + return nested_application("array_sort", [array_expr]) diff --git a/src/functions/nested/array_union.incn b/src/functions/nested/array_union.incn new file mode 100644 index 0000000..6aaba15 --- /dev/null +++ b/src/functions/nested/array_union.incn @@ -0,0 +1,34 @@ +"""Array set-union helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_UNION_FUNCTION_ANCHOR + + +@function_registry.add("array_union", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_union", ARRAY_UNION_FUNCTION_ANCHOR), +)) +pub def array_union(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Return the distinct union of two arrays. + + Examples: + all_tags = array_union(col("left_tags"), col("right_tags")) + + Parameters: + left: First array expression. + right: Second array expression. 
+ """ + return nested_application("array_union", [left, right]) diff --git a/src/functions/nested/arrays_overlap.incn b/src/functions/nested/arrays_overlap.incn new file mode 100644 index 0000000..cedebed --- /dev/null +++ b/src/functions/nested/arrays_overlap.incn @@ -0,0 +1,34 @@ +"""Array overlap predicate helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_HAS_ANY_FUNCTION_ANCHOR + + +@function_registry.add("arrays_overlap", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_has_any", ARRAY_HAS_ANY_FUNCTION_ANCHOR), +)) +pub def arrays_overlap(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Return whether two arrays have any elements in common. + + Examples: + has_shared_tag = arrays_overlap(col("tags"), array([str_lit("paid")])) + + Parameters: + left: First array expression. + right: Second array expression. 
+ """ + return nested_application("arrays_overlap", [left, right]) diff --git a/src/functions/nested/cardinality.incn b/src/functions/nested/cardinality.incn new file mode 100644 index 0000000..9049ccc --- /dev/null +++ b/src/functions/nested/cardinality.incn @@ -0,0 +1,33 @@ +"""Nested collection cardinality helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import CARDINALITY_FUNCTION_ANCHOR + + +@function_registry.add("cardinality", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("cardinality", CARDINALITY_FUNCTION_ANCHOR), +)) +pub def cardinality(value: ColumnExpr) -> ColumnExpr: + """ + Return the number of entries in an array or map expression. + + Examples: + tag_count = cardinality(col("tags")) + + Parameters: + value: Array or map expression to size. 
+ """ + return nested_application("cardinality", [value]) diff --git a/src/functions/nested/common.incn b/src/functions/nested/common.incn new file mode 100644 index 0000000..319c115 --- /dev/null +++ b/src/functions/nested/common.incn @@ -0,0 +1,33 @@ +"""Shared implementation helpers for nested scalar functions.""" + +from rust::incan_stdlib::errors import raise_value_error +from functions.registry import registered_application +from projection_builders import ColumnExpr, str_expr + + +pub def nested_application(canonical_name: str, arguments: list[ColumnExpr]) -> ColumnExpr: + """Build one registry-backed nested scalar function application.""" + return registered_application(canonical_name, arguments) + + +pub def require_non_empty_args(arguments: list[ColumnExpr]) -> None: + """Reject empty variadic nested constructors that cannot infer a value type.""" + if len(arguments) == 0: + return raise_value_error("nested constructor requires at least one scalar expression") + return + + +pub def named_struct_arguments(field_names: list[str], values: list[ColumnExpr]) -> list[ColumnExpr]: + """Build alternating field-name/value arguments for a named struct function call.""" + if len(field_names) == 0: + return raise_value_error("named_struct requires at least one field") + if len(field_names) != len(values): + return raise_value_error("named_struct requires one value for each field name") + + mut arguments: list[ColumnExpr] = [] + for idx, field_name in enumerate(field_names): + if len(field_name) == 0: + return raise_value_error("named_struct field names must be non-empty") + arguments.append(str_expr(field_name)) + arguments.append(values[idx]) + return arguments diff --git a/src/functions/nested/element_at.incn b/src/functions/nested/element_at.incn new file mode 100644 index 0000000..e726baa --- /dev/null +++ b/src/functions/nested/element_at.incn @@ -0,0 +1,34 @@ +"""Array element access helper.""" + +from function_registry import ( + FunctionClass, + 
FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import ARRAY_ELEMENT_FUNCTION_ANCHOR + + +@function_registry.add("element_at", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("array_element", ARRAY_ELEMENT_FUNCTION_ANCHOR), +)) +pub def element_at(array_expr: ColumnExpr, index: ColumnExpr) -> ColumnExpr: + """ + Return an array element by one-based index. + + Examples: + first_tag = element_at(col("tags"), int_lit(1)) + + Parameters: + array_expr: Array expression to access. + index: One-based element index. Negative indexes count from the end where supported by the backend adapter. + """ + return nested_application("element_at", [array_expr, index]) diff --git a/src/functions/nested/map_contains_key.incn b/src/functions/nested/map_contains_key.incn new file mode 100644 index 0000000..8d0d02e --- /dev/null +++ b/src/functions/nested/map_contains_key.incn @@ -0,0 +1,35 @@ +"""Map key containment predicate helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.nested.cardinality import cardinality +from functions.nested.map_extract import map_extract +from functions.operators.gt import gt +from functions.registry import function_registry +from projection_builders import ColumnExpr, int_expr + + +@function_registry.add("map_contains_key", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + rewrite_mapping("gt(cardinality(map_extract(map_expr, key)), int_expr(0))"), +)) +pub def map_contains_key(map_expr: 
ColumnExpr, key: ColumnExpr) -> ColumnExpr: + """ + Return whether a map contains a key. + + Examples: + has_status = map_contains_key(col("attributes"), str_lit("status")) + + Parameters: + map_expr: Map expression to inspect. + key: Key expression to look up. + """ + return gt(cardinality(map_extract(map_expr, key)), int_expr(0)) diff --git a/src/functions/nested/map_entries.incn b/src/functions/nested/map_entries.incn new file mode 100644 index 0000000..49c12ff --- /dev/null +++ b/src/functions/nested/map_entries.incn @@ -0,0 +1,33 @@ +"""Map entries helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAP_ENTRIES_FUNCTION_ANCHOR + + +@function_registry.add("map_entries", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("map_entries", MAP_ENTRIES_FUNCTION_ANCHOR), +)) +pub def map_entries(map_expr: ColumnExpr) -> ColumnExpr: + """ + Return a map as an array of key/value entry structs. + + Examples: + entries = map_entries(col("attributes")) + + Parameters: + map_expr: Map expression to inspect. 
+ """ + return nested_application("map_entries", [map_expr]) diff --git a/src/functions/nested/map_extract.incn b/src/functions/nested/map_extract.incn new file mode 100644 index 0000000..30e4e7b --- /dev/null +++ b/src/functions/nested/map_extract.incn @@ -0,0 +1,34 @@ +"""Map key extraction helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAP_EXTRACT_FUNCTION_ANCHOR + + +@function_registry.add("map_extract", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("map_extract", MAP_EXTRACT_FUNCTION_ANCHOR), +)) +pub def map_extract(map_expr: ColumnExpr, key: ColumnExpr) -> ColumnExpr: + """ + Return the value-list associated with a map key. + + Examples: + status_value = map_extract(col("attributes"), str_lit("status")) + + Parameters: + map_expr: Map expression to inspect. + key: Key expression to look up. 
+ """ + return nested_application("map_extract", [map_expr, key]) diff --git a/src/functions/nested/map_from_arrays.incn b/src/functions/nested/map_from_arrays.incn new file mode 100644 index 0000000..1c94165 --- /dev/null +++ b/src/functions/nested/map_from_arrays.incn @@ -0,0 +1,34 @@ +"""Map construction helper from key and value arrays.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAP_FUNCTION_ANCHOR + + +@function_registry.add("map_from_arrays", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("map", MAP_FUNCTION_ANCHOR), +)) +pub def map_from_arrays(keys: ColumnExpr, values: ColumnExpr) -> ColumnExpr: + """ + Build a map from equal-length key and value arrays. + + Examples: + attrs = map_from_arrays(array([str_lit("status")]), array([col("status")])) + + Parameters: + keys: Array expression containing non-null map keys. + values: Array expression containing map values. 
+ """ + return nested_application("map_from_arrays", [keys, values]) diff --git a/src/functions/nested/map_keys.incn b/src/functions/nested/map_keys.incn new file mode 100644 index 0000000..d223229 --- /dev/null +++ b/src/functions/nested/map_keys.incn @@ -0,0 +1,33 @@ +"""Map keys helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAP_KEYS_FUNCTION_ANCHOR + + +@function_registry.add("map_keys", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("map_keys", MAP_KEYS_FUNCTION_ANCHOR), +)) +pub def map_keys(map_expr: ColumnExpr) -> ColumnExpr: + """ + Return the keys of a map as an array. + + Examples: + keys = map_keys(col("attributes")) + + Parameters: + map_expr: Map expression to inspect. 
+ """ + return nested_application("map_keys", [map_expr]) diff --git a/src/functions/nested/map_values.incn b/src/functions/nested/map_values.incn new file mode 100644 index 0000000..35ec19a --- /dev/null +++ b/src/functions/nested/map_values.incn @@ -0,0 +1,33 @@ +"""Map values helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import MAP_VALUES_FUNCTION_ANCHOR + + +@function_registry.add("map_values", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("map_values", MAP_VALUES_FUNCTION_ANCHOR), +)) +pub def map_values(map_expr: ColumnExpr) -> ColumnExpr: + """ + Return the values of a map as an array. + + Examples: + values = map_values(col("attributes")) + + Parameters: + map_expr: Map expression to inspect. 
+ """ + return nested_application("map_values", [map_expr]) diff --git a/src/functions/nested/mod.incn b/src/functions/nested/mod.incn new file mode 100644 index 0000000..bdbdff1 --- /dev/null +++ b/src/functions/nested/mod.incn @@ -0,0 +1,24 @@ +"""Nested scalar function helpers for arrays, maps, and structs.""" + +pub from functions.nested.array import array +pub from functions.nested.array_contains import array_contains +pub from functions.nested.array_distinct import array_distinct +pub from functions.nested.array_except import array_except +pub from functions.nested.array_flatten import array_flatten +pub from functions.nested.array_intersect import array_intersect +pub from functions.nested.array_join import array_join +pub from functions.nested.array_position import array_position +pub from functions.nested.array_reverse import array_reverse +pub from functions.nested.array_slice import array_slice +pub from functions.nested.array_sort import array_sort +pub from functions.nested.array_union import array_union +pub from functions.nested.arrays_overlap import arrays_overlap +pub from functions.nested.cardinality import cardinality +pub from functions.nested.element_at import element_at +pub from functions.nested.map_contains_key import map_contains_key +pub from functions.nested.map_entries import map_entries +pub from functions.nested.map_extract import map_extract +pub from functions.nested.map_from_arrays import map_from_arrays +pub from functions.nested.map_keys import map_keys +pub from functions.nested.map_values import map_values +pub from functions.nested.named_struct import named_struct diff --git a/src/functions/nested/named_struct.incn b/src/functions/nested/named_struct.incn new file mode 100644 index 0000000..2f18a30 --- /dev/null +++ b/src/functions/nested/named_struct.incn @@ -0,0 +1,34 @@ +"""Named struct construction helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + 
deterministic_spec, + extension_mapping, + v0_1, +) +from functions.nested.common import named_struct_arguments, nested_application +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import NAMED_STRUCT_FUNCTION_ANCHOR + + +@function_registry.add("named_struct", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("named_struct", NAMED_STRUCT_FUNCTION_ANCHOR), +)) +pub def named_struct(field_names: list[str], values: list[ColumnExpr]) -> ColumnExpr: + """ + Build a struct expression with explicit field names. + + Examples: + event = named_struct(["status", "amount"], [col("status"), col("amount")]) + + Parameters: + field_names: Non-empty struct field names. + values: Field value expressions in the same order as `field_names`. + """ + return nested_application("named_struct", named_struct_arguments(field_names, values)) diff --git a/src/lib.incn b/src/lib.incn index 7b96f4a..a1767b6 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -59,6 +59,28 @@ pub from functions.math.abs import abs pub from functions.math.ceil import ceil pub from functions.math.floor import floor pub from functions.math.round import round +pub from functions.nested.array import array +pub from functions.nested.array_contains import array_contains +pub from functions.nested.array_distinct import array_distinct +pub from functions.nested.array_except import array_except +pub from functions.nested.array_flatten import array_flatten +pub from functions.nested.array_intersect import array_intersect +pub from functions.nested.array_join import array_join +pub from functions.nested.array_position import array_position +pub from functions.nested.array_reverse import array_reverse +pub from functions.nested.array_slice import array_slice +pub from functions.nested.array_sort import array_sort +pub from 
functions.nested.array_union import array_union +pub from functions.nested.arrays_overlap import arrays_overlap +pub from functions.nested.cardinality import cardinality +pub from functions.nested.element_at import element_at +pub from functions.nested.map_contains_key import map_contains_key +pub from functions.nested.map_entries import map_entries +pub from functions.nested.map_extract import map_extract +pub from functions.nested.map_from_arrays import map_from_arrays +pub from functions.nested.map_keys import map_keys +pub from functions.nested.map_values import map_values +pub from functions.nested.named_struct import named_struct pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 0ad86f1..490f93c 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -54,6 +54,27 @@ pub const ABS_FUNCTION_ANCHOR: u32 = 27 pub const CEIL_FUNCTION_ANCHOR: u32 = 28 pub const FLOOR_FUNCTION_ANCHOR: u32 = 29 pub const ROUND_FUNCTION_ANCHOR: u32 = 30 +pub const MAKE_ARRAY_FUNCTION_ANCHOR: u32 = 31 +pub const CARDINALITY_FUNCTION_ANCHOR: u32 = 32 +pub const ARRAY_HAS_FUNCTION_ANCHOR: u32 = 33 +pub const ARRAY_POSITION_FUNCTION_ANCHOR: u32 = 34 +pub const ARRAY_ELEMENT_FUNCTION_ANCHOR: u32 = 35 +pub const ARRAY_SORT_FUNCTION_ANCHOR: u32 = 36 +pub const ARRAY_DISTINCT_FUNCTION_ANCHOR: u32 = 37 +pub const ARRAY_EXCEPT_FUNCTION_ANCHOR: u32 = 38 +pub const ARRAY_INTERSECT_FUNCTION_ANCHOR: u32 = 39 +pub const ARRAY_UNION_FUNCTION_ANCHOR: u32 = 40 +pub const ARRAY_TO_STRING_FUNCTION_ANCHOR: u32 = 41 +pub const ARRAY_SLICE_FUNCTION_ANCHOR: u32 = 42 +pub const ARRAY_REVERSE_FUNCTION_ANCHOR: u32 = 43 +pub const MAP_FUNCTION_ANCHOR: u32 = 44 +pub const MAP_KEYS_FUNCTION_ANCHOR: u32 = 45 +pub const MAP_VALUES_FUNCTION_ANCHOR: u32 = 46 +pub const 
MAP_ENTRIES_FUNCTION_ANCHOR: u32 = 47 +pub const MAP_EXTRACT_FUNCTION_ANCHOR: u32 = 48 +pub const NAMED_STRUCT_FUNCTION_ANCHOR: u32 = 49 +pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 +pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 8d673f3..22f8739 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -11,6 +11,19 @@ from functions import ( asc, asc_nulls_first, asc_nulls_last, + array, + array_contains, + array_distinct, + array_except, + array_flatten, + array_intersect, + array_join, + array_position, + array_reverse, + array_slice, + array_sort, + array_union, + arrays_overlap, avg, between, bool_expr, @@ -20,6 +33,7 @@ from functions import ( ceil, col, coalesce, + cardinality, count, count_distinct, count_expr, @@ -30,6 +44,7 @@ from functions import ( div, eq, equal_null, + element_at, floor, float_expr, function_registry_canonical_names, @@ -50,12 +65,19 @@ from functions import ( lit, lt, lte, + map_contains_key, + map_entries, + map_extract, + map_from_arrays, + map_keys, + map_values, max, min, modulo, mul, ne, neg, + named_struct, not_, nullif, or_, @@ -93,8 +115,22 @@ from substrait.function_extensions import ( ABS_FUNCTION_ANCHOR, ADD_FUNCTION_ANCHOR, AND_FUNCTION_ANCHOR, + ARRAY_DISTINCT_FUNCTION_ANCHOR, + ARRAY_ELEMENT_FUNCTION_ANCHOR, + ARRAY_EXCEPT_FUNCTION_ANCHOR, + ARRAY_FLATTEN_FUNCTION_ANCHOR, + ARRAY_HAS_FUNCTION_ANCHOR, + ARRAY_HAS_ANY_FUNCTION_ANCHOR, + ARRAY_INTERSECT_FUNCTION_ANCHOR, + ARRAY_POSITION_FUNCTION_ANCHOR, + ARRAY_REVERSE_FUNCTION_ANCHOR, + ARRAY_SLICE_FUNCTION_ANCHOR, + ARRAY_SORT_FUNCTION_ANCHOR, + ARRAY_TO_STRING_FUNCTION_ANCHOR, + ARRAY_UNION_FUNCTION_ANCHOR, AVG_FUNCTION_ANCHOR, BETWEEN_FUNCTION_ANCHOR, + 
CARDINALITY_FUNCTION_ANCHOR, CEIL_FUNCTION_ANCHOR, COALESCE_FUNCTION_ANCHOR, COUNT_FUNCTION_ANCHOR, @@ -109,10 +145,17 @@ from substrait.function_extensions import ( IS_NULL_FUNCTION_ANCHOR, LT_FUNCTION_ANCHOR, LTE_FUNCTION_ANCHOR, + MAKE_ARRAY_FUNCTION_ANCHOR, + MAP_ENTRIES_FUNCTION_ANCHOR, + MAP_EXTRACT_FUNCTION_ANCHOR, + MAP_FUNCTION_ANCHOR, + MAP_KEYS_FUNCTION_ANCHOR, + MAP_VALUES_FUNCTION_ANCHOR, MAX_FUNCTION_ANCHOR, MIN_FUNCTION_ANCHOR, MODULUS_FUNCTION_ANCHOR, MULTIPLY_FUNCTION_ANCHOR, + NAMED_STRUCT_FUNCTION_ANCHOR, NEGATE_FUNCTION_ANCHOR, NOT_EQUAL_FUNCTION_ANCHOR, NOT_FUNCTION_ANCHOR, @@ -180,18 +223,21 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", 
"array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round"] + return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] def _exercise_current_public_helpers() -> None: """Touch each current registered helper so runtime registry tests observe loaded modules only.""" amount = col("amount") status = col("status") + tags = array([str_lit("paid"), status]) + backup_tags = array([str_lit("open"), str_lit("paid")]) + attr_map = map_from_arrays(array([str_lit("status")]), array([status])) lit(1) int_expr(1) float_expr(1.5) @@ -247,6 +293,26 @@ def _exercise_current_public_helpers() -> None: ceil(amount) floor(amount) round(amount) + array_contains(tags, str_lit("paid")) + array_distinct(tags) + array_except(tags, backup_tags) + array_flatten(array([tags, backup_tags])) + array_intersect(tags, 
backup_tags) + array_join(tags, str_lit("|")) + array_position(tags, str_lit("paid")) + array_reverse(tags) + array_slice(tags, int_lit(1), int_lit(2)) + array_sort(tags) + array_union(tags, backup_tags) + arrays_overlap(tags, backup_tags) + cardinality(tags) + element_at(tags, int_lit(1)) + map_contains_key(attr_map, str_lit("status")) + map_entries(attr_map) + map_extract(attr_map, str_lit("status")) + map_keys(attr_map) + map_values(attr_map) + named_struct(["status", "amount"], [status, amount]) return @@ -519,6 +585,27 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("ceil", "ceil", CEIL_FUNCTION_ANCHOR) _assert_extension_mapping("floor", "floor", FLOOR_FUNCTION_ANCHOR) _assert_extension_mapping("round", "round", ROUND_FUNCTION_ANCHOR) + _assert_extension_mapping("array", "make_array", MAKE_ARRAY_FUNCTION_ANCHOR) + _assert_extension_mapping("array_contains", "array_has", ARRAY_HAS_FUNCTION_ANCHOR) + _assert_extension_mapping("array_distinct", "array_distinct", ARRAY_DISTINCT_FUNCTION_ANCHOR) + _assert_extension_mapping("array_except", "array_except", ARRAY_EXCEPT_FUNCTION_ANCHOR) + _assert_extension_mapping("array_flatten", "flatten", ARRAY_FLATTEN_FUNCTION_ANCHOR) + _assert_extension_mapping("array_intersect", "array_intersect", ARRAY_INTERSECT_FUNCTION_ANCHOR) + _assert_extension_mapping("array_join", "array_to_string", ARRAY_TO_STRING_FUNCTION_ANCHOR) + _assert_extension_mapping("array_position", "array_position", ARRAY_POSITION_FUNCTION_ANCHOR) + _assert_extension_mapping("array_reverse", "array_reverse", ARRAY_REVERSE_FUNCTION_ANCHOR) + _assert_extension_mapping("array_slice", "array_slice", ARRAY_SLICE_FUNCTION_ANCHOR) + _assert_extension_mapping("array_sort", "array_sort", ARRAY_SORT_FUNCTION_ANCHOR) + _assert_extension_mapping("array_union", "array_union", ARRAY_UNION_FUNCTION_ANCHOR) + _assert_extension_mapping("arrays_overlap", "array_has_any", ARRAY_HAS_ANY_FUNCTION_ANCHOR) + 
_assert_extension_mapping("cardinality", "cardinality", CARDINALITY_FUNCTION_ANCHOR) + _assert_extension_mapping("element_at", "array_element", ARRAY_ELEMENT_FUNCTION_ANCHOR) + _assert_extension_mapping("map_entries", "map_entries", MAP_ENTRIES_FUNCTION_ANCHOR) + _assert_extension_mapping("map_extract", "map_extract", MAP_EXTRACT_FUNCTION_ANCHOR) + _assert_extension_mapping("map_from_arrays", "map", MAP_FUNCTION_ANCHOR) + _assert_extension_mapping("map_keys", "map_keys", MAP_KEYS_FUNCTION_ANCHOR) + _assert_extension_mapping("map_values", "map_values", MAP_VALUES_FUNCTION_ANCHOR) + _assert_extension_mapping("named_struct", "named_struct", NAMED_STRUCT_FUNCTION_ANCHOR) def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: @@ -568,6 +655,7 @@ def test_function_registry__rewrite_mappings_identify_non_extension_helpers() -> assert always_true_entry.substrait.kind == SubstraitMappingKind.Rewrite, "always_true should lower as a literal rewrite" assert always_false_entry.substrait.kind == SubstraitMappingKind.Rewrite, "always_false should lower as a literal rewrite" _assert_rewrite_mapping("is_not_nan", "not_(is_nan(expr))") + _assert_rewrite_mapping("map_contains_key", "gt(cardinality(map_extract(map_expr, key)), int_expr(0))") assert always_true_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" assert always_false_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" diff --git a/tests/test_nested_data_functions.incn b/tests/test_nested_data_functions.incn new file mode 100644 index 0000000..f0b53be --- /dev/null +++ b/tests/test_nested_data_functions.incn @@ -0,0 +1,124 @@ +"""Test: RFC 020 nested scalar helper surface.""" + +from std.testing import assert_raises +from functions import ( + array, + array_contains, + array_distinct, + array_except, + array_flatten, + array_intersect, + array_join, + array_position, + 
array_reverse, + array_slice, + array_sort, + array_union, + arrays_overlap, + cardinality, + col, + element_at, + int_lit, + map_contains_key, + map_entries, + map_extract, + map_from_arrays, + map_keys, + map_values, + named_struct, + str_lit, +) +from function_registry import function_ref_for +from projection_builders import ( + ColumnExpr, + ColumnExprKind, + column_expr_argument_count, + column_expr_function_name, + column_expr_function_ref, + column_expr_kind, +) + + +def _assert_nested_application(expr: ColumnExpr, expected_name: str, expected_args: int) -> None: + """Assert one helper uses the shared registry-backed scalar application node.""" + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction, f"{expected_name} should use the scalar function kind" + assert column_expr_function_name(expr) == expected_name, f"{expected_name} should preserve its canonical name" + assert column_expr_function_ref(expr) == function_ref_for(expected_name), "nested helper should preserve its registry function ref" + assert column_expr_argument_count(expr) == expected_args, f"{expected_name} should carry its scalar arguments" + + +def _call_empty_array() -> None: + """Call array with no values for ValueError assertions.""" + array([]) + return + + +def _call_empty_named_struct() -> None: + """Call named_struct with no fields for ValueError assertions.""" + named_struct([], []) + return + + +def _call_mismatched_named_struct() -> None: + """Call named_struct with mismatched fields and values for ValueError assertions.""" + named_struct(["status"], []) + return + + +def _call_empty_named_struct_field() -> None: + """Call named_struct with an empty field name for ValueError assertions.""" + named_struct([""], [str_lit("paid")]) + return + + +def test_nested_data_functions__array_helpers_share_scalar_application_node() -> None: + """Assert array helpers use the shared registry-backed scalar expression model.""" + # -- Arrange -- + tags = array([str_lit("paid"), 
col("status")]) + other_tags = array([str_lit("open"), str_lit("paid")]) + + # -- Act / Assert -- + _assert_nested_application(tags, "array", 2) + _assert_nested_application(cardinality(tags), "cardinality", 1) + _assert_nested_application(array_contains(tags, str_lit("paid")), "array_contains", 2) + _assert_nested_application(array_position(tags, str_lit("paid")), "array_position", 2) + _assert_nested_application(element_at(tags, int_lit(1)), "element_at", 2) + _assert_nested_application(array_sort(tags), "array_sort", 1) + _assert_nested_application(array_distinct(tags), "array_distinct", 1) + _assert_nested_application(array_except(tags, other_tags), "array_except", 2) + _assert_nested_application(array_flatten(array([tags, other_tags])), "array_flatten", 1) + _assert_nested_application(array_intersect(tags, other_tags), "array_intersect", 2) + _assert_nested_application(array_union(tags, other_tags), "array_union", 2) + _assert_nested_application(arrays_overlap(tags, other_tags), "arrays_overlap", 2) + _assert_nested_application(array_join(tags, str_lit("|")), "array_join", 2) + _assert_nested_application(array_slice(tags, int_lit(1), int_lit(2)), "array_slice", 3) + _assert_nested_application(array_reverse(tags), "array_reverse", 1) + + +def test_nested_data_functions__map_and_struct_helpers_share_scalar_application_node() -> None: + """Assert map and struct helpers use scalar expressions rather than relation-shaping nodes.""" + # -- Arrange -- + keys = array([str_lit("status")]) + values = array([col("status")]) + attr_map = map_from_arrays(keys, values) + contains_key = map_contains_key(attr_map, str_lit("status")) + + # -- Act / Assert -- + _assert_nested_application(attr_map, "map_from_arrays", 2) + _assert_nested_application(map_entries(attr_map), "map_entries", 1) + _assert_nested_application(map_extract(attr_map, str_lit("status")), "map_extract", 2) + _assert_nested_application(map_keys(attr_map), "map_keys", 1) + 
_assert_nested_application(map_values(attr_map), "map_values", 1) + _assert_nested_application(named_struct(["status", "amount"], [col("status"), col("amount")]), "named_struct", 4) + assert column_expr_kind(contains_key) == ColumnExprKind.ScalarFunction, "map_contains_key rewrite should still be scalar" + assert column_expr_function_name(contains_key) == "gt", "map_contains_key should lower through its documented predicate rewrite" + assert column_expr_function_ref(contains_key) == function_ref_for("gt"), "rewrite should use the registered greater-than helper" + + +def test_nested_data_functions__constructor_shape_errors_raise_value_error() -> None: + """Assert nested constructors reject shapes that cannot produce typed scalar values.""" + # -- Arrange / Act / Assert -- + assert_raises[ValueError](_call_empty_array) + assert_raises[ValueError](_call_empty_named_struct) + assert_raises[ValueError](_call_mismatched_named_struct) + assert_raises[ValueError](_call_empty_named_struct_field) diff --git a/tests/test_session_projection.incn b/tests/test_session_projection.incn index 0c2777e..fb6207e 100644 --- a/tests/test_session_projection.incn +++ b/tests/test_session_projection.incn @@ -3,6 +3,11 @@ from functions import ( abs, add, + array, + array_contains, + array_distinct, + array_join, + array_position, case_when, cast, ceil, @@ -20,6 +25,8 @@ from functions import ( round, sub, try_cast, + cardinality, + element_at, ) from dataset import DataFrame, LazyFrame from session import Session, SessionErrorKind @@ -185,6 +192,41 @@ def test_session_projection__collect_executes_common_math_scalar_projection_func assert payload.contains("3"), "round projection should include round(10 / 4.0)" +def test_session_projection__collect_executes_nested_scalar_projection_functions() -> None: + """collect should execute RFC 020 nested scalar helpers through DataFusion.""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = 
assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + tags = array([lit("paid"), col("customer_id"), lit("paid")]) + with_count = lazy.with_column("tag_count", cardinality(tags)) + with_contains = with_count.with_column("has_paid", array_contains(tags, lit("paid"))) + with_first = with_contains.with_column("first_tag", element_at(tags, lit(1))) + with_position = with_first.with_column("paid_position", array_position(tags, lit("paid"))) + projected = with_position.with_column("joined_tags", array_join(array_distinct(tags), lit("|"))) + df = _collect_or_fail(session, projected) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "nested scalar projections should preserve the input rows" + assert len(resolved) == 7, "projection should expose all appended nested outputs" + assert payload.contains("tag_count"), "cardinality projection should materialize its alias" + assert payload.contains("has_paid"), "array_contains projection should materialize its alias" + assert payload.contains("first_tag"), "element_at projection should materialize its alias" + assert payload.contains("paid_position"), "array_position projection should materialize its alias" + assert payload.contains("joined_tags"), "array_join projection should materialize its alias" + assert payload.contains("3"), "cardinality should report three input array elements" + assert payload.contains("true"), "array_contains should find the paid tag" + assert payload.contains("paid"), "element_at should return the first tag" + assert payload.contains("1"), "array_position should use one-based positions" + assert payload.contains("paid|A"), "array_join should materialize distinct string tags" + + def test_session_projection__collect_executes_identity_select() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_substrait_plan.incn 
b/tests/test_substrait_plan.incn index 0e69548..17a70c6 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -6,6 +6,19 @@ from functions import ( add, always_true, and_, + array, + array_contains, + array_distinct, + array_except, + array_flatten, + array_intersect, + array_join, + array_position, + array_reverse, + array_slice, + array_sort, + array_union, + arrays_overlap, asc, asc_nulls_last, avg, @@ -34,12 +47,19 @@ from functions import ( lit, lt, lte, + map_contains_key, + map_entries, + map_extract, + map_from_arrays, + map_keys, + map_values, max, min, modulo, mul, ne, neg, + named_struct, not_, nullif, or_, @@ -47,6 +67,8 @@ from functions import ( sub, sum, try_cast, + cardinality, + element_at, ) from projection_builders import ColumnExpr, with_column_assignment from substrait.errors import SubstraitLoweringErrorKind @@ -374,6 +396,38 @@ def test_plan__core_scalar_extension_mappings_lower_to_substrait() -> None: _assert_scalar_expr_lowers(round(div(col("amount"), lit(4.0)))) +def test_plan__nested_scalar_extension_mappings_lower_to_substrait() -> None: + """Assert RFC 020 nested scalar helpers emit Substrait scalar functions.""" + # -- Arrange -- + tags = array([lit("paid"), lit("open"), col("status")]) + other_tags = array([lit("paid"), lit("closed")]) + attr_map = map_from_arrays(array([lit("status")]), array([col("status")])) + + # -- Act / Assert -- + _assert_scalar_expr_lowers(tags) + _assert_scalar_expr_lowers(cardinality(tags)) + _assert_scalar_expr_lowers(array_contains(tags, lit("paid"))) + _assert_scalar_expr_lowers(array_position(tags, lit("paid"))) + _assert_scalar_expr_lowers(element_at(tags, lit(1))) + _assert_scalar_expr_lowers(array_sort(tags)) + _assert_scalar_expr_lowers(array_distinct(tags)) + _assert_scalar_expr_lowers(array_except(tags, other_tags)) + _assert_scalar_expr_lowers(array_flatten(array([tags, other_tags]))) + _assert_scalar_expr_lowers(array_intersect(tags, other_tags)) + 
_assert_scalar_expr_lowers(array_union(tags, other_tags)) + _assert_scalar_expr_lowers(arrays_overlap(tags, other_tags)) + _assert_scalar_expr_lowers(array_join(tags, lit("|"))) + _assert_scalar_expr_lowers(array_slice(tags, lit(1), lit(2))) + _assert_scalar_expr_lowers(array_reverse(tags)) + _assert_scalar_expr_lowers(attr_map) + _assert_scalar_expr_lowers(map_entries(attr_map)) + _assert_scalar_expr_lowers(map_extract(attr_map, lit("status"))) + _assert_scalar_expr_lowers(map_keys(attr_map)) + _assert_scalar_expr_lowers(map_values(attr_map)) + _assert_scalar_expr_lowers(map_contains_key(attr_map, lit("status"))) + _assert_scalar_expr_lowers(named_struct(["status", "amount"], [col("status"), col("amount")])) + + def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None: # -- Arrange -- _register_orders_schema() From bbe913e95711eccd4efa4732f5ddd4f2aa359441 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 26 May 2026 00:57:28 +0200 Subject: [PATCH 4/6] feature - implement RFC 021 generator foundations (#38) --- docs/language/reference/dataset_methods.md | 4 +- .../reference/functions/generators.md | 32 ++++ docs/language/reference/functions/index.md | 4 +- .../reference/substrait/operator_catalog.md | 3 + docs/release_notes/v0_1.md | 1 + docs/rfcs/021_generator_table_functions.md | 34 ++-- docs/rfcs/README.md | 2 +- src/dataset/mod.incn | 26 +++ src/dataset/ops.incn | 12 ++ src/function_registry.incn | 13 ++ src/functions/generators/explode.incn | 42 +++++ src/functions/generators/explode_outer.incn | 42 +++++ src/functions/generators/mod.incn | 6 + src/functions/generators/posexplode.incn | 44 +++++ .../generators/posexplode_outer.incn | 44 +++++ src/functions/mod.incn | 4 + src/generator_builders.incn | 150 ++++++++++++++++++ src/lib.incn | 30 +++- src/prism/lower.incn | 7 + src/prism/mod.incn | 38 +++++ src/prism/output_columns.incn | 11 ++ src/prism/rewrite.incn | 6 + src/prism/store.incn | 54 +++++++ src/prism/types.incn | 3 + 
src/substrait/expr_lowering.incn | 6 + src/substrait/extensions.incn | 25 ++- src/substrait/function_extensions.incn | 20 ++- src/substrait/inspect.incn | 38 ++++- src/substrait/mod.incn | 7 + src/substrait/relations.incn | 90 +++++++++++ tests/test_dataset.incn | 36 ++++- tests/test_function_registry.incn | 35 +++- tests/test_generator_functions.incn | 55 +++++++ tests/test_prism.incn | 38 ++++- tests/test_substrait_plan.incn | 19 ++- 35 files changed, 954 insertions(+), 27 deletions(-) create mode 100644 docs/language/reference/functions/generators.md create mode 100644 src/functions/generators/explode.incn create mode 100644 src/functions/generators/explode_outer.incn create mode 100644 src/functions/generators/mod.incn create mode 100644 src/functions/generators/posexplode.incn create mode 100644 src/functions/generators/posexplode_outer.incn create mode 100644 src/generator_builders.incn create mode 100644 tests/test_generator_functions.incn diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md index ab4926e..9a9a701 100644 --- a/docs/language/reference/dataset_methods.md +++ b/docs/language/reference/dataset_methods.md @@ -19,9 +19,10 @@ The Substrait helper surface behind these methods is split by semantic role: | `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using a scalar expression. | | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. | | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | +| `generate` | `def generate(self, generator: GeneratorApplication) -> Self` | Apply a relation-shaping generator such as `explode(...)` with explicit output aliases. 
| | `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. | | `limit` | `def limit(self, n: int) -> Self` | Cap row count. | -| `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. | +| `explode` | `def explode(self) -> Self` | Compatibility marker for the older EXPLODE extension path. Prefer `generate(explode(...))`. | ## `with_column` @@ -67,6 +68,7 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: - `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature. - `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders. +- `generate(...)` preserves all input columns and appends generated output aliases. Alias collisions are rejected during planning/lowering. - `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface. - Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics. diff --git a/docs/language/reference/functions/generators.md b/docs/language/reference/functions/generators.md new file mode 100644 index 0000000..844cb20 --- /dev/null +++ b/docs/language/reference/functions/generators.md @@ -0,0 +1,32 @@ +# Generator and Table-Valued Functions (Reference) + +Generators are relation-shaping operations. They are registry-backed like scalar and aggregate helpers, but they return +`GeneratorApplication` values and must be applied through a relation method such as `generate(...)`. 
+ +```incan +from pub::inql import LazyFrame +from pub::inql.functions import col, explode +from models import Order + +def order_lines(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.generate(explode(col("line_items"), "line_item")) +``` + +The explicit generator surface currently includes: + +| Function | Output aliases | Relation effect | +| --- | --- | --- | +| `explode(expr, as_)` | one value column | Emits one row per array element; null or empty inputs emit zero rows. | +| `explode_outer(expr, as_)` | one value column | Preserves the input row for null or empty inputs and emits a null generated value. | +| `posexplode(expr, position_as, value_as)` | position and value columns | Emits one row per array element with a zero-based position column. | +| `posexplode_outer(expr, position_as, value_as)` | position and value columns | Outer positional explode with the same zero-based position rule. | + +Generator applications preserve input columns and append generated columns in declaration order. Generated aliases are +required, must be non-empty, and must not collide with existing input columns. + +The older zero-argument `DataSet.explode()` method remains available as a compatibility marker for the current Substrait +extension relation gap. New code should prefer `generate(explode(...))` so the relation-shaping function identity and +output schema are explicit. + +Nested scalar helpers such as `array_flatten(...)` remain scalar expressions. They do not expand rows and are documented +on the [nested data functions](nested.md) page. 
diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index f6347a8..e65ea90 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -7,11 +7,12 @@ Today the concrete shipped surfaces are documented here: - [Filter builders](../builders/filters.md) - [Aggregate builders](../builders/aggregates.md) - [Projection builders](../builders/projections.md) +- [Generator and table-valued functions](generators.md) - [Nested data functions](nested.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/<family>/<name>.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/<family>/<name>.incn` modules.
The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. 
@@ -33,6 +34,7 @@ The registered helper surface currently includes: | `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) | | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | | `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | +| `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md index 4560185..327ad49 100644 --- a/docs/language/reference/substrait/operator_catalog.md +++ 
b/docs/language/reference/substrait/operator_catalog.md @@ -81,6 +81,9 @@ Core Substrait does not define a portable unnest or explode `Rel` at the logical Current package-level RFC 002 boundary registration: - `https://inql.io/extensions/v0.1/unnest.yaml#explode` +- `https://inql.io/extensions/v0.1/unnest.yaml#explode_outer` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer` ### Pivot / unpivot diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 2543685..5c23085 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -16,6 +16,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. - **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics. 
+- **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/021_generator_table_functions.md b/docs/rfcs/021_generator_table_functions.md index b33febb..ad0039c 100644 --- a/docs/rfcs/021_generator_table_functions.md +++ b/docs/rfcs/021_generator_table_functions.md @@ -1,6 +1,6 @@ # InQL RFC 021: Generator and table-valued functions -- **Status:** Draft +- **Status:** In Progress - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -42,14 +42,15 @@ InQL already has an unnest/explode design direction through its Substrait work. ## Guide-level explanation (how authors think about it) -Authors should use generators when one input row may become multiple output rows: +Authors should use generators when one input row may become multiple output rows. 
In the current builder surface, +generators are constructed as explicit applications and then applied to a relation: ```incan -from pub::inql.functions import col +from pub::inql.functions import col, explode items = ( orders - .explode(col("line_items"), as_="line_item") + .generate(explode(col("line_items"), "line_item")) .select(["order_id", "line_item"]) ) ``` @@ -64,13 +65,13 @@ Generator functions must be registry entries with function class `generator` or `explode_outer(array_expr)` must preserve the input row when the input array is null or empty and must produce a null generated value according to its output schema. -`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. The position origin must be specified before this RFC reaches Planned status. +`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. Positional output is zero-based because `posexplode` follows the Spark-compatible naming convention rather than InQL's one-based scalar collection indexing rule. `inline(array_of_struct_expr)` must expand each struct element into output columns. `inline_outer` must preserve outer rows for null or empty input according to the outer generator rule. `stack` must construct multiple output rows from explicit expressions according to a declared row count and output schema. -`flatten` must be treated as a table-valued/generator operation when supported. Its exact input type, recursive behavior, path behavior, and output columns must be specified before it reaches Planned status. +`flatten` must be treated as a table-valued/generator operation when supported. Portable InQL does not yet define Snowflake-style recursive/path flattening; scalar `array_flatten(...)` remains part of RFC 020 and does not change row cardinality. 
Every generator must define output column names, output types, nullability, interaction with existing columns, and aliasing requirements. Name collisions must be diagnosed unless an explicit overwrite or qualification rule applies. @@ -78,11 +79,11 @@ Every generator must define output column names, output types, nullability, inte ### Syntax -Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. +Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. The initial builder API uses `generate(generator)` to avoid overloading the existing zero-argument compatibility `explode()` method. ### Semantics -Generator output schema is part of the relation schema after the generator operation. Generators may preserve input columns, replace a nested column with generated columns, or produce a new relation depending on the function and syntax, but the behavior must be explicit. +Generator output schema is part of the relation schema after the generator operation. The initial portable generator applications preserve all input columns and append generated output columns in declaration order. Generated aliases are required, must be non-empty, and must not collide with existing columns. ### Interaction with other InQL surfaces @@ -112,11 +113,16 @@ Existing unnest/explode behavior should align with this RFC. If current behavior - **Execution / interchange** — Prism and Substrait lowering must represent cardinality changes and output schemas faithfully. - **Documentation** — generator docs should explain cardinality and schema effects before listing helper names. -## Unresolved questions +## Design Decisions -- Should positional generators use zero-based or one-based positions? -- Should `.explode(...)` preserve all input columns by default? 
-- What aliasing syntax should be required for generated output columns? -- What subset of Snowflake-style `flatten` behavior belongs in portable InQL versus a warehouse compatibility extension? +### Resolved - +- Positional generators use zero-based positions for compatibility with the `posexplode` naming convention. +- Explicit generator applications preserve all input columns by default and append generated output columns. +- Generated aliases are required at builder construction time. +- Snowflake-style recursive/path `flatten` remains outside the portable core until its output schema and compatibility category are specified separately. + +### Remaining + +- `inline`, `inline_outer`, `stack`, and portable table-valued `flatten` need separate helper slices on top of the generator application model. +- Query-block generator syntax still needs compiler/query-surface work. diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index fac71de..c42c434 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -27,7 +27,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [018][rfc-018] | Draft | Common scalar function catalog | | | [019][rfc-019] | Draft | Window functions | | | [020][rfc-020] | Draft | Nested data functions | | -| [021][rfc-021] | Draft | Generator and table-valued functions | | +| [021][rfc-021] | In Progress | Generator and table-valued functions | | | [022][rfc-022] | Draft | Semi-structured and format functions | | | [023][rfc-023] | Draft | Approximate and sketch functions | | | [024][rfc-024] | Draft | Function extension policy | | diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index fa850bd..e9b31b1 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -22,6 +22,7 @@ The current method-chain surface in this module is the explicit builder-based AP - `with_column(name: str, expr: ColumnExpr)` - `group_by(columns: list[ColumnExpr])` - `agg(measures: list[AggregateMeasure])` +- 
`generate(generator: GeneratorApplication)` - plus the structural operators `join`, `select`, `order_by`, `limit`, and `explode` Illustrative current-shape examples: @@ -53,6 +54,7 @@ See also: from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr from dataset.materialization import DataFrameMaterialization from substrait.errors import SubstraitLoweringError @@ -63,6 +65,7 @@ from dataset.ops import ( agg_ds_of_columns, explode_ds, filter_ds_of_columns, + generate_ds_of_columns, group_by_ds_of_columns, join_ds, limit_ds, @@ -76,6 +79,7 @@ from prism import ( prism_cursor_apply_agg, prism_cursor_apply_explode, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_group_by, prism_cursor_apply_join, prism_cursor_apply_limit, @@ -98,6 +102,7 @@ pub trait DataSet[T with Clone]: def with_column(self, name: str, expr: ColumnExpr) -> Self def group_by(self, columns: list[ColumnExpr]) -> Self def agg(self, measures: list[AggregateMeasure]) -> Self + def generate(self, generator: GeneratorApplication) -> Self def order_by(self, columns: list[ColumnExpr]) -> Self def limit(self, n: int) -> Self def explode(self) -> Self @@ -207,6 +212,12 @@ pub class DataFrame[T with Clone] with BoundedDataSet: agg_ds_of_columns(self._substrait_rel, self.planned_columns(), measures), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataFrame with a generator stage and stale materialization cleared.""" + return _data_frame_with_invalidated_materialization( + generate_ds_of_columns(self._substrait_rel, self.planned_columns(), generator), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" return _data_frame_with_invalidated_materialization( @@ -288,6 +299,10 @@ pub class LazyFrame[T with Clone] with 
BoundedDataSet: """Return one new lazy carrier with an appended aggregation stage.""" return LazyFrame(_cursor=prism_cursor_apply_agg(self._cursor, measures)) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new lazy carrier with an appended generator stage.""" + return LazyFrame(_cursor=prism_cursor_apply_generate(self._cursor, generator)) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new lazy carrier with an appended ordering stage.""" return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns)) @@ -430,6 +445,17 @@ pub class DataStream[T with Clone] with UnboundedDataSet: ), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataStream with a generator stage.""" + return DataStream( + _row_schema_marker=self._row_schema_marker.clone(), + _substrait_rel=generate_ds_of_columns( + self._substrait_rel, + relation_output_columns(self._substrait_rel.clone()), + generator, + ), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataStream with an ordering stage.""" return DataStream( diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index 5319f4d..bafad30 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -8,6 +8,7 @@ views stay aligned with the lowered relation tree. 
from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from substrait.function_extensions import explode_extension_uri from substrait.inspect import relation_output_columns @@ -19,6 +20,7 @@ from substrait.relations import ( join_rel, project_rel_of_columns, sort_rel_of_columns, + generator_rel_of_columns, ) @@ -122,6 +124,16 @@ pub def agg_ds_of_columns(rel: Rel, input_columns: list[str], measures: list[Agg return aggregate_rel_of_columns(rel, input_columns, [], measures) +pub def generate_ds(rel: Rel, generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator to a relation.""" + return generate_ds_of_columns(rel, relation_output_columns(rel.clone()), generator) + + +pub def generate_ds_of_columns(rel: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator using explicit input-column names.""" + return generator_rel_of_columns(rel, input_columns, generator) + + pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: """ Apply dataset-level ordering intent to one relation. 
diff --git a/src/function_registry.incn b/src/function_registry.incn index b5642f9..2ac97ff 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -75,6 +75,7 @@ pub enum SubstraitMappingKind(str): CoreFunction = "core_function" ExtensionFunction = "extension_function" + RelationExtension = "relation_extension" Rewrite = "rewrite" StructuralFunction = "structural_function" @@ -294,6 +295,18 @@ pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: ) +pub def relation_extension_mapping(function_name: str, uri: str) -> SubstraitMapping: + """Build one registered Substrait relation-extension mapping.""" + return SubstraitMapping( + kind=SubstraitMappingKind.RelationExtension, + uri=uri, + function_name=function_name, + anchor=0, + rewrite="", + detail="extension_single", + ) + + pub def core_mapping(function_name: str) -> SubstraitMapping: """Build one mapping for a built-in Substrait Rex shape rather than an extension function declaration.""" return SubstraitMapping( diff --git a/src/functions/generators/explode.incn b/src/functions/generators/explode.incn new file mode 100644 index 0000000..1b6f2ed --- /dev/null +++ b/src/functions/generators/explode.incn @@ -0,0 +1,42 @@ +"""Inner explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, explode as explode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import explode_extension_uri + + +@function_registry.add("explode", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("explode", explode_extension_uri()), +)) +pub def explode(expr: ColumnExpr, as_: str) -> 
GeneratorApplication: + """ + Build an inner row-expanding generator for array values. + + Examples: + generated = explode(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated value column. + """ + return explode_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_builds_generator_application() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.canonical_name == "explode" + assert generator.output_columns[0] == "line_item" diff --git a/src/functions/generators/explode_outer.incn b/src/functions/generators/explode_outer.incn new file mode 100644 index 0000000..bdbc1c9 --- /dev/null +++ b/src/functions/generators/explode_outer.incn @@ -0,0 +1,42 @@ +"""Outer explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, explode_outer as explode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import explode_outer_extension_uri + + +@function_registry.add("explode_outer", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("explode_outer", explode_outer_extension_uri()), +)) +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator for array values. + + Examples: + generated = explode_outer(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated nullable value column. 
+ """ + return explode_outer_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_outer_builds_outer_generator_application() -> None: + generator = explode_outer(col("line_items"), "line_item") + assert generator.canonical_name == "explode_outer" + assert generator.is_outer diff --git a/src/functions/generators/mod.incn b/src/functions/generators/mod.incn new file mode 100644 index 0000000..4865e2b --- /dev/null +++ b/src/functions/generators/mod.incn @@ -0,0 +1,6 @@ +"""Relation-shaping generator helpers.""" + +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer diff --git a/src/functions/generators/posexplode.incn b/src/functions/generators/posexplode.incn new file mode 100644 index 0000000..b4d5185 --- /dev/null +++ b/src/functions/generators/posexplode.incn @@ -0,0 +1,44 @@ +"""Inner positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, posexplode as posexplode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import posexplode_extension_uri + + +@function_registry.add("posexplode", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("posexplode", posexplode_extension_uri()), +)) +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an inner row-expanding generator with a zero-based position column. 
+ + Examples: + generated = posexplode(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated value column. + """ + return posexplode_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_builds_positional_generator_application() -> None: + generator = posexplode(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode" + assert generator.position_origin == 0 + assert generator.output_columns[0] == "position" diff --git a/src/functions/generators/posexplode_outer.incn b/src/functions/generators/posexplode_outer.incn new file mode 100644 index 0000000..20bda72 --- /dev/null +++ b/src/functions/generators/posexplode_outer.incn @@ -0,0 +1,44 @@ +"""Outer positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, posexplode_outer as posexplode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import posexplode_outer_extension_uri + + +@function_registry.add("posexplode_outer", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("posexplode_outer", posexplode_outer_extension_uri()), +)) +pub def posexplode_outer(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator with a zero-based position column. 
+ + Examples: + generated = posexplode_outer(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated nullable value column. + """ + return posexplode_outer_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_outer_builds_outer_positional_generator_application() -> None: + generator = posexplode_outer(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode_outer" + assert generator.is_outer + assert generator.output_columns[1] == "line_item" diff --git a/src/functions/mod.incn b/src/functions/mod.incn index e0e8754..c20b662 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -61,6 +61,10 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/generator_builders.incn b/src/generator_builders.incn new file mode 100644 index 0000000..d3d6b16 --- /dev/null +++ b/src/generator_builders.incn @@ -0,0 +1,150 @@ +""" +Relation-shaping generator builder surface. + +Generators are not scalar expressions: they may change row cardinality and append output columns. This module carries +the authoring intent through Dataset, Prism, and Substrait boundaries without making generators valid in ordinary +row-level expression positions. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import function_ref_for +from projection_builders import ColumnExpr + + +@derive(Clone) +pub enum GeneratorKind(str): + """Supported relation-shaping generator kinds in the current portable slice.""" + + Explode = "explode" + ExplodeOuter = "explode_outer" + PosExplode = "posexplode" + PosExplodeOuter = "posexplode_outer" + + +@derive(Clone) +pub model GeneratorApplication: + """One registry-backed relation-shaping generator application.""" + + pub kind: GeneratorKind + pub function_ref: str + pub canonical_name: str + pub expr: ColumnExpr + pub output_columns: list[str] + pub preserves_input_columns: bool + pub is_outer: bool + pub position_origin: int + + +pub def explode(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an inner `explode` generator that appends one value column.""" + return _generator_application("explode", GeneratorKind.Explode, expr, [as_], true, false, 0) + + +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an outer `explode` generator that appends one nullable value column.""" + return _generator_application("explode_outer", GeneratorKind.ExplodeOuter, expr, [as_], true, true, 0) + + +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an inner positional explode generator with zero-based positions.""" + return _generator_application("posexplode", GeneratorKind.PosExplode, expr, [position_as, value_as], true, false, 0) + + +pub def posexplode_outer(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an outer positional explode generator with zero-based positions.""" + return _generator_application( + "posexplode_outer", + GeneratorKind.PosExplodeOuter, + expr, + [position_as, value_as], + true, + true, + 0, + ) + + +pub def generator_output_columns(input_columns: list[str], generator: GeneratorApplication) -> 
list[str]: + """Return output columns after applying one generator to the provided input columns.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + message = f"generator output column `{output_column}` conflicts with an existing column" + return raise_value_error(message) + output_columns.append(output_column) + return output_columns + + +pub def generator_primary_output_column(generator: GeneratorApplication) -> str: + """Return the primary generated value column for inspection and tests.""" + if len(generator.output_columns) == 0: + return "" + return generator.output_columns[len(generator.output_columns) - 1] + + +def _generator_application( + canonical_name: str, + kind: GeneratorKind, + expr: ColumnExpr, + output_columns: list[str], + preserves_input_columns: bool, + is_outer: bool, + position_origin: int, +) -> GeneratorApplication: + """Build one generator application after validating declared output aliases.""" + _validate_output_columns(canonical_name, output_columns) + return GeneratorApplication( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + expr=expr, + output_columns=output_columns, + preserves_input_columns=preserves_input_columns, + is_outer=is_outer, + position_origin=position_origin, + ) + + +def _validate_output_columns(canonical_name: str, output_columns: list[str]) -> None: + """Validate mandatory generator output aliases.""" + if len(output_columns) == 0: + message = f"{canonical_name} requires at least one output alias" + return raise_value_error(message) + mut seen: list[str] = [] + for output_column in output_columns: + if len(output_column) == 0: + message = f"{canonical_name} output aliases must be non-empty" + return raise_value_error(message) + if _contains_text(seen, output_column): + message = f"{canonical_name} 
output alias `{output_column}` is duplicated" + return raise_value_error(message) + seen.append(output_column) + return + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + +module tests: + from projection_builders import col, column_expr_name + def test_explode_application_records_function_identity_and_output_column() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.kind == GeneratorKind.Explode + assert generator.canonical_name == "explode" + assert generator.function_ref == "inql.functions.explode" + assert column_expr_name(generator.expr) == "line_items" + assert generator.output_columns[0] == "line_item" + assert generator.preserves_input_columns + assert not generator.is_outer + def test_posexplode_uses_zero_based_position_origin() -> None: + generator = posexplode(col("line_items"), "pos", "line_item") + assert generator.kind == GeneratorKind.PosExplode + assert generator.position_origin == 0 + assert generator.output_columns[0] == "pos" + assert generator.output_columns[1] == "line_item" diff --git a/src/lib.incn b/src/lib.incn index a1767b6..87604d7 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -6,8 +6,24 @@ Consumers depend on this package via `[dependencies]` and import with `from pub: """ pub from dataset import BoundedDataSet, DataFrame, DataSet, DataStream, LazyFrame, UnboundedDataSet -pub from dataset.ops import agg_ds, explode_ds, filter_ds, group_by_ds, join_ds, limit_ds, order_by_ds, select_ds +pub from dataset.ops import ( + agg_ds, + explode_ds, + filter_ds, + generate_ds, + group_by_ds, + join_ds, + limit_ds, + order_by_ds, + select_ds, +) pub from aggregate_builders import AggregateKind, AggregateMeasure +pub from generator_builders import ( + GeneratorApplication, + GeneratorKind, + generator_output_columns, + generator_primary_output_column, +) pub from 
projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -81,6 +97,10 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -142,6 +162,7 @@ pub from function_registry import ( function_policy_spec, namespaced_function_ref, rejected_function_policy, + relation_extension_mapping, rewrite_mapping, sort_field_mapping, structural_mapping, @@ -184,6 +205,8 @@ pub from substrait.relations import ( extension_single_rel, fetch_rel, filter_rel, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, @@ -211,6 +234,8 @@ pub from substrait.inspect import ( aggregate_measure_function_names, aggregate_measure_invocation_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -222,7 +247,10 @@ pub from substrait.inspect import ( ) pub from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) pub from substrait.conformance_catalog import ( diff --git a/src/prism/lower.incn b/src/prism/lower.incn index 9ae303c..6020b57 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -10,6 +10,7 @@ from substrait.relations import ( fetch_rel, join_rel, read_named_table_rel, + try_generator_rel_of_columns, 
sort_rel_of_columns, try_aggregate_rel_of_columns, try_filter_rel_of_columns, @@ -118,6 +119,12 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst [], node.aggregate_measures, ) + PrismNodeKind.Generate => + return try_generator_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) PrismNodeKind.OrderBy => return Ok( sort_rel_of_columns( diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 229cbaa..3564d35 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -13,6 +13,7 @@ This façade keeps one stable internal import surface while the implementation i from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure from filter_builders import always_true +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from prism.lower import ( lower_prism_tip as lower_prism_tip_impl, @@ -69,6 +70,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -87,6 +89,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -102,6 +105,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -119,6 +123,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return 
PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -136,6 +141,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[with_column_assignment(name, expr)], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -153,6 +159,7 @@ pub class PrismCursor[T with Clone]: group_columns=columns, sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -170,6 +177,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=measures, + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -187,6 +195,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -204,6 +213,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -221,6 +231,25 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], + projection_assignments=[], + ) + return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) + + def generate(self, generator: GeneratorApplication) -> Self: + """Append one explicit generator node and return the derived tip.""" + next_tip_id = append_node( + store_id=self.store_id, + kind=PrismNodeKind.Generate, + input_ids=[self.tip_id], + named_table="", + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + 
group_columns=[], + sort_columns=[], + aggregate_measures=[], + generator_applications=[generator], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -264,6 +293,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) @@ -325,6 +355,14 @@ pub def prism_cursor_apply_explode[T with Clone](cursor: PrismCursor[T]) -> Pris return cursor.explode() +pub def prism_cursor_apply_generate[T with Clone]( + cursor: PrismCursor[T], + generator: GeneratorApplication, +) -> PrismCursor[T]: + """Apply one explicit generator through Prism.""" + return cursor.generate(generator) + + pub def prism_cursor_output_columns[T with Clone](cursor: PrismCursor[T]) -> list[str]: """Return plan-time output columns for one cursor tip.""" return cursor.planned_columns() diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index f1de58c..d1cfa06 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -3,6 +3,7 @@ from prism.store import node_at from prism.rewrite import rewritten_node_at from prism.types import PrismNodeKind, PrismOptimizedView, PrismStoreId +from generator_builders import generator_output_columns from projection_builders import ColumnExpr, project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns @@ -27,6 +28,11 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str return authored_output_columns(store_id, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(authored_output_columns(store_id, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return 
generator_output_columns( + authored_output_columns(store_id, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Join output columns preserve the conventional left-then-right relation order. # We keep both sides verbatim here; duplicate names are part of the current output shape and are resolved later @@ -59,6 +65,11 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list return rewritten_output_columns(view, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(rewritten_output_columns(view, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return generator_output_columns( + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Rewritten views keep the same left-then-right join column order as authored views # so output-column inference stays stable across Prism rewrite passes. 
diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index 6247b0b..419f968 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -168,6 +168,7 @@ def _build_collapsed_limit_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -204,6 +205,7 @@ def _build_collapsed_project_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=merged_assignments, ) @@ -240,6 +242,7 @@ def _build_collapsed_aggregate_node( group_columns=[], sort_columns=[], aggregate_measures=merged_measures, + generator_applications=[], projection_assignments=[], ) @@ -274,6 +277,7 @@ def _build_collapsed_order_by_node( group_columns=[], sort_columns=node.sort_columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -291,6 +295,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten group_columns=node.group_columns, sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, + generator_applications=node.generator_applications, projection_assignments=node.projection_assignments, ) @@ -336,6 +341,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: group_columns=old_node.group_columns, sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, + generator_applications=old_node.generator_applications, projection_assignments=old_node.projection_assignments, ), ) diff --git a/src/prism/store.incn b/src/prism/store.incn index d620574..e451ade 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -1,6 +1,7 @@ """Append-only Prism store allocation, storage, reachability, and cross-store adoption.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -54,6 +55,7 @@ pub def append_node( 
group_columns: list[ColumnExpr], sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], + generator_applications: list[GeneratorApplication], projection_assignments: list[ProjectionAssignment], ) -> int: """ @@ -73,6 +75,7 @@ pub def append_node( group_columns=group_columns, sort_columns=sort_columns, aggregate_measures=aggregate_measures, + generator_applications=generator_applications, projection_assignments=projection_assignments, ) prism_stored_nodes.append(PrismStoredNode(store_id_raw=store_id.0, node=appended)) @@ -119,10 +122,12 @@ pub def adopt_cursor_subgraph( adopted_group_columns = [column for column in source_node.group_columns] adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] + adopted_generators = [generator for generator in source_node.generator_applications] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] + target_generators = [generator for generator in source_node.generator_applications] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( store_id=target_store_id, @@ -135,6 +140,7 @@ pub def adopt_cursor_subgraph( group_columns=adopted_group_columns, sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, + generator_applications=adopted_generators, projection_assignments=adopted_assignments, ) target_store_nodes.append( @@ -149,6 +155,7 @@ pub def adopt_cursor_subgraph( group_columns=target_group_columns, sort_columns=target_sort_columns, aggregate_measures=target_measures, + generator_applications=target_generators, projection_assignments=target_assignments, ), ) @@ -232,6 +239,11 @@ def 
_nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if not _aggregate_measure_lists_structurally_equal(candidate.aggregate_measures, source_node.aggregate_measures): return false + if not _generator_application_lists_structurally_equal( + candidate.generator_applications, + source_node.generator_applications, + ): + return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -271,6 +283,48 @@ def _aggregate_measures_structurally_equal(left: AggregateMeasure, right: Aggreg return _column_expr_lists_structurally_equal(left.ordering, right.ordering) +def _generator_application_lists_structurally_equal( + left: list[GeneratorApplication], + right: list[GeneratorApplication], +) -> bool: + """Return whether two generator-application lists carry identical relation-shaping semantics.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if not _generator_applications_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _generator_applications_structurally_equal(left: GeneratorApplication, right: GeneratorApplication) -> bool: + """Return whether two generator applications carry identical registry identity and schema effects.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.preserves_input_columns != right.preserves_input_columns: + return false + if left.is_outer != right.is_outer: + return false + if left.position_origin != right.position_origin: + return false + if not _text_lists_structurally_equal(left.output_columns, right.output_columns): + return false + return _column_exprs_structurally_equal(left.expr, right.expr) + + +def _text_lists_structurally_equal(left: list[str], right: list[str]) -> bool: + """Return whether two string lists are structurally equivalent.""" 
+ if len(left) != len(right): + return false + for idx in range(len(left)): + if left[idx] != right[idx]: + return false + return true + + def _filter_predicates_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: """Return whether two filter scalar expressions are structurally equivalent.""" return _column_exprs_structurally_equal(left, right) diff --git a/src/prism/types.incn b/src/prism/types.incn index a5573cf..59472c1 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -1,6 +1,7 @@ """Shared Prism types that define the internal planning substrate contract.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment @@ -17,6 +18,7 @@ pub enum PrismNodeKind(str): Project = "Project" GroupBy = "GroupBy" Aggregate = "Aggregate" + Generate = "Generate" OrderBy = "OrderBy" Limit = "Limit" Explode = "Explode" @@ -41,6 +43,7 @@ pub model PrismNode: pub group_columns: list[ColumnExpr] pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] + pub generator_applications: list[GeneratorApplication] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/substrait/expr_lowering.incn b/src/substrait/expr_lowering.incn index d5dcd72..21384eb 100644 --- a/src/substrait/expr_lowering.incn +++ b/src/substrait/expr_lowering.incn @@ -281,6 +281,12 @@ def _resolved_scalar_function_application_expr( f"{entry.function_ref} is only valid in {entry.substrait.function_name} context", ), ) + SubstraitMappingKind.RelationExtension => + return Err( + invalid_scalar_expression( + f"{entry.function_ref} is a relation-shaping generator and must be applied through generate(...)", + ), + ) SubstraitMappingKind.Rewrite => return Err( invalid_scalar_expression( diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index f4efeb4..4f0edad 100644 --- a/src/substrait/extensions.incn +++ 
b/src/substrait/extensions.incn @@ -6,6 +6,7 @@ expression trees. """ from rust::incan_stdlib::errors import raise_value_error +from rust::std::primitive import u32 as RustU32 from rust::substrait::proto import AggregateFunction, Expression, FunctionArgument, Rel, SortField from rust::substrait::proto::extensions import SimpleExtensionDeclaration, SimpleExtensionUrn from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType @@ -28,7 +29,23 @@ model ExtensionUrnSpec: const FUNCTION_EXTENSION_URN_ANCHOR: u32 = 0 -const RELATION_EXTENSION_URN_ANCHOR: u32 = 1 + + +def _to_extension_urn_anchor(value: int) -> RustU32: + """Convert a small extension-URN anchor into the protobuf field type.""" + match RustU32.try_from(value): + Ok(converted) => return converted + Err(_) => + message = f"extension URN anchor {value} does not fit Rust u32" + return raise_value_error(message) + + +def _has_extension_urn_spec(specs: list[ExtensionUrnSpec], urn: str) -> bool: + """Return whether a plan-level extension URN list already contains one URI.""" + for spec in specs: + if spec.urn == urn: + return true + return false pub def aggregate_function_name_from_anchor(anchor: u32) -> str: @@ -407,6 +424,10 @@ pub def extension_urns_for_rel(rel: Rel) -> list[SimpleExtensionUrn]: mut specs: list[ExtensionUrnSpec] = [] if _function_extension_urn_is_required(rel.clone()): specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=function_extension_uri())) + mut relation_anchor_count = 0 for urn in _collect_extension_urn_strings(rel): - specs.append(ExtensionUrnSpec(anchor=RELATION_EXTENSION_URN_ANCHOR, urn=urn)) + if _has_extension_urn_spec(specs, urn): + continue + relation_anchor_count += 1 + specs.append(ExtensionUrnSpec(anchor=_to_extension_urn_anchor(relation_anchor_count), urn=urn)) return [SimpleExtensionUrn(extension_urn_anchor=spec.anchor, urn=spec.urn) for spec in specs] diff --git a/src/substrait/function_extensions.incn 
b/src/substrait/function_extensions.incn index 490f93c..649a680 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -77,6 +77,9 @@ pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" +const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" +const POSEXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode" +const POSEXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer" pub def function_extension_uri() -> str: @@ -89,6 +92,21 @@ pub def explode_extension_uri() -> str: return EXPLODE_EXTENSION_URI +pub def explode_outer_extension_uri() -> str: + """Return the registered extension URI used for outer EXPLODE gap encoding.""" + return EXPLODE_OUTER_EXTENSION_URI + + +pub def posexplode_extension_uri() -> str: + """Return the registered extension URI used for positional EXPLODE gap encoding.""" + return POSEXPLODE_EXTENSION_URI + + +pub def posexplode_outer_extension_uri() -> str: + """Return the registered extension URI used for outer positional EXPLODE gap encoding.""" + return POSEXPLODE_OUTER_EXTENSION_URI + + pub def registered_substrait_extension_uris() -> list[str]: """Return the registered extension URIs used by current package-level Substrait lowering.""" - return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] + return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI, EXPLODE_OUTER_EXTENSION_URI, POSEXPLODE_EXTENSION_URI, POSEXPLODE_OUTER_EXTENSION_URI] diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 37005e0..063061f 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -19,8 +19,14 @@ from rust::substrait::proto::set_rel import SetOp 
from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name -from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit +from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, rust_u32_to_int from substrait.extensions import aggregate_function_name_from_anchor +from substrait.function_extensions import ( + explode_extension_uri, + explode_outer_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, +) from substrait.schema_registry import named_table_columns, unknown_named_struct from substrait.traversal import relation_children @@ -185,7 +191,11 @@ def _relation_output_columns(rel: Rel) -> list[str]: None => return [] Some(RelType.ExtensionSingle(extension_rel)) => match extension_rel.input: - Some(child) => return _relation_output_columns(child.as_ref().clone()) + Some(child) => + input_columns = _relation_output_columns(child.as_ref().clone()) + match extension_rel.detail: + Some(detail) => return _extension_single_output_columns(input_columns, detail.type_url) + None => return input_columns None => return [] Some(RelType.Join(join_rel)) => mut names: list[str] = [] @@ -218,6 +228,18 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +def _extension_single_output_columns(input_columns: list[str], extension_uri: str) -> list[str]: + """Return best-effort output columns for known extension-single relation encodings.""" + mut columns: list[str] = [] + columns.extend(input_columns) + if extension_uri == explode_extension_uri() or extension_uri == explode_outer_extension_uri(): + columns.append("value") + elif extension_uri == posexplode_extension_uri() or extension_uri == posexplode_outer_extension_uri(): + columns.append("position") + columns.append("value") + return columns + + pub def 
aggregate_measure_function_names(rel: Rel) -> list[str]: """Return aggregate function names used by a top-level AggregateRel, otherwise empty.""" match rel.rel_type: @@ -453,3 +475,15 @@ pub def plan_has_extension_urn(plan: Plan, extension_uri: str) -> bool: if urn.urn == extension_uri: return true return false + + +pub def plan_extension_urn_count(plan: Plan) -> int: + """Return the number of extension URN declarations carried by one plan.""" + return len(plan.extension_urns) + + +pub def plan_extension_urn_anchor_at(plan: Plan, index: int) -> int: + """Return one extension URN anchor as an Incan integer for tests and diagnostics.""" + if index < 0 or index >= len(plan.extension_urns): + return -1 + return rust_u32_to_int(plan.extension_urns[index].extension_urn_anchor) diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 16e0f38..2f15c20 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -26,6 +26,8 @@ pub from substrait.relations import ( fetch_rel, filter_rel, filter_rel_of_columns, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, @@ -58,6 +60,8 @@ pub from substrait.inspect import ( aggregate_measure_invocation_names, aggregate_measure_output_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -76,6 +80,9 @@ pub from substrait.inspect import ( ) pub from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index 849beba..b075e5f 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -46,6 +46,7 @@ from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, 
AggregateMeasure from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMappingKind from functions.registry import function_registry_entry +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( bool_expr, @@ -81,6 +82,15 @@ model ResolvedRelationExpression: expr: Expression +@derive(Clone) +model ResolvedGeneratorApplication: + """One generator application resolved against input columns and registry metadata.""" + + generator: GeneratorApplication + entry: FunctionRegistryEntry + expr: Expression + + pub enum SubstraitJoinKind: Inner Left @@ -259,6 +269,61 @@ def _validate_aggregate_modifiers(measure: ResolvedAggregateMeasure) -> Result[N return Ok(None) +def _generator_registry_entry(generator: GeneratorApplication) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one generator registry entry and validate its semantic class.""" + match function_registry_entry(generator.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Generator: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as a generator function")) + if entry.substrait.kind != SubstraitMappingKind.RelationExtension: + return Err( + invalid_scalar_expression(f"{entry.function_ref} does not declare a relation-extension mapping"), + ) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing generator registry entry for `{generator.canonical_name}`")) + + +def _resolved_generator( + generator: GeneratorApplication, + input_columns: list[str], +) -> Result[ResolvedGeneratorApplication, SubstraitLoweringError]: + """Resolve one generator application against input-column names.""" + _validate_generator_output_columns(input_columns, generator.clone())? 
+ return Ok( + ResolvedGeneratorApplication( + generator=generator.clone(), + entry=_generator_registry_entry(generator.clone())?, + expr=scalar_expr(input_columns, generator.expr)?, + ), + ) + + +def _validate_generator_output_columns( + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[None, SubstraitLoweringError]: + """Validate generator output columns against the current input relation shape.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + return Err( + invalid_scalar_expression(f"generator output column `{output_column}` conflicts with an existing column"), + ) + output_columns.append(output_column) + return Ok(None) + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: """Resolve one aggregate measure through declaration-side registry metadata.""" match _aggregate_registry_entry(measure): @@ -603,6 +668,31 @@ pub def try_aggregate_rel_of_columns( ) +pub def generator_rel(input: Rel, generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator relation-extension node.""" + return _lowered_rel_or_raise(try_generator_rel(input, generator)) + + +pub def try_generator_rel(input: Rel, generator: GeneratorApplication) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node.""" + return try_generator_rel_of_columns(input.clone(), relation_output_columns(input), generator) + + +pub def generator_rel_of_columns(input: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator 
relation-extension node using explicit input-column names.""" + return _lowered_rel_or_raise(try_generator_rel_of_columns(input, input_columns, generator)) + + +pub def try_generator_rel_of_columns( + input: Rel, + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node using explicit input-column names.""" + resolved = _resolved_generator(generator, input_columns)? + return Ok(extension_single_rel(input, resolved.entry.substrait.uri)) + + pub def sort_rel(input: Rel) -> Rel: """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" input_columns = relation_output_columns(input.clone()) diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 8d762b1..3140a03 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -17,6 +17,8 @@ from functions import ( count_expr, count_if, eq, + explode, + explode_outer, float_expr, int_expr, int_lit, @@ -24,12 +26,19 @@ from functions import ( max, min, mul, + posexplode, + posexplode_outer, str_expr, str_lit, sum, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name -from substrait.function_extensions import explode_extension_uri +from substrait.function_extensions import ( + explode_extension_uri, + explode_outer_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, +) from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len, plan_from_named_table, plan_from_root_relation from substrait.relations import read_named_table_rel @@ -422,6 +431,14 @@ def test_lazy_frame__independent_roots_can_join_and_lower() -> None: def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None: # -- Arrange -- _register_order_schema("orders") + register_named_table_schema( + 
"orders_generator_dataset", + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) projected: LazyFrame[Order] = lazy_frame_named_table("orders").select() grouped: LazyFrame[Order] = lazy_frame_named_table("orders").group_by([col("id")]) @@ -430,6 +447,18 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by([col("id")]) limited: LazyFrame[Order] = lazy_frame_named_table("orders").limit(10) exploded: LazyFrame[Order] = lazy_frame_named_table("orders").explode() + generated: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode(col("line_items"), "line_item"), + ) + generated_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode_outer(col("line_items"), "line_item"), + ) + generated_positional: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode(col("line_items"), "position", "line_item"), + ) + generated_positional_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode_outer(col("line_items"), "position", "line_item"), + ) # -- Assert -- assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project boundary shape" @@ -438,6 +467,11 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None assert relation_kind_name(root_rel(ordered.to_substrait_plan())) == "SortRel", "order_by should lower to SortRel" assert relation_kind_name(root_rel(limited.to_substrait_plan())) == "FetchRel", "limit should lower to FetchRel" assert plan_has_extension_urn(exploded.to_substrait_plan(), explode_extension_uri()), "explode should keep emitting the registered extension boundary" + assert 
relation_kind_name(root_rel(generated.to_substrait_plan())) == "ExtensionSingleRel", "generate should lower through the relation extension boundary" + assert generated.planned_columns() == ["id", "line_items", "line_item"], "generate should append declared output aliases" + assert plan_has_extension_urn(generated_outer.to_substrait_plan(), explode_outer_extension_uri()), "outer explode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional.to_substrait_plan(), posexplode_extension_uri()), "posexplode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional_outer.to_substrait_plan(), posexplode_outer_extension_uri()), "posexplode_outer should use its relation extension URI" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 22f8739..424147e 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -45,6 +45,8 @@ from functions import ( eq, equal_null, element_at, + explode, + explode_outer, floor, float_expr, function_registry_canonical_names, @@ -81,6 +83,8 @@ from functions import ( not_, nullif, or_, + posexplode, + posexplode_outer, registered_substrait_mapped_function_refs, round, str_expr, @@ -164,7 +168,11 @@ from substrait.function_extensions import ( ROUND_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, + explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, ) @@ -223,7 +231,7 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", 
"always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer"] def _expected_substrait_mapped_names() -> list[str]: @@ -313,6 +321,10 @@ def _exercise_current_public_helpers() -> None: map_keys(attr_map) map_values(attr_map) named_struct(["status", "amount"], [status, amount]) + explode(tags, "tag") + explode_outer(tags, "tag") + 
posexplode(tags, "position", "tag") + posexplode_outer(tags, "position", "tag") return @@ -346,6 +358,15 @@ def _assert_extension_mapping(canonical_name: str, function_name: str, anchor: u assert entry.substrait.anchor == anchor, f"{canonical_name} should carry the stable Substrait anchor" +def _assert_relation_extension_mapping(canonical_name: str, function_name: str, extension_uri: str) -> None: + """Assert one generator helper declares a relation-extension mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + assert entry.function_class == FunctionClass.Generator, f"{canonical_name} should be classified as a generator" + assert entry.substrait.kind == SubstraitMappingKind.RelationExtension, f"{canonical_name} should use a relation extension" + assert entry.substrait.uri == extension_uri, f"{canonical_name} should carry the registered relation extension URI" + assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered extension name" + + def _assert_core_mapping(canonical_name: str, function_name: str) -> None: """Assert one helper declares the expected built-in Substrait Rex mapping.""" entry = _entry_or_fail(function_ref_for(canonical_name)) @@ -608,6 +629,18 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("named_struct", "named_struct", NAMED_STRUCT_FUNCTION_ANCHOR) +def test_function_registry__generator_helpers_are_relation_extensions() -> None: + """Assert generator helpers are registry entries without scalar or aggregate extension anchors.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act / Assert -- + _assert_relation_extension_mapping("explode", "explode", explode_extension_uri()) + _assert_relation_extension_mapping("explode_outer", "explode_outer", explode_outer_extension_uri()) + _assert_relation_extension_mapping("posexplode", "posexplode", posexplode_extension_uri()) + 
_assert_relation_extension_mapping("posexplode_outer", "posexplode_outer", posexplode_outer_extension_uri()) + + def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" # -- Arrange -- diff --git a/tests/test_generator_functions.incn b/tests/test_generator_functions.incn new file mode 100644 index 0000000..052d784 --- /dev/null +++ b/tests/test_generator_functions.incn @@ -0,0 +1,55 @@ +"""Tests for registry-backed generator and table-valued function builders.""" + +from std.testing import assert_raises +from generator_builders import GeneratorKind, generator_output_columns, generator_primary_output_column +from functions import col, explode, explode_outer, posexplode, posexplode_outer + + +def test_generator_functions__explode_family_builds_relation_applications() -> None: + # -- Arrange -- + items = col("line_items") + + # -- Act -- + inner = explode(items, "line_item") + outer = explode_outer(items, "line_item") + positional = posexplode(items, "position", "line_item") + positional_outer = posexplode_outer(items, "position", "line_item") + + # -- Assert -- + assert inner.kind == GeneratorKind.Explode + assert outer.kind == GeneratorKind.ExplodeOuter + assert positional.kind == GeneratorKind.PosExplode + assert positional_outer.kind == GeneratorKind.PosExplodeOuter + assert not inner.is_outer + assert outer.is_outer + assert positional.position_origin == 0 + assert positional_outer.position_origin == 0 + assert generator_primary_output_column(inner) == "line_item" + assert generator_primary_output_column(positional) == "line_item" + + +def test_generator_functions__output_columns_preserve_input_then_append_aliases() -> None: + # -- Arrange -- + input_columns = ["id", "line_items"] + + # -- Act -- + exploded_columns = generator_output_columns(input_columns, explode(col("line_items"), "line_item")) + positional_columns = 
generator_output_columns(input_columns, posexplode(col("line_items"), "position", "line_item")) + + # -- Assert -- + assert exploded_columns == ["id", "line_items", "line_item"] + assert positional_columns == ["id", "line_items", "position", "line_item"] + + +def _call_generator_with_input_collision() -> None: + """Call generator output inference with a generated name that collides with input.""" + generator_output_columns(["id", "line_items"], explode(col("line_items"), "id")) + return + + +def test_generator_functions__output_alias_collisions_are_rejected() -> None: + # -- Arrange -- + call = _call_generator_with_input_collision + + # -- Act / Assert -- + assert_raises[ValueError](call) diff --git a/tests/test_prism.incn b/tests/test_prism.incn index f0c7490..fc1f097 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,9 +1,10 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, count_expr, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, sum from prism import ( PrismCursor, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_limit, prism_cursor_apply_select, prism_cursor_authored_node_count, @@ -35,6 +36,17 @@ def _register_projection_test_schema(table_name: str) -> None: register_named_table_schema(table_name, [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_generator_test_schema(table_name: str) -> None: + register_named_table_schema( + table_name, + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) + + def test_prism__branching_keeps_base_reachable_history_small() -> None: # -- Arrange -- base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) @@ -206,6 +218,7 @@ def 
test_prism__cross_store_adoption_keeps_distinct_aggregate_modifier_state() - def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Arrange -- _register_projection_test_schema(str("orders")) + _register_generator_test_schema(str("orders_generator_prism")) projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select() grouped: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]) @@ -214,6 +227,9 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by([col("id")]) limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10) exploded: PrismCursor[Order] = prism_cursor_named_table(str("orders")).explode() + generated: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_prism")).generate( + explode(col("line_items"), "line_item"), + ) # -- Assert -- assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" @@ -222,6 +238,8 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: assert prism_cursor_tip_kind_name(ordered) == str("OrderBy"), "order_by should append a native sort node" assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" + assert prism_cursor_tip_kind_name(generated) == str("Generate"), "generate should append a native generator node" + assert prism_cursor_output_columns(generated) == ["id", "line_items", "line_item"], "generate should append declared output aliases" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: @@ -332,3 +350,21 @@ def test_prism__cursor_methods_match_apply_helpers() -> None: assert relation_kind_name(root_rel(via_methods.to_substrait_plan())) == relation_kind_name( 
root_rel(via_helpers.to_substrait_plan()), ), "method and helper paths should lower to equivalent root relation kinds" + + +def test_prism__generate_method_matches_apply_helper() -> None: + # -- Arrange -- + _register_generator_test_schema("orders_generator_apply") + base: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_apply")) + generator = explode(col("line_items"), "line_item") + + # -- Act -- + via_method = base.generate(generator) + via_helper = prism_cursor_apply_generate(base, generator) + + # -- Assert -- + assert prism_cursor_tip_kind_name(via_method) == prism_cursor_tip_kind_name(via_helper), "method and helper paths should produce the same generator node kind" + assert prism_cursor_output_columns(via_method) == ["id", "line_items", "line_item"], "generator helper should preserve planned output columns" + assert relation_kind_name(root_rel(via_method.to_substrait_plan())) == relation_kind_name( + root_rel(via_helper.to_substrait_plan()), + ), "generator method and helper paths should lower to equivalent root relation kinds" diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 17a70c6..04bb902 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -75,7 +75,10 @@ from substrait.errors import SubstraitLoweringErrorKind from substrait.expr_lowering import scalar_expr from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) from substrait.inspect import ( @@ -83,6 +86,8 @@ from substrait.inspect import ( aggregate_measure_filter_flags, aggregate_measure_invocation_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -591,13 +596,22 @@ def test_plan__extension_urns_are_surfaced() -> None: # -- Arrange -- 
extension_uri = explode_extension_uri() rel = extension_single_rel(read_named_table_rel("orders"), extension_uri) + nested = extension_single_rel( + extension_single_rel(read_named_table_rel("orders"), explode_extension_uri()), + posexplode_extension_uri(), + ) # -- Act -- plan = plan_from_root_relation(rel, ["id"]) + nested_plan = plan_from_root_relation(nested, ["id", "position", "value"]) # -- Assert -- assert plan_has_extension_urn(plan, extension_uri), "extension relation should populate extension URNs" assert plan_contains_relation_kind(plan, "ExtensionSingleRel"), "extension root should remain inspectable" + assert plan_has_extension_urn(nested_plan, explode_extension_uri()), "nested extension plans should include child extension URNs" + assert plan_has_extension_urn(nested_plan, posexplode_extension_uri()), "nested extension plans should include root extension URNs" + assert plan_extension_urn_count(nested_plan) == 2, "different relation extension URIs should be declared once each" + assert plan_extension_urn_anchor_at(nested_plan, 0) != plan_extension_urn_anchor_at(nested_plan, 1), "relation extension URNs should use distinct anchors" def test_plan__revision_pin_and_extension_registry_are_exported() -> None: @@ -611,9 +625,12 @@ def test_plan__revision_pin_and_extension_registry_are_exported() -> None: # -- Assert -- assert tag == "v0.63.0", "revision helpers should expose the currently targeted Substrait release tag" assert producer == "inql-rfc002", "revision helpers should expose the package producer label" - assert len(registered) == 2, "current package boundary should register both extension URIs" + assert len(registered) == 5, "current package boundary should register function and generator extension URIs" assert registered[0] == function_extension_uri(), "registry should include the shared function extension URI first" assert registered[1] == explode_extension_uri(), "registry should include the emitted explode extension URI" + assert registered[2] 
== explode_outer_extension_uri(), "registry should include the outer explode extension URI" + assert registered[3] == posexplode_extension_uri(), "registry should include the positional explode extension URI" + assert registered[4] == posexplode_outer_extension_uri(), "registry should include the outer positional explode extension URI" def test_conformance__core_scenarios_validate_emission_output() -> None: From 083dfc2326fa443999a36e8d0edc29a667b36991 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 26 May 2026 01:36:52 +0200 Subject: [PATCH 5/6] feature - implement RFC 019 window foundations (#36) --- docs/language/reference/dataset_methods.md | 3 + docs/language/reference/functions/index.md | 4 +- docs/language/reference/functions/windows.md | 33 ++++ .../reference/substrait/operator_catalog.md | 2 +- docs/release_notes/v0_1.md | 1 + docs/rfcs/019_window_functions.md | 28 ++-- docs/rfcs/README.md | 2 +- src/dataset/mod.incn | 27 ++++ src/dataset/ops.incn | 12 ++ src/functions/mod.incn | 4 + src/functions/windows/dense_rank.incn | 36 +++++ src/functions/windows/mod.incn | 6 + src/functions/windows/rank.incn | 36 +++++ src/functions/windows/row_number.incn | 36 +++++ src/functions/windows/window.incn | 35 ++++ src/lib.incn | 22 +++ src/prism/lower.incn | 7 + src/prism/mod.incn | 41 +++++ src/prism/output_columns.incn | 5 + src/prism/rewrite.incn | 6 + src/prism/store.incn | 48 ++++++ src/prism/types.incn | 3 + src/substrait/extensions.incn | 56 ++++++- src/substrait/function_extensions.incn | 4 + src/substrait/inspect.incn | 51 +++++- src/substrait/relations.incn | 118 ++++++++++++++ src/substrait/traversal.incn | 4 + src/window_builders.incn | 152 ++++++++++++++++++ tests/test_dataset.incn | 10 ++ tests/test_function_registry.incn | 32 +++- tests/test_prism.incn | 27 +++- tests/test_substrait_plan.incn | 91 +++++++++++ tests/test_window_functions.incn | 47 ++++++ 33 files changed, 967 insertions(+), 22 deletions(-) create mode 100644 
docs/language/reference/functions/windows.md create mode 100644 src/functions/windows/dense_rank.incn create mode 100644 src/functions/windows/mod.incn create mode 100644 src/functions/windows/rank.incn create mode 100644 src/functions/windows/row_number.incn create mode 100644 src/functions/windows/window.incn create mode 100644 src/window_builders.incn create mode 100644 tests/test_window_functions.incn diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md index 9a9a701..05503fd 100644 --- a/docs/language/reference/dataset_methods.md +++ b/docs/language/reference/dataset_methods.md @@ -20,6 +20,7 @@ The Substrait helper surface behind these methods is split by semantic role: | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. | | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | | `generate` | `def generate(self, generator: GeneratorApplication) -> Self` | Apply a relation-shaping generator such as `explode(...)` with explicit output aliases. | +| `with_window_column` | `def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self` | Add or replace one projected column using a placed window function. | | `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. | | `limit` | `def limit(self, n: int) -> Self` | Cap row count. | | `explode` | `def explode(self) -> Self` | Compatibility marker for the older EXPLODE extension path. Prefer `generate(explode(...))`. | @@ -69,6 +70,7 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: - `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature. 
- `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders. - `generate(...)` preserves all input columns and appends generated output aliases. Alias collisions are rejected during planning/lowering. +- `with_window_column(...)` currently supports ranking helpers over explicit window specs and lowers through Substrait window relations. Backend execution support is tracked separately from logical planning support. - `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface. - Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics. @@ -77,3 +79,4 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: - [Filter builders](builders/filters.md) - [Aggregate builders](builders/aggregates.md) - [Projection builders](builders/projections.md) +- [Window functions](functions/windows.md) diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index e65ea90..adc070d 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -9,10 +9,11 @@ Today the concrete shipped surfaces are documented here: - [Projection builders](../builders/projections.md) - [Generator and table-valued functions](generators.md) - [Nested data functions](nested.md) +- [Window functions](windows.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. 
The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, and windows. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. 
Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -35,6 +36,7 @@ The registered helper surface currently includes: | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | | `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | | `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions | +| `window()`, `row_number()`, `rank()`, `dense_rank()` | window | `window()` builds structural window-spec metadata; ranking helpers lower through `ConsistentPartitionWindowRel` when placed with `with_window_column(...)` | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` 
helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | diff --git a/docs/language/reference/functions/windows.md b/docs/language/reference/functions/windows.md new file mode 100644 index 0000000..37185ae --- /dev/null +++ b/docs/language/reference/functions/windows.md @@ -0,0 +1,33 @@ +# Window Functions (Reference) + +Window helpers are relation-aware. A window function application produces one output value per input row while reading a +partition of related rows. It is not an ordinary scalar expression and must be placed through a projection-like dataset +method. + +```incan +from pub::inql import LazyFrame +from pub::inql.functions import col, desc, rank, window +from models import Order + +def ranked_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.with_window_column( + "customer_rank", + rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])), + ) +``` + +The current foundation slice includes: + +| Function | Meaning | Placement | +| --- | --- | --- | +| `window()` | Build an empty window specification. | Structural builder used before `.over(...)`. | +| `row_number()` | Assign a sequential row number inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. | +| `rank()` | Rank rows with gaps after ties inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. | +| `dense_rank()` | Rank rows without gaps after ties inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. | + +`WindowSpec.partition_by(...)` replaces the partition expressions. `WindowSpec.order_by(...)` replaces the ordering +expressions. Ranking helpers require explicit ordering; missing ordering is rejected during logical lowering. 
+ +`with_window_column(name, application)` preserves input columns and adds or replaces `name` using add-or-replace +projection semantics. Each call lowers one window projection through Substrait `ConsistentPartitionWindowRel` with a +registry-backed function anchor. Backend execution support is separate from this logical planning surface. diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md index 327ad49..81bfecb 100644 --- a/docs/language/reference/substrait/operator_catalog.md +++ b/docs/language/reference/substrait/operator_catalog.md @@ -34,7 +34,7 @@ The following table maps InQL plan capabilities to Substrait logical relations a | Group by / aggregates | `AggregateRel` with scalar grouping keys and aggregate measures; grouping sets are tracked as a distinct capability below | core | | Rollup / cube / grouping sets | `AggregateRel` with multiple groupings | core | | Distinct rows | `AggregateRel` with grouping keys and no measures | core | -| Window / analytic functions | `ProjectRel` with window expressions | core | +| Window / analytic functions | `ConsistentPartitionWindowRel` with partition/order expressions and registered window function anchors | core | | Sort | `SortRel` | core | | Limit / offset | `FetchRel` | core | | Union, intersect, except | `SetRel` with the appropriate set operation enum | core | diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 5c23085..d337e4e 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -17,6 +17,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. 
- **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics. - **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding. +- **Window functions:** RFC 019 adds the first window-function planning slice with `window()` specs, `row_number()`, `rank()`, `dense_rank()`, and `with_window_column(...)`. Ranking windows require explicit ordering and lower through Substrait `ConsistentPartitionWindowRel`; backend execution support remains a separate adapter capability. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. 
diff --git a/docs/rfcs/019_window_functions.md b/docs/rfcs/019_window_functions.md index 7e6fb1c..b509d88 100644 --- a/docs/rfcs/019_window_functions.md +++ b/docs/rfcs/019_window_functions.md @@ -1,6 +1,6 @@ # InQL RFC 019: Window functions -- **Status:** Draft +- **Status:** In Progress - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -11,7 +11,7 @@ - InQL RFC 016 (core aggregate functions) - **Issue:** [InQL #36](https://github.com/dannys-code-corner/InQL/issues/36) - **RFC PR:** — -- **Written against:** Incan v0.2 +- **Written against:** Incan v0.3-era InQL - **Shipped in:** — ## Summary @@ -40,19 +40,18 @@ Window functions also force a clearer relation between row-level expressions and ## Guide-level explanation (how authors think about it) -Authors should be able to rank and compare rows within a partition: +Authors can rank rows within a partition using the builder surface: ```incan -from pub::inql.functions import col, desc, lag, rank, window +from pub::inql.functions import col, desc, rank, window ranked = ( orders - .with_column("customer_rank", rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]))) - .with_column("previous_amount", lag(col("amount"), 1).over(window().partition_by([col("customer_id")]).order_by([col("created_at")]))) + .with_window_column("customer_rank", rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]))) ) ``` -The exact builder syntax may evolve, but authors should understand that a window function returns a row-level value computed with access to nearby or related rows. +The exact query-block syntax may evolve, but authors should understand that a window function returns a row-level value computed with access to nearby or related rows. ## Reference-level explanation (precise rules) @@ -108,10 +107,17 @@ No current InQL function should be reclassified silently as a window function. 
A - **Execution / interchange** — Prism and Substrait lowering must preserve window partitioning, ordering, frames, and function identity. - **Documentation** — docs should clearly separate aggregate functions from window functions. -## Unresolved questions +## Design Decisions + +### Resolved + +- The first implementation slice exposes explicit `with_window_column(...)` projection-like placement rather than accepting window functions in arbitrary scalar-expression positions. +- Ranking helpers require explicit `order_by(...)` in the window spec. InQL does not invent a silent default ordering. +- The current foundation slice lowers `row_number`, `rank`, and `dense_rank` through `ConsistentPartitionWindowRel` with registry-backed function anchors. +- DataFusion execution for window relations is not claimed until a backend adapter slice explicitly supports the lowered window relation. + +### Remaining - What default frame should InQL use for ordered window functions? -- Should window functions be allowed in `WHERE` or only in projection/order positions? - Should null treatment use explicit `IGNORE NULLS` / `RESPECT NULLS` style modifiers? - - +- How should `lag`, `lead`, first/last/nth value functions, aggregate-over-window calls, and query-block `OVER (...)` syntax be phased on top of the foundation model? 
diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index c42c434..4ec010f 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -25,7 +25,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [016][rfc-016] | Draft | Core aggregate functions | | | [017][rfc-017] | Draft | Aggregate modifiers | | | [018][rfc-018] | Draft | Common scalar function catalog | | -| [019][rfc-019] | Draft | Window functions | | +| [019][rfc-019] | In Progress | Window functions | | | [020][rfc-020] | Draft | Nested data functions | | | [021][rfc-021] | In Progress | Generator and table-valued functions | | | [022][rfc-022] | Draft | Semi-structured and format functions | | diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index e9b31b1..9a1b134 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -23,6 +23,7 @@ The current method-chain surface in this module is the explicit builder-based AP - `group_by(columns: list[ColumnExpr])` - `agg(measures: list[AggregateMeasure])` - `generate(generator: GeneratorApplication)` +- `with_window_column(name: str, application: WindowFunctionApplication)` - plus the structural operators `join`, `select`, `order_by`, `limit`, and `explode` Illustrative current-shape examples: @@ -56,6 +57,7 @@ from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication from projection_builders import ColumnExpr +from window_builders import WindowFunctionApplication from dataset.materialization import DataFrameMaterialization from substrait.errors import SubstraitLoweringError from substrait.schema_registry import named_table_columns @@ -72,6 +74,7 @@ from dataset.ops import ( order_by_ds_of_columns, select_ds_of_columns, with_column_ds, + with_window_column_ds, ) from session.types import SessionError, collect_with_active_session from prism import ( @@ -86,6 +89,7 @@ from prism import ( prism_cursor_apply_order_by, 
prism_cursor_apply_select, prism_cursor_apply_with_column, + prism_cursor_apply_with_window_column, prism_cursor_named_table, prism_cursor_output_columns, ) @@ -103,6 +107,7 @@ pub trait DataSet[T with Clone]: def group_by(self, columns: list[ColumnExpr]) -> Self def agg(self, measures: list[AggregateMeasure]) -> Self def generate(self, generator: GeneratorApplication) -> Self + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self def order_by(self, columns: list[ColumnExpr]) -> Self def limit(self, n: int) -> Self def explode(self) -> Self @@ -218,6 +223,12 @@ pub class DataFrame[T with Clone] with BoundedDataSet: generate_ds_of_columns(self._substrait_rel, self.planned_columns(), generator), ) + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self: + """Return one new DataFrame with a named window projection stage and stale materialization cleared.""" + return _data_frame_with_invalidated_materialization( + with_window_column_ds(self._substrait_rel, self.planned_columns(), name, application), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" return _data_frame_with_invalidated_materialization( @@ -303,6 +314,10 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: """Return one new lazy carrier with an appended generator stage.""" return LazyFrame(_cursor=prism_cursor_apply_generate(self._cursor, generator)) + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self: + """Return one new lazy carrier with an appended named window projection stage.""" + return LazyFrame(_cursor=prism_cursor_apply_with_window_column(self._cursor, name, application)) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new lazy carrier with an appended ordering stage.""" return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns)) @@ -456,6 +471,18 
@@ pub class DataStream[T with Clone] with UnboundedDataSet: ), ) + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self: + """Return one new DataStream with a named window projection stage.""" + return DataStream( + _row_schema_marker=self._row_schema_marker.clone(), + _substrait_rel=with_window_column_ds( + self._substrait_rel, + relation_output_columns(self._substrait_rel.clone()), + name, + application, + ), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataStream with an ordering stage.""" return DataStream( diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index bafad30..d8c90c2 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -10,6 +10,7 @@ from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment +from window_builders import WindowFunctionApplication, window_projection from substrait.function_extensions import explode_extension_uri from substrait.inspect import relation_output_columns from substrait.relations import ( @@ -21,6 +22,7 @@ from substrait.relations import ( project_rel_of_columns, sort_rel_of_columns, generator_rel_of_columns, + window_rel_of_columns, ) @@ -134,6 +136,16 @@ pub def generate_ds_of_columns(rel: Rel, input_columns: list[str], generator: Ge return generator_rel_of_columns(rel, input_columns, generator) +pub def with_window_column_ds( + rel: Rel, + input_columns: list[str], + name: str, + application: WindowFunctionApplication, +) -> Rel: + """Apply one dataset-level named window projection using explicit input-column names.""" + return window_rel_of_columns(rel, input_columns, [window_projection(name, application)]) + + pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: """ Apply dataset-level ordering intent to one relation. 
diff --git a/src/functions/mod.incn b/src/functions/mod.incn index c20b662..1cfc03c 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -65,6 +65,10 @@ pub from functions.generators.explode import explode pub from functions.generators.explode_outer import explode_outer pub from functions.generators.posexplode import posexplode pub from functions.generators.posexplode_outer import posexplode_outer +pub from functions.windows.window import window +pub from functions.windows.row_number import row_number +pub from functions.windows.rank import rank +pub from functions.windows.dense_rank import dense_rank pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/functions/windows/dense_rank.incn b/src/functions/windows/dense_rank.incn new file mode 100644 index 0000000..2545172 --- /dev/null +++ b/src/functions/windows/dense_rank.incn @@ -0,0 +1,36 @@ +"""Dense-rank window helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from substrait.function_extensions import DENSE_RANK_FUNCTION_ANCHOR +from window_builders import WindowFunctionCall, dense_rank as dense_rank_builder + + +@function_registry.add("dense_rank", deterministic_spec( + FunctionClass.Window, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + extension_mapping("dense_rank", DENSE_RANK_FUNCTION_ANCHOR), +)) +pub def dense_rank() -> WindowFunctionCall: + """ + Build a dense-rank window function call. 
+ + Examples: + ranked = dense_rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])) + """ + return dense_rank_builder() + + +module tests: + def test_dense_rank_builds_window_call() -> None: + call = dense_rank() + assert call.canonical_name == "dense_rank" + assert call.requires_ordering diff --git a/src/functions/windows/mod.incn b/src/functions/windows/mod.incn new file mode 100644 index 0000000..3c203cf --- /dev/null +++ b/src/functions/windows/mod.incn @@ -0,0 +1,6 @@ +"""Window specification and ranking helper functions.""" + +pub from functions.windows.window import window +pub from functions.windows.row_number import row_number +pub from functions.windows.rank import rank +pub from functions.windows.dense_rank import dense_rank diff --git a/src/functions/windows/rank.incn b/src/functions/windows/rank.incn new file mode 100644 index 0000000..54f7ff0 --- /dev/null +++ b/src/functions/windows/rank.incn @@ -0,0 +1,36 @@ +"""Rank window helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from substrait.function_extensions import RANK_FUNCTION_ANCHOR +from window_builders import WindowFunctionCall, rank as rank_builder + + +@function_registry.add("rank", deterministic_spec( + FunctionClass.Window, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + extension_mapping("rank", RANK_FUNCTION_ANCHOR), +)) +pub def rank() -> WindowFunctionCall: + """ + Build a rank window function call. 
+ + Examples: + ranked = rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])) + """ + return rank_builder() + + +module tests: + def test_rank_builds_window_call() -> None: + call = rank() + assert call.canonical_name == "rank" + assert call.requires_ordering diff --git a/src/functions/windows/row_number.incn b/src/functions/windows/row_number.incn new file mode 100644 index 0000000..f22ee64 --- /dev/null +++ b/src/functions/windows/row_number.incn @@ -0,0 +1,36 @@ +"""Row-number window helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from substrait.function_extensions import ROW_NUMBER_FUNCTION_ANCHOR +from window_builders import WindowFunctionCall, row_number as row_number_builder + + +@function_registry.add("row_number", deterministic_spec( + FunctionClass.Window, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + extension_mapping("row_number", ROW_NUMBER_FUNCTION_ANCHOR), +)) +pub def row_number() -> WindowFunctionCall: + """ + Build a row-number window function call. 
+ + Examples: + numbered = row_number().over(window().partition_by([col("customer_id")]).order_by([col("created_at")])) + """ + return row_number_builder() + + +module tests: + def test_row_number_builds_window_call() -> None: + call = row_number() + assert call.canonical_name == "row_number" + assert call.requires_ordering diff --git a/src/functions/windows/window.incn b/src/functions/windows/window.incn new file mode 100644 index 0000000..7bcdc18 --- /dev/null +++ b/src/functions/windows/window.incn @@ -0,0 +1,35 @@ +"""Window specification builder helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + structural_mapping, + v0_1, +) +from functions.registry import function_registry +from window_builders import WindowSpec, window as window_builder + + +@function_registry.add("window", deterministic_spec( + FunctionClass.Window, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + structural_mapping("window_spec"), +)) +pub def window() -> WindowSpec: + """ + Build an empty window specification. 
+ + Examples: + spec = window().partition_by([col("customer_id")]).order_by([col("created_at")]) + """ + return window_builder() + + +module tests: + def test_window_builds_empty_window_spec() -> None: + spec = window() + assert len(spec.partition_columns) == 0 + assert len(spec.sort_columns) == 0 diff --git a/src/lib.incn b/src/lib.incn index 87604d7..a707823 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -16,6 +16,7 @@ pub from dataset.ops import ( limit_ds, order_by_ds, select_ds, + with_window_column_ds, ) pub from aggregate_builders import AggregateKind, AggregateMeasure pub from generator_builders import ( @@ -24,6 +25,15 @@ pub from generator_builders import ( generator_output_columns, generator_primary_output_column, ) +pub from window_builders import ( + WindowFunctionApplication, + WindowFunctionCall, + WindowFunctionKind, + WindowProjection, + WindowSpec, + window_output_columns, + window_projection, +) pub from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -101,6 +111,10 @@ pub from functions.generators.explode import explode pub from functions.generators.explode_outer import explode_outer pub from functions.generators.posexplode import posexplode pub from functions.generators.posexplode_outer import posexplode_outer +pub from functions.windows.window import window +pub from functions.windows.row_number import row_number +pub from functions.windows.rank import rank +pub from functions.windows.dense_rank import dense_rank pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -218,6 +232,8 @@ pub from substrait.relations import ( set_rel_of_kind, sort_rel, sort_rel_of_columns, + window_rel, + window_rel_of_columns, ) pub from substrait.plans import ( empty_plan, @@ -244,6 +260,9 @@ pub from substrait.inspect import ( rel_contains_kind, root_rel, set_operation_name, + window_function_names, + window_partition_count, + window_sort_count, ) pub from 
substrait.function_extensions import ( explode_extension_uri, @@ -251,6 +270,9 @@ pub from substrait.function_extensions import ( function_extension_uri, posexplode_extension_uri, posexplode_outer_extension_uri, + DENSE_RANK_FUNCTION_ANCHOR, + RANK_FUNCTION_ANCHOR, + ROW_NUMBER_FUNCTION_ANCHOR, registered_substrait_extension_uris, ) pub from substrait.conformance_catalog import ( diff --git a/src/prism/lower.incn b/src/prism/lower.incn index 6020b57..078a2cd 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -15,6 +15,7 @@ from substrait.relations import ( try_aggregate_rel_of_columns, try_filter_rel_of_columns, try_project_rel_of_columns, + try_window_rel_of_columns, ) from substrait.errors import SubstraitLoweringError from prism.rewrite import derive_rewritten_view, rewritten_node_at @@ -125,6 +126,12 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst rewritten_output_columns(view, node.input_ids[0]), node.generator_applications[0], ) + PrismNodeKind.Window => + return try_window_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, + rewritten_output_columns(view, node.input_ids[0]), + node.window_projections, + ) PrismNodeKind.OrderBy => return Ok( sort_rel_of_columns( diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 3564d35..09d2933 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -15,6 +15,7 @@ from aggregate_builders import AggregateMeasure from filter_builders import always_true from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment +from window_builders import WindowFunctionApplication, window_projection from prism.lower import ( lower_prism_tip as lower_prism_tip_impl, prism_rel_to_plan, @@ -71,6 +72,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return 
PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -90,6 +92,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -106,6 +109,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -124,6 +128,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -142,6 +147,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[with_column_assignment(name, expr)], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -160,6 +166,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -178,6 +185,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=measures, generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -196,6 +204,7 @@ pub class PrismCursor[T with Clone]: sort_columns=columns, aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -214,6 +223,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], 
aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -232,6 +242,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -250,6 +261,26 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[generator], + window_projections=[], + projection_assignments=[], + ) + return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) + + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self: + """Append one named window projection and return the derived tip.""" + next_tip_id = append_node( + store_id=self.store_id, + kind=PrismNodeKind.Window, + input_ids=[self.tip_id], + named_table="", + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + group_columns=[], + sort_columns=[], + aggregate_measures=[], + generator_applications=[], + window_projections=[window_projection(name, application)], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -294,6 +325,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) @@ -363,6 +395,15 @@ pub def prism_cursor_apply_generate[T with Clone]( return cursor.generate(generator) +pub def prism_cursor_apply_with_window_column[T with Clone]( + cursor: PrismCursor[T], + name: str, + application: WindowFunctionApplication, +) -> PrismCursor[T]: + """Apply one named window projection through Prism.""" + return 
cursor.with_window_column(name, application) + + pub def prism_cursor_output_columns[T with Clone](cursor: PrismCursor[T]) -> list[str]: """Return plan-time output columns for one cursor tip.""" return cursor.planned_columns() diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index d1cfa06..6c88c6a 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -7,6 +7,7 @@ from generator_builders import generator_output_columns from projection_builders import ColumnExpr, project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns +from window_builders import window_output_columns def _is_passthrough_output_kind(kind: PrismNodeKind) -> bool: @@ -33,6 +34,8 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str authored_output_columns(store_id, node.input_ids[0]), node.generator_applications[0], ) + if node.kind == PrismNodeKind.Window: + return window_output_columns(authored_output_columns(store_id, node.input_ids[0]), node.window_projections) if node.kind == PrismNodeKind.Join: # Join output columns preserve the conventional left-then-right relation order. # We keep both sides verbatim here; duplicate names are part of the current output shape and are resolved later @@ -70,6 +73,8 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list rewritten_output_columns(view, node.input_ids[0]), node.generator_applications[0], ) + if node.kind == PrismNodeKind.Window: + return window_output_columns(rewritten_output_columns(view, node.input_ids[0]), node.window_projections) if node.kind == PrismNodeKind.Join: # Rewritten views keep the same left-then-right join column order as authored views # so output-column inference stays stable across Prism rewrite passes. 
diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index 419f968..12095fd 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -169,6 +169,7 @@ def _build_collapsed_limit_node( sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -206,6 +207,7 @@ def _build_collapsed_project_node( sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=merged_assignments, ) @@ -243,6 +245,7 @@ def _build_collapsed_aggregate_node( sort_columns=[], aggregate_measures=merged_measures, generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -278,6 +281,7 @@ def _build_collapsed_order_by_node( sort_columns=node.sort_columns, aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -296,6 +300,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, generator_applications=node.generator_applications, + window_projections=node.window_projections, projection_assignments=node.projection_assignments, ) @@ -342,6 +347,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, generator_applications=old_node.generator_applications, + window_projections=old_node.window_projections, projection_assignments=old_node.projection_assignments, ), ) diff --git a/src/prism/store.incn b/src/prism/store.incn index e451ade..e7af5c3 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -2,6 +2,7 @@ from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication +from window_builders import WindowFunctionApplication, WindowProjection, WindowSpec from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -56,6 +57,7 @@ 
pub def append_node( sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], generator_applications: list[GeneratorApplication], + window_projections: list[WindowProjection], projection_assignments: list[ProjectionAssignment], ) -> int: """ @@ -76,6 +78,7 @@ pub def append_node( sort_columns=sort_columns, aggregate_measures=aggregate_measures, generator_applications=generator_applications, + window_projections=window_projections, projection_assignments=projection_assignments, ) prism_stored_nodes.append(PrismStoredNode(store_id_raw=store_id.0, node=appended)) @@ -123,11 +126,13 @@ pub def adopt_cursor_subgraph( adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] adopted_generators = [generator for generator in source_node.generator_applications] + adopted_windows = [projection for projection in source_node.window_projections] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] target_generators = [generator for generator in source_node.generator_applications] + target_windows = [projection for projection in source_node.window_projections] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( store_id=target_store_id, @@ -141,6 +146,7 @@ pub def adopt_cursor_subgraph( sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, generator_applications=adopted_generators, + window_projections=adopted_windows, projection_assignments=adopted_assignments, ) target_store_nodes.append( @@ -156,6 +162,7 @@ pub def adopt_cursor_subgraph( sort_columns=target_sort_columns, aggregate_measures=target_measures, generator_applications=target_generators, 
+ window_projections=target_windows, projection_assignments=target_assignments, ), ) @@ -244,6 +251,8 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema source_node.generator_applications, ): return false + if not _window_projection_lists_structurally_equal(candidate.window_projections, source_node.window_projections): + return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -315,6 +324,45 @@ def _generator_applications_structurally_equal(left: GeneratorApplication, right return _column_exprs_structurally_equal(left.expr, right.expr) +def _window_projection_lists_structurally_equal(left: list[WindowProjection], right: list[WindowProjection]) -> bool: + """Return whether two window projection lists carry identical output names and applications.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if not _window_projections_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _window_projections_structurally_equal(left: WindowProjection, right: WindowProjection) -> bool: + """Return whether two window projections carry identical names and window semantics.""" + if left.output_name != right.output_name: + return false + return _window_applications_structurally_equal(left.application, right.application) + + +def _window_applications_structurally_equal(left: WindowFunctionApplication, right: WindowFunctionApplication) -> bool: + """Return whether two window applications carry identical registry identity and window specs.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.requires_ordering != right.requires_ordering: + return false + if not _column_expr_lists_structurally_equal(left.arguments, right.arguments): + return false + return 
_window_specs_structurally_equal(left.spec, right.spec) + + +def _window_specs_structurally_equal(left: WindowSpec, right: WindowSpec) -> bool: + """Return whether two window specs carry identical partitioning and ordering.""" + if not _column_expr_lists_structurally_equal(left.partition_columns, right.partition_columns): + return false + return _column_expr_lists_structurally_equal(left.sort_columns, right.sort_columns) + + def _text_lists_structurally_equal(left: list[str], right: list[str]) -> bool: """Return whether two string lists are structurally equivalent.""" if len(left) != len(right): diff --git a/src/prism/types.incn b/src/prism/types.incn index 59472c1..666b266 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -3,6 +3,7 @@ from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment +from window_builders import WindowProjection pub type PrismStoreId = newtype int @@ -19,6 +20,7 @@ pub enum PrismNodeKind(str): GroupBy = "GroupBy" Aggregate = "Aggregate" Generate = "Generate" + Window = "Window" OrderBy = "OrderBy" Limit = "Limit" Explode = "Explode" @@ -44,6 +46,7 @@ pub model PrismNode: pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] pub generator_applications: list[GeneratorApplication] + pub window_projections: list[WindowProjection] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index 4f0edad..bc2b60b 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -58,6 +58,11 @@ pub def scalar_function_name_from_anchor(anchor: u32) -> str: return _function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Scalar) +pub def window_function_name_from_anchor(anchor: u32) -> str: + """Resolve one known window-function anchor back to its registered function name.""" + return 
_function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Window) + + def _function_extension_specs() -> list[FunctionExtensionSpec]: """Return Substrait extension specs derived from declaration-side registry metadata.""" mut specs: list[FunctionExtensionSpec] = [] @@ -81,6 +86,14 @@ def _function_extension_specs() -> list[FunctionExtensionSpec]: kind=ExtensionFunctionKind.Scalar, ), ) + elif entry.function_class == FunctionClass.Window: + specs.append( + FunctionExtensionSpec( + anchor=entry.substrait.anchor, + name=entry.substrait.function_name, + kind=ExtensionFunctionKind.Window, + ), + ) return specs @@ -145,6 +158,15 @@ def _scalar_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: return raise_value_error(message) +def _window_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: + """Build one window-function extension declaration for the provided anchor.""" + match _function_spec_from_anchor_of_kind(anchor, ExtensionFunctionKind.Window): + Ok(spec) => return _function_extension_decl(spec) + Err(err) => + message = err.error_message() + return raise_value_error(message) + + def _expr_uses_scalar_function_anchor(expr: Expression, expected_anchor: u32) -> bool: """Return whether one expression tree uses the requested scalar-function anchor.""" match expr.rex_type: @@ -264,6 +286,19 @@ def _rel_uses_aggregate_function_anchor(rel: Rel, expected_anchor: u32) -> bool: return false +def _rel_uses_window_function_anchor(rel: Rel, expected_anchor: u32) -> bool: + """Return whether one relation subtree uses the requested window-function anchor.""" + if let Some(RelType.Window(window_rel)) = rel.rel_type.clone(): + for window_function in window_rel.window_functions: + if window_function.function_reference == expected_anchor: + return true + + for child in relation_children(rel): + if _rel_uses_window_function_anchor(child, expected_anchor): + return true + return false + + def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: u32) -> 
bool: """Return whether one relation subtree uses the requested scalar-function anchor.""" match rel.rel_type.clone(): @@ -336,6 +371,17 @@ def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[u32]: return anchors +def _window_extension_anchors_for_rel(rel: Rel) -> list[u32]: + """Collect window-function anchors used by one relation subtree in stable declaration order.""" + mut anchors: list[u32] = [] + for spec in _function_extension_specs(): + if (spec.kind == ExtensionFunctionKind.Window and _rel_uses_window_function_anchor(rel.clone(), spec.anchor) and not anchors.contains( + spec.anchor, + )): + anchors.append(spec.anchor) + return anchors + + def _scalar_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect scalar-function anchors used by one relation subtree in stable declaration order.""" if _rel_uses_if_then(rel.clone()): @@ -362,8 +408,13 @@ def _scalar_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclarati return [_scalar_extension_decl(anchor) for anchor in anchors] +def _window_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclaration]: + """Lower window-function anchors into extension declarations in the provided order.""" + return [_window_extension_decl(anchor) for anchor in anchors] + + def _extension_decl_for_anchor(anchor: u32) -> SimpleExtensionDeclaration: - """Lower one known aggregate or scalar function anchor into its extension declaration.""" + """Lower one known aggregate, window, or scalar function anchor into its extension declaration.""" match _function_spec_from_anchor(anchor): Ok(spec) => return _function_extension_decl(spec) Err(err) => @@ -372,8 +423,9 @@ def _extension_decl_for_anchor(anchor: u32) -> SimpleExtensionDeclaration: def _plan_extension_anchors_for_rel(rel: Rel) -> list[u32]: - """Collect aggregate and scalar function anchors used by one relation subtree in stable plan order.""" + """Collect function anchors used by one relation subtree in stable plan order.""" mut anchors = 
_aggregate_extension_anchors_for_rel(rel.clone()) + anchors.extend(_window_extension_anchors_for_rel(rel.clone())) anchors.extend(_scalar_extension_anchors_for_rel(rel)) return anchors diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 649a680..72e5d5f 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -12,6 +12,7 @@ pub enum ExtensionFunctionKind(str): Aggregate = "aggregate" Scalar = "scalar" + Window = "window" @derive(Clone) @@ -75,6 +76,9 @@ pub const MAP_EXTRACT_FUNCTION_ANCHOR: u32 = 48 pub const NAMED_STRUCT_FUNCTION_ANCHOR: u32 = 49 pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 +pub const ROW_NUMBER_FUNCTION_ANCHOR: u32 = 52 +pub const RANK_FUNCTION_ANCHOR: u32 = 53 +pub const DENSE_RANK_FUNCTION_ANCHOR: u32 = 54 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 063061f..7e23b48 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -7,7 +7,7 @@ inspection utilities used by tests, dataset carriers, and conformance validation from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32 -from rust::substrait::proto import AggregateRel, Expression, Plan, ReadRel, Rel, RelCommon +from rust::substrait::proto import AggregateRel, ConsistentPartitionWindowRel, Expression, Plan, ReadRel, Rel, RelCommon from rust::substrait::proto::aggregate_rel import Measure from rust::substrait::proto::aggregate_function import AggregationInvocation from rust::substrait::proto::function_argument import ArgType @@ -20,7 +20,7 @@ from rust::substrait::proto::sort_field import SortDirection, 
SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, rust_u32_to_int -from substrait.extensions import aggregate_function_name_from_anchor +from substrait.extensions import aggregate_function_name_from_anchor, window_function_name_from_anchor from substrait.function_extensions import ( explode_extension_uri, explode_outer_extension_uri, @@ -158,6 +158,30 @@ def _aggregate_output_columns(aggregate_rel: AggregateRel) -> list[str]: return columns +def _window_input_columns(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Resolve the current input-column names feeding one window relation.""" + match window_rel.input: + Some(child) => return relation_output_columns(child.as_ref().clone()) + None => return [] + + +def _window_output_columns(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Return input-column names followed by best-effort lowered window output names.""" + mut columns = _window_input_columns(window_rel.clone()) + for idx, window_function in enumerate(window_rel.window_functions): + function_name = window_function_name_from_anchor(window_function.function_reference) + if len(function_name) > 0: + columns.append(function_name) + else: + columns.append(f"window_{idx}") + return columns + + +def _window_function_names_for_rel(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Return lowered window-function names in declaration order.""" + return [window_function_name_from_anchor(fun.function_reference) for fun in window_rel.window_functions] + + def _relation_output_columns(rel: Rel) -> list[str]: """Return the current best-effort output column names for one relation subtree.""" match rel.rel_type.clone(): @@ -220,6 +244,7 @@ def _relation_output_columns(rel: Rel) -> list[str]: return [] return _relation_output_columns(set_rel.inputs[0]) 
Some(RelType.Aggregate(aggregate_rel)) => return _aggregate_output_columns(aggregate_rel.as_ref().clone()) + Some(RelType.Window(window_rel)) => return _window_output_columns(window_rel.as_ref().clone()) _ => return [] @@ -228,6 +253,27 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +pub def window_function_names(rel: Rel) -> list[str]: + """Return lowered window-function names when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return _window_function_names_for_rel(window_rel.as_ref().clone()) + _ => return [] + + +pub def window_partition_count(rel: Rel) -> int: + """Return the number of partition expressions when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return len(window_rel.partition_expressions) + _ => return 0 + + +pub def window_sort_count(rel: Rel) -> int: + """Return the number of sort expressions when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return len(window_rel.sorts) + _ => return 0 + + def _extension_single_output_columns(input_columns: list[str], extension_uri: str) -> list[str]: """Return best-effort output columns for known extension-single relation encodings.""" mut columns: list[str] = [] @@ -332,6 +378,7 @@ pub def relation_kind_name(rel: Rel) -> str: Some(RelType.Join(_)) => return "JoinRel" Some(RelType.Cross(_)) => return "CrossRel" Some(RelType.Aggregate(_)) => return "AggregateRel" + Some(RelType.Window(_)) => return "ConsistentPartitionWindowRel" Some(RelType.Sort(_)) => return "SortRel" Some(RelType.Fetch(_)) => return "FetchRel" Some(RelType.Set(_)) => return "SetRel" diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index b075e5f..9ae957c 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -13,6 +13,7 @@ from rust::std::primitive import i32 as RustI32 from rust::substrait::proto import ( AggregateFunction, 
AggregateRel, + ConsistentPartitionWindowRel, CrossRel, ExtensionSingleRel, Expression, @@ -32,6 +33,7 @@ from rust::substrait::proto import ( ) from rust::substrait::proto::aggregate_function import AggregationInvocation from rust::substrait::proto::aggregate_rel import Grouping, Measure +from rust::substrait::proto::consistent_partition_window_rel import WindowRelFunction from rust::substrait::proto::expression::nested import Struct as NestedStruct from rust::substrait::proto::fetch_rel import CountMode, OffsetMode from rust::substrait::proto::function_argument import ArgType @@ -48,6 +50,7 @@ from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMap from functions.registry import function_registry_entry from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col +from window_builders import WindowFunctionApplication, WindowProjection from substrait.expr_lowering import ( bool_expr, filter_predicate_expr, @@ -91,6 +94,18 @@ model ResolvedGeneratorApplication: expr: Expression +@derive(Clone) +model ResolvedWindowProjection: + """One named window projection resolved against input columns and registry metadata.""" + + output_name: str + application: WindowFunctionApplication + entry: FunctionRegistryEntry + arguments: list[FunctionArgument] + partition_expressions: list[Expression] + sorts: list[SortField] + + pub enum SubstraitJoinKind: Inner Left @@ -146,6 +161,11 @@ def _rel_aggregate(aggregate: AggregateRel) -> Rel: return Rel(rel_type=Some(RelType.Aggregate(Box.new(aggregate)))) +def _rel_window(window: ConsistentPartitionWindowRel) -> Rel: + """Wrap ConsistentPartitionWindowRel payload into one generic Rel union value.""" + return Rel(rel_type=Some(RelType.Window(Box.new(window)))) + + def _rel_sort(sort: SortRel) -> Rel: """Wrap SortRel payload into one generic Rel union value.""" return Rel(rel_type=Some(RelType.Sort(Box.new(sort)))) @@ -316,6 
+336,64 @@ def _validate_generator_output_columns( return Ok(None) +def _window_registry_entry( + application: WindowFunctionApplication, +) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one window registry entry and validate its semantic class.""" + match function_registry_entry(application.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Window: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as a window function")) + if entry.substrait.kind != SubstraitMappingKind.ExtensionFunction: + return Err( + invalid_scalar_expression(f"{entry.function_ref} does not declare a window extension mapping"), + ) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing window registry entry for `{application.canonical_name}`")) + + +def _resolved_window_projection( + input_columns: list[str], + projection: WindowProjection, +) -> Result[ResolvedWindowProjection, SubstraitLoweringError]: + """Resolve one window projection against input-column names.""" + if len(projection.output_name) == 0: + return Err(invalid_scalar_expression("window output alias must be non-empty")) + application = projection.application + if application.requires_ordering and len(application.spec.sort_columns) == 0: + return Err(invalid_scalar_expression(f"{application.function_ref} requires an explicit window ordering")) + return Ok( + ResolvedWindowProjection( + output_name=projection.output_name, + application=application.clone(), + entry=_window_registry_entry(application.clone())?, + arguments=[FunctionArgument(arg_type=Some(ArgType.Value(scalar_expr(input_columns, arg)?))) for arg in application.arguments], + partition_expressions=[scalar_expr(input_columns, column)? for column in application.spec.partition_columns], + sorts=[_sort_field(input_columns, column)? 
for column in application.spec.sort_columns], + ), + ) + + +def _resolved_window_projection_to_substrait( + projection: ResolvedWindowProjection, +) -> Result[WindowRelFunction, SubstraitLoweringError]: + """Lower one resolved window projection into a Substrait window function payload.""" + return Ok( + WindowRelFunction( + function_reference=projection.entry.substrait.anchor, + arguments=projection.arguments, + options=[], + output_type=None, + phase=AggregationPhase.InitialToResult.into(), + invocation=AggregationInvocation.All.into(), + lower_bound=None, + upper_bound=None, + bounds_type=0, + ), + ) + + def _contains_text(values: list[str], expected: str) -> bool: """Return whether a string list contains a value.""" for value in values: @@ -693,6 +771,46 @@ pub def try_generator_rel_of_columns( return Ok(extension_single_rel(input, resolved.entry.substrait.uri)) +pub def window_rel(input: Rel, projection: WindowProjection) -> Rel: + """Wrap a child relation in a window relation with one named window projection.""" + return _lowered_rel_or_raise(try_window_rel(input, projection)) + + +pub def try_window_rel(input: Rel, projection: WindowProjection) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a window relation with one named window projection.""" + return try_window_rel_of_columns(input.clone(), relation_output_columns(input), [projection]) + + +pub def window_rel_of_columns(input: Rel, input_columns: list[str], projections: list[WindowProjection]) -> Rel: + """Wrap a child relation in a window relation using explicit input-column names.""" + return _lowered_rel_or_raise(try_window_rel_of_columns(input, input_columns, projections)) + + +pub def try_window_rel_of_columns( + input: Rel, + input_columns: list[str], + projections: list[WindowProjection], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a window relation using explicit input-column names.""" + if len(projections) == 0: + return 
Err(invalid_scalar_expression("window relation requires at least one window projection")) + if len(projections) > 1: + return Err(invalid_scalar_expression("window relation currently accepts exactly one window projection")) + resolved = _resolved_window_projection(input_columns, projections[0])? + return Ok( + _rel_window( + ConsistentPartitionWindowRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + window_functions=[_resolved_window_projection_to_substrait(resolved.clone())?], + partition_expressions=resolved.partition_expressions, + sorts=resolved.sorts, + advanced_extension=None, + ), + ), + ) + + pub def sort_rel(input: Rel) -> Rel: """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" input_columns = relation_output_columns(input.clone()) diff --git a/src/substrait/traversal.incn b/src/substrait/traversal.incn index 93ea530..f7bbcec 100644 --- a/src/substrait/traversal.incn +++ b/src/substrait/traversal.incn @@ -37,6 +37,10 @@ pub def relation_children(rel: Rel) -> list[Rel]: match aggregate.input: Some(child) => return [child.as_ref().clone()] None => return [] + Some(RelType.Window(window)) => + match window.input: + Some(child) => return [child.as_ref().clone()] + None => return [] Some(RelType.Sort(sort)) => match sort.input: Some(child) => return [child.as_ref().clone()] diff --git a/src/window_builders.incn b/src/window_builders.incn new file mode 100644 index 0000000..5e9f8ea --- /dev/null +++ b/src/window_builders.incn @@ -0,0 +1,152 @@ +""" +Window specification and window-function builder surface. + +Window applications are relation-aware expressions: they produce one value per input row while reading a partition of +related rows. They intentionally do not reuse scalar expression nodes, so invalid scalar positions can remain +diagnosable as the query surface grows. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import function_ref_for +from projection_builders import ColumnExpr + + +@derive(Clone) +pub enum WindowFunctionKind(str): + """Supported window function kinds in the current foundation slice.""" + + RowNumber = "row_number" + Rank = "rank" + DenseRank = "dense_rank" + + +@derive(Clone) +pub model WindowSpec: + """Partitioning and ordering shared by one or more window function applications.""" + + pub partition_columns: list[ColumnExpr] + pub sort_columns: list[ColumnExpr] + + def partition_by(self, columns: list[ColumnExpr]) -> Self: + """Return this window spec with partition expressions replaced.""" + return WindowSpec(partition_columns=columns, sort_columns=self.sort_columns) + + def order_by(self, columns: list[ColumnExpr]) -> Self: + """Return this window spec with ordering expressions replaced.""" + return WindowSpec(partition_columns=self.partition_columns, sort_columns=columns) + + +@derive(Clone) +pub model WindowFunctionApplication: + """One placed window function application.""" + + pub kind: WindowFunctionKind + pub function_ref: str + pub canonical_name: str + pub arguments: list[ColumnExpr] + pub spec: WindowSpec + pub requires_ordering: bool + + +@derive(Clone) +pub model WindowFunctionCall: + """Unplaced window function call waiting for an explicit window specification.""" + + pub kind: WindowFunctionKind + pub function_ref: str + pub canonical_name: str + pub arguments: list[ColumnExpr] + pub requires_ordering: bool + + def over(self, spec: WindowSpec) -> WindowFunctionApplication: + """Place this window function call over a concrete window specification.""" + return WindowFunctionApplication( + kind=self.kind, + function_ref=self.function_ref, + canonical_name=self.canonical_name, + arguments=self.arguments, + spec=spec, + requires_ordering=self.requires_ordering, + ) + + +@derive(Clone) +pub model WindowProjection: + """One named output column backed by a 
placed window function application.""" + + pub output_name: str + pub application: WindowFunctionApplication + + +pub def window() -> WindowSpec: + """Build an empty window specification.""" + return WindowSpec(partition_columns=[], sort_columns=[]) + + +pub def row_number() -> WindowFunctionCall: + """Build a `row_number` window function call.""" + return _ranking_call("row_number", WindowFunctionKind.RowNumber) + + +pub def rank() -> WindowFunctionCall: + """Build a `rank` window function call.""" + return _ranking_call("rank", WindowFunctionKind.Rank) + + +pub def dense_rank() -> WindowFunctionCall: + """Build a `dense_rank` window function call.""" + return _ranking_call("dense_rank", WindowFunctionKind.DenseRank) + + +pub def window_projection(output_name: str, application: WindowFunctionApplication) -> WindowProjection: + """Build one named window projection after validating its output alias.""" + if len(output_name) == 0: + return raise_value_error("window output alias must be non-empty") + return WindowProjection(output_name=output_name, application=application) + + +pub def window_output_columns(input_columns: list[str], projections: list[WindowProjection]) -> list[str]: + """Return output columns after applying window projections with add-or-replace semantics.""" + mut output_columns: list[str] = [] + output_columns.extend(input_columns) + for projection in projections: + existing_idx = _index_of_text(output_columns, projection.output_name) + if existing_idx >= 0: + output_columns[existing_idx] = projection.output_name + else: + output_columns.append(projection.output_name) + return output_columns + + +def _ranking_call(canonical_name: str, kind: WindowFunctionKind) -> WindowFunctionCall: + """Build one ranking-family window call.""" + return WindowFunctionCall( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + arguments=[], + requires_ordering=true, + ) + + +def _index_of_text(values: list[str], expected: 
str) -> int: + """Return the index of a string value, or -1 when absent.""" + for idx, value in enumerate(values): + if value == expected: + return idx + return -1 + + +module tests: + from projection_builders import col, column_expr_name + def test_window_spec_builders_preserve_partition_and_order() -> None: + spec = window().partition_by([col("customer_id")]).order_by([col("amount")]) + assert len(spec.partition_columns) == 1 + assert column_expr_name(spec.partition_columns[0]) == "customer_id" + assert len(spec.sort_columns) == 1 + assert column_expr_name(spec.sort_columns[0]) == "amount" + def test_ranking_call_over_records_registry_identity() -> None: + application = rank().over(window().order_by([col("amount")])) + assert application.kind == WindowFunctionKind.Rank + assert application.function_ref == "inql.functions.rank" + assert application.requires_ordering diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 3140a03..b933297 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -28,14 +28,17 @@ from functions import ( mul, posexplode, posexplode_outer, + row_number, str_expr, str_lit, sum, + window, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name from substrait.function_extensions import ( explode_extension_uri, explode_outer_extension_uri, + function_extension_uri, posexplode_extension_uri, posexplode_outer_extension_uri, ) @@ -459,6 +462,10 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None generated_positional_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( posexplode_outer(col("line_items"), "position", "line_item"), ) + windowed: LazyFrame[Order] = lazy_frame_named_table("orders").with_window_column( + "row_num", + row_number().over(window().order_by([col("id")])), + ) # -- Assert -- assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project 
boundary shape" @@ -472,6 +479,9 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None assert plan_has_extension_urn(generated_outer.to_substrait_plan(), explode_outer_extension_uri()), "outer explode should use its relation extension URI" assert plan_has_extension_urn(generated_positional.to_substrait_plan(), posexplode_extension_uri()), "posexplode should use its relation extension URI" assert plan_has_extension_urn(generated_positional_outer.to_substrait_plan(), posexplode_outer_extension_uri()), "posexplode_outer should use its relation extension URI" + assert relation_kind_name(root_rel(windowed.to_substrait_plan())) == "ConsistentPartitionWindowRel", "with_window_column should lower through the window boundary shape" + assert windowed.planned_columns() == ["id", "row_num"], "window projections should append declared output aliases" + assert plan_has_extension_urn(windowed.to_substrait_plan(), function_extension_uri()), "window plans should use the shared function extension URI" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 424147e..23902fa 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -85,13 +85,17 @@ from functions import ( or_, posexplode, posexplode_outer, + dense_rank, registered_substrait_mapped_function_refs, + rank, round, + row_number, str_expr, str_lit, sub, sum, try_cast, + window, ) from function_registry import ( FunctionAliasPolicy, @@ -138,6 +142,7 @@ from substrait.function_extensions import ( CEIL_FUNCTION_ANCHOR, COALESCE_FUNCTION_ANCHOR, COUNT_FUNCTION_ANCHOR, + DENSE_RANK_FUNCTION_ANCHOR, DIVIDE_FUNCTION_ANCHOR, EQUAL_FUNCTION_ANCHOR, FLOOR_FUNCTION_ANCHOR, @@ -165,6 +170,8 @@ from substrait.function_extensions import ( NOT_FUNCTION_ANCHOR, NULLIF_FUNCTION_ANCHOR, OR_FUNCTION_ANCHOR, + RANK_FUNCTION_ANCHOR, + ROW_NUMBER_FUNCTION_ANCHOR, 
ROUND_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, @@ -231,12 +238,12 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", 
"array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank"] def _exercise_current_public_helpers() -> None: @@ -325,6 +332,10 @@ def _exercise_current_public_helpers() -> None: explode_outer(tags, "tag") posexplode(tags, "position", "tag") posexplode_outer(tags, "position", "tag") + window() + row_number() + rank() + dense_rank() return @@ -641,6 +652,23 @@ def 
test_function_registry__generator_helpers_are_relation_extensions() -> None: _assert_relation_extension_mapping("posexplode_outer", "posexplode_outer", posexplode_outer_extension_uri()) +def test_function_registry__window_helpers_are_relation_window_functions() -> None: + """Assert window helpers carry relation-aware window-function metadata.""" + # -- Arrange -- + _exercise_current_public_helpers() + window_entry = _entry_by_name_or_fail("window") + row_number_entry = _entry_by_name_or_fail("row_number") + + # -- Act / Assert -- + assert window_entry.function_class == FunctionClass.Window, "window spec builder should be classified as window metadata" + assert window_entry.substrait.kind == SubstraitMappingKind.StructuralFunction, "window spec builder should be structural metadata" + assert window_entry.substrait.function_name == "window_spec", "window spec builder should name the window-spec context" + assert row_number_entry.function_class == FunctionClass.Window, "ranking helpers should be classified as window functions" + _assert_extension_mapping("row_number", "row_number", ROW_NUMBER_FUNCTION_ANCHOR) + _assert_extension_mapping("rank", "rank", RANK_FUNCTION_ANCHOR) + _assert_extension_mapping("dense_rank", "dense_rank", DENSE_RANK_FUNCTION_ANCHOR) + + def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" # -- Arrange -- diff --git a/tests/test_prism.incn b/tests/test_prism.incn index fc1f097..16b8564 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,12 +1,13 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, row_number, sum, window from prism import ( PrismCursor, prism_cursor_apply_filter, 
prism_cursor_apply_generate, prism_cursor_apply_limit, prism_cursor_apply_select, + prism_cursor_apply_with_window_column, prism_cursor_authored_node_count, prism_cursor_named_table, prism_cursor_rewrite_applied_rule_count, @@ -21,7 +22,8 @@ from prism import ( prism_cursor_tip_origin_id, prism_cursors_share_store, ) -from substrait.inspect import plan_contains_relation_kind, relation_kind_name, root_rel +from substrait.function_extensions import function_extension_uri +from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len from substrait.schema_registry import register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind @@ -230,6 +232,10 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: generated: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_prism")).generate( explode(col("line_items"), "line_item"), ) + windowed: PrismCursor[Order] = prism_cursor_named_table(str("orders")).with_window_column( + "row_num", + row_number().over(window().order_by([col("id")])), + ) # -- Assert -- assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" @@ -239,7 +245,24 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" assert prism_cursor_tip_kind_name(generated) == str("Generate"), "generate should append a native generator node" + assert prism_cursor_tip_kind_name(windowed) == str("Window"), "with_window_column should append a native window node" assert prism_cursor_output_columns(generated) == ["id", "line_items", "line_item"], "generate should append declared output aliases" + assert 
prism_cursor_output_columns(windowed) == ["id", "row_num"], "window projections should append declared output aliases" + + +def test_prism__window_column_lowers_through_substrait_boundary() -> None: + # -- Arrange -- + _register_projection_test_schema(str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) + + # -- Act -- + windowed = prism_cursor_apply_with_window_column(base, "row_num", row_number().over(window().order_by([col("id")]))) + plan = windowed.to_substrait_plan() + + # -- Assert -- + assert prism_cursor_tip_kind_name(windowed) == str("Window"), "window helper should create a Prism window node" + assert relation_kind_name(root_rel(plan)) == str("ConsistentPartitionWindowRel"), "window Prism node should lower to a window relation" + assert plan_has_extension_urn(plan, function_extension_uri()), "window plans should declare the shared function extension URI" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 04bb902..8a44395 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -32,6 +32,7 @@ from functions import ( count_distinct, count_expr, count_if, + dense_rank, desc, div, eq, @@ -63,12 +64,15 @@ from functions import ( not_, nullif, or_, + rank, round, + row_number, sub, sum, try_cast, cardinality, element_at, + window, ) from projection_builders import ColumnExpr, with_column_assignment from substrait.errors import SubstraitLoweringErrorKind @@ -100,6 +104,9 @@ from substrait.inspect import ( sort_field_count, sort_field_direction_name, sort_field_expr_index, + window_function_names, + window_partition_count, + window_sort_count, ) from substrait.plans import ( plan_encoded_len, @@ -131,9 +138,12 @@ from substrait.relations import ( set_rel_of_kind, sort_rel_of_columns, try_aggregate_rel_of_columns, + try_window_rel_of_columns, + window_rel_of_columns, ) from substrait.schema_registry import 
named_table_columns, register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind +from window_builders import WindowProjection, window_projection from substrait.conformance import ( ConformanceCapabilityTags, ConformancePortability, @@ -202,6 +212,32 @@ def _register_orders_schema() -> None: register_named_table_schema("orders", [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_window_orders_schema() -> None: + register_named_table_schema( + "orders_window", + [RowColumnSpec(name="customer_id", kind=SubstraitPrimitiveKind.String, nullable=false), RowColumnSpec( + name="amount", + kind=SubstraitPrimitiveKind.I64, + nullable=false, + )], + ) + + +def _row_number_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("row_num", row_number().over(spec)) + + +def _rank_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("amount_rank", rank().over(spec)) + + +def _dense_rank_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("dense_amount_rank", dense_rank().over(spec)) + + def _register_fixture_schema(table_name: str) -> None: register_named_table_schema( table_name, @@ -541,6 +577,61 @@ def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: assert ordered_err.message.contains("ordered aggregate input"), "ordered aggregate diagnostic should identify the unsupported modifier" +def _assert_window_projection_lowers(projection: WindowProjection, expected_function_name: str) -> None: + """Assert one window projection lowers to a concrete Substrait window relation.""" + _register_window_orders_schema() + base = read_named_table_rel("orders_window") + windowed = window_rel_of_columns(base, ["customer_id", 
"amount"], [projection]) + plan = plan_from_root_relation(windowed, ["customer_id", "amount", projection.output_name]) + + assert relation_kind_name(windowed) == "ConsistentPartitionWindowRel", "window lowering should emit the Substrait window relation" + assert plan_contains_relation_kind(plan, "ConsistentPartitionWindowRel"), "window plans should preserve the window relation shape" + assert plan_has_extension_urn(plan, function_extension_uri()), "window function plans should register the shared function extension URN" + assert window_function_names(windowed) == [expected_function_name], "window relation should carry the registered window function" + assert window_partition_count(windowed) == 1, "window relation should lower explicit partition expressions" + assert window_sort_count(windowed) == 1, "ranking window relation should lower explicit ordering" + + +def test_plan__row_number_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_row_number_projection(), "row_number") + + +def test_plan__rank_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_rank_projection(), "rank") + + +def test_plan__dense_rank_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_dense_rank_projection(), "dense_rank") + + +def test_plan__ranking_window_rel_rejects_missing_ordering() -> None: + # -- Arrange -- + _register_window_orders_schema() + base = read_named_table_rel("orders_window") + unordered = window_projection("row_num", row_number().over(window().partition_by([col("customer_id")]))) + + # -- Act -- + result = try_window_rel_of_columns(base, ["customer_id", "amount"], [unordered]) + + # -- Assert -- + assert_is_err(result, "ranking window helpers should require explicit ordering") + + +def test_plan__window_rel_rejects_multiple_projections_until_partition_grouping_lands() -> None: + # -- Arrange -- + 
_register_window_orders_schema() + base = read_named_table_rel("orders_window") + + # -- Act -- + result = try_window_rel_of_columns(base, ["customer_id", "amount"], [_row_number_projection(), _rank_projection()]) + + # -- Assert -- + assert_is_err(result, "current window relation lowering should reject multiple projections explicitly") + + def test_plan__set_rel_uses_operation_enum() -> None: # -- Arrange -- left = read_named_table_rel("orders_current") diff --git a/tests/test_window_functions.incn b/tests/test_window_functions.incn new file mode 100644 index 0000000..17fb30c --- /dev/null +++ b/tests/test_window_functions.incn @@ -0,0 +1,47 @@ +"""Tests for RFC 019 window specification and ranking helpers.""" + +from functions import col, dense_rank, rank, row_number, window +from projection_builders import column_expr_name +from window_builders import WindowFunctionKind, window_output_columns, window_projection + + +def test_window_builders__spec_preserves_partition_and_order_columns() -> None: + # -- Arrange / Act -- + spec = window().partition_by([col("customer_id")]).order_by([col("amount")]) + + # -- Assert -- + assert len(spec.partition_columns) == 1, "window partition should record explicit partition expressions" + assert column_expr_name(spec.partition_columns[0]) == "customer_id", "partition expression should preserve column refs" + assert len(spec.sort_columns) == 1, "window ordering should record explicit sort expressions" + assert column_expr_name(spec.sort_columns[0]) == "amount", "sort expression should preserve column refs" + + +def test_window_builders__ranking_helpers_return_unplaced_calls() -> None: + # -- Arrange -- + spec = window().order_by([col("amount")]) + + # -- Act -- + row_number_app = row_number().over(spec) + rank_app = rank().over(spec) + dense_rank_app = dense_rank().over(spec) + + # -- Assert -- + assert row_number_app.kind == WindowFunctionKind.RowNumber, "row_number should keep typed window identity" + assert rank_app.kind == 
WindowFunctionKind.Rank, "rank should keep typed window identity" + assert dense_rank_app.kind == WindowFunctionKind.DenseRank, "dense_rank should keep typed window identity" + assert row_number_app.requires_ordering, "ranking helpers should require explicit window ordering" + assert rank_app.function_ref == "inql.functions.rank", "rank should derive stable registry identity" + assert dense_rank_app.canonical_name == "dense_rank", "dense_rank should expose its canonical name" + + +def test_window_builders__output_columns_use_add_or_replace_alias_semantics() -> None: + # -- Arrange -- + spec = window().order_by([col("amount")]) + + # -- Act -- + appended = window_output_columns(["id", "amount"], [window_projection("row_num", row_number().over(spec))]) + replaced = window_output_columns(["id", "rank"], [window_projection("rank", rank().over(spec))]) + + # -- Assert -- + assert appended == ["id", "amount", "row_num"], "new window aliases should append to input columns" + assert replaced == ["id", "rank"], "existing window aliases should replace in place without duplicating names" From 4ebb788e698b6c21640dc14d7e85d04bd44b1d87 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 26 May 2026 02:03:37 +0200 Subject: [PATCH 6/6] feature - implement RFC 022 hashing functions (#39) --- docs/language/reference/functions/format.md | 31 ++++++++ docs/language/reference/functions/index.md | 4 +- docs/release_notes/v0_1.md | 1 + .../022_semi_structured_format_functions.md | 39 ++++++++-- docs/rfcs/README.md | 2 +- src/functions/hashing/md5.incn | 51 +++++++++++++ src/functions/hashing/sha2.incn | 76 +++++++++++++++++++ src/functions/hashing/sha224.incn | 51 +++++++++++++ src/functions/hashing/sha256.incn | 51 +++++++++++++ src/functions/hashing/sha384.incn | 51 +++++++++++++ src/functions/hashing/sha512.incn | 51 +++++++++++++ src/functions/mod.incn | 6 ++ src/lib.incn | 6 ++ src/substrait/function_extensions.incn | 5 ++ tests/test_function_registry.incn | 35 ++++++++- 
tests/test_hashing_functions.incn | 55 ++++++++++++++ tests/test_session_projection.incn | 39 ++++++++++ tests/test_substrait_plan.incn | 12 +++ 18 files changed, 556 insertions(+), 10 deletions(-) create mode 100644 docs/language/reference/functions/format.md create mode 100644 src/functions/hashing/md5.incn create mode 100644 src/functions/hashing/sha2.incn create mode 100644 src/functions/hashing/sha224.incn create mode 100644 src/functions/hashing/sha256.incn create mode 100644 src/functions/hashing/sha384.incn create mode 100644 src/functions/hashing/sha512.incn create mode 100644 tests/test_hashing_functions.incn diff --git a/docs/language/reference/functions/format.md b/docs/language/reference/functions/format.md new file mode 100644 index 0000000..568e349 --- /dev/null +++ b/docs/language/reference/functions/format.md @@ -0,0 +1,31 @@ +# Format Functions (Reference) + +Format helpers operate on scalar payloads that are already present in a relation. They do not read files, infer source +schemas from external locations, or change relation cardinality. + +The current implemented slice is deterministic string hashing: + +| Function | Meaning | +| --- | --- | +| `md5(expr)` | Return the lowercase hexadecimal MD5 digest for one string expression. | +| `sha224(expr)` | Return the lowercase hexadecimal SHA-224 digest for one string expression. | +| `sha256(expr)` | Return the lowercase hexadecimal SHA-256 digest for one string expression. | +| `sha384(expr)` | Return the lowercase hexadecimal SHA-384 digest for one string expression. | +| `sha512(expr)` | Return the lowercase hexadecimal SHA-512 digest for one string expression. | +| `sha2(expr, bit_length)` | Compatibility helper that rewrites to `sha224`, `sha256`, `sha384`, or `sha512` for supported literal bit lengths. 
| + +```incan +from pub::inql.functions import col, md5, sha2 + +projected = ( + events + .with_column("user_hash", sha2(col("user_id"), 256)) + .with_column("payload_md5", md5(col("payload"))) +) +``` + +Hash helpers operate on UTF-8 string bytes and return lowercase hexadecimal strings. `sha2(...)` accepts `224`, `256`, +`384`, and `512`; unsupported digest lengths are rejected by the helper rather than being passed through to a backend. + +JSON, CSV, URL, and dynamic-value predicate helpers remain future format-function slices until their schema arguments, +option records, path validation rules, and dynamic value model are specified. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index adc070d..23ec6d0 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -10,10 +10,11 @@ Today the concrete shipped surfaces are documented here: - [Generator and table-valued functions](generators.md) - [Nested data functions](nested.md) - [Window functions](windows.md) +- [Format functions](format.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/<family>/<helper>.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, and windows.
Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions/<family>/<helper>.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, windows, and format helpers. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures.
Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -37,6 +38,7 @@ The registered helper surface currently includes: | `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | | `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions | | `window()`, `row_number()`, `rank()`, `dense_rank()` | window | `window()` builds structural window-spec metadata; ranking helpers lower through `ConsistentPartitionWindowRel` when placed with `with_window_column(...)` | +| `md5(...)`, `sha224(...)`, `sha256(...)`, `sha384(...)`, `sha512(...)`, `sha2(...)` | scalar | registered format/hash helpers; concrete helpers lower through Substrait extension mappings, while `sha2(...)` rewrites to a supported concrete SHA-2 helper | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait 
extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index d337e4e..9be3264 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -18,6 +18,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics. - **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding. - **Window functions:** RFC 019 adds the first window-function planning slice with `window()` specs, `row_number()`, `rank()`, `dense_rank()`, and `with_window_column(...)`. Ranking windows require explicit ordering and lower through Substrait `ConsistentPartitionWindowRel`; backend execution support remains a separate adapter capability. +- **Format functions:** RFC 022 adds the first deterministic hashing slice with `md5(...)`, `sha224(...)`, `sha256(...)`, `sha384(...)`, `sha512(...)`, and `sha2(...)`. 
Hash helpers operate on UTF-8 string bytes, return lowercase hexadecimal strings, lower through registry-owned Substrait metadata, and execute through the DataFusion-backed Session path. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/022_semi_structured_format_functions.md b/docs/rfcs/022_semi_structured_format_functions.md index 8df68b7..a07650d 100644 --- a/docs/rfcs/022_semi_structured_format_functions.md +++ b/docs/rfcs/022_semi_structured_format_functions.md @@ -1,6 +1,6 @@ # InQL RFC 022: Semi-structured and format functions -- **Status:** Draft +- **Status:** In Progress - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -12,7 +12,7 @@ - InQL RFC 020 (nested data functions) - **Issue:** [InQL #39](https://github.com/dannys-code-corner/InQL/issues/39) - **RFC PR:** — -- **Written against:** Incan v0.2 +- **Written against:** Incan v0.3-era InQL - **Shipped in:** — ## Summary @@ -115,12 +115,41 @@ This RFC is additive. It should not change existing CSV ingestion behavior. - **Execution / interchange** — Prism and Substrait lowering must preserve parser options, hash encodings, and structured return values or diagnose unsupported functions. 
- **Documentation** — docs should distinguish scalar format functions from session read/write APIs. -## Unresolved questions +## Design Decisions + +### Resolved + +- The first implementation slice is deterministic hashing. JSON, CSV, URL, dynamic-value predicates, and structured parser helpers remain future slices because their schema arguments, option records, path validation, and dynamic value model are not settled here. +- Hash helpers in this slice operate on UTF-8 string bytes and return lowercase hexadecimal strings. +- Portable concrete hash helpers are `md5`, `sha224`, `sha256`, `sha384`, and `sha512`, each with an honest Substrait extension mapping and DataFusion-backed execution coverage. +- `sha2(expr, bit_length)` is a compatibility helper, not a separate backend mapping. It rewrites to `sha224`, `sha256`, `sha384`, or `sha512` for supported literal bit lengths and rejects unsupported values. +- `sha1`, `crc32`, and `xxhash64` are not implemented in the first slice because no honest Substrait/DataFusion mapping was validated for this branch. + +### Remaining - Should `from_json` accept model types directly as schema arguments, or only explicit schema values? - Should invalid JSON path expressions be compile-time errors when literal and runtime errors otherwise? - What option-record shape should CSV and JSON scalar parsers use? -- Should hash functions return binary values or lowercase hexadecimal strings by default? +- Should future binary-oriented hash helpers return binary values, lowercase hexadecimal strings, or an explicit typed encoding wrapper? - Which variant-style type predicates are portable enough for InQL core, and which should stay in a Snowflake-compatibility extension? - +## Implementation Plan + +1. Add registry-backed hashing helpers under a logical function family. +2. Add stable Substrait extension anchors for concrete hash helpers. +3. Keep `sha2(...)` as a compatibility rewrite over concrete helpers rather than a second mapping. 
+4. Add focused helper, registry, Substrait lowering, and DataFusion session tests with concrete digest values. +5. Add user-facing format-function docs and release notes. +6. Leave parser, URL, and dynamic-value helpers for later RFC 022 slices once their remaining design questions are resolved. + +## Progress Checklist + +- [x] RFC 022 moved to In Progress with a first implementation slice and recorded design decisions. +- [x] `md5`, `sha224`, `sha256`, `sha384`, `sha512`, and `sha2` helpers added under the function catalog. +- [x] Concrete hash helpers registered with Substrait extension metadata. +- [x] `sha2(...)` implemented as a literal-bit-length rewrite with invalid-input diagnostics. +- [x] Focused helper, registry, Substrait lowering, and DataFusion-backed session tests added. +- [x] User-facing format-function docs and release notes added. +- [ ] JSON and CSV scalar parser helpers specified and implemented. +- [ ] URL helper semantics specified and implemented. +- [ ] Dynamic-value predicate semantics specified and implemented. diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index 4ec010f..290c162 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -28,7 +28,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [019][rfc-019] | In Progress | Window functions | | | [020][rfc-020] | Draft | Nested data functions | | | [021][rfc-021] | In Progress | Generator and table-valued functions | | -| [022][rfc-022] | Draft | Semi-structured and format functions | | +| [022][rfc-022] | In Progress | Semi-structured and format functions | | | [023][rfc-023] | Draft | Approximate and sketch functions | | | [024][rfc-024] | Draft | Function extension policy | | diff --git a/src/functions/hashing/md5.incn b/src/functions/hashing/md5.incn new file mode 100644 index 0000000..6d4b1cc --- /dev/null +++ b/src/functions/hashing/md5.incn @@ -0,0 +1,51 @@ +""" +MD5 hash helper. 
+ +`md5` hashes a string expression and returns its lowercase hexadecimal digest. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import MD5_FUNCTION_ANCHOR + + +@function_registry.add("md5", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("md5", MD5_FUNCTION_ANCHOR), +)) +pub def md5(expr: ColumnExpr) -> ColumnExpr: + """ + Build a registered `md5` scalar application that hashes `expr` to a lowercase hexadecimal MD5 digest. + + Examples: + user_digest = md5(col("user_id")) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed; it becomes the application's only argument. + """ + return registered_application("md5", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_md5_builds_registered_application() -> None: + expr = md5(col("payload")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "md5" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/hashing/sha2.incn b/src/functions/hashing/sha2.incn new file mode 100644 index 0000000..154d424 --- /dev/null +++ b/src/functions/hashing/sha2.incn @@ -0,0 +1,76 @@ +""" +SHA-2 compatibility helper. + +`sha2(expr, bit_length)` rewrites to the matching concrete SHA-2 helper (`sha224`, `sha256`, `sha384`, or `sha512`) for supported digest lengths.
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import ( + FunctionClass, + FunctionDeterminism, + FunctionErrorBehavior, + FunctionLifecycle, + FunctionNullBehavior, + compatibility_alias_spec, + core_function_namespace, + rewrite_mapping, + v0_1, +) +from functions.hashing.sha224 import sha224 +from functions.hashing.sha256 import sha256 +from functions.hashing.sha384 import sha384 +from functions.hashing.sha512 import sha512 +from functions.registry import function_registry +from projection_builders import ColumnExpr + + +@function_registry.add("sha2", compatibility_alias_spec( + core_function_namespace(), + FunctionClass.Scalar, + ["sha224", "sha256", "sha384", "sha512"], + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionDeterminism.Deterministic, + FunctionNullBehavior.DependsOnInputs, + FunctionErrorBehavior.InvalidInputDiagnostic, + rewrite_mapping("sha2(expr, bits) -> sha224/sha256/sha384/sha512(expr) for supported literal bit lengths"), +)) +pub def sha2(expr: ColumnExpr, bit_length: int) -> ColumnExpr: + """ + Build a SHA-2 hexadecimal digest expression by delegating to `sha224`, `sha256`, `sha384`, or `sha512` according to `bit_length`. + + Examples: + user_digest = sha2(col("user_id"), 256) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed. + bit_length: Supported digest size: 224, 256, 384, or 512; any other value raises ValueError through `raise_value_error`.
+ """ + if bit_length == 224: + return sha224(expr) + if bit_length == 256: + return sha256(expr) + if bit_length == 384: + return sha384(expr) + if bit_length == 512: + return sha512(expr) + return raise_value_error("sha2 bit_length must be one of 224, 256, 384, or 512") + + +module tests: + from std.testing import assert_raises + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sha2_rewrites_to_supported_sha2_helper() -> None: + expr = sha2(col("payload"), 256) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sha256" + assert column_expr_argument_count(expr) == 1 + def _call_sha2_with_unsupported_length() -> None: + sha2(col("payload"), 1) + def test_sha2_rejects_unsupported_bit_length() -> None: + assert_raises[ValueError](_call_sha2_with_unsupported_length) diff --git a/src/functions/hashing/sha224.incn b/src/functions/hashing/sha224.incn new file mode 100644 index 0000000..4b209d1 --- /dev/null +++ b/src/functions/hashing/sha224.incn @@ -0,0 +1,51 @@ +""" +SHA-224 hash helper. + +`sha224` hashes a string expression and returns its lowercase hexadecimal digest. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import SHA224_FUNCTION_ANCHOR + + +@function_registry.add("sha224", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("sha224", SHA224_FUNCTION_ANCHOR), +)) +pub def sha224(expr: ColumnExpr) -> ColumnExpr: + """ + Build a registered `sha224` scalar application that hashes `expr` to a lowercase hexadecimal SHA-224 digest.
+ + Examples: + payload_digest = sha224(col("payload")) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed. + """ + return registered_application("sha224", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sha224_builds_registered_application() -> None: + expr = sha224(col("payload")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sha224" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/hashing/sha256.incn b/src/functions/hashing/sha256.incn new file mode 100644 index 0000000..32d0963 --- /dev/null +++ b/src/functions/hashing/sha256.incn @@ -0,0 +1,51 @@ +""" +SHA-256 hash helper. + +`sha256` hashes a string expression and returns its lowercase hexadecimal digest. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import SHA256_FUNCTION_ANCHOR + + +@function_registry.add("sha256", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("sha256", SHA256_FUNCTION_ANCHOR), +)) +pub def sha256(expr: ColumnExpr) -> ColumnExpr: + """ + Build a registered `sha256` scalar application that hashes `expr` to a lowercase hexadecimal SHA-256 digest. + + Examples: + payload_digest = sha256(col("payload")) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed.
+ """ + return registered_application("sha256", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sha256_builds_registered_application() -> None: + expr = sha256(col("payload")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sha256" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/hashing/sha384.incn b/src/functions/hashing/sha384.incn new file mode 100644 index 0000000..c7afab1 --- /dev/null +++ b/src/functions/hashing/sha384.incn @@ -0,0 +1,51 @@ +""" +SHA-384 hash helper. + +`sha384` hashes a string expression and returns its lowercase hexadecimal digest. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import SHA384_FUNCTION_ANCHOR + + +@function_registry.add("sha384", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("sha384", SHA384_FUNCTION_ANCHOR), +)) +pub def sha384(expr: ColumnExpr) -> ColumnExpr: + """ + Build a registered `sha384` scalar application that hashes `expr` to a lowercase hexadecimal SHA-384 digest. + + Examples: + payload_digest = sha384(col("payload")) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed.
+ """ + return registered_application("sha384", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sha384_builds_registered_application() -> None: + expr = sha384(col("payload")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sha384" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/hashing/sha512.incn b/src/functions/hashing/sha512.incn new file mode 100644 index 0000000..193fe54 --- /dev/null +++ b/src/functions/hashing/sha512.incn @@ -0,0 +1,51 @@ +""" +SHA-512 hash helper. + +`sha512` hashes a string expression and returns its lowercase hexadecimal digest. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import SHA512_FUNCTION_ANCHOR + + +@function_registry.add("sha512", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("sha512", SHA512_FUNCTION_ANCHOR), +)) +pub def sha512(expr: ColumnExpr) -> ColumnExpr: + """ + Build a registered `sha512` scalar application that hashes `expr` to a lowercase hexadecimal SHA-512 digest. + + Examples: + payload_digest = sha512(col("payload")) + + Parameters: + expr: String expression whose UTF-8 bytes should be hashed.
+ """ + return registered_application("sha512", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sha512_builds_registered_application() -> None: + expr = sha512(col("payload")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sha512" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/mod.incn b/src/functions/mod.incn index 1cfc03c..6652a8f 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -69,6 +69,12 @@ pub from functions.windows.window import window pub from functions.windows.row_number import row_number pub from functions.windows.rank import rank pub from functions.windows.dense_rank import dense_rank +pub from functions.hashing.md5 import md5 +pub from functions.hashing.sha2 import sha2 +pub from functions.hashing.sha224 import sha224 +pub from functions.hashing.sha256 import sha256 +pub from functions.hashing.sha384 import sha384 +pub from functions.hashing.sha512 import sha512 pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/lib.incn b/src/lib.incn index a707823..2b6670e 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -115,6 +115,12 @@ pub from functions.windows.window import window pub from functions.windows.row_number import row_number pub from functions.windows.rank import rank pub from functions.windows.dense_rank import dense_rank +pub from functions.hashing.md5 import md5 +pub from functions.hashing.sha2 import sha2 +pub from functions.hashing.sha224 import sha224 +pub from functions.hashing.sha256 import sha256 +pub from functions.hashing.sha384 import sha384 +pub from functions.hashing.sha512 import sha512 pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from
functions.operators.div import div diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 72e5d5f..fa9cfed 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -79,6 +79,11 @@ pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 pub const ROW_NUMBER_FUNCTION_ANCHOR: u32 = 52 pub const RANK_FUNCTION_ANCHOR: u32 = 53 pub const DENSE_RANK_FUNCTION_ANCHOR: u32 = 54 +pub const MD5_FUNCTION_ANCHOR: u32 = 55 +pub const SHA224_FUNCTION_ANCHOR: u32 = 56 +pub const SHA256_FUNCTION_ANCHOR: u32 = 57 +pub const SHA384_FUNCTION_ANCHOR: u32 = 58 +pub const SHA512_FUNCTION_ANCHOR: u32 = 59 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 23902fa..197c275 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -74,6 +74,7 @@ from functions import ( map_keys, map_values, max, + md5, min, modulo, mul, @@ -90,6 +91,11 @@ from functions import ( rank, round, row_number, + sha2, + sha224, + sha256, + sha384, + sha512, str_expr, str_lit, sub, @@ -161,6 +167,7 @@ from substrait.function_extensions import ( MAP_KEYS_FUNCTION_ANCHOR, MAP_VALUES_FUNCTION_ANCHOR, MAX_FUNCTION_ANCHOR, + MD5_FUNCTION_ANCHOR, MIN_FUNCTION_ANCHOR, MODULUS_FUNCTION_ANCHOR, MULTIPLY_FUNCTION_ANCHOR, @@ -173,6 +180,10 @@ from substrait.function_extensions import ( RANK_FUNCTION_ANCHOR, ROW_NUMBER_FUNCTION_ANCHOR, ROUND_FUNCTION_ANCHOR, + SHA224_FUNCTION_ANCHOR, + SHA256_FUNCTION_ANCHOR, + SHA384_FUNCTION_ANCHOR, + SHA512_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, explode_extension_uri, @@ -238,12 +249,12 @@ def _local_entry_by_namespace_and_name_or_fail( def 
_expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", 
"map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "sha2", "md5"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank"] + return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "md5"] def _exercise_current_public_helpers() -> None: @@ -336,6 +347,12 @@ def _exercise_current_public_helpers() -> None: row_number() rank() dense_rank() + sha224(status) + sha256(status) + sha384(status) + sha512(status) + sha2(status, 256) + md5(status) return @@ -458,7 +475,7 @@ def 
test_function_registry__core_helpers_expose_portable_policy_metadata() -> No # -- Act / Assert -- for entry in entries: assert entry.namespace == core_function_namespace(), f"{entry.function_ref} should live in the core function namespace" - if entry.canonical_name == "count_expr" or entry.canonical_name == "count_distinct" or entry.canonical_name == "count_if": + if entry.canonical_name == "count_expr" or entry.canonical_name == "count_distinct" or entry.canonical_name == "count_if" or entry.canonical_name == "sha2": assert entry.policy_category == FunctionPolicyCategory.CompatibilityAlias, f"{entry.canonical_name} should be marked as a compatibility helper" assert entry.alias_policy == FunctionAliasPolicy.OptInCompatibility, "compatibility helpers should be opt-in by policy" continue @@ -638,6 +655,11 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("map_keys", "map_keys", MAP_KEYS_FUNCTION_ANCHOR) _assert_extension_mapping("map_values", "map_values", MAP_VALUES_FUNCTION_ANCHOR) _assert_extension_mapping("named_struct", "named_struct", NAMED_STRUCT_FUNCTION_ANCHOR) + _assert_extension_mapping("sha224", "sha224", SHA224_FUNCTION_ANCHOR) + _assert_extension_mapping("sha256", "sha256", SHA256_FUNCTION_ANCHOR) + _assert_extension_mapping("sha384", "sha384", SHA384_FUNCTION_ANCHOR) + _assert_extension_mapping("sha512", "sha512", SHA512_FUNCTION_ANCHOR) + _assert_extension_mapping("md5", "md5", MD5_FUNCTION_ANCHOR) def test_function_registry__generator_helpers_are_relation_extensions() -> None: @@ -717,6 +739,10 @@ def test_function_registry__rewrite_mappings_identify_non_extension_helpers() -> assert always_false_entry.substrait.kind == SubstraitMappingKind.Rewrite, "always_false should lower as a literal rewrite" _assert_rewrite_mapping("is_not_nan", "not_(is_nan(expr))") _assert_rewrite_mapping("map_contains_key", "gt(cardinality(map_extract(map_expr, key)), int_expr(0))") + _assert_rewrite_mapping( + 
"sha2", + "sha2(expr, bits) -> sha224/sha256/sha384/sha512(expr) for supported literal bit lengths", + ) assert always_true_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" assert always_false_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" @@ -749,6 +775,7 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: status, [str_lit("paid"), str_lit("open")], ), lt(amount, int_lit(10)), lte(amount, int_lit(10)), modulo(amount, lit(2)), round(amount)] + hash_exprs = [md5(status), sha2(status, 256), sha224(status), sha256(status), sha384(status), sha512(status)] # -- Assert -- assert column_expr_kind(amount) == ColumnExprKind.Column, "col should still build a column reference" @@ -770,5 +797,7 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: assert column_expr_kind(gt_expr) == ColumnExprKind.ScalarFunction, "gt should use the shared scalar function kind" for core_expr in core_exprs: assert column_expr_kind(core_expr) != ColumnExprKind.Column, "core scalar helpers should build scalar expressions" + for hash_expr in hash_exprs: + assert column_expr_kind(hash_expr) == ColumnExprKind.ScalarFunction, "hash helpers should build scalar expressions" assert column_expr_kind(always_true()) == ColumnExprKind.BoolLiteral, "always_true should still build a bool literal" assert column_expr_kind(always_false()) == ColumnExprKind.BoolLiteral, "always_false should still build a bool literal" diff --git a/tests/test_hashing_functions.incn b/tests/test_hashing_functions.incn new file mode 100644 index 0000000..2cfc4c5 --- /dev/null +++ b/tests/test_hashing_functions.incn @@ -0,0 +1,55 @@ +"""Test: RFC 022 hashing helper surface.""" + +from std.testing import assert_raises +from functions import col, md5, sha2, sha224, sha256, sha384, sha512 +from function_registry import function_ref_for +from 
projection_builders import ( + ColumnExpr, + ColumnExprKind, + column_expr_argument_count, + column_expr_function_name, + column_expr_function_ref, + column_expr_kind, +) + + +def _assert_hash_application(expr: ColumnExpr, expected_name: str) -> None: + """Assert `expr` is a scalar-function node named `expected_name` that carries the matching registry function ref and exactly one argument.""" + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction, f"{expected_name} should use scalar application nodes" + assert column_expr_function_name(expr) == expected_name, f"{expected_name} should preserve its canonical name" + assert column_expr_function_ref(expr) == function_ref_for(expected_name), f"{expected_name} should preserve its function ref" + assert column_expr_argument_count(expr) == 1, f"{expected_name} should carry one string input expression" + + +def _call_sha2_with_unsupported_length() -> None: + """Call sha2 with bit_length 1 so a test can pass this function to assert_raises and expect ValueError.""" + sha2(col("payload"), 1) + return + + +def test_hashing_functions__concrete_helpers_share_scalar_application_node() -> None: + # -- Arrange -- + payload = col("payload") + + # -- Act / Assert -- (each concrete helper must report its own canonical name) + _assert_hash_application(md5(payload), "md5") + _assert_hash_application(sha224(payload), "sha224") + _assert_hash_application(sha256(payload), "sha256") + _assert_hash_application(sha384(payload), "sha384") + _assert_hash_application(sha512(payload), "sha512") + + +def test_hashing_functions__sha2_rewrites_to_concrete_sha2_helpers() -> None: + # -- Arrange -- + payload = col("payload") + + # -- Act / Assert -- (the resulting node is named after the concrete helper, not `sha2`) + _assert_hash_application(sha2(payload, 224), "sha224") + _assert_hash_application(sha2(payload, 256), "sha256") + _assert_hash_application(sha2(payload, 384), "sha384") + _assert_hash_application(sha2(payload, 512), "sha512") + + +def test_hashing_functions__sha2_rejects_unsupported_digest_lengths() -> None: + # -- Arrange / Act / Assert -- (assert_raises receives the helper itself and invokes it) + assert_raises[ValueError](_call_sha2_with_unsupported_length) diff --git
a/tests/test_session_projection.incn b/tests/test_session_projection.incn index fb6207e..d3d2132 100644 --- a/tests/test_session_projection.incn +++ b/tests/test_session_projection.incn @@ -18,11 +18,16 @@ from functions import ( floor, gt, lit, + md5, modulo, mul, neg, nullif, round, + sha2, + sha224, + sha384, + sha512, sub, try_cast, cardinality, @@ -192,6 +197,40 @@ def test_session_projection__collect_executes_common_math_scalar_projection_func assert payload.contains("3"), "round projection should include round(10 / 4.0)" + +def test_session_projection__collect_executes_format_hashing_projection_functions() -> None: + """collect should execute the first RFC 022 hashing helpers through DataFusion and surface the expected lowercase hex digests of the literal `abc`; sha256 is exercised via sha2(..., 256).""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("md5_abc", md5(lit("abc"))).with_column("sha224_abc", sha224(lit("abc"))).with_column( + "sha2_256_abc", + sha2(lit("abc"), 256), + ).with_column("sha384_abc", sha384(lit("abc"))).with_column("sha512_abc", sha512(lit("abc"))) + df = _collect_or_fail(session, projected) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "hashing projections should preserve the input rows" + assert len(resolved) == 7, "projection should expose all appended hash outputs" + assert payload.contains("md5_abc"), "md5 projection should materialize its alias" + assert payload.contains("sha2_256_abc"), "sha2 compatibility projection should materialize its alias" + assert payload.contains("900150983cd24fb0d6963f7d28e17f72"), "md5 should return the lowercase hex digest for abc" + assert payload.contains("23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"), "sha224 should return the lowercase hex digest for abc" + assert
payload.contains("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"), "sha2(..., 256) should rewrite to sha256" + assert payload.contains( + "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7", + ), "sha384 should return the lowercase hex digest for abc" + assert payload.contains( + "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f", + ), "sha512 should return the lowercase hex digest for abc" + + def test_session_projection__collect_executes_nested_scalar_projection_functions() -> None: """collect should execute RFC 020 nested scalar helpers through DataFusion.""" # -- Arrange -- diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 8a44395..0ec722b 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -55,6 +55,7 @@ from functions import ( map_keys, map_values, max, + md5, min, modulo, mul, @@ -67,6 +68,11 @@ from functions import ( rank, round, row_number, + sha2, + sha224, + sha256, + sha384, + sha512, sub, sum, try_cast, @@ -435,6 +441,12 @@ def test_plan__core_scalar_extension_mappings_lower_to_substrait() -> None: _assert_scalar_expr_lowers(ceil(div(col("amount"), lit(4.0)))) _assert_scalar_expr_lowers(floor(div(col("amount"), lit(4.0)))) _assert_scalar_expr_lowers(round(div(col("amount"), lit(4.0)))) + _assert_scalar_expr_lowers(md5(col("status"))) + _assert_scalar_expr_lowers(sha224(col("status"))) + _assert_scalar_expr_lowers(sha256(col("status"))) + _assert_scalar_expr_lowers(sha384(col("status"))) + _assert_scalar_expr_lowers(sha512(col("status"))) + _assert_scalar_expr_lowers(sha2(col("status"), 256)) def test_plan__nested_scalar_extension_mappings_lower_to_substrait() -> None: