File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1429,6 +1429,23 @@ def test_lower_to_index_lambda():
14291429 assert idx_lambda .expr .index_tuple [4 ] == Variable ("_1" )
14301430
14311431
1432+ def test_lower_to_index_lambda_flatten_reshape ():
1433+ # Before commit<=577cb39, pytato would generate redundant floor div
1434+ # and modulo computations. Specifically, in this case we would get:
1435+ # out[_0] = x[((0 + _0*1) % 40) // 4, ((0 + _0*1) % 4) // 1].
1436+
1437+ from pymbolic import parse
1438+
1439+ from pytato .array import IndexLambda
1440+ x = pt .make_placeholder (name = "x" , dtype = float , shape = (10 , 4 ))
1441+ idx_lambda = pt .to_index_lambda (x .reshape (- 1 ))
1442+ assert isinstance (idx_lambda , IndexLambda )
1443+ assert idx_lambda .expr .index_tuple == (
1444+ parse ("_0 // 4" ),
1445+ parse ("_0 % 4" ),
1446+ )
1447+
1448+
14321449def test_reserved_binding_name_patterns ():
14331450 from pytato .transform .metadata import BINDING_NAME_RESERVED_PATTERN
14341451
You can’t perform that action at this time.
0 commit comments