Skip to content

Commit 7f892a6

Browse files
rustyconover and claude committed
Convert all legacy scalar examples to new Param API with type_bound support
- Add type_bound validation to _validate_param_types for AnyArrow params
- Convert DoubleColumnFunction to use Param(AnyArrow, type_bound=...)
- Convert AddNumericColumnsFunction to use Param(AnyArrow, type_bound=...)
- Convert SumColumnsFunction to use Param(AnyArrow, type_bound=..., varargs=True)
- Fix type_bound warning check to also consider is_any flag for new API
- Update module docstring to reflect all functions using new API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c175e6d commit 7f892a6

3 files changed

Lines changed: 72 additions & 76 deletions

File tree

vgi/arguments.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -719,7 +719,8 @@ def __call__(
719719
)
720720

721721
# Warn if type_bound is used with non-AnyArrow type
722-
if type_bound is not None and self._type_param is not AnyArrow:
722+
# Check both _type_param (legacy API) and is_any (new Param API)
723+
if type_bound is not None and self._type_param is not AnyArrow and not is_any:
723724
type_name = getattr(self._type_param, "__name__", str(self._type_param))
724725
warnings.warn(
725726
f"type_bound is only meaningful for Arg[AnyArrow], "

vgi/examples/scalar.py

Lines changed: 55 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -3,28 +3,30 @@
33
This module provides example scalar functions that transform input batches
44
to single-column output with 1:1 row mapping.
55
6-
NEW PARAM/CONSTPARAM/RETURNS API
7-
--------------------------------
6+
All functions use the Param/ConstParam/Returns API:
7+
8+
STATIC OUTPUT TYPE
9+
------------------
810
MultiplyFunction - Multiplies column by constant (ConstParam example)
911
UpperCaseFunction - Converts string column to uppercase
1012
NullHandlingFunction - Demonstrates special null handling (NullHandling.SPECIAL)
1113
RandomIntFunction - Generates random integers (VOLATILE stability)
1214
13-
LEGACY ARG DESCRIPTOR API (still supported)
14-
-------------------------------------------
15+
DYNAMIC OUTPUT TYPE (with type_bound)
16+
-------------------------------------
1517
DoubleColumnFunction - Doubles values in a numeric column (AnyArrow + type_bound)
1618
AddNumericColumnsFunction - Adds two numeric columns (type promotion)
1719
SumColumnsFunction - Sums multiple columns (varargs example)
1820
"""
1921

2022
from __future__ import annotations
2123

22-
from typing import Annotated, Any
24+
from typing import Any
2325

2426
import pyarrow as pa
2527
import pyarrow.compute as pc
2628

27-
from vgi.arguments import AnyArrow, AnyArrowValue, Arg, ConstParam, Param, Returns
29+
from vgi.arguments import AnyArrow, ConstParam, Param, Returns
2830
from vgi.exceptions import SchemaValidationError
2931
from vgi.metadata import FunctionExample, FunctionStability, NullHandling
3032
from vgi.scalar_function import ScalarFunction
@@ -125,17 +127,16 @@ def compute(
125127
return pc.multiply(column, factor)
126128

127129

128-
# =============================================================================
129-
# Legacy Arg Descriptor API Examples (still supported)
130-
# =============================================================================
131-
132-
133130
class DoubleColumnFunction(ScalarFunction):
134131
"""Doubles values in a numeric column.
135132
133+
This example demonstrates the new Param API with:
134+
- AnyArrow type with type_bound for flexible numeric input
135+
- Dynamic output type computed in bind()
136+
136137
Example:
137-
Input: x=[1, 2, 3]
138-
Args: column="x"
138+
SQL: SELECT double_column(price) FROM products
139+
Input: price=[1, 2, 3]
139140
Output: result=[2, 4, 6]
140141
141142
"""
@@ -145,7 +146,6 @@ class Meta:
145146

146147
name = "double_column"
147148
description = "Doubles values in a numeric column"
148-
output_type = AnyArrow # Output type depends on input column type
149149
examples = [
150150
FunctionExample(
151151
sql="SELECT double_column(price) FROM products",
@@ -157,37 +157,42 @@ class Meta:
157157
),
158158
]
159159

160-
# Explicit arrow_type demonstrates type specification
161-
column: Annotated[
162-
AnyArrowValue, Arg(0, doc="Value to double", type_bound=_is_addable_type)
163-
]
160+
_output_type: pa.DataType
164161

165162
def bind(self) -> None:
166163
"""Compute output type from input column types."""
167-
field1 = self.input_schema.field(self.column.value)
168-
169-
# Since we're going to be multiplying by 2, promote to a wider type
170-
self._output_type = _promote_for_addition(field1.type)
164+
# Get the input column type from the schema
165+
field = self.input_schema.field(0)
166+
# Promote to a wider type since we're multiplying by 2
167+
self._output_type = _promote_for_addition(field.type)
171168

172169
@property
173170
def output_type(self) -> pa.DataType:
174171
"""Return the type of the doubled column."""
175172
return self._output_type
176173

177-
def compute(self, *, column: pa.Array[Any]) -> pa.Array[Any]:
174+
def compute(
175+
self,
176+
column: Param(AnyArrow, "Numeric value to double", type_bound=_is_addable_type), # type: ignore[valid-type]
177+
) -> pa.Array[Any]:
178178
"""Double the values in the specified column."""
179-
return pc.multiply(column, 2)
179+
result: pa.Array[Any] = pc.multiply(column, 2)
180+
return result
180181

181182

182183
class AddNumericColumnsFunction(ScalarFunction):
183184
"""Adds two numeric columns together.
184185
186+
This example demonstrates:
187+
- Multiple Param() annotations with type_bound validation
188+
- Dynamic output type with type promotion for overflow safety
189+
185190
Validates that both columns are numeric types (integer, float, decimal, or
186-
temporal) at bind time, raising SchemaValidationError if not.
191+
temporal) at compute time, raising SchemaValidationError if not.
187192
188193
Example:
189-
Input: a=[1, 2, 3], b=[10, 20, 30]
190-
Args: col1="a", col2="b"
194+
SQL: SELECT add_columns(price, tax) FROM orders
195+
Input: price=[1, 2, 3], tax=[10, 20, 30]
191196
Output: result=[11, 22, 33]
192197
193198
Raises:
@@ -200,7 +205,6 @@ class Meta:
200205

