Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 0 additions & 64 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -8461,54 +8461,6 @@
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 4,
"endColumn": 9,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 11,
"endColumn": 15,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 17,
"endColumn": 21,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 24,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 38,
"endColumn": 49,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 51,
"endColumn": 62,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand All @@ -8517,22 +8469,6 @@
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 16,
"endColumn": 21,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 16,
"endColumn": 20,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down
4 changes: 3 additions & 1 deletion pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,9 @@ def _entries_are_identical(

# {{{ array interface

ConvertibleToIndexExpr = Union[Integer, slice, "Array", EllipsisType, None]
ConvertibleToIndexExpr = Union[
Integer, slice, "NormalizedSlice", "Array", EllipsisType, None
]
IndexExpr = Union[Integer, "NormalizedSlice", "Array", None]
PyScalarType = type[bool] | type[int] | type[float] | type[complex]
DtypeOrPyScalarType = _dtype_any | PyScalarType
Expand Down
34 changes: 15 additions & 19 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def _is_non_positive(expr: ShapeComponent) -> Bool:

# {{{ normalized slice

def _normalize_slice(slice_: slice,
def _normalize_slice(slice_: slice | NormalizedSlice,
axis_len: ShapeComponent) -> NormalizedSlice:
start, stop, step = slice_.start, slice_.stop, slice_.step
if step is None:
Expand All @@ -458,28 +458,24 @@ def _normalize_slice(slice_: slice,
if start is None:
start = default_start
else:
if isinstance(axis_len, INT_CLASSES):
if -axis_len <= start < axis_len:
start = start % axis_len
elif start >= axis_len:
start = axis_len if step > 0 else axis_len - 1
else:
start = 0 if step > 0 else -1
if _is_non_negative(start + axis_len) and not _is_non_positive(
axis_len - start
):
start = start % axis_len
elif _is_non_negative(start - axis_len):
start = axis_len if step > 0 else axis_len - 1
else:
raise NotImplementedError
start = 0 if step > 0 else -1
Comment on lines +461 to +468

if stop is None:
stop = default_stop
else:
if isinstance(axis_len, INT_CLASSES):
if -axis_len <= stop < axis_len:
stop = stop % axis_len
elif stop >= axis_len:
stop = axis_len if step > 0 else axis_len - 1
else:
stop = 0 if step > 0 else -1
if _is_non_negative(stop + axis_len) and not _is_non_positive(axis_len - stop):
stop = stop % axis_len
elif _is_non_negative(stop - axis_len):
stop = axis_len if step > 0 else axis_len - 1
else:
raise NotImplementedError
stop = 0 if step > 0 else -1
Comment on lines +473 to +478

return NormalizedSlice(start, stop, step)

Expand Down Expand Up @@ -572,7 +568,7 @@ def _index_into(
# {{{ validate index

for i, idx in enumerate(indices):
if isinstance(idx, slice):
if isinstance(idx, (slice, NormalizedSlice)):
pass
elif isinstance(idx, INT_CLASSES):
if not (_is_non_negative(idx + ary.shape[i])
Expand All @@ -594,7 +590,7 @@ def _index_into(
# {{{ normalize slices

normalized_indices: list[IndexExpr] = [_normalize_slice(idx, axis_len)
if isinstance(idx, slice)
if isinstance(idx, (slice, NormalizedSlice))
else idx
for idx, axis_len
in zip(indices, ary.shape, strict=True)]
Expand Down
6 changes: 6 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,12 @@ def test_replace_if_different_on_idx_lambda():
assert 3*x == new_expr


def test_normalized_slice_is_valid_indexee():
from pytato.array import NormalizedSlice
a = pt.make_placeholder("a", 10)
assert a[NormalizedSlice(0, 10, 1)] == a[:]


Comment on lines +2069 to +2072
if __name__ == "__main__":
import os
if "INVOCATION_INFO" in os.environ:
Expand Down
Loading