Commit 2be9090

cleanups for scalar functions
1 parent f67ba46 commit 2be9090

24 files changed

Lines changed: 881 additions & 518 deletions

CLAUDE.md

Lines changed: 64 additions & 12 deletions
@@ -40,6 +40,9 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne
 │ ▼ │
 │ ┌───────────────────────────────────────────────────────────────┐ │
 │ │ Worker Process │ │
+│ │ SCALAR FUNCTION (ScalarFunction) │ │
+│ │ - compute(batch): Transform each row to single output column │ │
+│ │ OR │ │
 │ │ TABLE FUNCTION (TableFunctionGenerator) │ │
 │ │ - process(): Generator yielding output batches (no input) │ │
 │ │ OR │ │
@@ -52,13 +55,15 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne
 
 | Type | Base Class | Input | Use Case |
 |------|------------|-------|----------|
+| **Scalar Function** | `ScalarFunction` | Batches | Per-row transforms (1:1 row mapping, single output column) |
 | **Table Function** | `TableFunctionGenerator` | None | Generate data (sequences, ranges) |
 | **Table-In-Out Function** | `TableInOutFunction` | Batches | Transform, filter, aggregate |
 
 ### Key Components
 
 - **Worker** (`vgi/worker.py`): Subprocess that hosts functions
 - **Client** (`vgi/client/client.py`): Spawns workers, streams data
+- **ScalarFunction** (`vgi/scalar_function.py`): Base for scalar functions
 - **TableFunctionGenerator** (`vgi/table_function.py`): Base for table functions
 - **TableInOutFunction** (`vgi/table_in_out_function.py`): Base for table-in-out functions
 
@@ -67,7 +72,8 @@ VGI (Vector Gateway Interface) provides an Apache Arrow-based protocol for conne
 ```
 vgi/
   __init__.py # Package exports
-  function.py # Invocation, OutputSpec, Arguments, GlobalInitResult
+  function.py # Invocation, OutputSpec, Arguments, FunctionType
+  scalar_function.py # ScalarFunction, ScalarFunctionGenerator
   table_function.py # TableFunctionGenerator, CardinalityInfo, Output
   table_in_out_function.py # TableInOutFunction, TableInOutGeneratorFunction
   metadata.py # Function metadata for introspection
@@ -76,6 +82,7 @@ vgi/
   client/
     client.py # Client class
   examples/
+    scalar.py # Example scalar functions
     table.py # Example table functions
     table_in_out.py # Example table-in-out functions
     worker.py # ExampleWorker with registry
@@ -89,6 +96,32 @@ vgi-client --input data.parquet --function echo --server vgi-example-worker
 vgi-client --input data.parquet --function sum_all_columns --server vgi-example-worker
 ```
 
+## Creating a Scalar Function (Per-Row Transform)
+
+```python
+import pyarrow as pa
+import pyarrow.compute as pc
+from vgi import ScalarFunction, Arg
+
+class DoubleColumn(ScalarFunction):
+    """Double the value in a specified column."""
+
+    column = Arg[str](0, doc="Column to double")
+
+    @property
+    def output_type(self) -> pa.DataType:
+        # Output type matches input column type
+        return self.input_schema.field(self.column).type
+
+    def compute(self, batch: pa.RecordBatch) -> pa.Array:
+        return pc.multiply(batch.column(self.column), 2)
+```
+
+### Key Constraints for Scalar Functions:
+- **1:1 row mapping**: Output must have exactly the same number of rows as input
+- **Single column output**: Output schema has exactly one column named "result"
+- **No finalize phase**: All processing happens in compute()
+
 ## Creating a Table-In-Out Function (Recommended)
 
 ```python
@@ -182,6 +215,9 @@ if __name__ == "__main__":
 ### Imports
 
 ```python
+# Scalar Functions (per-row transform)
+from vgi import ScalarFunction, Arg, Worker
+
 # Table Functions (no input)
 from vgi import TableFunctionGenerator, Output, Arg, Worker
 
@@ -221,6 +257,17 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"})
 
 ### Method Override Summary
 
+**ScalarFunction:**
+
+| Method | When to Override | Default |
+|--------|------------------|---------|
+| `output_type` | Define output column type | Required |
+| `compute(batch)` | Transform batch to single array | Required |
+| `setup()` | Acquire resources | No-op |
+| `teardown()` | Release resources | No-op |
+
+**TableInOutFunction:**
+
 | Method | When to Override | Default |
 |--------|------------------|---------|
 | `output_schema` | Change output columns | Returns input_schema |
@@ -232,17 +279,22 @@ output_schema = schema_like(self.input_schema, rename={"old": "new"})
 ### Pattern Decision Tree
 
 ```
-Need to implement a VGI function?
-
-├─ Does the function receive input data?
-│ │
-│ ├─ NO → Use TableFunctionGenerator
-│ │ Override process() to yield Output batches
-│ │
-│ └─ YES → Use TableInOutFunction
-│ ├─ Transform each batch? → Override transform()
-│ ├─ Aggregate results? → Accumulate in transform(), emit in finish()
-│ └─ Need generator control? → See docs/generator-api.md
+How will your function be used in SQL?
+
+1. SELECT my_func(col1, col2) FROM table
+   → SCALAR FUNCTION: Returns one value per input row
+   → Use ScalarFunction, override output_type and compute()
+   → Example: upper(), abs(), concat()
+
+2. SELECT * FROM my_func(args)
+   → TABLE FUNCTION: Generates rows from arguments (no input table)
+   → Use TableFunctionGenerator, override process()
+   → Example: range(), read_csv(), glob()
+
+3. SELECT * FROM my_func(args, (SELECT * FROM input_table))
+   → TABLE-IN-OUT FUNCTION: Transforms input rows to output rows
+   → Use TableInOutFunction, override transform() and optionally finish()
+   → Example: filtering, enrichment, aggregation
 ```
 
 ## Additional Documentation

tests/client/test_cli.py

Lines changed: 2 additions & 6 deletions
@@ -795,9 +795,7 @@ def test_auto_type_without_input_uses_table(
         lines = output_file.read_text().strip().split("\n")
         assert len(lines) == 3
 
-    def test_scalar_with_add_columns(
-        self, example_worker: str, tmp_path: Path
-    ) -> None:
+    def test_scalar_with_add_columns(self, example_worker: str, tmp_path: Path) -> None:
         """Test add_columns scalar function via CLI."""
         # Create input with two columns
         batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]})
@@ -832,9 +830,7 @@ def test_scalar_with_add_columns(
         results = [json.loads(line)["result"] for line in lines]
         assert results == [11, 22, 33]
 
-    def test_scalar_with_upper_case(
-        self, example_worker: str, tmp_path: Path
-    ) -> None:
+    def test_scalar_with_upper_case(self, example_worker: str, tmp_path: Path) -> None:
         """Test upper_case scalar function via CLI."""
         batch = pa.RecordBatch.from_pydict({"name": ["alice", "bob"]})
         input_file = tmp_path / "input.parquet"
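
The CLI tests above read the output file line by line and pull the `result` key out of each JSON object. A minimal sketch of that parsing step follows. The `raw_output` string is a hypothetical stand-in for the file contents, and the assumption that each line carries only a `result` key is mine, not something the diff confirms.

```python
import json

# Hypothetical stand-in for the CLI output file: one JSON object per line.
raw_output = '{"result": 11}\n{"result": 22}\n{"result": 33}\n'

# Same parsing steps as the tests: strip, split on newlines, decode each line.
lines = raw_output.strip().split("\n")
results = [json.loads(line)["result"] for line in lines]
```

Stripping before splitting avoids a trailing empty string, which `json.loads` would reject.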

tests/scalar/test_function.py

Lines changed: 16 additions & 37 deletions
@@ -9,23 +9,23 @@
 import structlog
 
 from vgi.arguments import Arg
-from vgi.function import Arguments, Invocation
+from vgi.function import Arguments, Invocation, InvocationType, SchemaValidationError
 from vgi.log import Level, Message
 from vgi.scalar_function import (
     Output,
-    OutputGenerator,
     ProtocolInput,
     ScalarFunction,
     ScalarFunctionGenerator,
+    ScalarOutputGenerator,
 )
-from vgi.table_function import SchemaValidationError
 
 
 def create_invocation(input_schema: pa.Schema) -> Invocation:
     """Create a test invocation with the given input schema."""
     return Invocation(
         function_name="test_function",
-        in_out_function_input_schema=input_schema,
+        input_schema=input_schema,
+        function_type=InvocationType.SCALAR,
         correlation_id="test-correlation",
         invocation_id=b"test-invocation",
         arguments=Arguments(),
@@ -43,8 +43,8 @@ class DoubleColumn(ScalarFunctionGenerator):
             def output_schema(self) -> pa.Schema:
                 return pa.schema([("result", pa.int64())])
 
-            def process(self, batch: pa.RecordBatch) -> OutputGenerator:
-                _ = yield None
+            def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
+                _ = yield Output(self.empty_output_batch)  # Priming yield
                 import pyarrow.compute as pc
 
                 while True:
@@ -83,12 +83,13 @@ class TestFunc(ScalarFunctionGenerator):
             def output_schema(self) -> pa.Schema:
                 return pa.schema([("result", pa.int64())])
 
-            def process(self, batch: pa.RecordBatch) -> OutputGenerator:
-                _ = yield None
+            def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
+                _ = yield Output(self.empty_output_batch)
 
         invocation = Invocation(
             function_name="test",
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.SCALAR,
             correlation_id="test",
             invocation_id=None,
             arguments=Arguments(),
@@ -105,8 +106,8 @@ class TwoColumnOutput(ScalarFunctionGenerator):
             def output_schema(self) -> pa.Schema:
                 return pa.schema([("a", pa.int64()), ("b", pa.int64())])
 
-            def process(self, batch: pa.RecordBatch) -> OutputGenerator:
-                _ = yield None
+            def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
+                _ = yield Output(self.empty_output_batch)
 
         input_schema = pa.schema([("x", pa.int64())])
         invocation = create_invocation(input_schema)
@@ -122,8 +123,8 @@ class LoggingScalar(ScalarFunctionGenerator):
             def output_schema(self) -> pa.Schema:
                 return pa.schema([("result", pa.int64())])
 
-            def process(self, batch: pa.RecordBatch) -> OutputGenerator:
-                _ = yield None
+            def process(self, batch: pa.RecordBatch) -> ScalarOutputGenerator:
+                _ = yield Output(self.empty_output_batch)  # Priming yield
                 import pyarrow.compute as pc
 
                 while True:
@@ -179,7 +180,8 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
         input_schema = pa.schema([("x", pa.int64())])
         invocation = Invocation(
             function_name="test",
-            in_out_function_input_schema=input_schema,
+            input_schema=input_schema,
+            function_type=InvocationType.SCALAR,
             correlation_id="test",
             invocation_id=None,
             arguments=Arguments(positional=(pa.scalar("x"),)),
@@ -198,29 +200,6 @@ def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
         assert output.batch.num_rows == 3
         assert output.batch.column("result").to_pylist() == [2, 4, 6]
 
-    def test_custom_output_name(self) -> None:
-        """Test custom output column name."""
-
-        class CustomName(ScalarFunction):
-            @property
-            def output_name(self) -> str:
-                return "doubled"
-
-            @property
-            def output_type(self) -> pa.DataType:
-                return pa.int64()
-
-            def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
-                import pyarrow.compute as pc
-
-                return pc.multiply(batch.column("x"), 2)
-
-        input_schema = pa.schema([("x", pa.int64())])
-        invocation = create_invocation(input_schema)
-
-        func = CustomName(invocation=invocation, logger=structlog.get_logger())
-        assert func.output_schema.names == ["doubled"]
-
     def test_log_method(self) -> None:
         """Test self.log() method."""
 
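
Several hunks in this file replace `_ = yield None` with a priming yield of `Output(self.empty_output_batch)`. The priming-yield pattern itself is ordinary Python generator behaviour and can be sketched without VGI. In the sketch below, `doubling_stage` is a hypothetical name and plain lists stand in for Arrow batches; it illustrates the pattern only and is not the `ScalarFunctionGenerator` protocol.

```python
from typing import Generator, List, Optional


def doubling_stage() -> Generator[List[int], Optional[List[int]], None]:
    """Receive batches through send() and yield each one doubled."""
    # Priming yield: an empty placeholder of the same type as real outputs,
    # emitted before any input has been received.
    batch = yield []
    while batch is not None:
        # Yield the transformed batch, then wait for the next input.
        batch = yield [value * 2 for value in batch]


stage = doubling_stage()
primed = next(stage)            # run up to the priming yield
first = stage.send([1, 2, 3])   # deliver the first batch, collect its output
second = stage.send([10])
stage.close()
```

Yielding an empty value of the output type instead of `None` means every value the generator produces has the same type, which is presumably why the tests now use `Output(self.empty_output_batch)`.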

tests/table/generator/test_constant_table_function.py

Lines changed: 4 additions & 3 deletions
@@ -29,14 +29,15 @@ def test_cardinality(self) -> None:
         """Cardinality should always be 1."""
         import structlog
 
-        from vgi.function import Arguments, Invocation
+        from vgi.function import Arguments, Invocation, InvocationType
 
         invocation = Invocation(
             function_name="constant_table",
-            arguments=Arguments(positional=(pa.scalar(42),)),
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.TABLE,
             correlation_id="test",
             invocation_id=b"test",
+            arguments=Arguments(positional=(pa.scalar(42),)),
         )
         func = ConstantTableFunction(
             invocation=invocation,

tests/table/generator/test_projected_data_function.py

Lines changed: 6 additions & 5 deletions
@@ -152,15 +152,16 @@ def test_output_schema_reflects_projection(self) -> None:
152152
"""The output_schema property should reflect the projection."""
153153
import structlog
154154

155-
from vgi.function import Invocation
156-
from vgi.table_function import GlobalStateInitInput
155+
from vgi.function import Invocation, InvocationType
156+
from vgi.table_function import TableFunctionInitInput
157157

158158
invocation = Invocation(
159159
function_name="projected_data",
160-
arguments=Arguments(positional=(pa.scalar(10),)),
161-
in_out_function_input_schema=None,
160+
input_schema=None,
161+
function_type=InvocationType.TABLE,
162162
correlation_id="test",
163163
invocation_id=b"test",
164+
arguments=Arguments(positional=(pa.scalar(10),)),
164165
)
165166
func = ProjectedDataFunction(
166167
invocation=invocation,
@@ -171,7 +172,7 @@ def test_output_schema_reflects_projection(self) -> None:
         assert func.output_schema == ProjectedDataFunction.FULL_SCHEMA
 
         # After setting init_data with projection, should return projected schema
-        func.init_data = GlobalStateInitInput(projection_ids=[0, 2])
+        func.init_data = TableFunctionInitInput(projection_ids=[0, 2])
         schema = func.output_schema
         assert len(schema) == 2
         assert schema.names == ["id", "value"]

tests/table/generator/test_sequence_function.py

Lines changed: 4 additions & 3 deletions
@@ -30,14 +30,15 @@ def test_cardinality(self) -> None:
         """Cardinality should match requested count."""
         import structlog
 
-        from vgi.function import Arguments, Invocation
+        from vgi.function import Arguments, Invocation, InvocationType
 
         invocation = Invocation(
             function_name="sequence",
-            arguments=Arguments(positional=(pa.scalar(100),)),
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.TABLE,
             correlation_id="test",
             invocation_id=b"test",
+            arguments=Arguments(positional=(pa.scalar(100),)),
         )
         func = SequenceFunction(
             invocation=invocation,

tests/table/test_function.py

Lines changed: 10 additions & 7 deletions
@@ -7,7 +7,7 @@
 import structlog
 
 from tests.utils import make_schema
-from vgi.function import Arguments, Invocation
+from vgi.function import Arguments, Invocation, InvocationType
 from vgi.table_function import (
     CardinalityInfo,
     Output,
@@ -194,10 +194,11 @@ def output_schema(self) -> pa.Schema:
 
         invocation = Invocation(
             function_name="test",
-            arguments=Arguments(),
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.TABLE,
             correlation_id="test",
             invocation_id=b"test",
+            arguments=Arguments(),
         )
         func = NoCardinalityFunction(
             invocation=invocation,
@@ -219,10 +220,11 @@ def cardinality(self) -> CardinalityInfo:
 
         invocation = Invocation(
             function_name="test",
-            arguments=Arguments(),
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.TABLE,
             correlation_id="test",
             invocation_id=b"test",
+            arguments=Arguments(),
         )
         func = CardinalityFunction(
             invocation=invocation,
@@ -311,10 +313,11 @@ def output_schema(self) -> pa.Schema:
 
         invocation = Invocation(
             function_name="test",
-            arguments=Arguments(),
-            in_out_function_input_schema=None,
+            input_schema=None,
+            function_type=InvocationType.TABLE,
             correlation_id="test",
             invocation_id=b"test",
+            arguments=Arguments(),
         )
         func = TestFunction(
             invocation=invocation,