201206
name = "add_columns"
202207
description = "Adds two numeric columns"
203-
output_type = AnyArrow # Output type depends on input column types
204208
examples = [
205209
FunctionExample(
206210
sql="SELECT add_columns(price, tax) FROM orders",
@@ -212,25 +216,15 @@ class Meta:
212216
),
213217
]
214218

215-
# type_bound validates value types at bind time (automatic via Function.__init__)
216-
col1: Annotated[
217-
AnyArrowValue, Arg(0, doc="First numeric value", type_bound=_is_addable_type)
218-
]
219-
col2: Annotated[
220-
AnyArrowValue, Arg(1, doc="Second numeric value", type_bound=_is_addable_type)
221-
]
222-
223219
_output_type: pa.DataType
224220

225221
def bind(self) -> None:
226222
"""Compute output type from input column types."""
227-
field1 = self.input_schema.field(self.col1.value)
228-
field2 = self.input_schema.field(self.col2.value)
223+
field1 = self.input_schema.field(0)
224+
field2 = self.input_schema.field(1)
229225

230226
# Compute the output type by promoting to the wider of the two types,
231227
# then promoting again to reduce overflow risk.
232-
# Use pc.add with null values to determine the common type, as PyArrow's
233-
# compute functions handle type promotion correctly.
234228
common_type = pc.add(
235229
pa.nulls(1, type=field1.type), pa.nulls(1, type=field2.type)
236230
).type
@@ -241,9 +235,14 @@ def output_type(self) -> pa.DataType:
241235
"""Return the computed output type based on input column types."""
242236
return self._output_type
243237

244-
def compute(self, *, col1: pa.Array[Any], col2: pa.Array[Any]) -> pa.Array[Any]:
238+
def compute(
239+
self,
240+
col1: Param(AnyArrow, "First numeric value", type_bound=_is_addable_type), # type: ignore[valid-type]
241+
col2: Param(AnyArrow, "Second numeric value", type_bound=_is_addable_type), # type: ignore[valid-type]
242+
) -> pa.Array[Any]:
245243
"""Add the two columns together."""
246-
return pc.add(col1, col2)
244+
result: pa.Array[Any] = pc.add(col1, col2)
245+
return result
247246

248247

