Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -7663,46 +7663,6 @@
"lineCount": 1
}
},
{
"code": "reportReturnType",
"range": {
"startColumn": 11,
"endColumn": 87,
"lineCount": 4
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 11,
"endColumn": 87,
"lineCount": 4
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 8,
"endColumn": 46,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 8,
"endColumn": 60,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 8,
"endColumn": 86,
"lineCount": 2
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
38 changes: 33 additions & 5 deletions pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

import pymbolic.primitives as prim
from pymbolic import ArithmeticExpression
from pytools import UniqueNameGenerator
from pytools import UniqueNameGenerator, product

from pytato.array import (
AbstractResultWithNamedArrays,
Expand Down Expand Up @@ -128,12 +128,40 @@ def _generate_index_expressions(
old_size_tills = old_size_tills[::-1]

flattened_index_expn = sum(
index_var*new_stride
for index_var, new_stride in zip(index_vars, new_strides, strict=True))
(
index_var if new_stride == 1 else index_var * new_stride
for index_var, new_stride in zip(
index_vars[1:], new_strides[1:], strict=True
)
),
start=(
(index_vars[0] if new_strides[0] == 1 else index_vars[0] * new_strides[0])
if len(index_vars)
else 0
),
)

old_size = product(old_shape)

def _mod(
num: ArithmeticExpression, denom: ArithmeticExpression
) -> ArithmeticExpression:
from pymbolic.typing import Integer
if isinstance(old_size, Integer) and denom == old_size and denom != 0:
return num
# Pyright has a point: complex numbers don't support '%'.
return num % denom # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]

def _floordiv(
num: ArithmeticExpression, denom: ArithmeticExpression
) -> ArithmeticExpression:
if denom == 1:
return num
# pyright has a point: complex numbers don't support '//'.
return num // denom # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]

return tuple(
# Mypy has a point: complex numbers don't support '//'.
(flattened_index_expn % old_size_till) // old_stride # type: ignore[operator]
_floordiv(_mod(flattened_index_expn, old_size_till), old_stride) # pyright: ignore[reportArgumentType]
for old_size_till, old_stride in zip(old_size_tills, old_strides, strict=True))


Expand Down
17 changes: 17 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,23 @@ def test_lower_to_index_lambda():
assert idx_lambda.expr.index_tuple[4] == Variable("_1")


def test_lower_to_index_lambda_flatten_reshape():
# Before commit<=577cb39, pytato would generate redundant floor div
# and modulo computations. Specifically, in this case we would get:
# out[_0] = x[((0 + _0*1) % 40) // 4, ((0 + _0*1) % 4) // 1].

from pymbolic import parse

from pytato.array import IndexLambda
x = pt.make_placeholder(name="x", dtype=float, shape=(10, 4))
idx_lambda = pt.to_index_lambda(x.reshape(-1))
assert isinstance(idx_lambda, IndexLambda)
assert idx_lambda.expr.index_tuple == (
parse("_0 // 4"),
parse("_0 % 4"),
)


def test_reserved_binding_name_patterns():
from pytato.transform.metadata import BINDING_NAME_RESERVED_PATTERN

Expand Down
Loading