Skip to content

Commit b87f40e

Browse files
rustyconoverclaude
andcommitted
Convert example functions to bind() and add varargs support to worker matching
- Convert AddNumericColumnsFunction to use bind() instead of __init__ - Convert RepeatInputsFunction, SumAllColumnsFunction, and SumAllColumnsFunctionWithLogging to use bind() - Fix worker _match_function to support varargs parameters (unlimited positional args) - Add comprehensive tests for SumColumnsFunction (7 test cases) - Remove unused imports from examples 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9165661 commit b87f40e

4 files changed

Lines changed: 159 additions & 34 deletions

File tree

tests/scalar/test_client.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,145 @@ def test_add_columns_accepts_mixed_int_types(self, example_worker: str) -> None:
235235
assert outputs[0].schema.field("result").type == pa.int64()
236236

237237

238+
class TestSumColumns:
239+
"""Tests for SumColumnsFunction via Client."""
240+
241+
def test_sum_two_columns(self, example_worker: str) -> None:
242+
"""Sum of two columns."""
243+
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
244+
batch = pa.RecordBatch.from_pydict(
245+
{"a": [1, 2, 3], "b": [10, 20, 30]}, schema=schema
246+
)
247+
248+
with Client(example_worker) as client:
249+
outputs = list(
250+
client.scalar_function(
251+
function_name="sum_columns",
252+
input=iter([batch]),
253+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
254+
)
255+
)
256+
257+
assert len(outputs) == 1
258+
assert outputs[0].to_pydict() == {"result": [11, 22, 33]}
259+
260+
def test_sum_three_columns(self, example_worker: str) -> None:
261+
"""Sum of three columns using varargs."""
262+
schema = pa.schema([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())])
263+
batch = pa.RecordBatch.from_pydict(
264+
{"a": [1, 2], "b": [10, 20], "c": [100, 200]}, schema=schema
265+
)
266+
267+
with Client(example_worker) as client:
268+
outputs = list(
269+
client.scalar_function(
270+
function_name="sum_columns",
271+
input=iter([batch]),
272+
arguments=Arguments(
273+
positional=(pa.scalar("a"), pa.scalar("b"), pa.scalar("c"))
274+
),
275+
)
276+
)
277+
278+
assert len(outputs) == 1
279+
assert outputs[0].to_pydict() == {"result": [111, 222]}
280+
281+
def test_sum_with_type_promotion(self, example_worker: str) -> None:
282+
"""Different int types promote correctly."""
283+
schema = pa.schema([("a", pa.int32()), ("b", pa.int64())])
284+
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
285+
286+
with Client(example_worker) as client:
287+
outputs = list(
288+
client.scalar_function(
289+
function_name="sum_columns",
290+
input=iter([batch]),
291+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
292+
)
293+
)
294+
295+
assert len(outputs) == 1
296+
assert outputs[0].to_pydict() == {"result": [11, 22]}
297+
# Output should be int64 (promoted from int32)
298+
assert outputs[0].schema.field("result").type == pa.int64()
299+
300+
def test_sum_rejects_string_column(self, example_worker: str) -> None:
301+
"""Type bound rejects non-numeric columns."""
302+
schema = pa.schema([("a", pa.int64()), ("b", pa.string())]) # type: ignore[arg-type]
303+
batch = pa.RecordBatch.from_pydict(
304+
{"a": [1, 2], "b": ["x", "y"]}, schema=schema
305+
)
306+
307+
with (
308+
Client(example_worker) as client,
309+
pytest.raises(Exception, match="does not match any of"),
310+
):
311+
list(
312+
client.scalar_function(
313+
function_name="sum_columns",
314+
input=iter([batch]),
315+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
316+
)
317+
)
318+
319+
def test_sum_multiple_batches(self, example_worker: str) -> None:
320+
"""Multiple input batches processed correctly."""
321+
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
322+
batch1 = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
323+
batch2 = pa.RecordBatch.from_pydict({"a": [3, 4], "b": [30, 40]}, schema=schema)
324+
325+
with Client(example_worker) as client:
326+
outputs = list(
327+
client.scalar_function(
328+
function_name="sum_columns",
329+
input=iter([batch1, batch2]),
330+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
331+
)
332+
)
333+
334+
assert_total_rows(outputs, 4)
335+
all_values: list[int] = []
336+
for batch in outputs:
337+
all_values.extend(cast(list[int], batch.column("result").to_pylist()))
338+
assert sorted(all_values) == [11, 22, 33, 44]
339+
340+
def test_sum_empty_batch(self, example_worker: str) -> None:
341+
"""Empty batch returns empty output."""
342+
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
343+
empty_batch = pa.RecordBatch.from_pydict({"a": [], "b": []}, schema=schema)
344+
345+
with Client(example_worker) as client:
346+
outputs = list(
347+
client.scalar_function(
348+
function_name="sum_columns",
349+
input=iter([empty_batch]),
350+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
351+
)
352+
)
353+
354+
assert len(outputs) == 1
355+
assert outputs[0].num_rows == 0
356+
357+
def test_sum_float_columns(self, example_worker: str) -> None:
358+
"""Sum of float columns."""
359+
schema = pa.schema([("a", pa.float64()), ("b", pa.float64())])
360+
batch = pa.RecordBatch.from_pydict(
361+
{"a": [1.5, 2.5], "b": [0.5, 0.5]}, schema=schema
362+
)
363+
364+
with Client(example_worker) as client:
365+
outputs = list(
366+
client.scalar_function(
367+
function_name="sum_columns",
368+
input=iter([batch]),
369+
arguments=Arguments(positional=(pa.scalar("a"), pa.scalar("b"))),
370+
)
371+
)
372+
373+
assert len(outputs) == 1
374+
assert outputs[0].to_pydict() == {"result": [2.0, 3.0]}
375+
376+
238377
class TestScalarFunctionParallel:
239378
"""Tests for scalar functions with parallel processing."""
240379

