Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
* Compile indexing extension with `-fno-sycl-id-queries-fit-in-int` to support huge arrays [#2721](https://github.com/IntelPython/dpnp/pull/2721)
* Updated `dpnp.fix` to reuse `dpnp.trunc` internally [#2722](https://github.com/IntelPython/dpnp/pull/2722)
* Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716)
* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733)

### Deprecated

Expand Down
14 changes: 12 additions & 2 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ def outer(a, b, out=None):
return result


def tensordot(a, b, axes=2):
def tensordot(a, b, /, *, axes=2):
r"""
Compute tensor dot product along specified axes.

Expand All @@ -1148,7 +1148,10 @@ def tensordot(a, b, axes=2):
axes must match.
* (2,) array_like: A list of axes to be summed over, first sequence
applying to `a`, second to `b`. Both elements array_like must be of
the same length.
the same length. Each axis may appear at most once; repeated axes are
not allowed.

Default: ``2``.

Returns
-------
Expand Down Expand Up @@ -1178,6 +1181,13 @@ def tensordot(a, b, axes=2):
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.

For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then
``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both
arrays and produces an output of shape ``(2, 5)``.

Each summation axis corresponds to a distinct contraction index; repeating
an axis (for example ``axes=([1, 1], [0, 0])``) is invalid.

The shape of the result consists of the non-contracted axes of the
first tensor, followed by the non-contracted axes of the second.

Expand Down
12 changes: 10 additions & 2 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,9 +1975,10 @@ def tensordot(a, b, /, *, axes=2):
axes must match.
* (2,) array_like: A list of axes to be summed over, first sequence
applying to `a`, second to `b`. Both elements array_like must be of
the same length.
the same length. Each axis may appear at most once; repeated axes are
not allowed.

Default: ``2``.
Default: ``2``.

Returns
-------
Expand Down Expand Up @@ -2007,6 +2008,13 @@ def tensordot(a, b, /, *, axes=2):
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.

For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then
``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both
arrays and produces an output of shape ``(2, 5)``.

Each summation axis corresponds to a distinct contraction index; repeating
an axis (for example ``axes=([1, 1], [0, 0])``) is invalid.

The shape of the result consists of the non-contracted axes of the
first tensor, followed by the non-contracted axes of the second.

Expand Down
7 changes: 7 additions & 0 deletions dpnp/tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,13 @@ def test_error(self):
with pytest.raises(ValueError):
dpnp.tensordot(dpnp.arange(4), dpnp.array(5), axes=-1)

@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_repeated_axes(self, xp):
a = xp.ones((2, 3, 3))
b = xp.ones((3, 3, 4))
with pytest.raises(ValueError):
xp.tensordot(a, b, axes=([1, 1], [0, 0]))


class TestVdot:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
Expand Down
Loading