-
Notifications
You must be signed in to change notification settings - Fork 183
Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend
#1860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
026202a
59f0217
de647d0
33b10f9
2873849
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| import scipy.sparse as sp | ||
|
|
||
| import pytensor.sparse.basic as psb | ||
| from pytensor import config | ||
| from pytensor.link.numba.dispatch import basic as numba_basic | ||
| from pytensor.link.numba.dispatch.basic import ( | ||
| register_funcify_and_cache_key, | ||
|
|
@@ -13,10 +14,24 @@ | |
| Dot, | ||
| SparseDenseMultiply, | ||
| SparseDenseVectorMultiply, | ||
| SpSum, | ||
| StructuredDot, | ||
| StructuredDotGradCSC, | ||
| StructuredDotGradCSR, | ||
| ) | ||
|
|
||
|
|
||
| @register_funcify_default_op_cache_key(SpSum) | ||
| def numba_funcify_SpSum(op, node, **kwargs): | ||
| axis = op.axis | ||
|
|
||
| @numba_basic.numba_njit | ||
| def perform(x): | ||
| return x.sum(axis) | ||
|
|
||
| return perform | ||
|
|
||
|
|
||
| @register_funcify_default_op_cache_key(SparseDenseMultiply) | ||
| @register_funcify_default_op_cache_key(SparseDenseVectorMultiply) | ||
| def numba_funcify_SparseDenseMultiply(op, node, **kwargs): | ||
|
|
@@ -402,3 +417,233 @@ def dmspm(x, y): | |
| return spmdm_csr(y.T, x.T).T | ||
|
|
||
| return dmspm, cache_key | ||
|
|
||
|
|
||
| @register_funcify_and_cache_key(StructuredDotGradCSR) | ||
| @register_funcify_and_cache_key(StructuredDotGradCSC) | ||
| def numba_funcify_StructuredDotGrad(op, node, **kwargs): | ||
| """Overload StructuredDotGrad in Numba. | ||
|
|
||
| Let: | ||
| Z = structured_dot(X, Y) | ||
| L = L(Z), a scalar loss depending on Z. | ||
|
|
||
| This function computes the gradient of the loss with respect to X: | ||
|
|
||
| dL/dX = dot(dL/dZ, Y^T) | ||
|
|
||
| where G = dL/dZ is the accumulated (upstream) gradient. | ||
|
|
||
| The returned gradient is structured, preserving the sparsity pattern of X, | ||
| and only the `.data` component of the sparse matrix is computed. | ||
| If Y is sparse, the sparsity pattern of the result is not recomputed. | ||
| The output may contain explicit zeros at positions that would be structural zeros | ||
| if the sparsity structure were updated. | ||
|
|
||
| The core of the algorithm is: | ||
|
|
||
| dot(g_xy[i], y[j]) | ||
|
|
||
| where g_xy[i] (row of G) and y[j] (column of Y^T) are vectors of length 'k' | ||
|
|
||
| Reminder: | ||
| x.shape (n, p) | ||
| y.shape (p, k) | ||
| g_xy.shape (n, k) | ||
| """ | ||
| _, _, y, g_xy = node.inputs | ||
|
|
||
| y_dtype = y.type.dtype | ||
| y_is_sparse = psb._is_sparse_variable(y) | ||
| y_format = y.type.format if y_is_sparse else None | ||
|
|
||
| g_xy_dtype = g_xy.type.dtype | ||
| g_xy_is_sparse = psb._is_sparse_variable(g_xy) | ||
| g_xy_format = g_xy.type.format if g_xy_is_sparse else None | ||
|
|
||
| x_format = "csc" if isinstance(op, StructuredDotGradCSC) else "csr" | ||
| out_dtype = g_xy_dtype | ||
|
|
||
| cache_key = sha256( | ||
| str( | ||
| ( | ||
| type(op), | ||
| x_format, | ||
| y_format, | ||
| y_dtype, | ||
| g_xy_format, | ||
| out_dtype, | ||
| y.type.shape, | ||
| ) | ||
| ).encode() | ||
| ).hexdigest() | ||
|
tomicapretto marked this conversation as resolved.
|
||
|
|
||
| if not g_xy_is_sparse: | ||
| # X is sparse, Y and G_xy are dense. | ||
| if x_format == "csr": | ||
| if y.type.shape[1] == 1: | ||
| # If Y is actually 1D, use more performant specialized algorithm | ||
| # Inputs with ndims > 2 will never appear in the StructuredDot Op | ||
| @numba_basic.numba_njit | ||
| def _grad_spmdv_csr(x_indices, x_ptr, y, g_xy): | ||
| output = np.empty(len(x_indices), dtype=out_dtype) | ||
| size = len(x_ptr) - 1 | ||
| x_indices = x_indices.view(np.uint32) | ||
| x_ptr = x_ptr.view(np.uint32) | ||
| for row_idx in range(size): | ||
| for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]): | ||
| output[value_idx] = g_xy[row_idx] * y[x_indices[value_idx]] | ||
| return output | ||
|
|
||
| @numba_basic.numba_njit | ||
| def grad_spmdv_csr(x_indices, x_ptr, y, g_xy): | ||
| return _grad_spmdv_csr(x_indices, x_ptr, y[:, 0], g_xy[:, 0]) | ||
|
|
||
| return grad_spmdv_csr, cache_key | ||
| else: | ||
| # Y is a matrix | ||
| if config.compiler_verbose and y_dtype != out_dtype: | ||
| print( # noqa: T201 | ||
| "Numba StructuredDotGrad requires a type casting of inputs: " | ||
| f"{y_dtype=}, {g_xy_dtype=}." | ||
| ) | ||
|
|
||
| @numba_basic.numba_njit | ||
| def grad_spmdm_csr(x_indices, x_ptr, y, g_xy): | ||
| size = len(x_ptr) - 1 | ||
| x_indices = x_indices.view(np.uint32) | ||
| x_ptr = x_ptr.view(np.uint32) | ||
|
|
||
| if y_dtype != out_dtype: | ||
| new_out_dtype = np.result_type(y, g_xy) | ||
| output = np.zeros(len(x_indices), dtype=new_out_dtype) | ||
| y = y.astype(out_dtype) | ||
| g_xy = g_xy.astype(out_dtype) | ||
| else: | ||
| output = np.zeros(len(x_indices), dtype=out_dtype) | ||
|
|
||
| for row_idx in range(size): | ||
| for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]): | ||
| output[value_idx] = np.dot( | ||
| g_xy[row_idx], y[x_indices[value_idx]] | ||
| ) | ||
| return output | ||
|
|
||
| return grad_spmdm_csr, cache_key | ||
| else: | ||
| # X is CSC | ||
| @numba_basic.numba_njit | ||
| def grad_spmdm_csc(x_indices, x_ptr, y, g_xy): | ||
| # len(x_indices) gives the number of non-zero elements in X. | ||
| output = np.zeros(len(x_indices), dtype=out_dtype) | ||
| size = len(x_ptr) - 1 | ||
| x_indices = x_indices.view(np.uint32) | ||
| x_ptr = x_ptr.view(np.uint32) | ||
|
|
||
| for col_idx in range(size): | ||
| for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]): | ||
| output[value_idx] = np.dot( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have to be careful with
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Argh, I'm using it since
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its probably fine as long as we're upcasting the inputs to a common dtype in the make_node of Dot? In the medium term we should consider re-implementing the BLAS calls ourselves
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have mixed integer / float types in the test. Or just discrete. I have >20% belief numba np.dot overload refuses those
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right. The issue was not covered by tests (I mixed something I saw in tests for Dot with what we're doing here with StructuredDot) plus Numba does not accept mixed types in I'll double check the upcasting
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't solve the numba side? In numba you'll have to explicitly cast one or both of the inputs to dot them?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're so right!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably fine, but perhaps add a print statement under a You should be able to know at dispatch time whether a conversion will be needed (saying this because the warning shouldn't be inside the numba function)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just added it, thanks for the hint @ricardoV94 |
||
| g_xy[x_indices[value_idx]], y[col_idx] | ||
| ) | ||
| return output | ||
|
|
||
| return grad_spmdm_csc, cache_key | ||
|
|
||
| # Y is sparse. In either case we need 'dot_csr_rows' | ||
| @numba_basic.numba_njit | ||
| def dot_csr_rows(x_ptr, x_indices, x_data, x_row, y_ptr, y_indices, y_data, y_row): | ||
| x_p = x_ptr[x_row] | ||
| x_end = x_ptr[x_row + 1] | ||
| y_p = y_ptr[y_row] | ||
| y_end = y_ptr[y_row + 1] | ||
|
|
||
| acc = 0.0 | ||
| while x_p < x_end and y_p < y_end: | ||
| x_col = x_indices[x_p] | ||
| y_col = y_indices[y_p] | ||
| if x_col == y_col: | ||
| acc += x_data[x_p] * y_data[y_p] | ||
| x_p += 1 | ||
| y_p += 1 | ||
| elif x_col < y_col: | ||
| x_p += 1 | ||
| else: | ||
| y_p += 1 | ||
|
|
||
| return acc | ||
|
|
||
| if x_format == "csr": | ||
| assert g_xy_format == "csr" | ||
| assert psb._is_sparse_variable(y) | ||
|
|
||
| @numba_basic.numba_njit | ||
| def grad_spmspm_csr(x_indices, x_ptr, y, g_xy): | ||
| if y_format == "csc": | ||
| y = y.tocsr() | ||
|
tomicapretto marked this conversation as resolved.
|
||
|
|
||
| g_xy_data = g_xy.data | ||
| g_xy_indices = g_xy.indices.view(np.uint32) | ||
| g_xy_ptr = g_xy.indptr.view(np.uint32) | ||
|
|
||
| y_data = y.data | ||
| y_indices = y.indices.view(np.uint32) | ||
| y_ptr = y.indptr.view(np.uint32) | ||
|
|
||
| n_row = len(x_ptr) - 1 | ||
| output = np.zeros(len(x_indices), dtype=out_dtype) | ||
|
|
||
| for x_row in range(n_row): | ||
| for data_idx in range(x_ptr[x_row], x_ptr[x_row + 1]): | ||
| x_col = x_indices[data_idx] | ||
| output[data_idx] = dot_csr_rows( | ||
| g_xy_ptr, | ||
| g_xy_indices, | ||
| g_xy_data, | ||
| x_row, | ||
| y_ptr, | ||
| y_indices, | ||
| y_data, | ||
| x_col, | ||
| ) | ||
| return output | ||
|
|
||
| return grad_spmspm_csr, cache_key | ||
| else: | ||
| assert g_xy_format == "csc" | ||
| assert psb._is_sparse_variable(y) | ||
|
|
||
| @numba_basic.numba_njit | ||
| def grad_spmspm_csc(x_indices, x_ptr, y, g_xy): | ||
| if y_format == "csc": | ||
| y = y.tocsr() | ||
|
|
||
| # Looping a CSC matrix rowwise is too painful, slow, and cryptic. | ||
| g_xy = g_xy.tocsr() | ||
|
tomicapretto marked this conversation as resolved.
|
||
|
|
||
| g_xy_data = g_xy.data | ||
| g_xy_indices = g_xy.indices.view(np.uint32) | ||
| g_xy_ptr = g_xy.indptr.view(np.uint32) | ||
|
|
||
| y_data = y.data | ||
| y_indices = y.indices.view(np.uint32) | ||
| y_ptr = y.indptr.view(np.uint32) | ||
|
|
||
| n_cols = len(x_ptr) - 1 | ||
| output = np.empty(len(x_indices), dtype=out_dtype) | ||
|
|
||
| for x_col in range(n_cols): | ||
| for data_idx in range(x_ptr[x_col], x_ptr[x_col + 1]): | ||
| x_row = x_indices[data_idx] | ||
| output[data_idx] = dot_csr_rows( | ||
| g_xy_ptr, | ||
| g_xy_indices, | ||
| g_xy_data, | ||
| x_row, | ||
| y_ptr, | ||
| y_indices, | ||
| y_data, | ||
| x_col, | ||
| ) | ||
| return output | ||
|
|
||
| return grad_spmspm_csc, cache_key | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does mypy freak out if you typehint this as SparseArray -> SparseArray? It would make the function more clear. Not required if it causes a headache (typehinting these overloads often does)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have not tried it, but the `SpSum` op returns a dense array (see this). What happens here is that this calls the function I implemented in `overload_sum` in `variable.py`.
Maybe a global somewhere (per op or at the top of the file) saying that many (if not all) Ops are using overloads written in a separate python file?