diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 49915bc9f..e27c54c1e 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1825,46 +1825,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 13, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 16, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 16, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -6657,38 +6617,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 26, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 44, - "lineCount": 1 - } - }, { "code": "reportUnreachable", "range": { @@ -6713,22 +6641,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 82, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -6737,22 +6649,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 24, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -7313,22 +7209,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 78, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -8129,38 +8009,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 49, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6de9e5de4..28239a8e6 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -52,6 +52,7 @@ from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.transform import ( ArrayOrNames, + CachedMapper, CachedWalkMapper, CombineMapper, Mapper, @@ -720,10 +721,9 @@ def rec(self, expr: ArrayOrNames) -> int: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - s = Mapper.rec(self, expr) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + s = super(CachedMapper, self).rec(expr) if ( isinstance(expr, Array) and ( diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 604ce4826..9f6739fef 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -280,15 +280,15 @@ def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: """Call the mapper method of *expr* and return the result.""" - method: Callable[..., FunctionResultT] | None - try: - method = self.map_function_definition # type: ignore[attr-defined] + method_name = "map_function_definition" + method: Callable[..., FunctionResultT] = cast( + "Callable[..., FunctionResultT]", + getattr(self, method_name)) except AttributeError: raise ValueError( f"{type(self).__name__} lacks a mapper method for functions.") from None - assert method is not None return method(expr, *args, **kwargs) @overload @@ -523,10 +523,11 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + # Reminder: If overriding this in a subclass and reimplementing the cache + # lookup logic there, must use super(CachedMapper, self) instead of + # super() below to avoid double caching, + # see https://github.com/inducer/pytato/pull/654. + return self._cache_add(inputs, super().rec(expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs @@ -535,11 +536,12 @@ def rec_function_definition( try: return self._function_cache_retrieve(inputs) except KeyError: + # Reminder: If overriding this in a subclass and reimplementing the cache + # lookup logic there, must use super(CachedMapper, self) instead of + # super() below to avoid double caching, + # see https://github.com/inducer/pytato/pull/654. return self._function_cache_add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) + inputs, super().rec_function_definition(expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -2004,10 +2006,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + return self._cache_add(inputs, + super(CachedMapper, self).rec(self.map_fn(expr))) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 36aab2391..e0c82b0f0 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -87,6 +87,7 @@ from pytato.transform import ( ArrayOrNames, ArrayOrNamesOrFunctionDefTc, + CachedMapper, CopyMapper, Mapper, TransformMapperCache, @@ -516,15 +517,14 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = Mapper.rec(self, expr) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + result = super(CachedMapper, self).rec(expr) if not isinstance( expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): assert isinstance(expr, Array) - # type-ignore reason: passed "ArrayOrNames"; expected "Array" - result = self._attach_tags(expr, result) # type: ignore[arg-type] + assert isinstance(result, Array) + result = self._attach_tags(expr, result) return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: