|
| 1 | +# Polars Scalar Functions |
| 2 | + |
| 3 | +This guide explains how to create scalar functions using Polars with the |
| 4 | +expression-based `PolarsScalarFunction` API. |
| 5 | + |
| 6 | +## Overview |
| 7 | + |
| 8 | +`PolarsScalarFunction` provides: |
| 9 | + |
| 10 | +- **Expression-based API**: Return `pl.Expr` instead of computing Series directly |
| 11 | +- **Zero-copy conversion**: Arrow ↔ Polars without data copying |
| 12 | +- **Named column references**: Reference columns by parameter name |
| 13 | +- **Type safety**: Optional type bounds for dynamic types |
| 14 | + |
| 15 | +## Quick Start |
| 16 | + |
| 17 | +```python |
| 18 | +from typing import Annotated |
| 19 | +import polars as pl |
| 20 | +from vgi import PolarsScalarFunction, Param |
| 21 | + |
| 22 | +class UpperCase(PolarsScalarFunction): |
| 23 | + """Convert text to uppercase.""" |
| 24 | + |
| 25 | + # 1. Declare parameter with position and Polars type |
| 26 | + text: Annotated[pl.Utf8, Param(position=0, doc="Input string")] |
| 27 | + |
| 28 | + # 2. Declare output type in Meta |
| 29 | + class Meta: |
| 30 | + output_type = pl.Utf8 |
| 31 | + |
| 32 | + # 3. Return a Polars expression |
| 33 | + def compute_polars(self) -> pl.Expr: |
| 34 | + return pl.col("text").str.to_uppercase() |
| 35 | +``` |
| 36 | + |
| 37 | +## Parameter Declaration |
| 38 | + |
| 39 | +Parameters are declared as class attributes using `Annotated[type, Param(...)]`: |
| 40 | + |
| 41 | +```python |
| 42 | +class MyFunction(PolarsScalarFunction): |
| 43 | + # Single parameter at position 0 |
| 44 | + value: Annotated[pl.Float64, Param(position=0, doc="Input value")] |
| 45 | + |
| 46 | + # Multiple parameters with different positions |
| 47 | + left: Annotated[pl.Int64, Param(position=0, doc="Left operand")] |
| 48 | + right: Annotated[pl.Int64, Param(position=1, doc="Right operand")] |
| 49 | +``` |
| 50 | + |
| 51 | +### Param Options |
| 52 | + |
| 53 | +| Option | Type | Description | |
| 54 | +|--------|------|-------------| |
| 55 | +| `position` | `int` | Column position in input batch (required) | |
| 56 | +| `doc` | `str` | Documentation string | |
| 57 | +| `varargs` | `bool` | Collect all remaining columns | |
| 58 | +| `type_bound` | `Callable` | Type constraint for dynamic types | |
| 59 | + |
| 60 | +## Writing Expressions |
| 61 | + |
| 62 | +In `compute_polars()`, reference columns by their parameter name: |
| 63 | + |
| 64 | +```python |
| 65 | +def compute_polars(self) -> pl.Expr: |
| 66 | + # Reference the "value" parameter as pl.col("value") |
| 67 | + return pl.col("value") * 2 |
| 68 | +``` |
| 69 | + |
| 70 | +### Multiple Columns |
| 71 | + |
| 72 | +```python |
| 73 | +class AddColumns(PolarsScalarFunction): |
| 74 | + left: Annotated[pl.Float64, Param(position=0, doc="First")] |
| 75 | + right: Annotated[pl.Float64, Param(position=1, doc="Second")] |
| 76 | + |
| 77 | + class Meta: |
| 78 | + output_type = pl.Float64 |
| 79 | + |
| 80 | + def compute_polars(self) -> pl.Expr: |
| 81 | + return pl.col("left") + pl.col("right") |
| 82 | +``` |
| 83 | + |
| 84 | +### Using Polars Methods |
| 85 | + |
| 86 | +```python |
| 87 | +def compute_polars(self) -> pl.Expr: |
| 88 | + # String operations |
| 89 | + return pl.col("text").str.to_uppercase() |
| 90 | + |
| 91 | + # Numeric operations |
| 92 | + return pl.col("value").abs().sqrt() |
| 93 | + |
| 94 | + # Conditional logic |
| 95 | + return pl.when(pl.col("x") > 0).then(1).otherwise(-1) |
| 96 | + |
| 97 | + # Aggregations (computed per-batch) |
| 98 | + col = pl.col("value") |
| 99 | + return (col - col.mean()) / col.std() |
| 100 | +``` |
| 101 | + |
| 102 | +## Output Types |
| 103 | + |
| 104 | +### Static Output Type |
| 105 | + |
| 106 | +When output type is known at definition time: |
| 107 | + |
| 108 | +```python |
| 109 | +class Meta: |
| 110 | + output_type = pl.Float64 # or pl.Utf8, pl.Int64, etc. |
| 111 | +``` |
| 112 | + |
| 113 | +### Dynamic Output Type |
| 114 | + |
| 115 | +When output type depends on input (e.g., preserving input type): |
| 116 | + |
| 117 | +```python |
| 118 | +from typing import Any |
| 119 | +import pyarrow.types as pat |
| 120 | +from vgi import AnyPolars |
| 121 | + |
| 122 | +class Double(PolarsScalarFunction): |
| 123 | + value: Annotated[ |
| 124 | + Any, # Accept any type |
| 125 | + Param( |
| 126 | + position=0, |
| 127 | + doc="Value to double", |
| 128 | + # Constrain to numeric types |
| 129 | + type_bound=[pat.is_integer, pat.is_floating], |
| 130 | + ), |
| 131 | + ] |
| 132 | + |
| 133 | + class Meta: |
| 134 | + output_type = AnyPolars # Dynamic type marker |
| 135 | + |
| 136 | + @property |
| 137 | + def output_polars_type(self) -> pl.DataType: |
| 138 | + # Return input type to preserve it |
| 139 | + return self.polars_schema[self.input_schema.field(0).name] |
| 140 | + |
| 141 | + def compute_polars(self) -> pl.Expr: |
| 142 | + return pl.col("value") * 2 |
| 143 | +``` |
| 144 | + |
| 145 | +### Type Bounds |
| 146 | + |
| 147 | +Type bounds constrain what input types are accepted: |
| 148 | + |
| 149 | +```python |
| 150 | +import pyarrow.types as pat |
| 151 | + |
| 152 | +# Single predicate |
| 153 | +type_bound=pat.is_integer |
| 154 | + |
| 155 | +# Multiple predicates (OR logic - any must match) |
| 156 | +type_bound=[pat.is_integer, pat.is_floating] |
| 157 | + |
| 158 | +# Available predicates from pyarrow.types: |
| 159 | +# - pat.is_integer, pat.is_floating, pat.is_numeric |
| 160 | +# - pat.is_string, pat.is_binary, pat.is_boolean |
| 161 | +# - pat.is_temporal, pat.is_date, pat.is_time, pat.is_timestamp |
| 162 | +``` |
| 163 | + |
| 164 | +If validation fails, you get a clear error: |
| 165 | +``` |
| 166 | +SchemaValidationError: Column 'value' has type string, |
| 167 | +but type_bound requires: is_integer, is_floating |
| 168 | +``` |
| 169 | + |
| 170 | +## Variable Arguments (Varargs) |
| 171 | + |
| 172 | +Accept any number of columns with `varargs=True`: |
| 173 | + |
| 174 | +```python |
| 175 | +class SumAll(PolarsScalarFunction): |
| 176 | + values: Annotated[ |
| 177 | + pl.Float64, |
| 178 | + Param(position=0, doc="Values to sum", varargs=True) |
| 179 | + ] |
| 180 | + |
| 181 | + class Meta: |
| 182 | + output_type = pl.Float64 |
| 183 | + |
| 184 | + def compute_polars(self) -> pl.Expr: |
| 185 | + # Vararg columns are renamed to values_0, values_1, etc. |
| 186 | + # Use regex to match all of them |
| 187 | + return pl.sum_horizontal(pl.col("^values_.*$")) |
| 188 | +``` |
| 189 | + |
| 190 | +### How Varargs Work |
| 191 | + |
| 192 | +1. Input columns: `["a", "b", "c"]` |
| 193 | +2. After rename: `["values_0", "values_1", "values_2"]` |
| 194 | +3. Match with: `pl.col("^values_.*$")` |
| 195 | + |
| 196 | +## Constant Arguments |
| 197 | + |
| 198 | +Access scalar values passed in SQL (not from table columns): |
| 199 | + |
| 200 | +```python |
| 201 | +class Multiply(PolarsScalarFunction): |
| 202 | + value: Annotated[pl.Float64, Param(position=0, doc="Column")] |
| 203 | + |
| 204 | + class Meta: |
| 205 | + output_type = pl.Float64 |
| 206 | + |
| 207 | + @property |
| 208 | + def factor(self) -> float: |
| 209 | + """Get constant from SQL arguments.""" |
| 210 | + return self.invocation.arguments.positional[0].as_py() |
| 211 | + |
| 212 | + def compute_polars(self) -> pl.Expr: |
| 213 | + return pl.col("value") * self.factor |
| 214 | +``` |
| 215 | + |
| 216 | +SQL usage: `SELECT polars_multiply(price, 1.1) FROM products` |
| 217 | + |
| 218 | +## Meta Class Options |
| 219 | + |
| 220 | +```python |
| 221 | +class Meta: |
| 222 | + # Output type (required) |
| 223 | + output_type = pl.Float64 |
| 224 | + |
| 225 | + # Function name for SQL (defaults to class name in snake_case) |
| 226 | + name = "my_custom_function" |
| 227 | + |
| 228 | + # Description for catalogs |
| 229 | + description = "Multiplies values by a factor" |
| 230 | + |
| 231 | + # Example queries |
| 232 | + examples = [ |
| 233 | + FunctionExample( |
| 234 | + sql="SELECT my_func(col) FROM table", |
| 235 | + description="Basic usage example", |
| 236 | + ), |
| 237 | + ] |
| 238 | +``` |
| 239 | + |
| 240 | +## Available Instance Attributes |
| 241 | + |
| 242 | +Inside your function methods, you have access to: |
| 243 | + |
| 244 | +| Attribute | Type | Description | |
| 245 | +|-----------|------|-------------| |
| 246 | +| `self.input_schema` | `pa.Schema` | Arrow schema of input | |
| 247 | +| `self.polars_schema` | `Mapping[str, pl.DataType]` | Polars schema | |
| 248 | +| `self.output_schema` | `pa.Schema` | Arrow output schema | |
| 249 | +| `self.invocation` | `Invocation` | Full invocation details | |
| 250 | +| `self.empty_output_batch` | `pa.RecordBatch` | Empty output batch | |
| 251 | + |
| 252 | +## Lifecycle Methods |
| 253 | + |
| 254 | +```python |
| 255 | +class MyFunction(PolarsScalarFunction): |
| 256 | + def bind(self) -> None: |
| 257 | + """Called after input_schema is set. Override to validate or compute.""" |
| 258 | + super().bind() |
| 259 | + # Access self.input_schema, self.polars_schema here |
| 260 | + |
| 261 | + def setup(self) -> None: |
| 262 | + """Called before processing. Acquire resources.""" |
| 263 | + pass |
| 264 | + |
| 265 | + def teardown(self) -> None: |
| 266 | + """Called after processing. Release resources.""" |
| 267 | + pass |
| 268 | +``` |
| 269 | + |
| 270 | +## Complete Example |
| 271 | + |
| 272 | +```python |
| 273 | +from typing import Annotated, Any |
| 274 | +import polars as pl |
| 275 | +import pyarrow.types as pat |
| 276 | +from vgi import PolarsScalarFunction, Param, AnyPolars |
| 277 | +from vgi.metadata import FunctionExample |
| 278 | + |
| 279 | +class ZScoreNormalize(PolarsScalarFunction): |
| 280 | + """Compute z-score normalization: (value - mean) / std. |
| 281 | +
|
| 282 | + Accepts any numeric type and preserves it in the output. |
| 283 | + """ |
| 284 | + |
| 285 | + value: Annotated[ |
| 286 | + Any, |
| 287 | + Param( |
| 288 | + position=0, |
| 289 | + doc="Numeric column to normalize", |
| 290 | + type_bound=[pat.is_integer, pat.is_floating], |
| 291 | + ), |
| 292 | + ] |
| 293 | + |
| 294 | + class Meta: |
| 295 | + name = "zscore_normalize" |
| 296 | + description = "Z-score normalization (standardization)" |
| 297 | + output_type = AnyPolars |
| 298 | + examples = [ |
| 299 | + FunctionExample( |
| 300 | + sql="SELECT zscore_normalize(score) FROM exams", |
| 301 | + description="Normalize exam scores", |
| 302 | + ), |
| 303 | + ] |
| 304 | + |
| 305 | + @property |
| 306 | + def output_polars_type(self) -> pl.DataType: |
| 307 | + # Always output Float64 for normalized values |
| 308 | + return pl.Float64 |
| 309 | + |
| 310 | + def compute_polars(self) -> pl.Expr: |
| 311 | + col = pl.col("value").cast(pl.Float64) |
| 312 | + return (col - col.mean()) / col.std() |
| 313 | +``` |
| 314 | + |
| 315 | +## See Also |
| 316 | + |
| 317 | +- [vgi/examples/scalar_polars.py](../vgi/examples/scalar_polars.py) - Example implementations |
| 318 | +- [vgi/scalar_function_polars.py](../vgi/scalar_function_polars.py) - Base class source |
0 commit comments