diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 5facdaac3..c67856ce2 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -104,6 +104,10 @@ def isnan(x: float) -> bool: ... def lgamma(x: float) -> float: ... +@lowering.wraps(stmts.log) +def log(x: float, base: float) -> float: ... + + @lowering.wraps(stmts.log10) def log10(x: float) -> float: ... diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index 63be0b4d9..78d637ba4 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -32,6 +32,8 @@ def builtin_math_functions(): # 3.10 compat "cbrt", "exp2", + # 3.13 compat + "fma", ): continue @@ -40,12 +42,32 @@ def builtin_math_functions(): sig = inspect.signature(obj) yield name, obj, sig except: # noqa: E722 - continue + if name == "log": + sig = inspect.Signature( + parameters=[ + inspect.Parameter( + "x", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=object, + ), + inspect.Parameter( + "base", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=object, + ), + ], + return_annotation=object, + ) + yield name, obj, sig + + else: + continue with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f: f.write("# This file is generated by gen.py\n") - f.write("from kirin import ir, types, lowering2\n") + f.write("from kirin import ir, types, lowering\n") f.write("from kirin.decl import statement, info\n") f.write("from kirin.dialects.math.dialect import dialect\n") f.write("\n") @@ -58,6 +80,8 @@ def builtin_math_functions(): ) if "is" in name: ret_type = "types.Bool" + elif name in {"trunc", "ceil", "floor"}: + ret_type = "types.Int" else: ret_type = "types.Float" f.write(textwrap.dedent(f""" @@ -66,7 +90,7 @@ class {name}(ir.Statement): \"\"\"{name} statement, wrapping the math.{name} function \"\"\" name = "{name}" - traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}}) + traits = frozenset({{ir.Pure(), lowering.FromPythonCall()}}) {fields} result: ir.ResultValue = info.result({ret_type}) """)) @@ -109,7 +133,7 @@ class MathMethodTable(MethodTable): f.write("pi = pymath.pi\n") f.write("e = pymath.e\n") f.write("tau = pymath.tau\n") - f.write("from kirin import lowering2\n") + f.write("from kirin import lowering\n") for name, obj, sig in builtin_math_functions(): if "is" in name: @@ -119,8 +143,8 @@ class MathMethodTable(MethodTable): else: ret_type = "float" f.write(textwrap.dedent(f""" - @lowering2.wraps(stmts.{name}) - def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ... + @lowering.wraps(stmts.{name}) + def {name}({", ".join(f"{arg}: float" for arg in sig.parameters.keys())}) -> {ret_type}: ... """)) f.write("\n") diff --git a/src/kirin/dialects/math/interp.py b/src/kirin/dialects/math/interp.py index f99d713a1..2ca9409c3 100644 --- a/src/kirin/dialects/math/interp.py +++ b/src/kirin/dialects/math/interp.py @@ -124,6 +124,11 @@ def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma): values = frame.get_values(stmt.args) return (math.lgamma(values[0]),) + @impl(stmts.log) + def log(self, interp, frame: Frame, stmt: stmts.log): + values = frame.get_values(stmt.args) + return (math.log(values[0], values[1]),) + @impl(stmts.log10) def log10(self, interp, frame: Frame, stmt: stmts.log10): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index dc0d22f3c..0a165b0d2 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -237,6 +237,17 @@ class lgamma(ir.Statement): result: ir.ResultValue = info.result(types.Float) +@statement(dialect=dialect) +class log(ir.Statement): + """log statement, wrapping the math.log function""" + + name = "log" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + x: ir.SSAValue = info.argument(types.Float) + base: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) + + @statement(dialect=dialect) class log10(ir.Statement): """log10 statement, wrapping the math.log10 function""" diff --git a/test/dialects/math/test_basic.py b/test/dialects/math/test_basic.py index 0a445b215..b67975176 100644 --- a/test/dialects/math/test_basic.py +++ b/test/dialects/math/test_basic.py @@ -1,366 +1,303 @@ # type: ignore # This file is generated by gen.py import math as pymath - from kirin.prelude import basic from kirin.dialects import math + @basic def acos_func(x): return math.acos(x) - def test_acos(): truth = pymath.acos(0.42) assert (acos_func(0.42) - truth) < 1e-6 - @basic def asin_func(x): return math.asin(x) - def test_asin(): truth = pymath.asin(0.42) assert (asin_func(0.42) - truth) < 1e-6 - @basic def asinh_func(x): return math.asinh(x) - def test_asinh(): truth = pymath.asinh(0.42) assert (asinh_func(0.42) - truth) < 1e-6 - @basic def atan_func(x): return math.atan(x) - def test_atan(): truth = pymath.atan(0.42) assert (atan_func(0.42) - truth) < 1e-6 - @basic def atan2_func(y, x): return math.atan2(y, x) - def test_atan2(): truth = pymath.atan2(0.42, 0.42) assert (atan2_func(0.42, 0.42) - truth) < 1e-6 - @basic def atanh_func(x): return math.atanh(x) - def test_atanh(): truth = pymath.atanh(0.42) assert (atanh_func(0.42) - truth) < 1e-6 - @basic def ceil_func(x): return math.ceil(x) - def test_ceil(): truth = pymath.ceil(0.42) assert (ceil_func(0.42) - truth) < 1e-6 - @basic def copysign_func(x, y): return math.copysign(x, y) - def test_copysign(): truth = pymath.copysign(0.42, 0.42) assert (copysign_func(0.42, 0.42) - truth) < 1e-6 - @basic def cos_func(x): return math.cos(x) - def test_cos(): truth = pymath.cos(0.42) assert (cos_func(0.42) - truth) < 1e-6 - @basic def cosh_func(x): return math.cosh(x) - def test_cosh(): truth = pymath.cosh(0.42) assert (cosh_func(0.42) - truth) < 1e-6 - @basic def degrees_func(x): return math.degrees(x) - def test_degrees(): truth = pymath.degrees(0.42) assert (degrees_func(0.42) - truth) < 1e-6 - @basic def erf_func(x): return math.erf(x) - def test_erf(): truth = pymath.erf(0.42) assert (erf_func(0.42) - truth) < 1e-6 - @basic def erfc_func(x): return math.erfc(x) - def test_erfc(): truth = pymath.erfc(0.42) assert (erfc_func(0.42) - truth) < 1e-6 - @basic def exp_func(x): return math.exp(x) - def test_exp(): truth = pymath.exp(0.42) assert (exp_func(0.42) - truth) < 1e-6 - @basic def expm1_func(x): return math.expm1(x) - def test_expm1(): truth = pymath.expm1(0.42) assert (expm1_func(0.42) - truth) < 1e-6 - @basic def fabs_func(x): return math.fabs(x) - def test_fabs(): truth = pymath.fabs(0.42) assert (fabs_func(0.42) - truth) < 1e-6 - @basic def floor_func(x): return math.floor(x) - def test_floor(): truth = pymath.floor(0.42) assert (floor_func(0.42) - truth) < 1e-6 - @basic def fmod_func(x, y): return math.fmod(x, y) - def test_fmod(): truth = pymath.fmod(0.42, 0.42) assert (fmod_func(0.42, 0.42) - truth) < 1e-6 - @basic def gamma_func(x): return math.gamma(x) - def test_gamma(): truth = pymath.gamma(0.42) assert (gamma_func(0.42) - truth) < 1e-6 - @basic def isfinite_func(x): return math.isfinite(x) - def test_isfinite(): truth = pymath.isfinite(0.42) assert (isfinite_func(0.42) - truth) < 1e-6 - @basic def isinf_func(x): return math.isinf(x) - def test_isinf(): truth = pymath.isinf(0.42) assert (isinf_func(0.42) - truth) < 1e-6 - @basic def isnan_func(x): return math.isnan(x) - def test_isnan(): truth = pymath.isnan(0.42) assert (isnan_func(0.42) - truth) < 1e-6 - @basic def lgamma_func(x): return math.lgamma(x) - def test_lgamma(): truth = pymath.lgamma(0.42) assert (lgamma_func(0.42) - truth) < 1e-6 +@basic +def log_func(x, base): + return math.log(x, base) + +def test_log(): + truth = pymath.log(0.42, 0.42) + assert (log_func(0.42, 0.42) - truth) < 1e-6 @basic def log10_func(x): return math.log10(x) - def test_log10(): truth = pymath.log10(0.42) assert (log10_func(0.42) - truth) < 1e-6 - @basic def log1p_func(x): return math.log1p(x) - def test_log1p(): truth = pymath.log1p(0.42) assert (log1p_func(0.42) - truth) < 1e-6 - @basic def log2_func(x): return math.log2(x) - def test_log2(): truth = pymath.log2(0.42) assert (log2_func(0.42) - truth) < 1e-6 - @basic def pow_func(x, y): return math.pow(x, y) - def test_pow(): truth = pymath.pow(0.42, 0.42) assert (pow_func(0.42, 0.42) - truth) < 1e-6 - @basic def radians_func(x): return math.radians(x) - def test_radians(): truth = pymath.radians(0.42) assert (radians_func(0.42) - truth) < 1e-6 - @basic def remainder_func(x, y): return math.remainder(x, y) - def test_remainder(): truth = pymath.remainder(0.42, 0.42) assert (remainder_func(0.42, 0.42) - truth) < 1e-6 - @basic def sin_func(x): return math.sin(x) - def test_sin(): truth = pymath.sin(0.42) assert (sin_func(0.42) - truth) < 1e-6 - @basic def sinh_func(x): return math.sinh(x) - def test_sinh(): truth = pymath.sinh(0.42) assert (sinh_func(0.42) - truth) < 1e-6 - @basic def sqrt_func(x): return math.sqrt(x) - def test_sqrt(): truth = pymath.sqrt(0.42) assert (sqrt_func(0.42) - truth) < 1e-6 - @basic def tan_func(x): return math.tan(x) - def test_tan(): truth = pymath.tan(0.42) assert (tan_func(0.42) - truth) < 1e-6 - @basic def tanh_func(x): return math.tanh(x) - def test_tanh(): truth = pymath.tanh(0.42) assert (tanh_func(0.42) - truth) < 1e-6 - @basic def trunc_func(x): return math.trunc(x) - def test_trunc(): truth = pymath.trunc(0.42) assert (trunc_func(0.42) - truth) < 1e-6 - @basic def ulp_func(x): return math.ulp(x) - def test_ulp(): truth = pymath.ulp(0.42) assert (ulp_func(0.42) - truth) < 1e-6