Skip to content

Commit e99bd45

Browse files
rustyconover and claude committed
Add catalog_output_type/catalog_output_schema to ScalarFunction
Scalar functions now expose their output type for catalog introspection:

- Added abstract classmethod catalog_output_type() returning pa.DataType or AnyArrow
- Added final classmethod catalog_output_schema() wrapping type in schema
- Default output_type property uses catalog_output_type() (DRY)
- Functions with AnyArrow output produce schema with pa.null() + vgi:any metadata
- CLI shows output_schema only for scalar functions, displays "any" for dynamic types

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 0ebcf9f commit e99bd45

6 files changed

Lines changed: 222 additions & 36 deletions

File tree

tests/catalog/test_example_worker_catalog.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,68 @@ def test_function_info_schema_name(self) -> None:
169169
# All functions should be in 'main' schema
170170
for item in contents:
171171
assert item.schema_name == "main"
172+
173+
def test_scalar_function_has_output_schema(self) -> None:
174+
"""Scalar functions with static output types have output_schema populated."""
175+
client = Client(EXAMPLE_WORKER)
176+
177+
attach_result = client.catalog_attach(name="example", options={})
178+
contents = list(
179+
client.schema_contents(attach_id=attach_result.attach_id, name="main")
180+
)
181+
functions = _get_functions(contents)
182+
183+
# Create lookup by name
184+
by_name = {fn.name: fn for fn in functions}
185+
186+
# upper_case has static output type (string)
187+
upper_info = by_name["upper_case"]
188+
output_schema = pa.ipc.read_schema(pa.py_buffer(upper_info.output_schema))
189+
190+
# Should have a single column named "result" with string type
191+
assert len(output_schema) == 1
192+
assert output_schema.field(0).name == "result"
193+
assert output_schema.field(0).type == pa.string()
194+
195+
def test_scalar_function_with_dynamic_output_has_any_type(self) -> None:
196+
"""Scalar functions with AnyArrow output type have 'any' output_schema."""
197+
client = Client(EXAMPLE_WORKER)
198+
199+
attach_result = client.catalog_attach(name="example", options={})
200+
contents = list(
201+
client.schema_contents(attach_id=attach_result.attach_id, name="main")
202+
)
203+
functions = _get_functions(contents)
204+
205+
# Create lookup by name
206+
by_name = {fn.name: fn for fn in functions}
207+
208+
# double_column returns AnyArrow (output depends on input)
209+
double_info = by_name["double_column"]
210+
output_schema = pa.ipc.read_schema(pa.py_buffer(double_info.output_schema))
211+
212+
# Should have a single "result" field with null type and vgi:any metadata
213+
assert len(output_schema) == 1
214+
assert output_schema.field(0).name == "result"
215+
assert output_schema.field(0).type == pa.null()
216+
assert output_schema.field(0).metadata == {b"vgi:any": b"true"}
217+
218+
def test_table_function_has_empty_output_schema(self) -> None:
219+
"""Table functions have empty output_schema (can't determine without input)."""
220+
client = Client(EXAMPLE_WORKER)
221+
222+
attach_result = client.catalog_attach(name="example", options={})
223+
contents = list(
224+
client.schema_contents(attach_id=attach_result.attach_id, name="main")
225+
)
226+
functions = _get_functions(contents)
227+
228+
# Create lookup by name
229+
by_name = {fn.name: fn for fn in functions}
230+
231+
# echo is a table function
232+
echo_info = by_name["echo"]
233+
output_schema = pa.ipc.read_schema(pa.py_buffer(echo_info.output_schema))
234+
235+
# Table functions don't have catalog_output_schema, so it's empty
236+
assert len(output_schema) == 0

tests/scalar/test_function.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import structlog
1010

1111
from tests.conftest import make_scalar_invocation
12-
from vgi.arguments import Arg, Arguments
12+
from vgi.arguments import AnyArrow, Arg, Arguments
1313
from vgi.exceptions import SchemaValidationError
1414
from vgi.invocation import Invocation, InvocationType
1515
from vgi.log import Level, Message
@@ -158,8 +158,8 @@ def test_basic_compute(self) -> None:
158158
class DoubleColumn(ScalarFunction):
159159
column = Arg[str](0)
160160

161-
@property
162-
def output_type(self) -> pa.DataType:
161+
@classmethod
162+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
163163
return pa.int64()
164164

165165
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
@@ -194,8 +194,8 @@ def test_log_method(self) -> None:
194194
"""Test self.log() method."""
195195

196196
class LoggingFunc(ScalarFunction):
197-
@property
198-
def output_type(self) -> pa.DataType:
197+
@classmethod
198+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
199199
return pa.int64()
200200

201201
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
@@ -228,8 +228,8 @@ def test_row_count_validation(self) -> None:
228228
"""Test that row count mismatch raises error."""
229229

230230
class WrongRowCount(ScalarFunction):
231-
@property
232-
def output_type(self) -> pa.DataType:
231+
@classmethod
232+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
233233
return pa.int64()
234234

235235
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
@@ -255,8 +255,8 @@ def test_row_count_exceeds_input(self) -> None:
255255
"""Test that output with more rows than input raises error (lines 134-142)."""
256256

