diff --git a/docs/dialects/python/data.md b/docs/dialects/python/data.md index 7274a8de0f..3bbd9cf557 100644 --- a/docs/dialects/python/data.md +++ b/docs/dialects/python/data.md @@ -30,3 +30,15 @@ This page provides a reference for dialects that bring in semantics for common P - "!statement" show_root_heading: true show_if_no_docstring: true + +### Set + +`set` support currently covers set literals and `set()` construction only. Mutation APIs +and `set(iterable)` are not part of this first version. + +::: kirin.dialects.py.set + options: + filters: + - "!statement" + show_root_heading: true + show_if_no_docstring: true diff --git a/src/kirin/dialects/func/typeinfer.py b/src/kirin/dialects/func/typeinfer.py index 862fdcf454..5b2f64977c 100644 --- a/src/kirin/dialects/func/typeinfer.py +++ b/src/kirin/dialects/func/typeinfer.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import Hashable + from kirin import ir, types from kirin.interp import Frame, MethodTable, ReturnValue, impl from kirin.analysis import const @@ -30,6 +32,7 @@ def return_( if ( isinstance(hint := stmt.value.hints.get("const"), const.Value) and hint.data is not None + and isinstance(hint.data, Hashable) ): return ReturnValue(types.Literal(hint.data, frame.get(stmt.value))) return ReturnValue(frame.get(stmt.value)) diff --git a/src/kirin/dialects/ilist/lowering.py b/src/kirin/dialects/ilist/lowering.py index 6548841a76..24557e3002 100644 --- a/src/kirin/dialects/ilist/lowering.py +++ b/src/kirin/dialects/ilist/lowering.py @@ -3,6 +3,7 @@ from kirin import types, lowering from kirin.dialects import py +from kirin.dialects.py._comprehension import lower_listcomp_via_desugaring from . import stmts as ilist from ._dialect import dialect @@ -23,6 +24,11 @@ def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result: return state.current_frame.push(ilist.New(values=tuple(elts), elem_type=typ)) + def lower_ListComp( + self, state: lowering.State, node: ast.ListComp + ) -> lowering.Result: + return lower_listcomp_via_desugaring(state, node) + @lowering.akin(ilist.IList) def lower_Call_IList( self, state: lowering.State, node: ast.Call diff --git a/src/kirin/dialects/py/__init__.py b/src/kirin/dialects/py/__init__.py index 9dfa6ef9c0..71e00fa27f 100644 --- a/src/kirin/dialects/py/__init__.py +++ b/src/kirin/dialects/py/__init__.py @@ -9,6 +9,7 @@ from . import ( cmp as cmp, len as len, + set as set, attr as attr, base as base, list as list, diff --git a/src/kirin/dialects/py/_comprehension.py b/src/kirin/dialects/py/_comprehension.py new file mode 100644 index 0000000000..360c4ea162 --- /dev/null +++ b/src/kirin/dialects/py/_comprehension.py @@ -0,0 +1,108 @@ +import ast + +from kirin import lowering + +_LISTCOMP_TMP_PREFIX = "_kirin_listcomp_tmp" +_SETCOMP_TMP_PREFIX = "_kirin_setcomp_tmp" + + +def lower_listcomp_via_desugaring( + state: lowering.State, node: ast.ListComp +) -> lowering.Result: + tmp_name = fresh_comp_name(state, _LISTCOMP_TMP_PREFIX) + init = ast.List(elts=[], ctx=ast.Load()) + leaf = ast.Assign( + targets=[ast.Name(id=tmp_name, ctx=ast.Store())], + value=ast.BinOp( + left=ast.Name(id=tmp_name, ctx=ast.Load()), + op=ast.Add(), + right=ast.List(elts=[node.elt], ctx=ast.Load()), + ), + ) + fix_locations(leaf, node.elt) + return lower_comprehension_via_desugaring( + state=state, + tmp_name=tmp_name, + init_value=init, + generators=node.generators, + leaf_stmt=leaf, + ref_node=node, + ) + + +def lower_setcomp_via_desugaring( + state: lowering.State, node: ast.SetComp +) -> lowering.Result: + tmp_name = fresh_comp_name(state, _SETCOMP_TMP_PREFIX) + init = ast.Call(func=ast.Name(id="set", ctx=ast.Load()), args=[], keywords=[]) + leaf = ast.Assign( + targets=[ast.Name(id=tmp_name, ctx=ast.Store())], + value=ast.BinOp( + left=ast.Name(id=tmp_name, ctx=ast.Load()), + op=ast.BitOr(), + right=ast.Set(elts=[node.elt]), + ), + ) + fix_locations(leaf, node.elt) + return lower_comprehension_via_desugaring( + state=state, + tmp_name=tmp_name, + init_value=init, + generators=node.generators, + leaf_stmt=leaf, + ref_node=node, + ) + + +def lower_comprehension_via_desugaring( + state: lowering.State, + tmp_name: str, + init_value: ast.expr, + generators: list[ast.comprehension], + leaf_stmt: ast.stmt, + ref_node: ast.AST, +) -> lowering.Result: + init = ast.Assign( + targets=[ast.Name(id=tmp_name, ctx=ast.Store())], + value=init_value, + ) + fix_locations(init, ref_node) + state.lower(init) + + for stmt in build_comprehension_stmts(generators, leaf_stmt): + state.lower(stmt) + + result = ast.Name(id=tmp_name, ctx=ast.Load()) + fix_locations(result, ref_node) + return state.lower(result).expect_one() + + +def fresh_comp_name(state: lowering.State, prefix: str) -> str: + frame = state.current_frame + idx = 0 + while True: + suffix = "" if idx == 0 else f"_{idx}" + name = f"{prefix}{suffix}" + if frame.get_local(name) is None and name not in frame.globals: + return name + idx += 1 + + +def build_comprehension_stmts( + generators: list[ast.comprehension], leaf_stmt: ast.stmt +) -> list[ast.stmt]: + acc = leaf_stmt + for gen in reversed(generators): + for if_ in reversed(gen.ifs): + acc = ast.If(test=if_, body=[acc], orelse=[]) + fix_locations(acc, if_) + + acc = ast.For(target=gen.target, iter=gen.iter, body=[acc], orelse=[]) + fix_locations(acc, gen) + + return [acc] + + +def fix_locations(node: ast.AST, ref: ast.AST) -> None: + ast.copy_location(node, ref) + ast.fix_missing_locations(node) diff --git a/src/kirin/dialects/py/list/lowering.py b/src/kirin/dialects/py/list/lowering.py index 01aeff1ea3..bf8d5bfc69 100644 --- a/src/kirin/dialects/py/list/lowering.py +++ b/src/kirin/dialects/py/list/lowering.py @@ -1,6 +1,7 @@ import ast from kirin import types, lowering +from kirin.dialects.py._comprehension import lower_listcomp_via_desugaring from .stmts import New from ._dialect import dialect @@ -8,7 +9,6 @@ @dialect.register class PythonLowering(lowering.FromPythonAST): - def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result: elts = tuple(state.lower(each).expect_one() for each in node.elts) @@ -20,3 +20,8 @@ def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result: typ = types.Any return state.current_frame.push(New(values=tuple(elts))) + + def lower_ListComp( + self, state: lowering.State, node: ast.ListComp + ) -> lowering.Result: + return lower_listcomp_via_desugaring(state, node) diff --git a/src/kirin/dialects/py/set/__init__.py b/src/kirin/dialects/py/set/__init__.py new file mode 100644 index 0000000000..3a0631ace2 --- /dev/null +++ b/src/kirin/dialects/py/set/__init__.py @@ -0,0 +1,12 @@ +"""The set dialect for Python. + +This module contains the dialect for set semantics in Python, including: + +- The `New` statement class. +- The lowering pass for set literals and `set()`. +- The concrete implementation of set operations. +""" + +from . import interp as interp, lowering as lowering, typeinfer as typeinfer +from .stmts import New as New +from ._dialect import dialect as dialect diff --git a/src/kirin/dialects/py/set/_dialect.py b/src/kirin/dialects/py/set/_dialect.py new file mode 100644 index 0000000000..8f4d241a71 --- /dev/null +++ b/src/kirin/dialects/py/set/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("py.set") diff --git a/src/kirin/dialects/py/set/interp.py b/src/kirin/dialects/py/set/interp.py new file mode 100644 index 0000000000..c17795bc7c --- /dev/null +++ b/src/kirin/dialects/py/set/interp.py @@ -0,0 +1,12 @@ +from kirin import interp + +from .stmts import New +from ._dialect import dialect + + +@dialect.register +class SetMethods(interp.MethodTable): + + @interp.impl(New) + def new(self, interp, frame: interp.Frame, stmt: New): + return (set(frame.get_values(stmt.values)),) diff --git a/src/kirin/dialects/py/set/lowering.py b/src/kirin/dialects/py/set/lowering.py new file mode 100644 index 0000000000..8e636b5b49 --- /dev/null +++ b/src/kirin/dialects/py/set/lowering.py @@ -0,0 +1,27 @@ +import ast + +from kirin import lowering +from kirin.dialects.py._comprehension import lower_setcomp_via_desugaring + +from .stmts import New +from ._dialect import dialect + + +@dialect.register +class PythonLowering(lowering.FromPythonAST): + + def lower_Set(self, state: lowering.State, node: ast.Set) -> lowering.Result: + return state.current_frame.push( + New(tuple(state.lower(each).expect_one() for each in node.elts)) + ) + + def lower_SetComp( + self, state: lowering.State, node: ast.SetComp + ) -> lowering.Result: + return lower_setcomp_via_desugaring(state, node) + + @lowering.akin(set) + def lower_Call_set(self, state: lowering.State, node: ast.Call) -> lowering.Result: + if len(node.args) != 0 or len(node.keywords) != 0: + raise lowering.BuildError("`set(iterable)` is not supported yet") + return state.current_frame.push(New(())) diff --git a/src/kirin/dialects/py/set/stmts.py b/src/kirin/dialects/py/set/stmts.py new file mode 100644 index 0000000000..c048d21540 --- /dev/null +++ b/src/kirin/dialects/py/set/stmts.py @@ -0,0 +1,29 @@ +from typing import Sequence + +from kirin import ir, types, lowering +from kirin.decl import info, statement + +from ._dialect import dialect + +T = types.TypeVar("T") + + +@statement(dialect=dialect, init=False) +class New(ir.Statement): + name = "set" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + values: tuple[ir.SSAValue, ...] = info.argument(T) + result: ir.ResultValue = info.result(types.Set[T]) + + def __init__(self, values: Sequence[ir.SSAValue]) -> None: + elem_type: types.TypeAttribute = types.Any + if values: + elem_type = values[0].type + for value in values[1:]: + elem_type = elem_type.join(value.type) + + super().__init__( + args=values, + result_types=(types.Set[elem_type],), + args_slice={"values": slice(0, len(values))}, + ) diff --git a/src/kirin/dialects/py/set/typeinfer.py b/src/kirin/dialects/py/set/typeinfer.py new file mode 100644 index 0000000000..3cf4306a05 --- /dev/null +++ b/src/kirin/dialects/py/set/typeinfer.py @@ -0,0 +1,12 @@ +from kirin import interp + +from .stmts import New +from ._dialect import dialect + + +@dialect.register(key="typeinfer") +class TypeInfer(interp.MethodTable): + + @interp.impl(New) + def new(self, interp, frame, stmt: New): + return (stmt.result.type,) diff --git a/src/kirin/dialects/scf/typeinfer.py b/src/kirin/dialects/scf/typeinfer.py index ccac6b985a..7cb33f35f0 100644 --- a/src/kirin/dialects/scf/typeinfer.py +++ b/src/kirin/dialects/scf/typeinfer.py @@ -10,7 +10,6 @@ @dialect.register(key="typeinfer") class TypeInfer(absint.Methods): - @interp.impl(IfElse) def if_else_( self, diff --git a/src/kirin/prelude.py b/src/kirin/prelude.py index bad0cc4ea7..a18a03e462 100644 --- a/src/kirin/prelude.py +++ b/src/kirin/prelude.py @@ -11,6 +11,7 @@ from kirin.dialects.py import ( cmp, len, + set, attr, base, list, @@ -60,6 +61,7 @@ def run_pass(mt: Method) -> None: python_basic.union( [ list, + set, slice, cf, func, @@ -85,6 +87,7 @@ def run_pass(mt: Method) -> None: python_basic.union( [ ilist, + set, slice, cf, func, @@ -179,6 +182,7 @@ def run_pass( python_basic.union( [ ilist, + set, slice, scf, cf, @@ -204,6 +208,7 @@ def run_pass(method: Method) -> None: python_basic.union( [ ilist, + set, slice, scf, cf, diff --git a/test/dialects/py_dialect/test_list_comp.py b/test/dialects/py_dialect/test_list_comp.py new file mode 100644 index 0000000000..9ff303d3cd --- /dev/null +++ b/test/dialects/py_dialect/test_list_comp.py @@ -0,0 +1,76 @@ +from kirin.prelude import basic, structural, structural_no_opt +from kirin.dialects import py, ilist + + +@basic +def simple(): + return [x for x in range(3)] + + +@basic +def filtered(): + return [x for x in range(5) if x % 2 == 0] + + +@structural_no_opt +def nested(): + return [(x, y) for x in range(2) for y in range(3) if y] + + +@structural +def structural_simple(): + return [x for x in range(3)] + + +@structural +def structural_nested(): + return [(x, y) for x in range(2) for y in range(3) if y] + + +@basic +def with_arg(i, j): + return [(x, y) for x in range(i) for y in range(j) if y] + + +@basic +def temp_name_collision(): + _kirin_listcomp_tmp = 99 + return _kirin_listcomp_tmp, [x for x in range(2)] + + +@basic.add(py.unpack) +def unpacking_target(): + pairs = [(1, 2), (3, 4)] + return [a + b for a, b in pairs] + + +def test_with_arg(): + assert with_arg(2, 3) == ilist.IList([(0, 1), (0, 2), (1, 1), (1, 2)]) + + +def test_simple_runtime(): + assert simple() == ilist.IList([0, 1, 2]) + + +def test_filtered_runtime(): + assert filtered() == ilist.IList([0, 2, 4]) + + +def test_nested_runtime(): + assert nested() == ilist.IList([(0, 1), (0, 2), (1, 1), (1, 2)]) + + +def test_structural_simple_runtime(): + assert structural_simple() == ilist.IList([0, 1, 2]) + + +def test_structural_nested_runtime(): + assert structural_nested() == ilist.IList([(0, 1), (0, 2), (1, 1), (1, 2)]) + + +def test_temp_name_collision(): + assert temp_name_collision() == (99, ilist.IList([0, 1])) + + +def test_unpacking_target(): + assert unpacking_target() == ilist.IList([3, 7]) diff --git a/test/dialects/py_dialect/test_set.py b/test/dialects/py_dialect/test_set.py new file mode 100644 index 0000000000..ed24e91890 --- /dev/null +++ b/test/dialects/py_dialect/test_set.py @@ -0,0 +1,80 @@ +from kirin.prelude import basic, structural +from kirin.dialects import py + + +@basic +def make_set(): + return {1, 1, 2} + + +@basic +def make_empty_set(): + return set() + + +@basic +def comp_simple(): + return {x for x in range(3)} + + +@basic +def comp_filtered(): + return {x for x in range(5) if x % 2 == 0} + + +@structural +def comp_nested(): + return {(x, y) for x in range(2) for y in range(3) if y} + + +@basic +def comp_dedup(): + return {x % 2 for x in range(5)} + + +@basic +def comp_temp_name_collision(): + _kirin_setcomp_tmp = 99 + return _kirin_setcomp_tmp, {x for x in range(2)} + + +@basic.add(py.unpack) +def comp_unpacking(): + pairs = [(1, 2), (3, 4)] + return {a + b for a, b in pairs} + + +def test_set_runtime_result(): + out = make_set() + assert isinstance(out, set) + assert out == {1, 2} + + +def test_empty_set_runtime_result(): + out = make_empty_set() + assert isinstance(out, set) + assert out == set() + + +def test_set_comp_runtime_simple(): + assert comp_simple() == {0, 1, 2} + + +def test_set_comp_runtime_filtered(): + assert comp_filtered() == {0, 2, 4} + + +def test_set_comp_runtime_nested(): + assert comp_nested() == {(0, 1), (0, 2), (1, 1), (1, 2)} + + +def test_set_comp_runtime_dedup(): + assert comp_dedup() == {0, 1} + + +def test_set_comp_temp_name_collision(): + assert comp_temp_name_collision() == (99, {0, 1}) + + +def test_set_comp_unpacking(): + assert comp_unpacking() == {3, 7} diff --git a/test/dialects/py_dialect/test_set_infer.py b/test/dialects/py_dialect/test_set_infer.py new file mode 100644 index 0000000000..dda277ac84 --- /dev/null +++ b/test/dialects/py_dialect/test_set_infer.py @@ -0,0 +1,35 @@ +from kirin import ir, types as ktypes +from kirin.prelude import structural +from kirin.analysis import TypeInference +from kirin.dialects import py + + +def set_stmt_result(kernel: ir.Method): + stmt = next( + stmt for stmt in kernel.code.body.blocks[0].stmts if isinstance(stmt, py.set.New) # type: ignore + ) + return stmt.results[0] + + +def test_set_type_infer_homogeneous(): + + @structural(typeinfer=True, fold=False) + def test(): + return {1, 2} + + typeinfer = TypeInference(structural) + frame, _ = typeinfer.run(test) + + assert frame.entries[set_stmt_result(test)] == ktypes.Set[ktypes.Int] + + +def test_set_type_infer_empty(): + + @structural(typeinfer=True, fold=False) + def test(): + return set() + + typeinfer = TypeInference(structural) + frame, _ = typeinfer.run(test) + + assert frame.entries[set_stmt_result(test)] == ktypes.Set[ktypes.Any] diff --git a/test/lowering/test_list_comp.py b/test/lowering/test_list_comp.py new file mode 100644 index 0000000000..7e44868f91 --- /dev/null +++ b/test/lowering/test_list_comp.py @@ -0,0 +1,29 @@ +from kirin import lowering +from kirin.prelude import basic_no_opt, structural_no_opt + + +def test_list_comp_lowers_with_cf(): + def main(): + return [x for x in range(3)] + + code = lowering.Python(basic_no_opt).python_function(main) + + assert code is not None + + +def test_list_comp_lowers_with_scf(): + def main(): + return [x for x in range(4) if x] + + code = lowering.Python(structural_no_opt).python_function(main) + + assert code is not None + + +def test_list_comp_nested_generators_lower(): + def main(): + return [(x, y) for x in range(2) for y in range(3) if y] + + code = lowering.Python(basic_no_opt).python_function(main) + + assert code is not None diff --git a/test/lowering/test_set.py b/test/lowering/test_set.py new file mode 100644 index 0000000000..f2c6694586 --- /dev/null +++ b/test/lowering/test_set.py @@ -0,0 +1,77 @@ +from kirin import types, lowering +from kirin.prelude import basic_no_opt, structural_no_opt +from kirin.dialects import cf, py, func +from kirin.dialects.lowering import func as func_lowering +from kirin.dialects.py._comprehension import lower_setcomp_via_desugaring + +lower = lowering.Python( + [cf, func, py.base, py.constant, py.set, py.assign, func_lowering] +) + + +def test_set_literal_lowers_to_new(): + + def set_literal(): + x = {1, 2} + return x + + code = lower.python_function(set_literal) + + set_stmt = next( + stmt for stmt in code.body.blocks[0].stmts if isinstance(stmt, py.set.New) # type: ignore + ) + + assert isinstance(set_stmt, py.set.New) + assert len(set_stmt.values) == 2 + assert set_stmt.result.type.is_subseteq(types.Set) + + +def test_empty_set_call_lowers_to_new(): + + def empty_set(): + x = set() + return x + + code = lower.python_function(empty_set) + + set_stmt = next( + stmt for stmt in code.body.blocks[0].stmts if isinstance(stmt, py.set.New) # type: ignore + ) + + assert isinstance(set_stmt, py.set.New) + assert len(set_stmt.values) == 0 + assert set_stmt.result.type.is_subseteq(types.Set) + + +def test_set_comp_lowers_with_cf(): + + def main(): + return {x for x in range(3)} + + code = lowering.Python(basic_no_opt).python_function(main) + + assert code is not None + + +def test_set_comp_lowers_with_scf(): + + def main(): + return {x for x in range(4) if x} + + code = lowering.Python(structural_no_opt).python_function(main) + + assert code is not None + + +def test_set_comp_nested_generators_lower(): + + def main(): + return {(x, y) for x in range(2) for y in range(3) if y} + + code = lowering.Python(basic_no_opt).python_function(main) + + assert code is not None + + +def test_set_comp_helper_import(): + assert callable(lower_setcomp_via_desugaring)