From d96a1f258b2def5756f64aaea415f3adf013667d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 26 Mar 2026 10:43:08 -0500 Subject: [PATCH 1/2] fix typing for CachedMapper.rec/rec_function_definition --- .basedpyright/baseline.json | 152 ----------------------------------- pytato/analysis/__init__.py | 8 +- pytato/transform/__init__.py | 34 ++++---- pytato/transform/metadata.py | 12 +-- 4 files changed, 28 insertions(+), 178 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 49915bc9f..e27c54c1e 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1825,46 +1825,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 13, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 16, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 16, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 19, - "endColumn": 25, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -6657,38 +6617,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 26, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 44, - "lineCount": 1 - } - }, { "code": "reportUnreachable", "range": { @@ -6713,22 +6641,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 82, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -6737,22 +6649,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 24, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 83, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -7313,22 +7209,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 43, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 78, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { @@ -8129,38 +8009,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 21, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 49, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 43, - "endColumn": 49, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6de9e5de4..28239a8e6 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -52,6 +52,7 @@ from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.transform import ( ArrayOrNames, + CachedMapper, CachedWalkMapper, CombineMapper, Mapper, @@ -720,10 +721,9 @@ def rec(self, expr: ArrayOrNames) -> int: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - s = Mapper.rec(self, expr) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + s = super(CachedMapper, self).rec(expr) if ( isinstance(expr, Array) and ( diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 604ce4826..4f969285b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -280,15 +280,15 @@ def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: """Call the mapper method of *expr* and return the result.""" - method: Callable[..., FunctionResultT] | None - try: - method = self.map_function_definition # type: ignore[attr-defined] + method_name = "map_function_definition" + method: Callable[..., FunctionResultT] = cast( + "Callable[..., FunctionResultT]", + getattr(self, method_name)) except AttributeError: raise ValueError( f"{type(self).__name__} lacks a mapper method for functions.") from None - assert method is not None return method(expr, *args, **kwargs) @overload @@ -523,10 +523,11 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + # Reminder: If overriding this in a subclass and reimplementing the cache + # lookup logic there, must use super(CachedMapper, self) instead of + # super() below to avoid double caching, + # see https://github.com/inducer/pytato/pull/585. + return self._cache_add(inputs, super().rec(expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs @@ -535,11 +536,12 @@ def rec_function_definition( try: return self._function_cache_retrieve(inputs) except KeyError: + # Reminder: If overriding this in a subclass and reimplementing the cache + # lookup logic there, must use super(CachedMapper, self) instead of + # super() below to avoid double caching, + # see https://github.com/inducer/pytato/pull/585. return self._function_cache_add( - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) + inputs, super().rec_function_definition(expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -2004,10 +2006,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + return self._cache_add(inputs, + super(CachedMapper, self).rec(self.map_fn(expr))) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 36aab2391..e0c82b0f0 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -87,6 +87,7 @@ from pytato.transform import ( ArrayOrNames, ArrayOrNamesOrFunctionDefTc, + CachedMapper, CopyMapper, Mapper, TransformMapperCache, @@ -516,15 +517,14 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: try: return self._cache_retrieve(inputs) except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = Mapper.rec(self, expr) + # Using super(CachedMapper, self) instead of super() to bypass + # CachedMapper.rec and avoid double caching + result = super(CachedMapper, self).rec(expr) if not isinstance( expr, AbstractResultWithNamedArrays | DistributedSendRefHolder): assert isinstance(expr, Array) - # type-ignore reason: passed "ArrayOrNames"; expected "Array" - result = self._attach_tags(expr, result) # type: ignore[arg-type] + assert isinstance(result, Array) + result = self._attach_tags(expr, result) return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: From 646a66b3f16091954684ae3104da0acb29253036 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 26 Mar 2026 10:56:09 -0500 Subject: [PATCH 2/2] update PR URLs --- pytato/transform/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 4f969285b..9f6739fef 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -526,7 +526,7 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: # Reminder: If overriding this in a subclass and reimplementing the cache # lookup logic there, must use super(CachedMapper, self) instead of # super() below to avoid double caching, - # see https://github.com/inducer/pytato/pull/585. + # see https://github.com/inducer/pytato/pull/654. return self._cache_add(inputs, super().rec(expr, *args, **kwargs)) def rec_function_definition( @@ -539,7 +539,7 @@ def rec_function_definition( # Reminder: If overriding this in a subclass and reimplementing the cache # lookup logic there, must use super(CachedMapper, self) instead of # super() below to avoid double caching, - # see https://github.com/inducer/pytato/pull/585. + # see https://github.com/inducer/pytato/pull/654. return self._function_cache_add( inputs, super().rec_function_definition(expr, *args, **kwargs))