Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,6 +1661,11 @@ def visit_Expr(self, node: ast.Expr) -> object:
def visit_Constant(self, node: ast.Constant) -> object:
    # Literal constants (ints, floats, strings, None, ...) lower directly
    # to their stored Python value; no tracing is needed.
    return node.value

def visit_Lambda(self, node: ast.Lambda) -> object:
    """Lower a lambda expression during device IR construction.

    Type propagation has already attached a CallableType to this node,
    so we simply unwrap its callable proxy instead of re-tracing the
    lambda body here.
    """
    assert isinstance(node, ExtendedAST)
    lambda_type = node._type_info
    assert isinstance(lambda_type, CallableType)
    return lambda_type.proxy()


class LiftTensorArgs:
values: dict[str, object]
Expand Down
11 changes: 8 additions & 3 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2347,9 +2347,14 @@ def visit_Global(self, node: ast.Global) -> TypeInfo:
# Global statements don't need child visiting since they only declare names
return NoType(origin=self.origin())

# TODO(jansel): support lambda
# pyrefly: ignore [bad-assignment, bad-param-name-override]
visit_Lambda: _VisitMethod = generic_visit
def visit_Lambda(self, node: ast.Lambda) -> TypeInfo:
    """Type-propagate a lambda expression.

    The lambda is re-compiled from its unparsed source and evaluated in
    the kernel function's global namespace, so only "pure" lambdas are
    supported: a lambda that closes over enclosing local variables cannot
    be recreated this way.  We detect names that would not resolve and
    fail eagerly with a descriptive error instead of letting a confusing
    ``NameError`` surface later during tracing.
    """
    assert isinstance(node, ExtendedAST)
    if isinstance(node._type_info, CallableType):
        # Already propagated (node revisited); reuse the cached type.
        return node._type_info
    source = f"lambda {ast.unparse(node.args)}: {ast.unparse(node.body)}"
    self._check_lambda_scope(node, source)
    code = compile(source, filename=node._location.filename, mode="eval")
    # NOTE(review): re-evaluating unparsed source discards the original
    # closure and only basic (traceable) operations are supported; the
    # eager scope check above keeps the error messages clear when users
    # go outside what we support.
    func = eval(code, self.func.fn.__globals__)
    return CallableType(self.origin(), func)

def _check_lambda_scope(self, node: ast.Lambda, source: str) -> None:
    """Raise a clear error if *source* references names that are neither
    the lambda's own parameters, kernel globals, nor builtins — i.e. it
    closes over local state that cannot be recreated by re-evaluation."""
    import builtins
    import symtable

    def global_names(table: symtable.SymbolTable) -> set[str]:
        # Names resolved as globals at any nesting depth inside the
        # stand-alone compiled expression are exactly the free names.
        names = {
            name
            for name in table.get_identifiers()
            if table.lookup(name).is_global()
        }
        for child in table.get_children():
            names |= global_names(child)
        return names

    top = symtable.symtable(source, node._location.filename, "eval")
    kernel_globals = self.func.fn.__globals__
    unresolved = sorted(
        name
        for name in global_names(top)
        if name not in kernel_globals and not hasattr(builtins, name)
    )
    if unresolved:
        raise NameError(
            f"lambda references {unresolved!r}, which are not kernel "
            "globals or builtins; lambdas that close over local "
            "variables are not supported"
        )

################################################################
# Control flow
Expand Down
59 changes: 59 additions & 0 deletions test/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,5 +634,64 @@ def test_argmax_unpacked_kernel(
torch.testing.assert_close(result, pytorch_result)


# TODO(hinriksnaer): expand lambda reduce tests to pallas backend
@onlyBackends(["triton", "cute"])
class TestReduceLambda(RefEagerTestBase, TestCase):
    """Test lambda support as combine_fn in hl.reduce.

    Each test defines a kernel whose combine_fn is ``lambda a, b: a + b``
    placed at a different syntactic position; the shared fixture and
    row-sum check live in :meth:`_check_row_sum` to avoid triplicating
    them across the tests.
    """

    def _check_row_sum(self, kernel) -> None:
        # Compile and run *kernel* on a fixed 3x4 tensor; since the
        # combine_fn is addition, the result must equal a row-wise sum.
        x = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
            device=DEVICE,
        )
        _, result = code_and_output(kernel, (x,))
        torch.testing.assert_close(result, x.sum(dim=1))

    def test_lambda_before_loop(self):
        """Lambda assigned at host level, before the tile loop."""

        @helion.kernel(autotune_effort="none")
        def kernel(x: torch.Tensor) -> torch.Tensor:
            add_fn = lambda a, b: a + b  # noqa: E731, FURB118
            result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
            for i in hl.tile(x.size(0)):
                result[i] = hl.reduce(add_fn, x[i, :], dim=1)
            return result

        self._check_row_sum(kernel)

    def test_lambda_inside_loop(self):
        """Lambda assigned inside the tile loop body."""

        @helion.kernel(autotune_effort="none")
        def kernel(x: torch.Tensor) -> torch.Tensor:
            result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
            for i in hl.tile(x.size(0)):
                add_fn = lambda a, b: a + b  # noqa: E731, FURB118
                result[i] = hl.reduce(add_fn, x[i, :], dim=1)
            return result

        self._check_row_sum(kernel)

    def test_lambda_inline(self):
        """Lambda passed directly as argument to hl.reduce."""

        @helion.kernel(autotune_effort="none")
        def kernel(x: torch.Tensor) -> torch.Tensor:
            result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
            for i in hl.tile(x.size(0)):
                result[i] = hl.reduce(lambda a, b: a + b, x[i, :], dim=1)  # noqa: FURB118
            return result

        self._check_row_sum(kernel)


# Allow running this test file directly (e.g. `python test_reduce.py`).
if __name__ == "__main__":
    unittest.main()
Loading