diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 258717472..f1d17230a 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -429,6 +429,60 @@ def map_quotient(self, expr: p.Quotient):
         else:
             return self.combine([n_dtype_set, d_dtype_set])
 
+    def _map_int_div_modulo(
+        self, expr: p.FloorDiv | p.Remainder, *, modulo: bool = False
+    ):
+        """Infer the type of a floor division or remainder expression.
+
+        Integer literal operands are used as they are. Any other operand is
+        stood in for by an empty array of its first inferred dtype, so that
+        NumPy's promotion rules pick the dtype of the result.
+
+        :arg modulo: evaluate ``%`` rather than ``//``. When both operands
+            are integer literals the computed value itself is handed back
+            to :meth:`rec`, so it has to be the value of the right operator.
+        :returns: an empty list if the type set of either operand is empty,
+            otherwise the type(s) inferred for the result.
+        """
+        n_dtype_set = self.rec(expr.numerator)
+        d_dtype_set = self.rec(expr.denominator)
+
+        if not (n_dtype_set and d_dtype_set):
+            return cast("list[NumpyType]", [])
+
+        n_dtype = cast("NumpyType", n_dtype_set[0]).dtype
+        d_dtype = cast("NumpyType", d_dtype_set[0]).dtype
+        num = (
+            np.empty(0, dtype=n_dtype)
+            if not is_integer(expr.numerator)
+            else expr.numerator
+        )
+        denom = (
+            np.empty(0, dtype=d_dtype)
+            if not is_integer(expr.denominator)
+            else expr.denominator
+        )
+        if is_integer(denom) and denom == 0:
+            # avoid divide by zero.
+            denom = cast("int | np.integer", denom + 1)
+
+        result = num % denom if modulo else num // denom
+
+        if is_integer(num) and is_integer(denom):
+            return self.rec(result)
+
+        assert isinstance(result, np.ndarray)
+
+        return [NumpyType(result.dtype)]
+
+    @override
+    def map_floor_div(self, expr: p.FloorDiv):
+        return self._map_int_div_modulo(expr)
+
+    @override
+    def map_remainder(self, expr: p.Remainder):
+        return self._map_int_div_modulo(expr, modulo=True)
+
     @override
     def map_constant(self, expr: object):
         if isinstance(expr, np.generic):
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 8badf6016..2f3995d55 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -3733,6 +3733,20 @@ def test_type_cast_parse_stringify_roundtrip():
     assert expr == parsed
 
 
+def test_floor_div_modulo_with_uint_index():
+    # '//' and '%' applied to a subscript read from a uint64 array
+    knl = lp.make_kernel(
+        "{[i]: 0<=i<10}",
+        "a[map[i] // 2, map[i] % 35] = i",
+        [
+            lp.GlobalArg("map", dtype=np.uint64, shape=lp.auto),
+            lp.GlobalArg("a", dtype=np.float64, shape=(10, 4)),
+        ],
+    )
+    # check the codegen is successful
+    lp.generate_code_v2(knl).device_code()
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: