diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index b7337de4f..217a16527 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -8461,54 +8461,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 4, - "endColumn": 9, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 11, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 17, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 24, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 38, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 51, - "endColumn": 62, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -8517,22 +8469,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 16, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 16, - "endColumn": 20, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { diff --git a/pytato/array.py b/pytato/array.py index c0a832c04..aa8538851 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -538,7 +538,9 @@ def _entries_are_identical( # {{{ array interface -ConvertibleToIndexExpr = Union[Integer, slice, "Array", EllipsisType, None] +ConvertibleToIndexExpr = Union[ + Integer, slice, "NormalizedSlice", "Array", EllipsisType, None +] IndexExpr = Union[Integer, "NormalizedSlice", "Array", None] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..ad48e388a 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -438,7 +438,7 @@ def _is_non_positive(expr: ShapeComponent) -> Bool: # {{{ normalized slice -def _normalize_slice(slice_: slice, +def _normalize_slice(slice_: slice | NormalizedSlice, axis_len: ShapeComponent) -> NormalizedSlice: start, stop, step = slice_.start, slice_.stop, slice_.step if step is None: @@ -458,28 +458,24 @@ def _normalize_slice(slice_: slice, if start is None: start = default_start else: - if isinstance(axis_len, INT_CLASSES): - if -axis_len <= start < axis_len: - start = start % axis_len - elif start >= axis_len: - start = axis_len if step > 0 else axis_len - 1 - else: - start = 0 if step > 0 else -1 + if _is_non_negative(start + axis_len) and not _is_non_positive( + axis_len - start + ): + start = start % axis_len + elif _is_non_negative(start - axis_len): + start = axis_len if step > 0 else axis_len - 1 else: - raise NotImplementedError + start = 0 if step > 0 else -1 if stop is None: stop = default_stop else: - if isinstance(axis_len, INT_CLASSES): - if -axis_len <= stop < axis_len: - stop = stop % axis_len - elif stop >= axis_len: - stop = axis_len if step > 0 else axis_len - 1 - else: - stop = 0 if step > 0 else -1 + if _is_non_negative(stop + axis_len) and not _is_non_positive(axis_len - stop): + stop = stop % axis_len + elif _is_non_negative(stop - axis_len): + stop = axis_len if step > 0 else axis_len - 1 else: - raise NotImplementedError + stop = 0 if step > 0 else -1 return NormalizedSlice(start, stop, step) @@ -572,7 +568,7 @@ def _index_into( # {{{ validate index for i, idx in enumerate(indices): - if isinstance(idx, slice): + if isinstance(idx, (slice, NormalizedSlice)): pass elif isinstance(idx, INT_CLASSES): if not (_is_non_negative(idx + ary.shape[i]) @@ -594,7 +590,7 @@ def _index_into( # {{{ normalize slices normalized_indices: list[IndexExpr] = [_normalize_slice(idx, axis_len) - if isinstance(idx, slice) + if isinstance(idx, (slice, NormalizedSlice)) else idx for idx, axis_len in zip(indices, ary.shape, strict=True)] diff --git a/test/test_pytato.py b/test/test_pytato.py index 8ec1a0748..f8f6ab77c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -2064,6 +2064,12 @@ def test_replace_if_different_on_idx_lambda(): assert 3*x == new_expr +def test_normalized_slice_is_valid_indexee(): + from pytato.array import NormalizedSlice + a = pt.make_placeholder("a", 10) + assert a[NormalizedSlice(0, 10, 1)] == a[:] + + if __name__ == "__main__": import os if "INVOCATION_INFO" in os.environ: