Merged
27 changes: 26 additions & 1 deletion pytensor/link/numba/dispatch/sparse/basic.py
@@ -10,7 +10,14 @@
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
-from pytensor.sparse import CSM, Cast, CSMProperties, DenseFromSparse, Transpose
+from pytensor.sparse import (
+    CSM,
+    Cast,
+    CSMProperties,
+    DenseFromSparse,
+    SparseFromDense,
+    Transpose,
+)


@overload(numba_deepcopy)
@@ -84,3 +91,21 @@ def to_array(x):
return x.toarray()

return to_array


@register_funcify_default_op_cache_key(SparseFromDense)
def numba_funcify_SparseFromDense(op, node, **kwargs):
if op.format == "csr":

@numba_basic.numba_njit
def dense_to_csr(matrix):
return sp.sparse.csr_matrix(matrix)

return dense_to_csr
else:

@numba_basic.numba_njit
def dense_to_csc(matrix):
return sp.sparse.csc_matrix(matrix)

return dense_to_csc
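
For context, a minimal plain-SciPy sketch (illustrative only, outside numba; the toy array is hypothetical) of the conversions these dispatches compile, plus the inverse to_array path from the hunk above:

import numpy as np
import scipy.sparse as sparse

dense = np.array([[1.0, 0.0], [0.0, 2.0]])
csr = sparse.csr_matrix(dense)   # the op.format == "csr" branch
csc = sparse.csc_matrix(dense)   # the "csc" branch (else case)
back = csr.toarray()             # what DenseFromSparse's to_array returns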
245 changes: 245 additions & 0 deletions pytensor/link/numba/dispatch/sparse/math.py
@@ -4,6 +4,7 @@
import scipy.sparse as sp

import pytensor.sparse.basic as psb
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
@@ -13,10 +14,24 @@
Dot,
SparseDenseMultiply,
SparseDenseVectorMultiply,
SpSum,
StructuredDot,
StructuredDotGradCSC,
StructuredDotGradCSR,
)


@register_funcify_default_op_cache_key(SpSum)
def numba_funcify_SpSum(op, node, **kwargs):
axis = op.axis

@numba_basic.numba_njit
def perform(x):
        return x.sum(axis)

    return perform

Member: Does mypy freak out if you typehint this as SparseArray -> SparseArray? It would make the function clearer. Not required if it causes a headache (typehinting these overloads often does).

Contributor Author: I have not tried it, but the SpSum op returns a dense array (see this). What happens here is that this calls the function I implemented in overload_sum in variable.py. Maybe add a global somewhere (per op, or at the top of the file) saying that many (if not all) Ops use overloads written in a separate Python file?
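
To make the point about the dense result concrete, a small plain-SciPy sketch (illustrative only, not part of the diff):

import numpy as np
import scipy.sparse as sparse

x = sparse.csr_matrix(np.array([[1.0, 0.0], [0.0, 2.0]]))
print(x.sum())        # 3.0: a plain scalar
print(x.sum(axis=0))  # matrix([[1., 2.]]): a dense result, not sparse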


@register_funcify_default_op_cache_key(SparseDenseMultiply)
@register_funcify_default_op_cache_key(SparseDenseVectorMultiply)
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
@@ -402,3 +417,233 @@ def dmspm(x, y):
return spmdm_csr(y.T, x.T).T

return dmspm, cache_key
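
The dmspm tail visible above relies on the transpose identity (X @ Y) == (Y.T @ X.T).T, so the dense @ sparse case can reuse the sparse @ dense CSR kernel. A quick NumPy check of the identity (dense stand-ins, illustrative only):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))                # dense operand
y = rng.normal(size=(4, 2))                # stands in for the sparse operand
assert np.allclose(x @ y, (y.T @ x.T).T)   # identity behind dmspm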


@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
"""Overload StructuredDotGrad in Numba.

Let:
Z = structured_dot(X, Y)
L = L(Z), a scalar loss depending on Z.

This function computes the gradient of the loss with respect to X:

dL/dX = dot(dL/dZ, Y^T)

where G = dL/dZ is the accumulated (upstream) gradient.

The returned gradient is structured, preserving the sparsity pattern of X,
and only the `.data` component of the sparse matrix is computed.
If Y is sparse, the sparsity pattern of the result is not recomputed.
The output may contain explicit zeros at positions that would be structural zeros
if the sparsity structure were updated.

The core of the algorithm is:

dot(g_xy[i], y[j])

    where g_xy[i] (row i of G) and y[j] (row j of Y, i.e. column j of Y^T) are
    vectors of length k.

Reminder:
x.shape (n, p)
y.shape (p, k)
g_xy.shape (n, k)
"""
_, _, y, g_xy = node.inputs

y_dtype = y.type.dtype
y_is_sparse = psb._is_sparse_variable(y)
y_format = y.type.format if y_is_sparse else None

g_xy_dtype = g_xy.type.dtype
g_xy_is_sparse = psb._is_sparse_variable(g_xy)
g_xy_format = g_xy.type.format if g_xy_is_sparse else None

x_format = "csc" if isinstance(op, StructuredDotGradCSC) else "csr"
out_dtype = g_xy_dtype

cache_key = sha256(
str(
(
type(op),
x_format,
y_format,
y_dtype,
g_xy_format,
out_dtype,
y.type.shape,
)
).encode()
).hexdigest()

if not g_xy_is_sparse:
# X is sparse, Y and G_xy are dense.
if x_format == "csr":
if y.type.shape[1] == 1:
                # If Y is effectively 1D, use a more performant specialized
                # algorithm. Inputs with ndim > 2 never reach the StructuredDot Op.
@numba_basic.numba_njit
def _grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
output = np.empty(len(x_indices), dtype=out_dtype)
                    size = len(x_ptr) - 1
                    # Reinterpret the int32 index buffers as unsigned so numba
                    # can skip negative-index handling in the hot loop.
                    x_indices = x_indices.view(np.uint32)
                    x_ptr = x_ptr.view(np.uint32)
for row_idx in range(size):
for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
output[value_idx] = g_xy[row_idx] * y[x_indices[value_idx]]
return output

@numba_basic.numba_njit
def grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
return _grad_spmdv_csr(x_indices, x_ptr, y[:, 0], g_xy[:, 0])

return grad_spmdv_csr, cache_key
else:
# Y is a matrix
if config.compiler_verbose and y_dtype != out_dtype:
print( # noqa: T201
"Numba StructuredDotGrad requires a type casting of inputs: "
f"{y_dtype=}, {g_xy_dtype=}."
)

@numba_basic.numba_njit
def grad_spmdm_csr(x_indices, x_ptr, y, g_xy):
size = len(x_ptr) - 1
x_indices = x_indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32)

                    if y_dtype != out_dtype:
                        # Numba's np.dot requires both arguments to share a
                        # dtype, so cast the inputs to a common type first.
                        new_out_dtype = np.result_type(y, g_xy)
                        output = np.zeros(len(x_indices), dtype=new_out_dtype)
                        y = y.astype(out_dtype)
                        g_xy = g_xy.astype(out_dtype)
                    else:
                        output = np.zeros(len(x_indices), dtype=out_dtype)

for row_idx in range(size):
for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
output[value_idx] = np.dot(
g_xy[row_idx], y[x_indices[value_idx]]
)
return output

return grad_spmdm_csr, cache_key
else:
# X is CSC
@numba_basic.numba_njit
def grad_spmdm_csc(x_indices, x_ptr, y, g_xy):
# len(x_indices) gives the number of non-zero elements in X.
output = np.zeros(len(x_indices), dtype=out_dtype)
size = len(x_ptr) - 1
x_indices = x_indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32)

for col_idx in range(size):
for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
                        output[value_idx] = np.dot(
                            g_xy[x_indices[value_idx]], y[col_idx]
                        )
                return output

            return grad_spmdm_csc, cache_key

Member: Have to be careful with np.dot. IIRC the numba overload doesn't support integer / mixed dtypes well.

Contributor Author: Argh, I'm using it since np.sum(x * y) was slower. There are a bunch of tests that pass different data types, and they have all passed. Probably that's ok?

Member: It's probably fine as long as we're upcasting the inputs to a common dtype in the make_node of Dot? In the medium term we should consider re-implementing the BLAS calls ourselves.

Member: Do we have mixed integer / float types in the tests, or just discrete? I have >20% belief the numba np.dot overload refuses those.

Contributor Author: You're right. The issue was not covered by tests (I mixed something I saw in tests for Dot with what we're doing here with StructuredDot), plus Numba does not accept mixed types in np.dot. It was rejected with a specific error:

    TypingError: np.dot() arguments must all have the same dtype

I'll double check the upcasting.

Member: This doesn't solve the numba side? In numba you'll have to explicitly cast one or both of the inputs to dot them.

Contributor Author: You're so right!

Member: Probably fine, but perhaps add a print statement under an if pytensor.config.compiler_verbose so users can track down a potential source of slowdown. We do this for a couple of the linalg numba dispatches. You should be able to know at dispatch time whether a conversion will be needed (saying this because the warning shouldn't be inside the numba function).

Contributor Author: Just added it, thanks for the hint @ricardoV94
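
A standalone sketch of the limitation discussed in this thread (illustrative, not from the PR; the dtypes are chosen arbitrarily): numba's np.dot overload rejects mixed dtypes, so the kernel casts explicitly before calling it.

import numpy as np
from numba import njit

@njit
def dot_mixed(a, b):
    return np.dot(a, b)

@njit
def dot_cast(a, b):
    # Cast to a common dtype first, mirroring the casting branch above.
    return np.dot(a.astype(np.float64), b.astype(np.float64))

a = np.arange(3).astype(np.float32)
b = np.ones(3, dtype=np.float64)
# dot_mixed(a, b)   # TypingError: np.dot() arguments must all have the same dtype
print(dot_cast(a, b))  # 3.0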

    # Y is sparse. In either case we need 'dot_csr_rows', which computes the
    # sparse dot product of row `x_row` of X with row `y_row` of Y by walking
    # both rows with a two-pointer merge over their sorted column indices.
    @numba_basic.numba_njit
    def dot_csr_rows(x_ptr, x_indices, x_data, x_row, y_ptr, y_indices, y_data, y_row):
        x_p = x_ptr[x_row]
        x_end = x_ptr[x_row + 1]
        y_p = y_ptr[y_row]
        y_end = y_ptr[y_row + 1]

        acc = 0.0
        while x_p < x_end and y_p < y_end:
            x_col = x_indices[x_p]
            y_col = y_indices[y_p]
            if x_col == y_col:
                # Columns match: accumulate the product and advance both.
                acc += x_data[x_p] * y_data[y_p]
                x_p += 1
                y_p += 1
            elif x_col < y_col:
                x_p += 1
            else:
                y_p += 1

        return acc
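
A tiny worked example of the merge (plain NumPy, with hypothetical single-row CSR buffers): for x with indices [0, 2] and data [1., 2.], and y with indices [1, 2] and data [3., 4.], only column 2 overlaps, so dot_csr_rows would return 2.0 * 4.0 == 8.0.

import numpy as np

# Densified versions of the two rows above; the merge result agrees
# with the ordinary dense dot product:
print(np.dot([1.0, 0.0, 2.0], [0.0, 3.0, 4.0]))  # 8.0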

if x_format == "csr":
assert g_xy_format == "csr"
assert psb._is_sparse_variable(y)

@numba_basic.numba_njit
def grad_spmspm_csr(x_indices, x_ptr, y, g_xy):
if y_format == "csc":
y = y.tocsr()

g_xy_data = g_xy.data
g_xy_indices = g_xy.indices.view(np.uint32)
g_xy_ptr = g_xy.indptr.view(np.uint32)

y_data = y.data
y_indices = y.indices.view(np.uint32)
y_ptr = y.indptr.view(np.uint32)

n_row = len(x_ptr) - 1
output = np.zeros(len(x_indices), dtype=out_dtype)

for x_row in range(n_row):
for data_idx in range(x_ptr[x_row], x_ptr[x_row + 1]):
x_col = x_indices[data_idx]
output[data_idx] = dot_csr_rows(
g_xy_ptr,
g_xy_indices,
g_xy_data,
x_row,
y_ptr,
y_indices,
y_data,
x_col,
)
return output

return grad_spmspm_csr, cache_key
else:
assert g_xy_format == "csc"
assert psb._is_sparse_variable(y)

@numba_basic.numba_njit
def grad_spmspm_csc(x_indices, x_ptr, y, g_xy):
if y_format == "csc":
y = y.tocsr()

# Looping a CSC matrix rowwise is too painful, slow, and cryptic.
g_xy = g_xy.tocsr()

g_xy_data = g_xy.data
g_xy_indices = g_xy.indices.view(np.uint32)
g_xy_ptr = g_xy.indptr.view(np.uint32)

y_data = y.data
y_indices = y.indices.view(np.uint32)
y_ptr = y.indptr.view(np.uint32)

n_cols = len(x_ptr) - 1
output = np.empty(len(x_indices), dtype=out_dtype)

for x_col in range(n_cols):
for data_idx in range(x_ptr[x_col], x_ptr[x_col + 1]):
x_row = x_indices[data_idx]
output[data_idx] = dot_csr_rows(
g_xy_ptr,
g_xy_indices,
g_xy_data,
x_row,
y_ptr,
y_indices,
y_data,
x_col,
)
return output

return grad_spmspm_csc, cache_key
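
As a sanity check on the docstring's formula, a dense reference sketch (pure NumPy/SciPy; the function name and test data are hypothetical, not from the PR): for every stored nonzero X[i, j], the structured gradient entry is dot(G[i, :], Y[j, :]), i.e. G @ Y.T restricted to X's sparsity pattern.

import numpy as np
import scipy.sparse as sparse

def structured_dot_grad_reference(x, y, g_xy):
    # Dense reference: dL/dX = G @ Y.T, keeping only the entries at the
    # stored nonzeros of x (the "structured" part of the gradient).
    full = np.asarray(g_xy) @ np.asarray(y).T  # shape (n, p)
    coo = x.tocoo()                            # stored (row, col) pairs of x
    return full[coo.row, coo.col]              # one value per stored nonzero

x = sparse.random(5, 4, density=0.4, format="csr", random_state=0)
y = np.random.default_rng(0).normal(size=(4, 3))
g_xy = np.random.default_rng(1).normal(size=(5, 3))
print(structured_dot_grad_reference(x, y, g_xy))  # aligned with x.data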