Skip to content

Commit 5ee344a

Browse files
rustyconoverclaude
andcommitted
Fix ConstParam type inference and improve documentation
- Fix metadata extraction to properly infer types from ConstParam annotations (was hardcoded as "double", now infers from Annotated[type, ...]) - Add clarifying comments in PolarsMultiplyFunction explaining the class attribute vs property pattern for ConstParam - Update docs/polars-scalar-functions.md with proper ConstParam pattern including the metadata declaration and runtime value access Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 5213347 commit 5ee344a

3 files changed

Lines changed: 72 additions & 25 deletions

File tree

docs/polars-scalar-functions.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,26 +195,38 @@ class SumAll(PolarsScalarFunction):
195195

196196
## Constant Arguments
197197

198-
Access scalar values passed in SQL (not from table columns):
198+
Use `ConstParam` to declare scalar values passed in SQL (not from table columns).
199+
This ensures the argument appears in function metadata for catalog registration.
199200

200201
```python
202+
from vgi import ConstParam
203+
201204
class Multiply(PolarsScalarFunction):
202-
value: Annotated[pl.Float64, Param(position=0, doc="Column")]
205+
# Column binding: input column at position 0
206+
value: Annotated[pl.Float64, Param(position=0, doc="Column to multiply")]
207+
208+
# ConstParam declaration: scalar argument at position 0 in function call
209+
# This is a type annotation for metadata - use _factor property to access value
210+
factor: Annotated[float, ConstParam("Multiplication factor", position=0)]
203211

204212
class Meta:
205213
output_type = pl.Float64
206214

207215
@property
208-
def factor(self) -> float:
209-
"""Get constant from SQL arguments."""
216+
def _factor(self) -> float:
217+
"""Get constant from SQL arguments at runtime."""
210218
return self.invocation.arguments.positional[0].as_py()
211219

212220
def compute_polars(self) -> pl.Expr:
213-
return pl.col("value") * self.factor
221+
return pl.col("value") * self._factor
214222
```
215223

216224
SQL usage: `SELECT polars_multiply(price, 1.1) FROM products`
217225

226+
**Important**: The `ConstParam` class attribute is a type annotation for metadata
227+
extraction only. To access the actual value at runtime, use a property that reads
228+
from `self.invocation.arguments.positional[position]`.
229+
218230
## Meta Class Options
219231

220232
```python

vgi/examples/scalar_polars.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,12 @@ class PolarsMultiplyFunction(PolarsScalarFunction):
277277
278278
"""
279279

280+
# Column binding: maps input column at position 0 to "value" in expression
280281
value: Annotated[pl.Float64, Param(position=0, doc="Value to multiply")]
282+
283+
# ConstParam declaration for metadata extraction (tells catalog about the argument).
284+
# The actual value is accessed via _factor property below since class-level
285+
# Annotated declarations are type hints only, not instance attributes.
281286
factor: Annotated[float, ConstParam("Multiplication factor", position=0)]
282287

283288
class Meta:

vgi/metadata.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,15 @@ def transform(self, batch):
4646

4747
from __future__ import annotations
4848

49+
import contextlib
4950
import functools
5051
import json
5152
import re
5253
import warnings
5354
from collections.abc import Sequence
5455
from dataclasses import dataclass, field
5556
from enum import Enum, auto
56-
from typing import TYPE_CHECKING, Any, get_type_hints
57+
from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin, get_type_hints
5758

5859
import pyarrow as pa
5960

@@ -486,26 +487,55 @@ def extract_parameters(
486487
# However, ConstParam entries (is_const=True) ARE function arguments and must
487488
# be extracted for the function signature.
488489
polars_params = getattr(cls, "_polars_params", {})
489-
for name, param_info in polars_params.items():
490-
if not param_info.is_const:
491-
continue # Skip column bindings, only extract ConstParam
492-
seen_names.add(name)
493-
# ConstParam is always required (no default support currently)
494-
type_name = "double" # TODO: Infer from annotation type
495-
parameters.append(
496-
ParameterInfo(
497-
name=name,
498-
position=param_info.position,
499-
type_name=type_name,
500-
description=param_info.doc,
501-
required=True,
502-
default=None,
503-
constraints=None,
504-
is_table_input=False,
505-
is_varargs=False,
506-
is_const=True,
490+
if polars_params:
491+
# Get class annotations to infer ConstParam types
492+
annotations = getattr(cls, "__annotations__", {})
493+
for name, param_info in polars_params.items():
494+
if not param_info.is_const:
495+
continue # Skip column bindings, only extract ConstParam
496+
seen_names.add(name)
497+
498+
# Infer type from annotation (e.g., Annotated[float, ConstParam(...)])
499+
type_name = "any"
500+
if name in annotations:
501+
hint = annotations[name]
502+
# Handle string annotations (from __future__ import annotations)
503+
if isinstance(hint, str):
504+
# Build namespace with builtins, typing constructs, and ConstParam
505+
import builtins
506+
507+
from vgi.arguments import ConstParam
508+
509+
eval_ns = dict(vars(builtins))
510+
eval_ns["Annotated"] = Annotated
511+
eval_ns["ConstParam"] = ConstParam
512+
with contextlib.suppress(Exception):
513+
hint = eval(hint, eval_ns) # noqa: S307
514+
# Extract base type from Annotated[base_type, ...]
515+
base_type = get_args(hint)[0] if get_origin(hint) is Annotated else hint
516+
# Map Python types to Arrow type names
517+
python_to_arrow = {
518+
float: "double",
519+
int: "int64",
520+
str: "string",
521+
bool: "bool",
522+
}
523+
type_name = python_to_arrow.get(base_type, "any")
524+
525+
parameters.append(
526+
ParameterInfo(
527+
name=name,
528+
position=param_info.position,
529+
type_name=type_name,
530+
description=param_info.doc,
531+
required=True, # ConstParam is always required
532+
default=None,
533+
constraints=None,
534+
is_table_input=False,
535+
is_varargs=False,
536+
is_const=True,
537+
)
507538
)
508-
)
509539

510540
# Check for new Param/ConstParam API (ScalarFunction subclasses)
511541
# These are stored in _compute_params and _const_params class attributes

0 commit comments

Comments
 (0)