
Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend #1860

Open

tomicapretto wants to merge 5 commits into pymc-devs:main from tomicapretto:sparse_gradients_numba

Conversation

@tomicapretto (Contributor) commented Jan 29, 2026

Description

The main contribution of this PR is the implementation of StructuredDotGradCSR and StructuredDotGradCSC in the numba backend.

While I was working on it, I noticed that the SpSum and SparseFromDense Ops were running in object mode, so I implemented them as well.
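
For context, the structured-dot gradient kernels only touch the stored positions of the sparse operand. Below is a rough standalone sketch of the CSR case in plain NumPy; the function name, argument order, and shapes are illustrative assumptions, not the PR's actual implementation:

```python
import numpy as np


def structured_dot_grad_csr_sketch(indptr, indices, b, g_ab):
    # For z = a @ b with `a` sparse in CSR format, the gradient w.r.t. the
    # stored value a[i, j] is dot(g_ab[i, :], b[j, :]); positions that are
    # not stored in `a` never receive a gradient (hence "structured").
    out = np.empty(indices.shape[0], dtype=g_ab.dtype)
    for i in range(indptr.shape[0] - 1):
        for k in range(indptr[i], indptr[i + 1]):
            j = indices[k]
            out[k] = np.dot(g_ab[i, :], b[j, :])
    return out
```

The same two-level loop over indptr is visible in the review excerpts further down.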

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from 190c587 to 4690cde Compare January 29, 2026 13:34
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from c96ae8c to 2025883 Compare January 29, 2026 14:56
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from 512fb59 to 32099f1 Compare January 30, 2026 03:42
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 3 times, most recently from a054b5d to 6af5a1a Compare January 31, 2026 17:54
@tomicapretto tomicapretto marked this pull request as ready for review January 31, 2026 17:56
@tomicapretto (Contributor Author):

The test that fails is:

FAILED tests/tensor/test_slinalg.py::TestSchur::test_schur_empty - ValueError: negative dimensions not allowed

which is unrelated to this PR.

@ricardoV94 (Member):

> The test that fails is:
>
> FAILED tests/tensor/test_slinalg.py::TestSchur::test_schur_empty - ValueError: negative dimensions not allowed
>
> which is unrelated to this PR.

@jessegrabowski (Member) commented on this code:

    for col_idx in range(size):
        for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
            output[value_idx] = np.dot(

Have to be careful with np.dot. IIRC numba overload doesn't support integer / mixed dtypes well
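
For illustration (not part of the PR), a minimal repro sketch of the dtype restriction being described, assuming a recent Numba version:

```python
import numpy as np
from numba import njit
from numba.core.errors import TypingError


@njit
def mixed_dot(x, y):
    return np.dot(x, y)


x = np.arange(3, dtype=np.float32)
y = np.arange(3, dtype=np.float64)
try:
    mixed_dot(x, y)  # compilation fails: the operands have different dtypes
except TypingError as exc:
    print(type(exc).__name__)  # the error quoted later in this thread
```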

@tomicapretto (Contributor Author):

Argh, I'm using it because np.sum(x * y) was slower. There are a bunch of tests that pass different data types, and they have all passed. Probably that's OK?

Member:

It's probably fine as long as we're upcasting the inputs to a common dtype in the make_node of Dot?

In the medium term we should consider re-implementing the BLAS calls ourselves

Member:

Do we have mixed integer / float types in the tests, or just discrete? I have >20% belief the numba np.dot overload refuses those.

@tomicapretto (Contributor Author):

You're right. The issue was not covered by tests (I mixed up something I saw in the tests for Dot with what we're doing here with StructuredDot), and Numba does not accept mixed types in np.dot:

    Rejected as the implementation raised a specific error:
      TypingError: np.dot() arguments must all have the same dtype

I'll double check the upcasting

Member:

This doesn't solve the numba side? In numba you'll have to explicitly cast one or both of the inputs to dot them?
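
A minimal sketch of that explicit cast inside the jitted function; the unconditional float64 target is purely illustrative:

```python
import numpy as np
from numba import njit


@njit
def dot_matching_dtypes(x, y):
    # Numba's np.dot refuses mixed dtypes, so cast both operands to a common
    # dtype before calling it (float64 chosen here only for illustration).
    return np.dot(x.astype(np.float64), y.astype(np.float64))
```

Casting unconditionally costs an extra copy even when the dtypes already match; the dispatch-time check discussed below avoids that.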

@tomicapretto (Contributor Author):

You're so right!

@ricardoV94 (Member):

Probably fine, but perhaps add a print statement under an if pytensor.config.compiler_verbose check so users can track down a potential source of slowdown. We do this for a couple of the linalg numba dispatches.

You should be able to know at dispatch time whether a conversion will be needed (saying this because the warning shouldn't be inside the numba function)
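
A sketch of that pattern; the config flag name is taken from the comment above, while the dispatcher name, input positions, and float64 target are illustrative assumptions (only the routing/warning shape matters, not the real Op signature):

```python
import numpy as np
from numba import njit
import pytensor


def numba_funcify_structured_dot_grad_sketch(op, node, **kwargs):
    # Hypothetical dispatcher: the input dtypes are known here, at dispatch
    # time, so both the decision and the message live outside the jitted kernel.
    b_dtype = node.inputs[2].type.dtype  # illustrative input positions
    g_dtype = node.inputs[3].type.dtype
    needs_cast = b_dtype != g_dtype

    if needs_cast and pytensor.config.compiler_verbose:
        print(
            f"StructuredDotGrad: inputs have dtypes {b_dtype} and {g_dtype}; "
            "they will be cast before np.dot, which may slow things down."
        )

    if needs_cast:

        @njit
        def dot_rows(b_row, g_row):
            # Blunt upcast so Numba's np.dot sees matching dtypes.
            return np.dot(b_row.astype(np.float64), g_row.astype(np.float64))

    else:

        @njit
        def dot_rows(b_row, g_row):
            return np.dot(b_row, g_row)

    return dot_rows
```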

@tomicapretto (Contributor Author):

Just added it, thanks for the hint @ricardoV94

@ricardoV94 (Member) left a review comment:

This looks great, I just left some minor comments

A Member commented on this code:

    axis = op.axis

    @numba_basic.numba_njit
    def perform(x):

does mypy freak out if you typehint this as SparseArray -> SparseArray? It would make the function more clear. Not required if it causes a headache (typehinting these overloads often does)

@tomicapretto (Contributor Author):

I have not tried it, but the SpSum op returns a dense array (see this).

What happens here is that this calls the function I implemented in overload_sum in variable.py.
Maybe add a note somewhere (per Op or at the top of the file) saying that many (if not all) Ops use overloads written in a separate Python file?

A Member commented on this code:

    # General spmspm algorithm in CSR format
    @numba_basic.numba_njit
    def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
    def _spmspm_csr(x, y, n_row, n_col):

I think it's worth considering a bit of reorganization here for future extensibility. We can make a new sparse/math sub-module and have a sum.py file with each of these inner njit functions defined independently. numba_funcify_SparseDenseMultiply can still live here, but it would be just an input checker and routing to the correct function. I'm thinking about what it will look like in the future to add support for a new sparse type.

The pattern I'm thinking about is what we are doing with linalg, for example QZ: each case is defined separately here, then the actual dispatch is defined here.
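
A sketch of the layout being proposed, loosely following the linalg pattern; the module path, kernel name, and signatures are illustrative assumptions, and only the routing shape matters:

```python
# Hypothetical pytensor/link/numba/dispatch/sparse/math/multiply.py
import numpy as np
from numba import njit


@njit
def sparse_dense_multiply_csr(ptr, ind, data, y):
    # Kernel for one concrete format; other formats get their own function
    # in the same sub-module, each independently njit-compiled and testable.
    out = np.empty_like(data)
    for i in range(ptr.shape[0] - 1):
        for k in range(ptr[i], ptr[i + 1]):
            out[k] = data[k] * y[i, ind[k]]
    return out


# The funcify stays in the dispatch module and only checks inputs and routes:
def numba_funcify_SparseDenseMultiply_sketch(op, node, **kwargs):
    fmt = node.inputs[0].type.format  # "csr" or "csc" for PyTensor sparse types
    if fmt == "csr":
        return sparse_dense_multiply_csr
    raise NotImplementedError(f"SparseDenseMultiply not implemented for format {fmt}")
```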

@tomicapretto (Contributor Author):

It sounds good to me. I thought a bit about this before starting to work on it, but I saw the other Ops in this module were implemented this way, so I assumed it was for a reason. Maybe I just overthought it and it was simply convenience.

@tomicapretto tomicapretto marked this pull request as draft February 2, 2026 13:24
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from e19f795 to 38e92e1 Compare February 2, 2026 13:40
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 3 times, most recently from b886acd to 63775d6 Compare February 2, 2026 14:18
@tomicapretto tomicapretto marked this pull request as ready for review February 2, 2026 14:29
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch 3 times, most recently from a2e9294 to 575eceb Compare February 3, 2026 00:59
@tomicapretto tomicapretto force-pushed the sparse_gradients_numba branch from 575eceb to 2873849 Compare February 5, 2026 12:04
@ricardoV94 (Member):

Is this ready for final review?

@tomicapretto (Contributor Author):

> Is this ready for final review?

Yes
