Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/dialects/python/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ This page provides a reference for dialects that bring in semantics for common P
- "!statement"
show_root_heading: true
show_if_no_docstring: true

### Set

`set` support currently covers set literals and `set()` construction only. Mutation APIs
and `set(iterable)` are not part of this first version.

::: kirin.dialects.py.set
options:
filters:
- "!statement"
show_root_heading: true
show_if_no_docstring: true
3 changes: 3 additions & 0 deletions src/kirin/dialects/func/typeinfer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections.abc import Hashable

from kirin import ir, types
from kirin.interp import Frame, MethodTable, ReturnValue, impl
from kirin.analysis import const
Expand Down Expand Up @@ -30,6 +32,7 @@ def return_(
if (
isinstance(hint := stmt.value.hints.get("const"), const.Value)
and hint.data is not None
and isinstance(hint.data, Hashable)
):
return ReturnValue(types.Literal(hint.data, frame.get(stmt.value)))
return ReturnValue(frame.get(stmt.value))
Expand Down
6 changes: 6 additions & 0 deletions src/kirin/dialects/ilist/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from kirin import types, lowering
from kirin.dialects import py
from kirin.dialects.py._comprehension import lower_listcomp_via_desugaring

from . import stmts as ilist
from ._dialect import dialect
Expand All @@ -23,6 +24,11 @@ def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result:

return state.current_frame.push(ilist.New(values=tuple(elts), elem_type=typ))

def lower_ListComp(
self, state: lowering.State, node: ast.ListComp
) -> lowering.Result:
return lower_listcomp_via_desugaring(state, node)

@lowering.akin(ilist.IList)
def lower_Call_IList(
self, state: lowering.State, node: ast.Call
Expand Down
1 change: 1 addition & 0 deletions src/kirin/dialects/py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import (
cmp as cmp,
len as len,
set as set,
attr as attr,
base as base,
list as list,
Expand Down
108 changes: 108 additions & 0 deletions src/kirin/dialects/py/_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import ast

from kirin import lowering

_LISTCOMP_TMP_PREFIX = "_kirin_listcomp_tmp"
_SETCOMP_TMP_PREFIX = "_kirin_setcomp_tmp"


def lower_listcomp_via_desugaring(
state: lowering.State, node: ast.ListComp
) -> lowering.Result:
tmp_name = fresh_comp_name(state, _LISTCOMP_TMP_PREFIX)
init = ast.List(elts=[], ctx=ast.Load())
leaf = ast.Assign(
targets=[ast.Name(id=tmp_name, ctx=ast.Store())],
value=ast.BinOp(
left=ast.Name(id=tmp_name, ctx=ast.Load()),
op=ast.Add(),
right=ast.List(elts=[node.elt], ctx=ast.Load()),
),
)
fix_locations(leaf, node.elt)
return lower_comprehension_via_desugaring(
state=state,
tmp_name=tmp_name,
init_value=init,
generators=node.generators,
leaf_stmt=leaf,
ref_node=node,
)


def lower_setcomp_via_desugaring(
state: lowering.State, node: ast.SetComp
) -> lowering.Result:
tmp_name = fresh_comp_name(state, _SETCOMP_TMP_PREFIX)
init = ast.Call(func=ast.Name(id="set", ctx=ast.Load()), args=[], keywords=[])
leaf = ast.Assign(
targets=[ast.Name(id=tmp_name, ctx=ast.Store())],
value=ast.BinOp(
left=ast.Name(id=tmp_name, ctx=ast.Load()),
op=ast.BitOr(),
right=ast.Set(elts=[node.elt]),
),
)
fix_locations(leaf, node.elt)
return lower_comprehension_via_desugaring(
state=state,
tmp_name=tmp_name,
init_value=init,
generators=node.generators,
leaf_stmt=leaf,
ref_node=node,
)


def lower_comprehension_via_desugaring(
state: lowering.State,
tmp_name: str,
init_value: ast.expr,
generators: list[ast.comprehension],
leaf_stmt: ast.stmt,
ref_node: ast.AST,
) -> lowering.Result:
init = ast.Assign(
targets=[ast.Name(id=tmp_name, ctx=ast.Store())],
value=init_value,
)
fix_locations(init, ref_node)
state.lower(init)

for stmt in build_comprehension_stmts(generators, leaf_stmt):
state.lower(stmt)

result = ast.Name(id=tmp_name, ctx=ast.Load())
fix_locations(result, ref_node)
return state.lower(result).expect_one()


def fresh_comp_name(state: lowering.State, prefix: str) -> str:
frame = state.current_frame
idx = 0
while True:
suffix = "" if idx == 0 else f"_{idx}"
name = f"{prefix}{suffix}"
if frame.get_local(name) is None and name not in frame.globals:
return name
idx += 1


def build_comprehension_stmts(
generators: list[ast.comprehension], leaf_stmt: ast.stmt
) -> list[ast.stmt]:
acc = leaf_stmt
for gen in reversed(generators):
for if_ in reversed(gen.ifs):
acc = ast.If(test=if_, body=[acc], orelse=[])
fix_locations(acc, if_)

acc = ast.For(target=gen.target, iter=gen.iter, body=[acc], orelse=[])
fix_locations(acc, gen)

return [acc]


def fix_locations(node: ast.AST, ref: ast.AST) -> None:
ast.copy_location(node, ref)
ast.fix_missing_locations(node)
7 changes: 6 additions & 1 deletion src/kirin/dialects/py/list/lowering.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import ast

from kirin import types, lowering
from kirin.dialects.py._comprehension import lower_listcomp_via_desugaring

from .stmts import New
from ._dialect import dialect


@dialect.register
class PythonLowering(lowering.FromPythonAST):

def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result:
elts = tuple(state.lower(each).expect_one() for each in node.elts)

Expand All @@ -20,3 +20,8 @@ def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result:
typ = types.Any

return state.current_frame.push(New(values=tuple(elts)))

def lower_ListComp(
self, state: lowering.State, node: ast.ListComp
) -> lowering.Result:
return lower_listcomp_via_desugaring(state, node)
12 changes: 12 additions & 0 deletions src/kirin/dialects/py/set/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""The set dialect for Python.

This module contains the dialect for set semantics in Python, including:

- The `New` statement class.
- The lowering pass for set literals and `set()`.
- The concrete implementation of set operations.
"""

from . import interp as interp, lowering as lowering, typeinfer as typeinfer
from .stmts import New as New
from ._dialect import dialect as dialect
3 changes: 3 additions & 0 deletions src/kirin/dialects/py/set/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("py.set")
12 changes: 12 additions & 0 deletions src/kirin/dialects/py/set/interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from kirin import interp

from .stmts import New
from ._dialect import dialect


@dialect.register
class SetMethods(interp.MethodTable):

@interp.impl(New)
def new(self, interp, frame: interp.Frame, stmt: New):
return (set(frame.get_values(stmt.values)),)
27 changes: 27 additions & 0 deletions src/kirin/dialects/py/set/lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import ast

from kirin import lowering
from kirin.dialects.py._comprehension import lower_setcomp_via_desugaring

from .stmts import New
from ._dialect import dialect


@dialect.register
class PythonLowering(lowering.FromPythonAST):

def lower_Set(self, state: lowering.State, node: ast.Set) -> lowering.Result:
return state.current_frame.push(
New(tuple(state.lower(each).expect_one() for each in node.elts))
)

def lower_SetComp(
self, state: lowering.State, node: ast.SetComp
) -> lowering.Result:
return lower_setcomp_via_desugaring(state, node)

@lowering.akin(set)
def lower_Call_set(self, state: lowering.State, node: ast.Call) -> lowering.Result:
if len(node.args) != 0 or len(node.keywords) != 0:
raise lowering.BuildError("`set(iterable)` is not supported yet")
return state.current_frame.push(New(()))
29 changes: 29 additions & 0 deletions src/kirin/dialects/py/set/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Sequence

from kirin import ir, types, lowering
from kirin.decl import info, statement

from ._dialect import dialect

T = types.TypeVar("T")


@statement(dialect=dialect, init=False)
class New(ir.Statement):
name = "set"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
values: tuple[ir.SSAValue, ...] = info.argument(T)
result: ir.ResultValue = info.result(types.Set[T])

def __init__(self, values: Sequence[ir.SSAValue]) -> None:
elem_type: types.TypeAttribute = types.Any
if values:
elem_type = values[0].type
for value in values[1:]:
elem_type = elem_type.join(value.type)

super().__init__(
args=values,
result_types=(types.Set[elem_type],),
args_slice={"values": slice(0, len(values))},
)
12 changes: 12 additions & 0 deletions src/kirin/dialects/py/set/typeinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from kirin import interp

from .stmts import New
from ._dialect import dialect


@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):

@interp.impl(New)
def new(self, interp, frame, stmt: New):
return (stmt.result.type,)
1 change: 0 additions & 1 deletion src/kirin/dialects/scf/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

@dialect.register(key="typeinfer")
class TypeInfer(absint.Methods):

@interp.impl(IfElse)
def if_else_(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/kirin/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from kirin.dialects.py import (
cmp,
len,
set,
attr,
base,
list,
Expand Down Expand Up @@ -60,6 +61,7 @@ def run_pass(mt: Method) -> None:
python_basic.union(
[
list,
set,
slice,
cf,
func,
Expand All @@ -85,6 +87,7 @@ def run_pass(mt: Method) -> None:
python_basic.union(
[
ilist,
set,
slice,
cf,
func,
Expand Down Expand Up @@ -179,6 +182,7 @@ def run_pass(
python_basic.union(
[
ilist,
set,
slice,
scf,
cf,
Expand All @@ -204,6 +208,7 @@ def run_pass(method: Method) -> None:
python_basic.union(
[
ilist,
set,
slice,
scf,
cf,
Expand Down
Loading
Loading