Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/kirin/dialects/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def isnan(x: float) -> bool: ...
def lgamma(x: float) -> float: ...


@lowering.wraps(stmts.log)
def log(x: float, base: float) -> float: ...


@lowering.wraps(stmts.log10)
def log10(x: float) -> float: ...

Expand Down
36 changes: 30 additions & 6 deletions src/kirin/dialects/math/_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def builtin_math_functions():
# 3.10 compat
"cbrt",
"exp2",
# 3.13 compat
"fma",
):
continue

Expand All @@ -40,12 +42,32 @@ def builtin_math_functions():
sig = inspect.signature(obj)
yield name, obj, sig
except: # noqa: E722
continue
if name == "log":
sig = inspect.Signature(
parameters=[
inspect.Parameter(
"x",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=object,
),
inspect.Parameter(
"base",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=None,
annotation=object,
),
],
return_annotation=object,
)
yield name, obj, sig

else:
continue


with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f:
f.write("# This file is generated by gen.py\n")
f.write("from kirin import ir, types, lowering2\n")
f.write("from kirin import ir, types, lowering\n")
f.write("from kirin.decl import statement, info\n")
f.write("from kirin.dialects.math.dialect import dialect\n")
f.write("\n")
Expand All @@ -58,6 +80,8 @@ def builtin_math_functions():
)
if "is" in name:
ret_type = "types.Bool"
elif name in {"trunc", "ceil", "floor"}:
ret_type = "types.Int"
else:
ret_type = "types.Float"
f.write(textwrap.dedent(f"""
Expand All @@ -66,7 +90,7 @@ class {name}(ir.Statement):
\"\"\"{name} statement, wrapping the math.{name} function
\"\"\"
name = "{name}"
traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}})
traits = frozenset({{ir.Pure(), lowering.FromPythonCall()}})
{fields}
result: ir.ResultValue = info.result({ret_type})
"""))
Expand Down Expand Up @@ -109,7 +133,7 @@ class MathMethodTable(MethodTable):
f.write("pi = pymath.pi\n")
f.write("e = pymath.e\n")
f.write("tau = pymath.tau\n")
f.write("from kirin import lowering2\n")
f.write("from kirin import lowering\n")

for name, obj, sig in builtin_math_functions():
if "is" in name:
Expand All @@ -119,8 +143,8 @@ class MathMethodTable(MethodTable):
else:
ret_type = "float"
f.write(textwrap.dedent(f"""
@lowering2.wraps(stmts.{name})
def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ...
@lowering.wraps(stmts.{name})
def {name}({", ".join(f"{arg}: float" for arg in sig.parameters.keys())}) -> {ret_type}: ...
"""))
f.write("\n")

Expand Down
5 changes: 5 additions & 0 deletions src/kirin/dialects/math/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma):
values = frame.get_values(stmt.args)
return (math.lgamma(values[0]),)

@impl(stmts.log)
def log(self, interp, frame: Frame, stmt: stmts.log):
values = frame.get_values(stmt.args)
return (math.log(values[0], values[1]),)

@impl(stmts.log10)
def log10(self, interp, frame: Frame, stmt: stmts.log10):
values = frame.get_values(stmt.args)
Expand Down
11 changes: 11 additions & 0 deletions src/kirin/dialects/math/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ class lgamma(ir.Statement):
result: ir.ResultValue = info.result(types.Float)


@statement(dialect=dialect)
class log(ir.Statement):
"""log statement, wrapping the math.log function"""

name = "log"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
x: ir.SSAValue = info.argument(types.Float)
base: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(types.Float)


@statement(dialect=dialect)
class log10(ir.Statement):
"""log10 statement, wrapping the math.log10 function"""
Expand Down
Loading