diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index a2885b7bf..42a0b5730 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -201,6 +201,12 @@ def map_nan(self, expr: p.NaN): self.saw_inf_or_nan = True return super().map_nan(expr) + @override + def map_variable(self, expr: p.Variable): + if expr.name in ("HUGE_VAL", "INFINITY"): + self.saw_inf_or_nan = True + return super().map_variable(expr) + def c99_preamble_generator(preamble_info: PreambleInfo) -> Iterator[tuple[str, str]]: if any(dtype.is_integral() for dtype in preamble_info.seen_dtypes): @@ -561,6 +567,11 @@ def c_symbol_mangler(kernel, name): if name in ["INT_MAX", "INT_MIN"]: return NumpyType(np.dtype(np.int32)), name + if name == "INFINITY": + return NumpyType(np.dtype(np.float32)), name + if name == "HUGE_VAL": + return NumpyType(np.dtype(np.float64)), name + return None # }}} diff --git a/test/test_target.py b/test/test_target.py index 302d451b5..c47d20fd9 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -875,6 +875,22 @@ def test_float3(): assert "float3" in device_code +def test_argmax_ctarget_floating_point(): + for dtype in (np.float32, np.float64): + knl = lp.make_kernel( + "{[i]: 0<=i 1: