@@ -235,6 +235,145 @@ def test_add_columns_accepts_mixed_int_types(self, example_worker: str) -> None:
235235 assert outputs [0 ].schema .field ("result" ).type == pa .int64 ()
236236
237237
class TestSumColumns:
    """Client-level tests for SumColumnsFunction (invoked as ``sum_columns``)."""

    def test_sum_two_columns(self, example_worker: str) -> None:
        """Two int64 columns are added row by row."""
        schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
        source = pa.RecordBatch.from_pydict(
            {"a": [1, 2, 3], "b": [10, 20, 30]},
            schema=schema,
        )

        with Client(example_worker) as client:
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([source]),
                arguments=column_args,
            )
            results = list(stream)

        assert len(results) == 1
        assert results[0].to_pydict() == {"result": [11, 22, 33]}

    def test_sum_three_columns(self, example_worker: str) -> None:
        """A third positional column name is accepted and included in the sum."""
        schema = pa.schema([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())])
        source = pa.RecordBatch.from_pydict(
            {"a": [1, 2], "b": [10, 20], "c": [100, 200]},
            schema=schema,
        )

        with Client(example_worker) as client:
            column_args = Arguments(
                positional=(pa.scalar("a"), pa.scalar("b"), pa.scalar("c")),
            )
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([source]),
                arguments=column_args,
            )
            results = list(stream)

        assert len(results) == 1
        assert results[0].to_pydict() == {"result": [111, 222]}

    def test_sum_with_type_promotion(self, example_worker: str) -> None:
        """An int32 column summed with an int64 column yields an int64 result."""
        schema = pa.schema([("a", pa.int32()), ("b", pa.int64())])
        source = pa.RecordBatch.from_pydict(
            {"a": [1, 2], "b": [10, 20]},
            schema=schema,
        )

        with Client(example_worker) as client:
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([source]),
                arguments=column_args,
            )
            results = list(stream)

        assert len(results) == 1
        assert results[0].to_pydict() == {"result": [11, 22]}
        # The narrower int32 input must not narrow the output column.
        assert results[0].schema.field("result").type == pa.int64()

    def test_sum_rejects_string_column(self, example_worker: str) -> None:
        """A string column is refused with a "does not match any of" error."""
        schema = pa.schema([("a", pa.int64()), ("b", pa.string())])  # type: ignore[arg-type]
        source = pa.RecordBatch.from_pydict(
            {"a": [1, 2], "b": ["x", "y"]},
            schema=schema,
        )

        with (
            Client(example_worker) as client,
            pytest.raises(Exception, match="does not match any of"),
        ):
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([source]),
                arguments=column_args,
            )
            # Drain the stream so the call is fully consumed inside the raises block.
            list(stream)

    def test_sum_multiple_batches(self, example_worker: str) -> None:
        """Every row of a two-batch input shows up in the combined output."""
        schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
        first = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
        second = pa.RecordBatch.from_pydict({"a": [3, 4], "b": [30, 40]}, schema=schema)

        with Client(example_worker) as client:
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([first, second]),
                arguments=column_args,
            )
            results = list(stream)

        assert_total_rows(results, 4)
        summed: list[int] = [
            value
            for out_batch in results
            for value in cast(list[int], out_batch.column("result").to_pylist())
        ]
        # Compared after sorting, so the check does not depend on output order.
        assert sorted(summed) == [11, 22, 33, 44]

    def test_sum_empty_batch(self, example_worker: str) -> None:
        """A zero-row input batch produces exactly one zero-row output batch."""
        schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
        no_rows = pa.RecordBatch.from_pydict({"a": [], "b": []}, schema=schema)

        with Client(example_worker) as client:
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([no_rows]),
                arguments=column_args,
            )
            results = list(stream)

        assert len(results) == 1
        assert results[0].num_rows == 0

    def test_sum_float_columns(self, example_worker: str) -> None:
        """Two float64 columns are added row by row."""
        schema = pa.schema([("a", pa.float64()), ("b", pa.float64())])
        source = pa.RecordBatch.from_pydict(
            {"a": [1.5, 2.5], "b": [0.5, 0.5]},
            schema=schema,
        )

        with Client(example_worker) as client:
            column_args = Arguments(positional=(pa.scalar("a"), pa.scalar("b")))
            stream = client.scalar_function(
                function_name="sum_columns",
                input=iter([source]),
                arguments=column_args,
            )
            results = list(stream)

        assert len(results) == 1
        assert results[0].to_pydict() == {"result": [2.0, 3.0]}
376+
238377class TestScalarFunctionParallel :
239378 """Tests for scalar functions with parallel processing."""
240379
0 commit comments