From 7194eb7f0cb5ecd1bf3f724186d279b24229f3f1 Mon Sep 17 00:00:00 2001 From: Advaith Anand Date: Fri, 3 Apr 2026 21:36:53 -0700 Subject: [PATCH 1/5] first commit --- src/kirin/dialects/math/__init__.py | 133 ++++------ src/kirin/dialects/math/_gen.py | 22 +- src/kirin/dialects/math/interp.py | 52 +++- src/kirin/dialects/math/stmts.py | 360 +++++++++++++--------------- test/dialects/math/test_basic.py | 85 ++----- 5 files changed, 308 insertions(+), 344 deletions(-) diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 5facdaac3..4e48d7e18 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -1,156 +1,123 @@ -"math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py - -import math as pymath - -from kirin import lowering +"math dialect, modeling functions in python's `math` stdlib"# This file is generated by gen.py from kirin.dialects.math.dialect import dialect as dialect - from . import stmts as stmts, interp as interp - +import math as pymath pi = pymath.pi e = pymath.e tau = pymath.tau +from kirin import lowering2 - -@lowering.wraps(stmts.acos) +@lowering2.wraps(stmts.acos) def acos(x: float) -> float: ... - -@lowering.wraps(stmts.asin) +@lowering2.wraps(stmts.asin) def asin(x: float) -> float: ... - -@lowering.wraps(stmts.asinh) +@lowering2.wraps(stmts.asinh) def asinh(x: float) -> float: ... - -@lowering.wraps(stmts.atan) +@lowering2.wraps(stmts.atan) def atan(x: float) -> float: ... - -@lowering.wraps(stmts.atan2) +@lowering2.wraps(stmts.atan2) def atan2(y: float, x: float) -> float: ... - -@lowering.wraps(stmts.atanh) +@lowering2.wraps(stmts.atanh) def atanh(x: float) -> float: ... +@lowering2.wraps(stmts.ceil) +def ceil(x: int) -> int: ... -@lowering.wraps(stmts.ceil) -def ceil(x: float) -> int: ... - - -@lowering.wraps(stmts.copysign) +@lowering2.wraps(stmts.copysign) def copysign(x: float, y: float) -> float: ... - -@lowering.wraps(stmts.cos) +@lowering2.wraps(stmts.cos) def cos(x: float) -> float: ... - -@lowering.wraps(stmts.cosh) +@lowering2.wraps(stmts.cosh) def cosh(x: float) -> float: ... - -@lowering.wraps(stmts.degrees) +@lowering2.wraps(stmts.degrees) def degrees(x: float) -> float: ... - -@lowering.wraps(stmts.erf) +@lowering2.wraps(stmts.erf) def erf(x: float) -> float: ... - -@lowering.wraps(stmts.erfc) +@lowering2.wraps(stmts.erfc) def erfc(x: float) -> float: ... - -@lowering.wraps(stmts.exp) +@lowering2.wraps(stmts.exp) def exp(x: float) -> float: ... - -@lowering.wraps(stmts.expm1) +@lowering2.wraps(stmts.expm1) def expm1(x: float) -> float: ... - -@lowering.wraps(stmts.fabs) +@lowering2.wraps(stmts.fabs) def fabs(x: float) -> float: ... +@lowering2.wraps(stmts.floor) +def floor(x: int) -> int: ... -@lowering.wraps(stmts.floor) -def floor(x: float) -> int: ... - +@lowering2.wraps(stmts.fma) +def fma(x: float, y: float, z: float) -> float: ... -@lowering.wraps(stmts.fmod) +@lowering2.wraps(stmts.fmod) def fmod(x: float, y: float) -> float: ... - -@lowering.wraps(stmts.gamma) +@lowering2.wraps(stmts.gamma) def gamma(x: float) -> float: ... +@lowering2.wraps(stmts.isfinite) +def isfinite(x: bool) -> bool: ... -@lowering.wraps(stmts.isfinite) -def isfinite(x: float) -> bool: ... - - -@lowering.wraps(stmts.isinf) -def isinf(x: float) -> bool: ... +@lowering2.wraps(stmts.isinf) +def isinf(x: bool) -> bool: ... +@lowering2.wraps(stmts.isnan) +def isnan(x: bool) -> bool: ... -@lowering.wraps(stmts.isnan) -def isnan(x: float) -> bool: ... - - -@lowering.wraps(stmts.lgamma) +@lowering2.wraps(stmts.lgamma) def lgamma(x: float) -> float: ... +@lowering2.wraps(stmts.log) +def log(x: float, base: float) -> float: ... -@lowering.wraps(stmts.log10) +@lowering2.wraps(stmts.log10) def log10(x: float) -> float: ... - -@lowering.wraps(stmts.log1p) +@lowering2.wraps(stmts.log1p) def log1p(x: float) -> float: ... - -@lowering.wraps(stmts.log2) +@lowering2.wraps(stmts.log2) def log2(x: float) -> float: ... - -@lowering.wraps(stmts.pow) +@lowering2.wraps(stmts.pow) def pow(x: float, y: float) -> float: ... - -@lowering.wraps(stmts.radians) +@lowering2.wraps(stmts.radians) def radians(x: float) -> float: ... - -@lowering.wraps(stmts.remainder) +@lowering2.wraps(stmts.remainder) def remainder(x: float, y: float) -> float: ... - -@lowering.wraps(stmts.sin) +@lowering2.wraps(stmts.sin) def sin(x: float) -> float: ... - -@lowering.wraps(stmts.sinh) +@lowering2.wraps(stmts.sinh) def sinh(x: float) -> float: ... - -@lowering.wraps(stmts.sqrt) +@lowering2.wraps(stmts.sqrt) def sqrt(x: float) -> float: ... - -@lowering.wraps(stmts.tan) +@lowering2.wraps(stmts.tan) def tan(x: float) -> float: ... - -@lowering.wraps(stmts.tanh) +@lowering2.wraps(stmts.tanh) def tanh(x: float) -> float: ... +@lowering2.wraps(stmts.trunc) +def trunc(x: int) -> int: ... -@lowering.wraps(stmts.trunc) -def trunc(x: float) -> int: ... - - -@lowering.wraps(stmts.ulp) +@lowering2.wraps(stmts.ulp) def ulp(x: float) -> float: ... + diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index 63be0b4d9..248f8222d 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -40,7 +40,27 @@ def builtin_math_functions(): sig = inspect.signature(obj) yield name, obj, sig except: # noqa: E722 - continue + if name == "log": + sig = inspect.Signature( + parameters=[ + inspect.Parameter( + "x", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=object, + ), + inspect.Parameter( + "base", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=object, + ), + ], + return_annotation=object, + ) + yield name, obj, sig + + else: + continue with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f: diff --git a/src/kirin/dialects/math/interp.py b/src/kirin/dialects/math/interp.py index f99d713a1..1943cd1ea 100644 --- a/src/kirin/dialects/math/interp.py +++ b/src/kirin/dialects/math/interp.py @@ -1,9 +1,8 @@ # This file is generated by gen.py import math - -from kirin.interp import Frame, MethodTable, impl -from kirin.dialects.math import stmts from kirin.dialects.math.dialect import dialect +from kirin.dialects.math import stmts +from kirin.interp import MethodTable, Frame, impl @dialect.register @@ -14,176 +13,223 @@ def acos(self, interp, frame: Frame, stmt: stmts.acos): values = frame.get_values(stmt.args) return (math.acos(values[0]),) + @impl(stmts.asin) def asin(self, interp, frame: Frame, stmt: stmts.asin): values = frame.get_values(stmt.args) return (math.asin(values[0]),) + @impl(stmts.asinh) def asinh(self, interp, frame: Frame, stmt: stmts.asinh): values = frame.get_values(stmt.args) return (math.asinh(values[0]),) + @impl(stmts.atan) def atan(self, interp, frame: Frame, stmt: stmts.atan): values = frame.get_values(stmt.args) return (math.atan(values[0]),) + @impl(stmts.atan2) def atan2(self, interp, frame: Frame, stmt: stmts.atan2): values = frame.get_values(stmt.args) return (math.atan2(values[0], values[1]),) + @impl(stmts.atanh) def atanh(self, interp, frame: Frame, stmt: stmts.atanh): values = frame.get_values(stmt.args) return (math.atanh(values[0]),) + @impl(stmts.ceil) def ceil(self, interp, frame: Frame, stmt: stmts.ceil): values = frame.get_values(stmt.args) return (math.ceil(values[0]),) + @impl(stmts.copysign) def copysign(self, interp, frame: Frame, stmt: stmts.copysign): values = frame.get_values(stmt.args) return (math.copysign(values[0], values[1]),) + @impl(stmts.cos) def cos(self, interp, frame: Frame, stmt: stmts.cos): values = frame.get_values(stmt.args) return (math.cos(values[0]),) + @impl(stmts.cosh) def cosh(self, interp, frame: Frame, stmt: stmts.cosh): values = frame.get_values(stmt.args) return (math.cosh(values[0]),) + @impl(stmts.degrees) def degrees(self, interp, frame: Frame, stmt: stmts.degrees): values = frame.get_values(stmt.args) return (math.degrees(values[0]),) + @impl(stmts.erf) def erf(self, interp, frame: Frame, stmt: stmts.erf): values = frame.get_values(stmt.args) return (math.erf(values[0]),) + @impl(stmts.erfc) def erfc(self, interp, frame: Frame, stmt: stmts.erfc): values = frame.get_values(stmt.args) return (math.erfc(values[0]),) + @impl(stmts.exp) def exp(self, interp, frame: Frame, stmt: stmts.exp): values = frame.get_values(stmt.args) return (math.exp(values[0]),) + @impl(stmts.expm1) def expm1(self, interp, frame: Frame, stmt: stmts.expm1): values = frame.get_values(stmt.args) return (math.expm1(values[0]),) + @impl(stmts.fabs) def fabs(self, interp, frame: Frame, stmt: stmts.fabs): values = frame.get_values(stmt.args) return (math.fabs(values[0]),) + @impl(stmts.floor) def floor(self, interp, frame: Frame, stmt: stmts.floor): values = frame.get_values(stmt.args) return (math.floor(values[0]),) + + @impl(stmts.fma) + def fma(self, interp, frame: Frame, stmt: stmts.fma): + values = frame.get_values(stmt.args) + return (math.fma(values[0], values[1], values[2]),) + + @impl(stmts.fmod) def fmod(self, interp, frame: Frame, stmt: stmts.fmod): values = frame.get_values(stmt.args) return (math.fmod(values[0], values[1]),) + @impl(stmts.gamma) def gamma(self, interp, frame: Frame, stmt: stmts.gamma): values = frame.get_values(stmt.args) return (math.gamma(values[0]),) + @impl(stmts.isfinite) def isfinite(self, interp, frame: Frame, stmt: stmts.isfinite): values = frame.get_values(stmt.args) return (math.isfinite(values[0]),) + @impl(stmts.isinf) def isinf(self, interp, frame: Frame, stmt: stmts.isinf): values = frame.get_values(stmt.args) return (math.isinf(values[0]),) + @impl(stmts.isnan) def isnan(self, interp, frame: Frame, stmt: stmts.isnan): values = frame.get_values(stmt.args) return (math.isnan(values[0]),) + @impl(stmts.lgamma) def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma): values = frame.get_values(stmt.args) return (math.lgamma(values[0]),) + + @impl(stmts.log) + def log(self, interp, frame: Frame, stmt: stmts.log): + values = frame.get_values(stmt.args) + return (math.log(values[0], values[1]),) + + @impl(stmts.log10) def log10(self, interp, frame: Frame, stmt: stmts.log10): values = frame.get_values(stmt.args) return (math.log10(values[0]),) + @impl(stmts.log1p) def log1p(self, interp, frame: Frame, stmt: stmts.log1p): values = frame.get_values(stmt.args) return (math.log1p(values[0]),) + @impl(stmts.log2) def log2(self, interp, frame: Frame, stmt: stmts.log2): values = frame.get_values(stmt.args) return (math.log2(values[0]),) + @impl(stmts.pow) def pow(self, interp, frame: Frame, stmt: stmts.pow): values = frame.get_values(stmt.args) return (math.pow(values[0], values[1]),) + @impl(stmts.radians) def radians(self, interp, frame: Frame, stmt: stmts.radians): values = frame.get_values(stmt.args) return (math.radians(values[0]),) + @impl(stmts.remainder) def remainder(self, interp, frame: Frame, stmt: stmts.remainder): values = frame.get_values(stmt.args) return (math.remainder(values[0], values[1]),) + @impl(stmts.sin) def sin(self, interp, frame: Frame, stmt: stmts.sin): values = frame.get_values(stmt.args) return (math.sin(values[0]),) + @impl(stmts.sinh) def sinh(self, interp, frame: Frame, stmt: stmts.sinh): values = frame.get_values(stmt.args) return (math.sinh(values[0]),) + @impl(stmts.sqrt) def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt): values = frame.get_values(stmt.args) return (math.sqrt(values[0]),) + @impl(stmts.tan) def tan(self, interp, frame: Frame, stmt: stmts.tan): values = frame.get_values(stmt.args) return (math.tan(values[0]),) + @impl(stmts.tanh) def tanh(self, interp, frame: Frame, stmt: stmts.tanh): values = frame.get_values(stmt.args) return (math.tanh(values[0]),) + @impl(stmts.trunc) def trunc(self, interp, frame: Frame, stmt: stmts.trunc): values = frame.get_values(stmt.args) return (math.trunc(values[0]),) + @impl(stmts.ulp) def ulp(self, interp, frame: Frame, stmt: stmts.ulp): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index dc0d22f3c..e6a76e5a2 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -1,369 +1,355 @@ # This file is generated by gen.py -from kirin import ir, types, lowering -from kirin.decl import info, statement +from kirin import ir, types, lowering2 +from kirin.decl import statement, info from kirin.dialects.math.dialect import dialect @statement(dialect=dialect) class acos(ir.Statement): - """acos statement, wrapping the math.acos function""" - + """acos statement, wrapping the math.acos function + """ name = "acos" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class asin(ir.Statement): - """asin statement, wrapping the math.asin function""" - + """asin statement, wrapping the math.asin function + """ name = "asin" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class asinh(ir.Statement): - """asinh statement, wrapping the math.asinh function""" - + """asinh statement, wrapping the math.asinh function + """ name = "asinh" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class atan(ir.Statement): - """atan statement, wrapping the math.atan function""" - + """atan statement, wrapping the math.atan function + """ name = "atan" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class atan2(ir.Statement): - """atan2 statement, wrapping the math.atan2 function""" - + """atan2 statement, wrapping the math.atan2 function + """ name = "atan2" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - y: ir.SSAValue = info.argument(types.Float) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + y : ir.SSAValue = info.argument(types.Float) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class atanh(ir.Statement): - """atanh statement, wrapping the math.atanh function""" - + """atanh statement, wrapping the math.atanh function + """ name = "atanh" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class ceil(ir.Statement): - """ceil statement, wrapping the math.ceil function""" - + """ceil statement, wrapping the math.ceil function + """ name = "ceil" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Int) - + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) @statement(dialect=dialect) class copysign(ir.Statement): - """copysign statement, wrapping the math.copysign function""" - + """copysign statement, wrapping the math.copysign function + """ name = "copysign" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - y: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class cos(ir.Statement): - """cos statement, wrapping the math.cos function""" - + """cos statement, wrapping the math.cos function + """ name = "cos" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class cosh(ir.Statement): - """cosh statement, wrapping the math.cosh function""" - + """cosh statement, wrapping the math.cosh function + """ name = "cosh" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class degrees(ir.Statement): - """degrees statement, wrapping the math.degrees function""" - + """degrees statement, wrapping the math.degrees function + """ name = "degrees" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class erf(ir.Statement): - """erf statement, wrapping the math.erf function""" - + """erf statement, wrapping the math.erf function + """ name = "erf" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class erfc(ir.Statement): - """erfc statement, wrapping the math.erfc function""" - + """erfc statement, wrapping the math.erfc function + """ name = "erfc" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class exp(ir.Statement): - """exp statement, wrapping the math.exp function""" - + """exp statement, wrapping the math.exp function + """ name = "exp" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class expm1(ir.Statement): - """expm1 statement, wrapping the math.expm1 function""" - + """expm1 statement, wrapping the math.expm1 function + """ name = "expm1" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class fabs(ir.Statement): - """fabs statement, wrapping the math.fabs function""" - + """fabs statement, wrapping the math.fabs function + """ name = "fabs" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class floor(ir.Statement): - """floor statement, wrapping the math.floor function""" - + """floor statement, wrapping the math.floor function + """ name = "floor" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Int) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) +@statement(dialect=dialect) +class fma(ir.Statement): + """fma statement, wrapping the math.fma function + """ + name = "fma" + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + y : ir.SSAValue = info.argument(types.Float) + z : ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) @statement(dialect=dialect) class fmod(ir.Statement): - """fmod statement, wrapping the math.fmod function""" - + """fmod statement, wrapping the math.fmod function + """ name = "fmod" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - y: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class gamma(ir.Statement): - """gamma statement, wrapping the math.gamma function""" - + """gamma statement, wrapping the math.gamma function + """ name = "gamma" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class isfinite(ir.Statement): - """isfinite statement, wrapping the math.isfinite function""" - + """isfinite statement, wrapping the math.isfinite function + """ name = "isfinite" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) - @statement(dialect=dialect) class isinf(ir.Statement): - """isinf statement, wrapping the math.isinf function""" - + """isinf statement, wrapping the math.isinf function + """ name = "isinf" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) - @statement(dialect=dialect) class isnan(ir.Statement): - """isnan statement, wrapping the math.isnan function""" - + """isnan statement, wrapping the math.isnan function + """ name = "isnan" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) - @statement(dialect=dialect) class lgamma(ir.Statement): - """lgamma statement, wrapping the math.lgamma function""" - + """lgamma statement, wrapping the math.lgamma function + """ name = "lgamma" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) +@statement(dialect=dialect) +class log(ir.Statement): + """log statement, wrapping the math.log function + """ + name = "log" + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + base : ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) @statement(dialect=dialect) class log10(ir.Statement): - """log10 statement, wrapping the math.log10 function""" - + """log10 statement, wrapping the math.log10 function + """ name = "log10" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class log1p(ir.Statement): - """log1p statement, wrapping the math.log1p function""" - + """log1p statement, wrapping the math.log1p function + """ name = "log1p" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class log2(ir.Statement): - """log2 statement, wrapping the math.log2 function""" - + """log2 statement, wrapping the math.log2 function + """ name = "log2" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class pow(ir.Statement): - """pow statement, wrapping the math.pow function""" - + """pow statement, wrapping the math.pow function + """ name = "pow" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - y: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class radians(ir.Statement): - """radians statement, wrapping the math.radians function""" - + """radians statement, wrapping the math.radians function + """ name = "radians" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class remainder(ir.Statement): - """remainder statement, wrapping the math.remainder function""" - + """remainder statement, wrapping the math.remainder function + """ name = "remainder" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - y: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class sin(ir.Statement): - """sin statement, wrapping the math.sin function""" - + """sin statement, wrapping the math.sin function + """ name = "sin" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class sinh(ir.Statement): - """sinh statement, wrapping the math.sinh function""" - + """sinh statement, wrapping the math.sinh function + """ name = "sinh" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class sqrt(ir.Statement): - """sqrt statement, wrapping the math.sqrt function""" - + """sqrt statement, wrapping the math.sqrt function + """ name = "sqrt" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class tan(ir.Statement): - """tan statement, wrapping the math.tan function""" - + """tan statement, wrapping the math.tan function + """ name = "tan" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class tanh(ir.Statement): - """tanh statement, wrapping the math.tanh function""" - + """tanh statement, wrapping the math.tanh function + """ name = "tanh" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class trunc(ir.Statement): - """trunc statement, wrapping the math.trunc function""" - + """trunc statement, wrapping the math.trunc function + """ name = "trunc" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Int) - + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) @statement(dialect=dialect) class ulp(ir.Statement): - """ulp statement, wrapping the math.ulp function""" - + """ulp statement, wrapping the math.ulp function + """ name = "ulp" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x: ir.SSAValue = info.argument(types.Float) + traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) diff --git a/test/dialects/math/test_basic.py b/test/dialects/math/test_basic.py index 0a445b215..386d96906 100644 --- a/test/dialects/math/test_basic.py +++ b/test/dialects/math/test_basic.py @@ -1,366 +1,311 @@ # type: ignore # This file is generated by gen.py import math as pymath - from kirin.prelude import basic from kirin.dialects import math + @basic def acos_func(x): return math.acos(x) - def test_acos(): truth = pymath.acos(0.42) assert (acos_func(0.42) - truth) < 1e-6 - @basic def asin_func(x): return math.asin(x) - def test_asin(): truth = pymath.asin(0.42) assert (asin_func(0.42) - truth) < 1e-6 - @basic def asinh_func(x): return math.asinh(x) - def test_asinh(): truth = pymath.asinh(0.42) assert (asinh_func(0.42) - truth) < 1e-6 - @basic def atan_func(x): return math.atan(x) - def test_atan(): truth = pymath.atan(0.42) assert (atan_func(0.42) - truth) < 1e-6 - @basic def atan2_func(y, x): return math.atan2(y, x) - def test_atan2(): truth = pymath.atan2(0.42, 0.42) assert (atan2_func(0.42, 0.42) - truth) < 1e-6 - @basic def atanh_func(x): return math.atanh(x) - def test_atanh(): truth = pymath.atanh(0.42) assert (atanh_func(0.42) - truth) < 1e-6 - @basic def ceil_func(x): return math.ceil(x) - def test_ceil(): truth = pymath.ceil(0.42) assert (ceil_func(0.42) - truth) < 1e-6 - @basic def copysign_func(x, y): return math.copysign(x, y) - def test_copysign(): truth = pymath.copysign(0.42, 0.42) assert (copysign_func(0.42, 0.42) - truth) < 1e-6 - @basic def cos_func(x): return math.cos(x) - def test_cos(): truth = pymath.cos(0.42) assert (cos_func(0.42) - truth) < 1e-6 - @basic def cosh_func(x): return math.cosh(x) - def test_cosh(): truth = pymath.cosh(0.42) assert (cosh_func(0.42) - truth) < 1e-6 - @basic def degrees_func(x): return math.degrees(x) - def test_degrees(): truth = pymath.degrees(0.42) assert (degrees_func(0.42) - truth) < 1e-6 - @basic def erf_func(x): return math.erf(x) - def test_erf(): truth = pymath.erf(0.42) assert (erf_func(0.42) - truth) < 1e-6 - @basic def erfc_func(x): return math.erfc(x) - def test_erfc(): truth = pymath.erfc(0.42) assert (erfc_func(0.42) - truth) < 1e-6 - @basic def exp_func(x): return math.exp(x) - def test_exp(): truth = pymath.exp(0.42) assert (exp_func(0.42) - truth) < 1e-6 - @basic def expm1_func(x): return math.expm1(x) - def test_expm1(): truth = pymath.expm1(0.42) assert (expm1_func(0.42) - truth) < 1e-6 - @basic def fabs_func(x): return math.fabs(x) - def test_fabs(): truth = pymath.fabs(0.42) assert (fabs_func(0.42) - truth) < 1e-6 - @basic def floor_func(x): return math.floor(x) - def test_floor(): truth = pymath.floor(0.42) assert (floor_func(0.42) - truth) < 1e-6 +@basic +def fma_func(x, y, z): + return math.fma(x, y, z) + +def test_fma(): + truth = pymath.fma(0.42, 0.42, 0.42) + assert (fma_func(0.42, 0.42, 0.42) - truth) < 1e-6 @basic def fmod_func(x, y): return math.fmod(x, y) - def test_fmod(): truth = pymath.fmod(0.42, 0.42) assert (fmod_func(0.42, 0.42) - truth) < 1e-6 - @basic def gamma_func(x): return math.gamma(x) - def test_gamma(): truth = pymath.gamma(0.42) assert (gamma_func(0.42) - truth) < 1e-6 - @basic def isfinite_func(x): return math.isfinite(x) - def test_isfinite(): truth = pymath.isfinite(0.42) assert (isfinite_func(0.42) - truth) < 1e-6 - @basic def isinf_func(x): return math.isinf(x) - def test_isinf(): truth = pymath.isinf(0.42) assert (isinf_func(0.42) - truth) < 1e-6 - @basic def isnan_func(x): return math.isnan(x) - def test_isnan(): truth = pymath.isnan(0.42) assert (isnan_func(0.42) - truth) < 1e-6 - @basic def lgamma_func(x): return math.lgamma(x) - def test_lgamma(): truth = pymath.lgamma(0.42) assert (lgamma_func(0.42) - truth) < 1e-6 +@basic +def log_func(x, base): + return math.log(x, base) + +def test_log(): + truth = pymath.log(0.42, 0.42) + assert (log_func(0.42, 0.42) - truth) < 1e-6 @basic def log10_func(x): return math.log10(x) - def test_log10(): truth = pymath.log10(0.42) assert (log10_func(0.42) - truth) < 1e-6 - @basic def log1p_func(x): return math.log1p(x) - def test_log1p(): truth = pymath.log1p(0.42) assert (log1p_func(0.42) - truth) < 1e-6 - @basic def log2_func(x): return math.log2(x) - def test_log2(): truth = pymath.log2(0.42) assert (log2_func(0.42) - truth) < 1e-6 - @basic def pow_func(x, y): return math.pow(x, y) - def test_pow(): truth = pymath.pow(0.42, 0.42) assert (pow_func(0.42, 0.42) - truth) < 1e-6 - @basic def radians_func(x): return math.radians(x) - def test_radians(): truth = pymath.radians(0.42) assert (radians_func(0.42) - truth) < 1e-6 - @basic def remainder_func(x, y): return math.remainder(x, y) - def test_remainder(): truth = pymath.remainder(0.42, 0.42) assert (remainder_func(0.42, 0.42) - truth) < 1e-6 - @basic def sin_func(x): return math.sin(x) - def test_sin(): truth = pymath.sin(0.42) assert (sin_func(0.42) - truth) < 1e-6 - @basic def sinh_func(x): return math.sinh(x) - def test_sinh(): truth = pymath.sinh(0.42) assert (sinh_func(0.42) - truth) < 1e-6 - @basic def sqrt_func(x): return math.sqrt(x) - def test_sqrt(): truth = pymath.sqrt(0.42) assert (sqrt_func(0.42) - truth) < 1e-6 - @basic def tan_func(x): return math.tan(x) - def test_tan(): truth = pymath.tan(0.42) assert (tan_func(0.42) - truth) < 1e-6 - @basic def tanh_func(x): return math.tanh(x) - def test_tanh(): truth = pymath.tanh(0.42) assert (tanh_func(0.42) - truth) < 1e-6 - @basic def trunc_func(x): return math.trunc(x) - def test_trunc(): truth = pymath.trunc(0.42) assert (trunc_func(0.42) - truth) < 1e-6 - @basic def ulp_func(x): return math.ulp(x) - def test_ulp(): truth = pymath.ulp(0.42) assert (ulp_func(0.42) - truth) < 1e-6 From ae71c3d914be265071b850c9760cb14cba50482b Mon Sep 17 00:00:00 2001 From: Advaith Anand Date: Sun, 5 Apr 2026 12:15:52 -0700 Subject: [PATCH 2/5] change lowering2 to lowering in math dialect --- src/kirin/dialects/math/__init__.py | 78 ++++++++++++++--------------- src/kirin/dialects/math/_gen.py | 8 +-- src/kirin/dialects/math/stmts.py | 78 ++++++++++++++--------------- 3 files changed, 82 insertions(+), 82 deletions(-) diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 4e48d7e18..35e5a03dd 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -5,119 +5,119 @@ pi = pymath.pi e = pymath.e tau = pymath.tau -from kirin import lowering2 +from kirin import lowering -@lowering2.wraps(stmts.acos) +@lowering.wraps(stmts.acos) def acos(x: float) -> float: ... -@lowering2.wraps(stmts.asin) +@lowering.wraps(stmts.asin) def asin(x: float) -> float: ... -@lowering2.wraps(stmts.asinh) +@lowering.wraps(stmts.asinh) def asinh(x: float) -> float: ... -@lowering2.wraps(stmts.atan) +@lowering.wraps(stmts.atan) def atan(x: float) -> float: ... -@lowering2.wraps(stmts.atan2) +@lowering.wraps(stmts.atan2) def atan2(y: float, x: float) -> float: ... -@lowering2.wraps(stmts.atanh) +@lowering.wraps(stmts.atanh) def atanh(x: float) -> float: ... -@lowering2.wraps(stmts.ceil) +@lowering.wraps(stmts.ceil) def ceil(x: int) -> int: ... -@lowering2.wraps(stmts.copysign) +@lowering.wraps(stmts.copysign) def copysign(x: float, y: float) -> float: ... -@lowering2.wraps(stmts.cos) +@lowering.wraps(stmts.cos) def cos(x: float) -> float: ... -@lowering2.wraps(stmts.cosh) +@lowering.wraps(stmts.cosh) def cosh(x: float) -> float: ... -@lowering2.wraps(stmts.degrees) +@lowering.wraps(stmts.degrees) def degrees(x: float) -> float: ... -@lowering2.wraps(stmts.erf) +@lowering.wraps(stmts.erf) def erf(x: float) -> float: ... -@lowering2.wraps(stmts.erfc) +@lowering.wraps(stmts.erfc) def erfc(x: float) -> float: ... -@lowering2.wraps(stmts.exp) +@lowering.wraps(stmts.exp) def exp(x: float) -> float: ... -@lowering2.wraps(stmts.expm1) +@lowering.wraps(stmts.expm1) def expm1(x: float) -> float: ... -@lowering2.wraps(stmts.fabs) +@lowering.wraps(stmts.fabs) def fabs(x: float) -> float: ... -@lowering2.wraps(stmts.floor) +@lowering.wraps(stmts.floor) def floor(x: int) -> int: ... -@lowering2.wraps(stmts.fma) +@lowering.wraps(stmts.fma) def fma(x: float, y: float, z: float) -> float: ... -@lowering2.wraps(stmts.fmod) +@lowering.wraps(stmts.fmod) def fmod(x: float, y: float) -> float: ... -@lowering2.wraps(stmts.gamma) +@lowering.wraps(stmts.gamma) def gamma(x: float) -> float: ... -@lowering2.wraps(stmts.isfinite) +@lowering.wraps(stmts.isfinite) def isfinite(x: bool) -> bool: ... -@lowering2.wraps(stmts.isinf) +@lowering.wraps(stmts.isinf) def isinf(x: bool) -> bool: ... -@lowering2.wraps(stmts.isnan) +@lowering.wraps(stmts.isnan) def isnan(x: bool) -> bool: ... -@lowering2.wraps(stmts.lgamma) +@lowering.wraps(stmts.lgamma) def lgamma(x: float) -> float: ... -@lowering2.wraps(stmts.log) +@lowering.wraps(stmts.log) def log(x: float, base: float) -> float: ... -@lowering2.wraps(stmts.log10) +@lowering.wraps(stmts.log10) def log10(x: float) -> float: ... -@lowering2.wraps(stmts.log1p) +@lowering.wraps(stmts.log1p) def log1p(x: float) -> float: ... -@lowering2.wraps(stmts.log2) +@lowering.wraps(stmts.log2) def log2(x: float) -> float: ... -@lowering2.wraps(stmts.pow) +@lowering.wraps(stmts.pow) def pow(x: float, y: float) -> float: ... -@lowering2.wraps(stmts.radians) +@lowering.wraps(stmts.radians) def radians(x: float) -> float: ... -@lowering2.wraps(stmts.remainder) +@lowering.wraps(stmts.remainder) def remainder(x: float, y: float) -> float: ... -@lowering2.wraps(stmts.sin) +@lowering.wraps(stmts.sin) def sin(x: float) -> float: ... -@lowering2.wraps(stmts.sinh) +@lowering.wraps(stmts.sinh) def sinh(x: float) -> float: ... -@lowering2.wraps(stmts.sqrt) +@lowering.wraps(stmts.sqrt) def sqrt(x: float) -> float: ... -@lowering2.wraps(stmts.tan) +@lowering.wraps(stmts.tan) def tan(x: float) -> float: ... -@lowering2.wraps(stmts.tanh) +@lowering.wraps(stmts.tanh) def tanh(x: float) -> float: ... -@lowering2.wraps(stmts.trunc) +@lowering.wraps(stmts.trunc) def trunc(x: int) -> int: ... -@lowering2.wraps(stmts.ulp) +@lowering.wraps(stmts.ulp) def ulp(x: float) -> float: ... diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index 248f8222d..2a6684c4b 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -65,7 +65,7 @@ def builtin_math_functions(): with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f: f.write("# This file is generated by gen.py\n") - f.write("from kirin import ir, types, lowering2\n") + f.write("from kirin import ir, types, lowering\n") f.write("from kirin.decl import statement, info\n") f.write("from kirin.dialects.math.dialect import dialect\n") f.write("\n") @@ -86,7 +86,7 @@ class {name}(ir.Statement): \"\"\"{name} statement, wrapping the math.{name} function \"\"\" name = "{name}" - traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}}) + traits = frozenset({{ir.Pure(), lowering.FromPythonCall()}}) {fields} result: ir.ResultValue = info.result({ret_type}) """)) @@ -129,7 +129,7 @@ class MathMethodTable(MethodTable): f.write("pi = pymath.pi\n") f.write("e = pymath.e\n") f.write("tau = pymath.tau\n") - f.write("from kirin import lowering2\n") + f.write("from kirin import lowering\n") for name, obj, sig in builtin_math_functions(): if "is" in name: @@ -139,7 +139,7 @@ class MathMethodTable(MethodTable): else: ret_type = "float" f.write(textwrap.dedent(f""" - @lowering2.wraps(stmts.{name}) + @lowering.wraps(stmts.{name}) def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ... """)) f.write("\n") diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index e6a76e5a2..5978fa1b0 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -1,5 +1,5 @@ # This file is generated by gen.py -from kirin import ir, types, lowering2 +from kirin import ir, types, lowering from kirin.decl import statement, info from kirin.dialects.math.dialect import dialect @@ -9,7 +9,7 @@ class acos(ir.Statement): """acos statement, wrapping the math.acos function """ name = "acos" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -18,7 +18,7 @@ class asin(ir.Statement): """asin statement, wrapping the math.asin function """ name = "asin" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -27,7 +27,7 @@ class asinh(ir.Statement): """asinh statement, wrapping the math.asinh function """ name = "asinh" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -36,7 +36,7 @@ class atan(ir.Statement): """atan statement, wrapping the math.atan function """ name = "atan" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -45,7 +45,7 @@ class atan2(ir.Statement): """atan2 statement, wrapping the math.atan2 function """ name = "atan2" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) y : ir.SSAValue = info.argument(types.Float) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -55,7 +55,7 @@ class atanh(ir.Statement): """atanh statement, wrapping the math.atanh function """ name = "atanh" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -64,7 +64,7 @@ class ceil(ir.Statement): """ceil statement, wrapping the math.ceil function """ name = "ceil" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -73,7 +73,7 @@ class copysign(ir.Statement): """copysign statement, wrapping the math.copysign function """ name = "copysign" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -83,7 +83,7 @@ class cos(ir.Statement): """cos statement, wrapping the math.cos function """ name = "cos" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -92,7 +92,7 @@ class cosh(ir.Statement): """cosh statement, wrapping the math.cosh function """ name = "cosh" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -101,7 +101,7 @@ class degrees(ir.Statement): """degrees statement, wrapping the math.degrees function """ name = "degrees" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -110,7 +110,7 @@ class erf(ir.Statement): """erf statement, wrapping the math.erf function """ name = "erf" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -119,7 +119,7 @@ class erfc(ir.Statement): """erfc statement, wrapping the math.erfc function """ name = "erfc" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -128,7 +128,7 @@ class exp(ir.Statement): """exp statement, wrapping the math.exp function """ name = "exp" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -137,7 +137,7 @@ class expm1(ir.Statement): """expm1 statement, wrapping the math.expm1 function """ name = "expm1" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -146,7 +146,7 @@ class fabs(ir.Statement): """fabs statement, wrapping the math.fabs function """ name = "fabs" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -155,7 +155,7 @@ class floor(ir.Statement): """floor statement, wrapping the math.floor function """ name = "floor" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -164,7 +164,7 @@ class fma(ir.Statement): """fma statement, wrapping the math.fma function """ name = "fma" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) y : ir.SSAValue = info.argument(types.Float) z : ir.SSAValue = info.argument(types.Float) @@ -175,7 +175,7 @@ class fmod(ir.Statement): """fmod statement, wrapping the math.fmod function """ name = "fmod" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -185,7 +185,7 @@ class gamma(ir.Statement): """gamma statement, wrapping the math.gamma function """ name = "gamma" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -194,7 +194,7 @@ class isfinite(ir.Statement): """isfinite statement, wrapping the math.isfinite function """ name = "isfinite" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) @@ -203,7 +203,7 @@ class isinf(ir.Statement): """isinf statement, wrapping the math.isinf function """ name = "isinf" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) @@ -212,7 +212,7 @@ class isnan(ir.Statement): """isnan statement, wrapping the math.isnan function """ name = "isnan" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) @@ -221,7 +221,7 @@ class lgamma(ir.Statement): """lgamma statement, wrapping the math.lgamma function """ name = "lgamma" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -230,7 +230,7 @@ class log(ir.Statement): """log statement, wrapping the math.log function """ name = "log" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) base : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -240,7 +240,7 @@ class log10(ir.Statement): """log10 statement, wrapping the math.log10 function """ name = "log10" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -249,7 +249,7 @@ class log1p(ir.Statement): """log1p statement, wrapping the math.log1p function """ name = "log1p" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -258,7 +258,7 @@ class log2(ir.Statement): """log2 statement, wrapping the math.log2 function """ name = "log2" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -267,7 +267,7 @@ class pow(ir.Statement): """pow statement, wrapping the math.pow function """ name = "pow" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -277,7 +277,7 @@ class radians(ir.Statement): """radians statement, wrapping the math.radians function """ name = "radians" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -286,7 +286,7 @@ class remainder(ir.Statement): """remainder statement, wrapping the math.remainder function """ name = "remainder" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) y : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -296,7 +296,7 @@ class sin(ir.Statement): """sin statement, wrapping the math.sin function """ name = "sin" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -305,7 +305,7 @@ class sinh(ir.Statement): """sinh statement, wrapping the math.sinh function """ name = "sinh" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -314,7 +314,7 @@ class sqrt(ir.Statement): """sqrt statement, wrapping the math.sqrt function """ name = "sqrt" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -323,7 +323,7 @@ class tan(ir.Statement): """tan statement, wrapping the math.tan function """ name = "tan" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -332,7 +332,7 @@ class tanh(ir.Statement): """tanh statement, wrapping the math.tanh function """ name = "tanh" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -341,7 +341,7 @@ class trunc(ir.Statement): """trunc statement, wrapping the math.trunc function """ name = "trunc" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) @@ -350,6 +350,6 @@ class ulp(ir.Statement): """ulp statement, wrapping the math.ulp function """ name = "ulp" - traits = frozenset({ir.Pure(), lowering2.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) From cd0c584e5c09db618531003dae02da6cc9221062 Mon Sep 17 00:00:00 2001 From: Advaith Anand Date: Mon, 6 Apr 2026 13:39:42 -0700 Subject: [PATCH 3/5] fix signature return and parameter type syncing --- src/kirin/dialects/math/__init__.py | 15 ++++++--------- src/kirin/dialects/math/_gen.py | 4 +++- src/kirin/dialects/math/interp.py | 6 ------ src/kirin/dialects/math/stmts.py | 11 ----------- test/dialects/math/test_basic.py | 8 -------- 5 files changed, 9 insertions(+), 35 deletions(-) diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 35e5a03dd..423f131f0 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -26,7 +26,7 @@ def atan2(y: float, x: float) -> float: ... def atanh(x: float) -> float: ... @lowering.wraps(stmts.ceil) -def ceil(x: int) -> int: ... +def ceil(x: float) -> int: ... @lowering.wraps(stmts.copysign) def copysign(x: float, y: float) -> float: ... @@ -56,10 +56,7 @@ def expm1(x: float) -> float: ... def fabs(x: float) -> float: ... @lowering.wraps(stmts.floor) -def floor(x: int) -> int: ... - -@lowering.wraps(stmts.fma) -def fma(x: float, y: float, z: float) -> float: ... +def floor(x: float) -> int: ... @lowering.wraps(stmts.fmod) def fmod(x: float, y: float) -> float: ... @@ -68,13 +65,13 @@ def fmod(x: float, y: float) -> float: ... def gamma(x: float) -> float: ... @lowering.wraps(stmts.isfinite) -def isfinite(x: bool) -> bool: ... +def isfinite(x: float) -> bool: ... @lowering.wraps(stmts.isinf) -def isinf(x: bool) -> bool: ... +def isinf(x: float) -> bool: ... @lowering.wraps(stmts.isnan) -def isnan(x: bool) -> bool: ... +def isnan(x: float) -> bool: ... @lowering.wraps(stmts.lgamma) def lgamma(x: float) -> float: ... @@ -116,7 +113,7 @@ def tan(x: float) -> float: ... def tanh(x: float) -> float: ... @lowering.wraps(stmts.trunc) -def trunc(x: int) -> int: ... +def trunc(x: float) -> int: ... @lowering.wraps(stmts.ulp) def ulp(x: float) -> float: ... diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index 2a6684c4b..a77ced391 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -32,6 +32,8 @@ def builtin_math_functions(): # 3.10 compat "cbrt", "exp2", + # 3.13 compat + "fma", ): continue @@ -140,7 +142,7 @@ class MathMethodTable(MethodTable): ret_type = "float" f.write(textwrap.dedent(f""" @lowering.wraps(stmts.{name}) - def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ... + def {name}({", ".join(f"{arg}: float" for arg in sig.parameters.keys())}) -> {ret_type}: ... """)) f.write("\n") diff --git a/src/kirin/dialects/math/interp.py b/src/kirin/dialects/math/interp.py index 1943cd1ea..9c8025572 100644 --- a/src/kirin/dialects/math/interp.py +++ b/src/kirin/dialects/math/interp.py @@ -110,12 +110,6 @@ def floor(self, interp, frame: Frame, stmt: stmts.floor): return (math.floor(values[0]),) - @impl(stmts.fma) - def fma(self, interp, frame: Frame, stmt: stmts.fma): - values = frame.get_values(stmt.args) - return (math.fma(values[0], values[1], values[2]),) - - @impl(stmts.fmod) def fmod(self, interp, frame: Frame, stmt: stmts.fmod): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index 5978fa1b0..c84b511b3 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -159,17 +159,6 @@ class floor(ir.Statement): x : ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) -@statement(dialect=dialect) -class fma(ir.Statement): - """fma statement, wrapping the math.fma function - """ - name = "fma" - traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - y : ir.SSAValue = info.argument(types.Float) - z : ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Float) - @statement(dialect=dialect) class fmod(ir.Statement): """fmod statement, wrapping the math.fmod function diff --git a/test/dialects/math/test_basic.py b/test/dialects/math/test_basic.py index 386d96906..b67975176 100644 --- a/test/dialects/math/test_basic.py +++ b/test/dialects/math/test_basic.py @@ -142,14 +142,6 @@ def test_floor(): truth = pymath.floor(0.42) assert (floor_func(0.42) - truth) < 1e-6 -@basic -def fma_func(x, y, z): - return math.fma(x, y, z) - -def test_fma(): - truth = pymath.fma(0.42, 0.42, 0.42) - assert (fma_func(0.42, 0.42, 0.42) - truth) < 1e-6 - @basic def fmod_func(x, y): return math.fmod(x, y) From ae2357b9d7e87151bab1ccb4f56d5e62328177af Mon Sep 17 00:00:00 2001 From: Advaith Anand Date: Mon, 6 Apr 2026 13:43:32 -0700 Subject: [PATCH 4/5] fix integer return types for appropriate functions in stmts.py --- src/kirin/dialects/math/_gen.py | 2 ++ src/kirin/dialects/math/stmts.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index a77ced391..78d637ba4 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -80,6 +80,8 @@ def builtin_math_functions(): ) if "is" in name: ret_type = "types.Bool" + elif name in {"trunc", "ceil", "floor"}: + ret_type = "types.Int" else: ret_type = "types.Float" f.write(textwrap.dedent(f""" diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index c84b511b3..85fff3023 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -66,7 +66,7 @@ class ceil(ir.Statement): name = "ceil" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Float) + result: ir.ResultValue = info.result(types.Int) @statement(dialect=dialect) class copysign(ir.Statement): @@ -157,7 +157,7 @@ class floor(ir.Statement): name = "floor" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Float) + result: ir.ResultValue = info.result(types.Int) @statement(dialect=dialect) class fmod(ir.Statement): @@ -332,7 +332,7 @@ class trunc(ir.Statement): name = "trunc" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) x : ir.SSAValue = info.argument(types.Float) - result: ir.ResultValue = info.result(types.Float) + result: ir.ResultValue = info.result(types.Int) @statement(dialect=dialect) class ulp(ir.Statement): From 583718f87ad327f2c8e689cc6dfaa0c9d7010e16 Mon Sep 17 00:00:00 2001 From: Advaith Anand Date: Tue, 7 Apr 2026 18:34:37 -0700 Subject: [PATCH 5/5] addressed linting errors with precommit --- src/kirin/dialects/math/__init__.py | 48 ++++- src/kirin/dialects/math/interp.py | 41 +---- src/kirin/dialects/math/stmts.py | 272 ++++++++++++++++------------ 3 files changed, 201 insertions(+), 160 deletions(-) diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 423f131f0..c67856ce2 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -1,120 +1,160 @@ -"math dialect, modeling functions in python's `math` stdlib"# This file is generated by gen.py +"math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py + +import math as pymath + +from kirin import lowering from kirin.dialects.math.dialect import dialect as dialect + from . import stmts as stmts, interp as interp -import math as pymath + pi = pymath.pi e = pymath.e tau = pymath.tau -from kirin import lowering + @lowering.wraps(stmts.acos) def acos(x: float) -> float: ... + @lowering.wraps(stmts.asin) def asin(x: float) -> float: ... + @lowering.wraps(stmts.asinh) def asinh(x: float) -> float: ... + @lowering.wraps(stmts.atan) def atan(x: float) -> float: ... + @lowering.wraps(stmts.atan2) def atan2(y: float, x: float) -> float: ... + @lowering.wraps(stmts.atanh) def atanh(x: float) -> float: ... + @lowering.wraps(stmts.ceil) def ceil(x: float) -> int: ... + @lowering.wraps(stmts.copysign) def copysign(x: float, y: float) -> float: ... + @lowering.wraps(stmts.cos) def cos(x: float) -> float: ... + @lowering.wraps(stmts.cosh) def cosh(x: float) -> float: ... + @lowering.wraps(stmts.degrees) def degrees(x: float) -> float: ... + @lowering.wraps(stmts.erf) def erf(x: float) -> float: ... + @lowering.wraps(stmts.erfc) def erfc(x: float) -> float: ... + @lowering.wraps(stmts.exp) def exp(x: float) -> float: ... + @lowering.wraps(stmts.expm1) def expm1(x: float) -> float: ... + @lowering.wraps(stmts.fabs) def fabs(x: float) -> float: ... + @lowering.wraps(stmts.floor) def floor(x: float) -> int: ... + @lowering.wraps(stmts.fmod) def fmod(x: float, y: float) -> float: ... + @lowering.wraps(stmts.gamma) def gamma(x: float) -> float: ... + @lowering.wraps(stmts.isfinite) def isfinite(x: float) -> bool: ... + @lowering.wraps(stmts.isinf) def isinf(x: float) -> bool: ... + @lowering.wraps(stmts.isnan) def isnan(x: float) -> bool: ... + @lowering.wraps(stmts.lgamma) def lgamma(x: float) -> float: ... + @lowering.wraps(stmts.log) def log(x: float, base: float) -> float: ... + @lowering.wraps(stmts.log10) def log10(x: float) -> float: ... + @lowering.wraps(stmts.log1p) def log1p(x: float) -> float: ... + @lowering.wraps(stmts.log2) def log2(x: float) -> float: ... + @lowering.wraps(stmts.pow) def pow(x: float, y: float) -> float: ... + @lowering.wraps(stmts.radians) def radians(x: float) -> float: ... + @lowering.wraps(stmts.remainder) def remainder(x: float, y: float) -> float: ... + @lowering.wraps(stmts.sin) def sin(x: float) -> float: ... + @lowering.wraps(stmts.sinh) def sinh(x: float) -> float: ... + @lowering.wraps(stmts.sqrt) def sqrt(x: float) -> float: ... + @lowering.wraps(stmts.tan) def tan(x: float) -> float: ... + @lowering.wraps(stmts.tanh) def tanh(x: float) -> float: ... + @lowering.wraps(stmts.trunc) def trunc(x: float) -> int: ... + @lowering.wraps(stmts.ulp) def ulp(x: float) -> float: ... - diff --git a/src/kirin/dialects/math/interp.py b/src/kirin/dialects/math/interp.py index 9c8025572..2ca9409c3 100644 --- a/src/kirin/dialects/math/interp.py +++ b/src/kirin/dialects/math/interp.py @@ -1,8 +1,9 @@ # This file is generated by gen.py import math -from kirin.dialects.math.dialect import dialect + +from kirin.interp import Frame, MethodTable, impl from kirin.dialects.math import stmts -from kirin.interp import MethodTable, Frame, impl +from kirin.dialects.math.dialect import dialect @dialect.register @@ -13,217 +14,181 @@ def acos(self, interp, frame: Frame, stmt: stmts.acos): values = frame.get_values(stmt.args) return (math.acos(values[0]),) - @impl(stmts.asin) def asin(self, interp, frame: Frame, stmt: stmts.asin): values = frame.get_values(stmt.args) return (math.asin(values[0]),) - @impl(stmts.asinh) def asinh(self, interp, frame: Frame, stmt: stmts.asinh): values = frame.get_values(stmt.args) return (math.asinh(values[0]),) - @impl(stmts.atan) def atan(self, interp, frame: Frame, stmt: stmts.atan): values = frame.get_values(stmt.args) return (math.atan(values[0]),) - @impl(stmts.atan2) def atan2(self, interp, frame: Frame, stmt: stmts.atan2): values = frame.get_values(stmt.args) return (math.atan2(values[0], values[1]),) - @impl(stmts.atanh) def atanh(self, interp, frame: Frame, stmt: stmts.atanh): values = frame.get_values(stmt.args) return (math.atanh(values[0]),) - @impl(stmts.ceil) def ceil(self, interp, frame: Frame, stmt: stmts.ceil): values = frame.get_values(stmt.args) return (math.ceil(values[0]),) - @impl(stmts.copysign) def copysign(self, interp, frame: Frame, stmt: stmts.copysign): values = frame.get_values(stmt.args) return (math.copysign(values[0], values[1]),) - @impl(stmts.cos) def cos(self, interp, frame: Frame, stmt: stmts.cos): values = frame.get_values(stmt.args) return (math.cos(values[0]),) - @impl(stmts.cosh) def cosh(self, interp, frame: Frame, stmt: stmts.cosh): values = frame.get_values(stmt.args) return (math.cosh(values[0]),) - @impl(stmts.degrees) def degrees(self, interp, frame: Frame, stmt: stmts.degrees): values = frame.get_values(stmt.args) return (math.degrees(values[0]),) - @impl(stmts.erf) def erf(self, interp, frame: Frame, stmt: stmts.erf): values = frame.get_values(stmt.args) return (math.erf(values[0]),) - @impl(stmts.erfc) def erfc(self, interp, frame: Frame, stmt: stmts.erfc): values = frame.get_values(stmt.args) return (math.erfc(values[0]),) - @impl(stmts.exp) def exp(self, interp, frame: Frame, stmt: stmts.exp): values = frame.get_values(stmt.args) return (math.exp(values[0]),) - @impl(stmts.expm1) def expm1(self, interp, frame: Frame, stmt: stmts.expm1): values = frame.get_values(stmt.args) return (math.expm1(values[0]),) - @impl(stmts.fabs) def fabs(self, interp, frame: Frame, stmt: stmts.fabs): values = frame.get_values(stmt.args) return (math.fabs(values[0]),) - @impl(stmts.floor) def floor(self, interp, frame: Frame, stmt: stmts.floor): values = frame.get_values(stmt.args) return (math.floor(values[0]),) - @impl(stmts.fmod) def fmod(self, interp, frame: Frame, stmt: stmts.fmod): values = frame.get_values(stmt.args) return (math.fmod(values[0], values[1]),) - @impl(stmts.gamma) def gamma(self, interp, frame: Frame, stmt: stmts.gamma): values = frame.get_values(stmt.args) return (math.gamma(values[0]),) - @impl(stmts.isfinite) def isfinite(self, interp, frame: Frame, stmt: stmts.isfinite): values = frame.get_values(stmt.args) return (math.isfinite(values[0]),) - @impl(stmts.isinf) def isinf(self, interp, frame: Frame, stmt: stmts.isinf): values = frame.get_values(stmt.args) return (math.isinf(values[0]),) - @impl(stmts.isnan) def isnan(self, interp, frame: Frame, stmt: stmts.isnan): values = frame.get_values(stmt.args) return (math.isnan(values[0]),) - @impl(stmts.lgamma) def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma): values = frame.get_values(stmt.args) return (math.lgamma(values[0]),) - @impl(stmts.log) def log(self, interp, frame: Frame, stmt: stmts.log): values = frame.get_values(stmt.args) return (math.log(values[0], values[1]),) - @impl(stmts.log10) def log10(self, interp, frame: Frame, stmt: stmts.log10): values = frame.get_values(stmt.args) return (math.log10(values[0]),) - @impl(stmts.log1p) def log1p(self, interp, frame: Frame, stmt: stmts.log1p): values = frame.get_values(stmt.args) return (math.log1p(values[0]),) - @impl(stmts.log2) def log2(self, interp, frame: Frame, stmt: stmts.log2): values = frame.get_values(stmt.args) return (math.log2(values[0]),) - @impl(stmts.pow) def pow(self, interp, frame: Frame, stmt: stmts.pow): values = frame.get_values(stmt.args) return (math.pow(values[0], values[1]),) - @impl(stmts.radians) def radians(self, interp, frame: Frame, stmt: stmts.radians): values = frame.get_values(stmt.args) return (math.radians(values[0]),) - @impl(stmts.remainder) def remainder(self, interp, frame: Frame, stmt: stmts.remainder): values = frame.get_values(stmt.args) return (math.remainder(values[0], values[1]),) - @impl(stmts.sin) def sin(self, interp, frame: Frame, stmt: stmts.sin): values = frame.get_values(stmt.args) return (math.sin(values[0]),) - @impl(stmts.sinh) def sinh(self, interp, frame: Frame, stmt: stmts.sinh): values = frame.get_values(stmt.args) return (math.sinh(values[0]),) - @impl(stmts.sqrt) def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt): values = frame.get_values(stmt.args) return (math.sqrt(values[0]),) - @impl(stmts.tan) def tan(self, interp, frame: Frame, stmt: stmts.tan): values = frame.get_values(stmt.args) return (math.tan(values[0]),) - @impl(stmts.tanh) def tanh(self, interp, frame: Frame, stmt: stmts.tanh): values = frame.get_values(stmt.args) return (math.tanh(values[0]),) - @impl(stmts.trunc) def trunc(self, interp, frame: Frame, stmt: stmts.trunc): values = frame.get_values(stmt.args) return (math.trunc(values[0]),) - @impl(stmts.ulp) def ulp(self, interp, frame: Frame, stmt: stmts.ulp): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index 85fff3023..0a165b0d2 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -1,344 +1,380 @@ # This file is generated by gen.py from kirin import ir, types, lowering -from kirin.decl import statement, info +from kirin.decl import info, statement from kirin.dialects.math.dialect import dialect @statement(dialect=dialect) class acos(ir.Statement): - """acos statement, wrapping the math.acos function - """ + """acos statement, wrapping the math.acos function""" + name = "acos" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class asin(ir.Statement): - """asin statement, wrapping the math.asin function - """ + """asin statement, wrapping the math.asin function""" + name = "asin" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class asinh(ir.Statement): - """asinh statement, wrapping the math.asinh function - """ + """asinh statement, wrapping the math.asinh function""" + name = "asinh" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class atan(ir.Statement): - """atan statement, wrapping the math.atan function - """ + """atan statement, wrapping the math.atan function""" + name = "atan" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class atan2(ir.Statement): - """atan2 statement, wrapping the math.atan2 function - """ + """atan2 statement, wrapping the math.atan2 function""" + name = "atan2" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - y : ir.SSAValue = info.argument(types.Float) - x : ir.SSAValue = info.argument(types.Float) + y: ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class atanh(ir.Statement): - """atanh statement, wrapping the math.atanh function - """ + """atanh statement, wrapping the math.atanh function""" + name = "atanh" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class ceil(ir.Statement): - """ceil statement, wrapping the math.ceil function - """ + """ceil statement, wrapping the math.ceil function""" + name = "ceil" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Int) + @statement(dialect=dialect) class copysign(ir.Statement): - """copysign statement, wrapping the math.copysign function - """ + """copysign statement, wrapping the math.copysign function""" + name = "copysign" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - y : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) + y: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class cos(ir.Statement): - """cos statement, wrapping the math.cos function - """ + """cos statement, wrapping the math.cos function""" + name = "cos" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class cosh(ir.Statement): - """cosh statement, wrapping the math.cosh function - """ + """cosh statement, wrapping the math.cosh function""" + name = "cosh" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class degrees(ir.Statement): - """degrees statement, wrapping the math.degrees function - """ + """degrees statement, wrapping the math.degrees function""" + name = "degrees" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class erf(ir.Statement): - """erf statement, wrapping the math.erf function - """ + """erf statement, wrapping the math.erf function""" + name = "erf" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class erfc(ir.Statement): - """erfc statement, wrapping the math.erfc function - """ + """erfc statement, wrapping the math.erfc function""" + name = "erfc" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class exp(ir.Statement): - """exp statement, wrapping the math.exp function - """ + """exp statement, wrapping the math.exp function""" + name = "exp" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class expm1(ir.Statement): - """expm1 statement, wrapping the math.expm1 function - """ + """expm1 statement, wrapping the math.expm1 function""" + name = "expm1" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class fabs(ir.Statement): - """fabs statement, wrapping the math.fabs function - """ + """fabs statement, wrapping the math.fabs function""" + name = "fabs" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class floor(ir.Statement): - """floor statement, wrapping the math.floor function - """ + """floor statement, wrapping the math.floor function""" + name = "floor" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Int) + @statement(dialect=dialect) class fmod(ir.Statement): - """fmod statement, wrapping the math.fmod function - """ + """fmod statement, wrapping the math.fmod function""" + name = "fmod" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - y : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) + y: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class gamma(ir.Statement): - """gamma statement, wrapping the math.gamma function - """ + """gamma statement, wrapping the math.gamma function""" + name = "gamma" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class isfinite(ir.Statement): - """isfinite statement, wrapping the math.isfinite function - """ + """isfinite statement, wrapping the math.isfinite function""" + name = "isfinite" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) + @statement(dialect=dialect) class isinf(ir.Statement): - """isinf statement, wrapping the math.isinf function - """ + """isinf statement, wrapping the math.isinf function""" + name = "isinf" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) + @statement(dialect=dialect) class isnan(ir.Statement): - """isnan statement, wrapping the math.isnan function - """ + """isnan statement, wrapping the math.isnan function""" + name = "isnan" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Bool) + @statement(dialect=dialect) class lgamma(ir.Statement): - """lgamma statement, wrapping the math.lgamma function - """ + """lgamma statement, wrapping the math.lgamma function""" + name = "lgamma" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class log(ir.Statement): - """log statement, wrapping the math.log function - """ + """log statement, wrapping the math.log function""" + name = "log" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - base : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) + base: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class log10(ir.Statement): - """log10 statement, wrapping the math.log10 function - """ + """log10 statement, wrapping the math.log10 function""" + name = "log10" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class log1p(ir.Statement): - """log1p statement, wrapping the math.log1p function - """ + """log1p statement, wrapping the math.log1p function""" + name = "log1p" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class log2(ir.Statement): - """log2 statement, wrapping the math.log2 function - """ + """log2 statement, wrapping the math.log2 function""" + name = "log2" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class pow(ir.Statement): - """pow statement, wrapping the math.pow function - """ + """pow statement, wrapping the math.pow function""" + name = "pow" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - y : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) + y: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class radians(ir.Statement): - """radians statement, wrapping the math.radians function - """ + """radians statement, wrapping the math.radians function""" + name = "radians" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class remainder(ir.Statement): - """remainder statement, wrapping the math.remainder function - """ + """remainder statement, wrapping the math.remainder function""" + name = "remainder" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) - y : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) + y: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class sin(ir.Statement): - """sin statement, wrapping the math.sin function - """ + """sin statement, wrapping the math.sin function""" + name = "sin" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class sinh(ir.Statement): - """sinh statement, wrapping the math.sinh function - """ + """sinh statement, wrapping the math.sinh function""" + name = "sinh" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class sqrt(ir.Statement): - """sqrt statement, wrapping the math.sqrt function - """ + """sqrt statement, wrapping the math.sqrt function""" + name = "sqrt" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class tan(ir.Statement): - """tan statement, wrapping the math.tan function - """ + """tan statement, wrapping the math.tan function""" + name = "tan" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class tanh(ir.Statement): - """tanh statement, wrapping the math.tanh function - """ + """tanh statement, wrapping the math.tanh function""" + name = "tanh" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float) + @statement(dialect=dialect) class trunc(ir.Statement): - """trunc statement, wrapping the math.trunc function - """ + """trunc statement, wrapping the math.trunc function""" + name = "trunc" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Int) + @statement(dialect=dialect) class ulp(ir.Statement): - """ulp statement, wrapping the math.ulp function - """ + """ulp statement, wrapping the math.ulp function""" + name = "ulp" traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) - x : ir.SSAValue = info.argument(types.Float) + x: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(types.Float)