Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend #1860
tomicapretto wants to merge 5 commits into pymc-devs:main

Conversation
force-pushed from 190c587 to 4690cde
force-pushed from c96ae8c to 2025883
force-pushed from 512fb59 to 32099f1
force-pushed from a054b5d to 6af5a1a
The failing test is unrelated to this PR.
```python
for col_idx in range(size):
    for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
        output[value_idx] = np.dot(
```
Have to be careful with `np.dot`. IIRC Numba's overload doesn't support integer / mixed dtypes well.
Argh, I'm using it since `np.sum(x * y)` was slower. There are a bunch of tests that pass different data types, and they have all passed. Probably that's OK?
It's probably fine as long as we're upcasting the inputs to a common dtype in the `make_node` of `Dot`?
In the medium term we should consider re-implementing the BLAS calls ourselves.
Do we have mixed integer / float types in the tests, or just discrete? I have >20% belief that Numba's `np.dot` overload refuses those.
You're right. The issue was not covered by tests (I had mixed up something I saw in the tests for `Dot` with what we're doing here with `StructuredDot`), and Numba does not accept mixed types in `np.dot`. They get rejected with a specific error:
`TypingError: np.dot() arguments must all have the same dtype`
I'll double-check the upcasting.
This doesn't solve the numba side? In numba you'll have to explicitly cast one or both of the inputs to dot them?
You're so right!
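For reference, a minimal sketch of that explicit cast, with the common dtype resolved once outside the jitted function (the `make_typed_dot` helper name and factory pattern are assumptions, not the PR's actual code):

```python
import numpy as np
from numba import njit


def make_typed_dot(x_dtype, y_dtype):
    # Resolve the common dtype once, at dispatch time, outside the njit body.
    out_dtype = np.promote_types(x_dtype, y_dtype).type

    @njit
    def typed_dot(x, y):
        # Numba's np.dot overload rejects mixed dtypes, so cast explicitly.
        return np.dot(x.astype(out_dtype), y.astype(out_dtype))

    return typed_dot


# Example: dot an int32 vector with a float64 vector.
dot_if = make_typed_dot(np.int32, np.float64)
result = dot_if(np.arange(3, dtype=np.int32), np.ones(3))
```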
Probably fine, but perhaps add a print statement under an `if pytensor.config.compiler_verbose` check, so users can track down a potential source of slowdown. We do this for a couple of the linalg Numba dispatches.
You should be able to know at dispatch time whether a conversion will be needed (saying this because the warning shouldn't live inside the Numba function). See the sketch below.
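A sketch of where that check could live, at dispatch time rather than inside the jitted kernel (the funcify function name and which node inputs carry the floating-point data are assumptions):

```python
import pytensor


def numba_funcify_structured_dot_grad(op, node, **kwargs):  # hypothetical name
    # Assume the last two inputs are the dense operand and the output gradient.
    b, g_ab = node.inputs[-2:]
    # `compiler_verbose` is the flag named in the review comment above.
    if b.type.dtype != g_ab.type.dtype and pytensor.config.compiler_verbose:
        print(
            f"StructuredDotGrad: mixed dtypes ({b.type.dtype}, {g_ab.type.dtype}); "
            "an explicit upcast will be inserted, which may slow this Op down."
        )
    ...  # build and return the njit kernel with the upcast baked in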
Just added it, thanks for the hint @ricardoV94
ricardoV94 left a comment:

This looks great, I just left some minor comments.
```python
axis = op.axis

@numba_basic.numba_njit
def perform(x):
```
Does mypy freak out if you typehint this as `SparseArray -> SparseArray`? It would make the function clearer. Not required if it causes a headache (typehinting these overloads often does).
I have not tried it, but the `SpSum` Op returns a dense array (see this).
What happens here is that this calls the function I implemented in `overload_sum` in variable.py.
Maybe add a comment somewhere (per Op, or at the top of the file) saying that many (if not all) Ops use overloads written in a separate Python file?
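For context, a minimal sketch of what such a sum can reduce to for CSR input along axis 0 (the standalone function and flat-argument signature are assumptions; per the comment above, the PR routes through the `overload_sum` in variable.py instead):

```python
import numpy as np
from numba import njit


@njit
def csr_sum_axis0(data, indices, n_cols):
    # SpSum returns a dense array, so the kernel allocates a dense output.
    out = np.zeros(n_cols, dtype=data.dtype)
    for k in range(data.shape[0]):
        out[indices[k]] += data[k]  # CSR indices are column indices
    return out
```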
```diff
  # General spmspm algorithm in CSR format
  @numba_basic.numba_njit
- def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
+ def _spmspm_csr(x, y, n_row, n_col):
```
I think it's worth considering a bit of reorganization here for future extensibility. We could make a new sparse/math sub-module and have a sum.py file with each of these inner njit functions defined independently. `numba_funcify_SparseDenseMultiply` can still live here, but it would be just an input checker that routes to the correct function (as sketched below). I'm thinking about what it will look like in the future to add support for a new sparse type.
The pattern I'm thinking about is what we are doing with linalg, for example QZ: each case is defined separately here, then the actual dispatch is defined here.
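A sketch of what that routing could look like, mirroring the linalg pattern (the kernel names, placeholder bodies, and the `format` check are assumptions about the eventual layout, not the PR's code):

```python
from numba import njit


@njit
def _multiply_csr(data, indices, indptr, y):
    # Placeholder: a real kernel would walk rows via indptr and scale data by y.
    return data * 1.0  # hypothetical


@njit
def _multiply_csc(data, indices, indptr, y):
    # Placeholder: same idea, walking columns via the CSC pointers.
    return data * 1.0  # hypothetical


def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
    # The dispatcher only validates inputs and routes to the per-format kernel,
    # which would be imported from the proposed sparse/math sub-module.
    fmt = node.inputs[0].type.format
    if fmt == "csr":
        return _multiply_csr
    if fmt == "csc":
        return _multiply_csc
    raise NotImplementedError(f"Unsupported sparse format: {fmt}")
```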
It sounds good to me. I thought a bit about it prior to starting work on this, but I saw the other Ops in this module were implemented this way, so I assumed it was for a reason. Maybe I just overthought it and it was simply convenience.
force-pushed from e19f795 to 38e92e1
force-pushed from b886acd to 63775d6
force-pushed from a2e9294 to 575eceb
…those implementations in SparseFromDense
force-pushed from 575eceb to 2873849
Is this ready for final review?

Yes
Description

The main contribution of this PR is the implementation of `StructuredDotGradCSR` and `StructuredDotGradCSC` in the numba backend. While I was working on it, I noticed the Ops `SpSum` and `SparseFromDense` were running in object mode, so I also implemented them.
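A hedged sketch of what the CSC-format gradient kernel can look like (variable names follow SciPy's CSC layout and are assumptions about the PR's actual code; compare the diff excerpt quoted in the review above):

```python
import numpy as np
from numba import njit


@njit
def structured_dot_grad_csc(a_indices, a_indptr, b, g_ab):
    # One gradient entry per stored nonzero of the sparse operand `a`.
    g_a_data = np.empty(a_indices.shape[0], dtype=g_ab.dtype)
    for col in range(a_indptr.shape[0] - 1):
        for k in range(a_indptr[col], a_indptr[col + 1]):
            row = a_indices[k]
            # d(a @ b) / d(a[row, col]) contracted with the output gradient.
            # Note: g_ab and b must share a dtype for Numba's np.dot
            # (see the review thread above).
            g_a_data[k] = np.dot(g_ab[row], b[col])
    return g_a_data
```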