Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 0 additions & 152 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -1825,46 +1825,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 12,
"endColumn": 13,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 26,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 16,
"endColumn": 22,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 16,
"endColumn": 22,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 19,
"endColumn": 25,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -6657,38 +6617,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 12,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 21,
"endColumn": 49,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 26,
"endColumn": 49,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 15,
"endColumn": 44,
"lineCount": 1
}
},
{
"code": "reportUnreachable",
"range": {
Expand All @@ -6713,22 +6641,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 43,
"endColumn": 53,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 43,
"endColumn": 82,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
Expand All @@ -6737,22 +6649,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 24,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 24,
"endColumn": 83,
"lineCount": 1
}
},
{
"code": "reportUnusedParameter",
"range": {
Expand Down Expand Up @@ -7313,22 +7209,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 43,
"endColumn": 53,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 43,
"endColumn": 78,
"lineCount": 1
}
},
{
"code": "reportIncompatibleMethodOverride",
"range": {
Expand Down Expand Up @@ -8129,38 +8009,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 12,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 21,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 49,
"endColumn": 55,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 43,
"endColumn": 49,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
Expand Down
8 changes: 4 additions & 4 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.transform import (
ArrayOrNames,
CachedMapper,
CachedWalkMapper,
CombineMapper,
Mapper,
Expand Down Expand Up @@ -720,10 +721,9 @@ def rec(self, expr: ArrayOrNames) -> int:
try:
return self._cache_retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
s = Mapper.rec(self, expr)
# Using super(CachedMapper, self) instead of super() to bypass
# CachedMapper.rec and avoid double caching
s = super(CachedMapper, self).rec(expr)
if (
isinstance(expr, Array)
and (
Expand Down
34 changes: 18 additions & 16 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> FunctionResultT:
"""Call the mapper method of *expr* and return the result."""
method: Callable[..., FunctionResultT] | None

try:
method = self.map_function_definition # type: ignore[attr-defined]
method_name = "map_function_definition"
method: Callable[..., FunctionResultT] = cast(
"Callable[..., FunctionResultT]",
getattr(self, method_name))
except AttributeError:
raise ValueError(
f"{type(self).__name__} lacks a mapper method for functions.") from None

assert method is not None
return method(expr, *args, **kwargs)

@overload
Expand Down Expand Up @@ -523,10 +523,11 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
try:
return self._cache_retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs))
# Reminder: If overriding this in a subclass and reimplementing the cache
# lookup logic there, must use super(CachedMapper, self) instead of
# super() below to avoid double caching,
# see https://github.com/inducer/pytato/pull/654.
return self._cache_add(inputs, super().rec(expr, *args, **kwargs))

def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
Expand All @@ -535,11 +536,12 @@ def rec_function_definition(
try:
return self._function_cache_retrieve(inputs)
except KeyError:
# Reminder: If overriding this in a subclass and reimplementing the cache
# lookup logic there, must use super(CachedMapper, self) instead of
# super() below to avoid double caching,
# see https://github.com/inducer/pytato/pull/654.
return self._function_cache_add(
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs))
inputs, super().rec_function_definition(expr, *args, **kwargs))

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
Expand Down Expand Up @@ -2004,10 +2006,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
try:
return self._cache_retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr)))
# Using super(CachedMapper, self) instead of super() to bypass
# CachedMapper.rec and avoid double caching
return self._cache_add(inputs,
super(CachedMapper, self).rec(self.map_fn(expr)))

# }}}

Expand Down
12 changes: 6 additions & 6 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from pytato.transform import (
ArrayOrNames,
ArrayOrNamesOrFunctionDefTc,
CachedMapper,
CopyMapper,
Mapper,
TransformMapperCache,
Expand Down Expand Up @@ -516,15 +517,14 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
try:
return self._cache_retrieve(inputs)
except KeyError:
# Intentionally going to Mapper instead of super() to avoid
# double caching when subclasses of CachedMapper override rec,
# see https://github.com/inducer/pytato/pull/585
result = Mapper.rec(self, expr)
# Using super(CachedMapper, self) instead of super() to bypass
# CachedMapper.rec and avoid double caching
result = super(CachedMapper, self).rec(expr)
if not isinstance(
expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):
assert isinstance(expr, Array)
# type-ignore reason: passed "ArrayOrNames"; expected "Array"
result = self._attach_tags(expr, result) # type: ignore[arg-type]
assert isinstance(result, Array)
result = self._attach_tags(expr, result)
return self._cache_add(inputs, result)

def map_named_call_result(self, expr: NamedCallResult) -> Array:
Expand Down
Loading