Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,9 +771,9 @@ def orderings(self, fgraph, ordered=True):
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
app.op, "destroyhandler_tolerate_aliased", ()
)
assert isinstance(tolerate_aliased, list)
assert isinstance(tolerate_aliased, tuple | list)
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
}
Expand Down
36 changes: 3 additions & 33 deletions pytensor/link/jax/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice


BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
Expand All @@ -35,10 +34,8 @@
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -48,10 +45,9 @@ def subtensor(x, *ilists):


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
Expand All @@ -62,7 +58,7 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
Expand All @@ -73,29 +69,3 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
return jax_fn(x, indices, y)

return incsubtensor


@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
return x.at[indices].set(y)

else:

def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)

return advancedincsubtensor


@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)

return makeslice
23 changes: 5 additions & 18 deletions pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice


@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
indices = indices_from_subtensor(
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -30,10 +29,8 @@ def subtensor(x, *ilists):
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -45,8 +42,6 @@ def advanced_subtensor(x, *ilists):
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):

def mlx_fn(x, indices, y):
Expand All @@ -63,7 +58,7 @@ def mlx_fn(x, indices, y):
x[indices] += y
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
Expand Down Expand Up @@ -95,11 +90,3 @@ def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
return mlx_fn(x, ilist, y)

return advancedincsubtensor


@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)

return makeslice
Loading
Loading