From 9b5510fa6c5e9214f1a4909c2a59a20ed9b1d2df Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 28 Mar 2026 13:26:16 -0700 Subject: [PATCH 1/2] Optimize unnecessary modulo expressions in reshape idx lambdas. --- .basedpyright/baseline.json | 40 ----------------------- pytato/transform/lower_to_index_lambda.py | 38 ++++++++++++++++++--- 2 files changed, 33 insertions(+), 45 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 49915bc9f..d007d7315 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -7663,46 +7663,6 @@ "lineCount": 1 } }, - { - "code": "reportReturnType", - "range": { - "startColumn": 11, - "endColumn": 87, - "lineCount": 4 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 87, - "lineCount": 4 - } - }, - { - "code": "reportOperatorIssue", - "range": { - "startColumn": 8, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportOperatorIssue", - "range": { - "startColumn": 8, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 8, - "endColumn": 86, - "lineCount": 2 - } - }, { "code": "reportUnknownMemberType", "range": { diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 3537b0486..49c3d3d64 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -36,7 +36,7 @@ import pymbolic.primitives as prim from pymbolic import ArithmeticExpression -from pytools import UniqueNameGenerator +from pytools import UniqueNameGenerator, product from pytato.array import ( AbstractResultWithNamedArrays, @@ -128,12 +128,40 @@ def _generate_index_expressions( old_size_tills = old_size_tills[::-1] flattened_index_expn = sum( - index_var*new_stride - for index_var, new_stride in zip(index_vars, new_strides, strict=True)) + ( + index_var if new_stride == 1 else index_var * 
new_stride + for index_var, new_stride in zip( + index_vars[1:], new_strides[1:], strict=True + ) + ), + start=( + (index_vars[0] if new_strides[0] == 1 else index_vars[0] * new_strides[0]) + if len(index_vars) + else 0 + ), + ) + + old_size = product(old_shape) + + def _mod( + num: ArithmeticExpression, denom: ArithmeticExpression + ) -> ArithmeticExpression: + from pymbolic.typing import Integer + if isinstance(old_size, Integer) and denom == old_size and denom != 0: + return num + # Pyright has a point: complex numbers don't support '%'. + return num % denom # pyright: ignore[reportOperatorIssue,reportUnknownVariableType] + + def _floordiv( + num: ArithmeticExpression, denom: ArithmeticExpression + ) -> ArithmeticExpression: + if denom == 1: + return num + # Pyright has a point: complex numbers don't support '//'. + return num // denom # pyright: ignore[reportOperatorIssue,reportUnknownVariableType] return tuple( - # Mypy has a point: complex numbers don't support '//'. - (flattened_index_expn % old_size_till) // old_stride # type: ignore[operator] + _floordiv(_mod(flattened_index_expn, old_size_till), old_stride) # pyright: ignore[reportArgumentType] for old_size_till, old_stride in zip(old_size_tills, old_strides, strict=True)) From 9ab995b2aedab1a17bc4e8aad05794cfe03c991c Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 28 Mar 2026 13:26:31 -0700 Subject: [PATCH 2/2] Test no sub-optimality in reshape lowering. --- test/test_pytato.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 9d07d5ccd..f217746f8 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1429,6 +1429,23 @@ def test_lower_to_index_lambda(): assert idx_lambda.expr.index_tuple[4] == Variable("_1") +def test_lower_to_index_lambda_flatten_reshape(): + # At or before commit 577cb39, pytato would generate redundant floor div + # and modulo computations. 
Specifically, in this case we would get: + # out[_0] = x[((0 + _0*1) % 40) // 4, ((0 + _0*1) % 4) // 1]. + + from pymbolic import parse + + from pytato.array import IndexLambda + x = pt.make_placeholder(name="x", dtype=float, shape=(10, 4)) + idx_lambda = pt.to_index_lambda(x.reshape(-1)) + assert isinstance(idx_lambda, IndexLambda) + assert idx_lambda.expr.index_tuple == ( + parse("_0 // 4"), + parse("_0 % 4"), + ) + + def test_reserved_binding_name_patterns(): from pytato.transform.metadata import BINDING_NAME_RESERVED_PATTERN