249248
class UpperCaseFunction(ScalarFunction):
@@ -285,12 +284,14 @@ def compute(
285284
class SumColumnsFunction(ScalarFunction):
286285
"""Sums values from multiple numeric columns.
287286
288-
Uses varargs with type_bound to accept any number of numeric columns
289-
and validates that all columns are addable types at bind time.
287+
This example demonstrates:
288+
- varargs=True to accept variable number of columns
289+
- type_bound validation on all varargs columns
290+
- Dynamic output type computed in bind()
290291
291292
Example:
292-
Input: a=[1, 2], b=[10, 20], c=[100, 200]
293-
Args: columns=('a', 'b', 'c')
293+
SQL: SELECT sum_columns(price, tax, shipping) FROM orders
294+
Input: price=[1, 2], tax=[10, 20], shipping=[100, 200]
294295
Output: result=[111, 222]
295296
296297
"""
@@ -300,43 +301,33 @@ class Meta:
300301

301302
name = "sum_columns"
302303
description = "Sum values from multiple numeric columns"
303-
output_type = AnyArrow # Output type depends on input column types
304304
examples = [
305305
FunctionExample(
306306
sql="SELECT sum_columns(price, tax, shipping) FROM orders",
307307
description="Calculate total cost from multiple columns",
308308
),
309309
]
310310

311-
# Varargs with type_bound validates all values are numeric
312-
# Note: varargs returns tuple[Any, ...], not AnyArrowValue
313-
columns: Annotated[
314-
tuple[Any, ...],
315-
Arg(
316-
0,
317-
varargs=True,
318-
type_bound=_is_addable_type,
319-
doc="Numeric values to sum",
320-
),
321-
]
322-
323311
_output_type: pa.DataType
324312

325313
def bind(self) -> None:
326314
"""Compute output type from first column, promoted for overflow safety."""
327-
# With varargs=True, self.columns is a tuple of column names
328-
first_col = self.columns[0] # type: ignore[index]
329-
first_type = self.input_schema.field(first_col).type
315+
first_type = self.input_schema.field(0).type
330316
self._output_type = _promote_for_addition(first_type)
331317

332318
@property
333319
def output_type(self) -> pa.DataType:
334320
"""Return the computed output type based on first column."""
335321
return self._output_type
336322

337-
def compute(self, *, columns: list[pa.Array[Any]]) -> pa.Array[Any]:
323+
def compute(
324+
self,
325+
columns: Param( # type: ignore[valid-type]
326+
AnyArrow, "Numeric values to sum", type_bound=_is_addable_type, varargs=True
327+
),
328+
) -> pa.Array[Any]:
338329
"""Sum values from all specified columns."""
339-
result = columns[0]
330+
result: pa.Array[Any] = columns[0]
340331
for col in columns[1:]:
341332
result = pc.add(result, col)
342333
return result

vgi/scalar_function.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -980,31 +980,31 @@ def _extract_compute_kwargs(self, batch: pa.RecordBatch) -> dict[str, Any]:
980980
def _validate_param_types(self, kwargs: dict[str, Any]) -> None:
981981
"""Validate that input array types match declared Param types.
982982
983-
Only validates for the new Param/ConstParam API, and only for params
984-
that have a declared arrow_type (not is_any=True).
983+
For the new Param/ConstParam API:
984+
- Validates exact type match for params with declared arrow_type
985+
- Validates type_bound predicates for AnyArrow params with type_bound
985986
986987
Args:
987988
kwargs: Dict of parameter names to arrays (from _extract_compute_kwargs).
988989
989990
Raises:
990991
TypeMismatchError: If any array type doesn't match its declared type.
992+
SchemaValidationError: If any array type fails type_bound validation.
991993
992994
"""
993995
if not self._uses_new_param_api:
994-
return # Legacy API doesn't have explicit type declarations
996+
return # Legacy API uses _validate_type_bounds in Function.__init__
995997

996998
for name, arg in self._compute_params.items():
997-
if arg.is_any:
998-
continue # Skip AnyArrow params
999-
1000-
if arg.arrow_type is None:
1001-
continue # No type declared (shouldn't happen with new API)
1002-
1003999
if arg.varargs:
10041000
# Validate all arrays in varargs
10051001
arrays = kwargs[name]
10061002
for i, arr in enumerate(arrays):
1007-
if arr.type != arg.arrow_type:
1003+
if arg.is_any:
1004+
# AnyArrow: validate type_bound if specified
1005+
if arg.type_bound is not None:
1006+
arg.validate_type_bound(arr.type)
1007+
elif arg.arrow_type is not None and arr.type != arg.arrow_type:
10081008
raise TypeMismatchError(
10091009
f"Input type mismatch for vararg parameter '{name}' "
10101010
f"at index {i}.",
@@ -1015,7 +1015,11 @@ def _validate_param_types(self, kwargs: dict[str, Any]) -> None:
10151015
)
10161016
else:
10171017
arr = kwargs[name]
1018-
if arr.type != arg.arrow_type:
1018+
if arg.is_any:
1019+
# AnyArrow: validate type_bound if specified
1020+
if arg.type_bound is not None:
1021+
arg.validate_type_bound(arr.type)
1022+
elif arg.arrow_type is not None and arr.type != arg.arrow_type:
10191023
raise TypeMismatchError(
10201024
f"Input type mismatch for parameter '{name}'.",
10211025
param_name=name,

0 commit comments

Comments
 (0)