vgi/examples/scalar.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
import pyarrow as pa
1818
import pyarrow.compute as pc
19-
import structlog
2019

21-
import vgi.invocation
2220
from vgi.arguments import AnyArrow, Arg
2321
from vgi.exceptions import SchemaValidationError
2422
from vgi.metadata import FunctionExample
@@ -151,18 +149,12 @@ class Meta:
151149
col1 = Arg[AnyArrow](0, doc="First column name", type_bound=_is_addable_type)
152150
col2 = Arg[AnyArrow](1, doc="Second column name", type_bound=_is_addable_type)
153151

154-
def __init__(
155-
self,
156-
invocation: vgi.invocation.Invocation,
157-
logger: structlog.stdlib.BoundLogger,
158-
):
159-
"""Initialize and compute output type based on input column types."""
160-
super().__init__(invocation, logger)
161-
assert invocation.input_schema is not None # Required for scalar functions
162-
163-
# Type validation is automatic via type_bound - we just compute output type
164-
field1 = invocation.input_schema.field(self.col1.value)
165-
field2 = invocation.input_schema.field(self.col2.value)
152+
_output_type: pa.DataType
153+
154+
def bind(self) -> None:
155+
"""Compute output type from input column types."""
156+
field1 = self.input_schema.field(self.col1.value)
157+
field2 = self.input_schema.field(self.col2.value)
166158

167159
# Compute the output type by promoting to the wider of the two types,
168160
# then promoting again to reduce overflow risk.

vgi/examples/table_in_out.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727

2828
import pyarrow as pa
2929
import pyarrow.compute as pc
30-
import structlog
3130

3231
from vgi.arguments import Arg, TableInput
33-
from vgi.invocation import Invocation
3432
from vgi.ipc_utils import RecordBatchState
3533
from vgi.log import Level, Message
3634
from vgi.metadata import FunctionExample
@@ -239,13 +237,8 @@ class Meta:
239237
repeat_count = Arg[int](0, doc="Number of times to repeat each input batch")
240238
data: TableInput = Arg[TableInput](1, doc="Input table to repeat") # type: ignore[assignment]
241239

242-
def __init__(
243-
self, invocation: Invocation, logger: structlog.stdlib.BoundLogger
244-
) -> None:
245-
"""Initialize and validate repeat count argument."""
246-
super().__init__(invocation=invocation, logger=logger)
247-
248-
# Access to trigger validation early
240+
def bind(self) -> None:
241+
"""Validate repeat count argument."""
249242
if self.repeat_count < 1:
250243
raise ValueError("Repeat count must be at least 1")
251244

@@ -363,11 +356,8 @@ def cardinality(self) -> TableCardinality | None:
363356
"""Return cardinality estimate of exactly 1 row."""
364357
return TableCardinality(estimate=1, max=1)
365358

366-
def __init__(
367-
self, invocation: Invocation, logger: structlog.stdlib.BoundLogger
368-
) -> None:
359+
def bind(self) -> None:
369360
"""Initialize the sum accumulator."""
370-
super().__init__(invocation=invocation, logger=logger)
371361
self.sums: dict[str, pa.Scalar[Any]] = {}
372362

373363
@property
@@ -671,11 +661,8 @@ class Meta:
671661

672662
data: TableInput = Arg[TableInput](0, doc="Input table with numeric columns") # type: ignore[assignment]
673663

674-
def __init__(
675-
self, invocation: Invocation, logger: structlog.stdlib.BoundLogger
676-
) -> None:
664+
def bind(self) -> None:
677665
"""Initialize with empty sums dict."""
678-
super().__init__(invocation=invocation, logger=logger)
679666
self.sums: dict[str, pa.Scalar[Any]] = {}
680667

681668
@property

vgi/worker.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,18 @@ def _match_function(
263263

264264
# Check positional arguments
265265
required_positional = [p for p in positional_params if p.required]
266-
max_positional = len(positional_params)
266+
has_varargs = any(p.is_varargs for p in positional_params)
267267
min_positional = len(required_positional)
268268

269-
if not (min_positional <= num_positional <= max_positional):
270-
continue # Wrong number of positional arguments
269+
if has_varargs:
270+
# Varargs: allow any number >= min_positional
271+
if num_positional < min_positional:
272+
continue # Too few positional arguments
273+
else:
274+
# Fixed positional: must be within [min, max]
275+
max_positional = len(positional_params)
276+
if not (min_positional <= num_positional <= max_positional):
277+
continue # Wrong number of positional arguments
271278

272279
# Check named arguments
273280
valid_named_keys = {p.position for p in named_params}

0 commit comments

Comments
 (0)