Skip to content
200 changes: 64 additions & 136 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -2547,6 +2547,22 @@
"lineCount": 1
}
},
{
"code": "reportIncompatibleVariableOverride",
"range": {
"startColumn": 6,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportIncompatibleVariableOverride",
"range": {
"startColumn": 6,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportConstantRedefinition",
"range": {
Expand Down Expand Up @@ -2627,6 +2643,14 @@
"lineCount": 1
}
},
{
"code": "reportCallInDefaultInitializer",
"range": {
"startColumn": 43,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -5617,6 +5641,14 @@
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand Down Expand Up @@ -7041,134 +7073,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 17,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 19,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 18,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 39,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportPrivateUsage",
"range": {
Expand All @@ -7177,14 +7081,6 @@
"lineCount": 1
}
},
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand Down Expand Up @@ -10741,6 +10637,38 @@
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
"startColumn": 4,
"endColumn": 9,
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
"startColumn": 8,
"endColumn": 13,
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
"startColumn": 8,
"endColumn": 13,
"lineCount": 1
}
},
{
"code": "reportUnusedExpression",
"range": {
"startColumn": 4,
"endColumn": 9,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
1 change: 1 addition & 0 deletions .test-conda-env-py3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dependencies:
- jax
- openmpi # Force using Open MPI since our pytest infrastructure needs it
- graphviz # for visualization tests
- matplotlib-base
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,17 @@

# It's :data:, not :class:, but we can't tell autodoc that.
["py:class", r"types\.EllipsisType"],
# pytools
# Got documented in Feb 2026, try removing?
["py:class", "ToTagSetConvertible"],
]


sphinxconfig_missing_reference_aliases = {
# pymbolic
"ArithmeticExpression": "obj:pymbolic.ArithmeticExpression",
# pytools
"lp.TemporaryVariable": "class:loopy.TemporaryVariable",
}


Expand Down
12 changes: 12 additions & 0 deletions pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def set_debug_enabled(flag: bool) -> None:
AxisPermutation,
BasicIndex,
Concatenate,
CSRMatmul,
CSRMatrix,
DataWrapper,
DictOfNamedArrays,
Einsum,
Expand All @@ -70,6 +72,8 @@ def set_debug_enabled(flag: bool) -> None:
Reshape,
Roll,
SizeParam,
SparseMatmul,
SparseMatrix,
Stack,
arange,
broadcast_to,
Expand All @@ -87,6 +91,7 @@ def set_debug_enabled(flag: bool) -> None:
logical_and,
logical_not,
logical_or,
make_csr_matrix,
make_data_wrapper,
make_dict_of_named_arrays,
make_placeholder,
Expand All @@ -99,6 +104,7 @@ def set_debug_enabled(flag: bool) -> None:
reshape,
roll,
set_traceback_tag_enabled,
sparse_matmul,
squeeze,
stack,
transpose,
Expand Down Expand Up @@ -179,6 +185,8 @@ def set_debug_enabled(flag: bool) -> None:
"Axis",
"AxisPermutation",
"BasicIndex",
"CSRMatmul",
"CSRMatrix",
"Concatenate",
"DataWrapper",
"DictOfNamedArrays",
Expand All @@ -200,6 +208,8 @@ def set_debug_enabled(flag: bool) -> None:
"Reshape",
"Roll",
"SizeParam",
"SparseMatmul",
"SparseMatrix",
"Stack",
"Target",
"abs",
Expand Down Expand Up @@ -247,6 +257,7 @@ def set_debug_enabled(flag: bool) -> None:
"logical_and",
"logical_not",
"logical_or",
"make_csr_matrix",
"make_data_wrapper",
"make_dict_of_named_arrays",
"make_distributed_recv",
Expand All @@ -273,6 +284,7 @@ def set_debug_enabled(flag: bool) -> None:
"show_fancy_placeholder_data_flow",
"sin",
"sinh",
"sparse_matmul",
"sqrt",
"squeeze",
"stack",
Expand Down
23 changes: 23 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pytato.array import (
Array,
Concatenate,
CSRMatmul,
DictOfNamedArrays,
Einsum,
IndexBase,
Expand Down Expand Up @@ -155,6 +156,20 @@ def map_einsum(self, expr: Einsum) -> None:
self.array_to_users[dim].append(expr)
self.rec(dim)

def map_csr_matmul(self, expr: CSRMatmul) -> None:
for ary in (
expr.matrix.elem_values,
expr.matrix.elem_col_indices,
expr.matrix.row_starts,
expr.array):
self.array_to_users[ary].append(expr)
self.rec(ary)

for dim in expr.shape:
if isinstance(dim, Array):
self.array_to_users[dim].append(expr)
self.rec(dim)

def map_named_array(self, expr: NamedArray) -> None:
self.rec(expr._container)

Expand Down Expand Up @@ -378,6 +393,14 @@ def map_concatenate(self, expr: Concatenate) -> list[ArrayOrNames]:
def map_einsum(self, expr: Einsum) -> list[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape) + list(expr.args)

def map_csr_matmul(self, expr: CSRMatmul) -> list[ArrayOrNames]:
return [
*self._get_preds_from_shape(expr.shape),
expr.matrix.elem_values,
expr.matrix.elem_col_indices,
expr.matrix.row_starts,
expr.array]

def map_loopy_call(self, expr: LoopyCall) -> list[ArrayOrNames]:
return [ary for ary in expr.bindings.values() if isinstance(ary, Array)]

Expand Down
Loading
Loading