257257
class TooManyRows(ScalarFunction):
258-
@property
259-
def output_type(self) -> pa.DataType:
258+
@classmethod
259+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
260260
return pa.int64()
261261

262262
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
@@ -284,8 +284,8 @@ def test_empty_batch(self) -> None:
284284
"""Test handling of empty batches."""
285285

286286
class DoubleFunc(ScalarFunction):
287-
@property
288-
def output_type(self) -> pa.DataType:
287+
@classmethod
288+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
289289
return pa.int64()
290290

291291
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:

vgi/catalog/catalog_interface.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,12 @@ def _function_to_info(self, func_cls: type, schema_name: str) -> FunctionInfo:
10431043
args_schema = argument_specs_to_schema(arg_specs)
10441044
args_bytes = SerializedSchema(args_schema.serialize().to_pybytes())
10451045

1046-
# Output schema placeholder (not available without instantiation)
1047-
output_schema = pa.schema([])
1046+
# Get output schema from catalog introspection methods if available
1047+
output_schema: pa.Schema = pa.schema([])
1048+
has_catalog_schema = hasattr(func_cls, "catalog_output_schema")
1049+
if func_type == FunctionType.SCALAR and has_catalog_schema:
1050+
# ScalarFunction has catalog_output_schema() classmethod
1051+
output_schema = func_cls.catalog_output_schema() # type: ignore[attr-defined]
10481052
output_bytes = SerializedSchema(output_schema.serialize().to_pybytes())
10491053

10501054
return FunctionInfo(

vgi/client/cli_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,20 @@ def arrow_schema_to_json(serialized: bytes) -> list[dict[str, str]]:
177177
"""
178178
reader = pa.BufferReader(serialized)
179179
schema = pa.ipc.read_schema(reader) # type: ignore[arg-type]
180-
return [{"name": f.name, "type": str(f.type)} for f in schema]
180+
result = []
181+
for f in schema:
182+
type_str = str(f.type)
183+
if f.metadata:
184+
# Check for vgi:any metadata (output schema)
185+
if f.metadata.get(b"vgi:any") == b"true":
186+
type_str = "any"
187+
# Check for vgi_type metadata (argument schema)
188+
elif f.metadata.get(b"vgi_type") == b"table":
189+
type_str = "table"
190+
elif f.metadata.get(b"vgi_type") == b"any":
191+
type_str = "any"
192+
result.append({"name": f.name, "type": type_str})
193+
return result
181194

182195

183196
def output_json(data: Any) -> None:
@@ -279,14 +292,18 @@ def function_info_to_dict(function_info: Any) -> dict[str, Any]:
279292
Dictionary representation
280293
281294
"""
282-
return {
295+
result: dict[str, Any] = {
283296
"name": function_info.name,
284297
"schema_name": function_info.schema_name,
285298
"function_type": function_info.function_type.value,
286299
"arguments": arrow_schema_to_json(function_info.arguments),
287300
"comment": function_info.comment,
288301
"tags": function_info.tags,
289302
}
303+
# Only include output_schema for scalar functions
304+
if function_info.function_type.value == "scalar":
305+
result["output_schema"] = arrow_schema_to_json(function_info.output_schema)
306+
return result
290307

291308

292309
def catalog_attach_result_to_dict(result: Any) -> dict[str, Any]:

vgi/examples/scalar.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pyarrow as pa
1818
import pyarrow.compute as pc
1919

20-
from vgi.arguments import Arg
20+
from vgi.arguments import AnyArrow, Arg
2121
from vgi.scalar_function import ScalarFunction
2222

2323
__all__ = [
@@ -46,6 +46,11 @@ class Meta:
4646
# Explicit arrow_type demonstrates type specification
4747
column = Arg[str](0, doc="Column name to double", arrow_type=pa.utf8())
4848

49+
@classmethod
50+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
51+
"""Output type depends on input column type."""
52+
return AnyArrow
53+
4954
@property
5055
def output_type(self) -> pa.DataType:
5156
"""Return the type of the doubled column."""
@@ -75,6 +80,11 @@ class Meta:
7580
col1 = Arg[str](0, doc="First column name")
7681
col2 = Arg[str](1, doc="Second column name")
7782

83+
@classmethod
84+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
85+
"""Output type depends on input column type."""
86+
return AnyArrow
87+
7888
@property
7989
def output_type(self) -> pa.DataType:
8090
"""Return the type of the first column."""
@@ -103,11 +113,13 @@ class Meta:
103113

104114
column = Arg[str](0, doc="Column name to uppercase")
105115

106-
@property
107-
def output_type(self) -> pa.DataType:
108-
"""Return string type."""
116+
@classmethod
117+
def catalog_output_type(cls) -> pa.DataType | type[AnyArrow]:
118+
"""Return string type (static output)."""
109119
return pa.string()
110120

121+
# Note: No need to override output_type - default uses catalog_output_type()
122+
111123
def compute(self, batch: pa.RecordBatch) -> pa.Array[Any]:
112124
"""Convert the column values to uppercase."""
113125
return pc.utf8_upper(batch.column(self.column)) # type: ignore[no-matching-overload]

0 commit comments

Comments
 (0)