diff --git a/CHANGELOG.md b/CHANGELOG.md index 3507b51efc6..6e098c365d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum * Compile indexing extension with `-fno-sycl-id-queries-fit-in-int` to support huge arrays [#2721](https://github.com/IntelPython/dpnp/pull/2721) * Updated `dpnp.fix` to reuse `dpnp.trunc` internally [#2722](https://github.com/IntelPython/dpnp/pull/2722) * Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716) +* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733) ### Deprecated diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index a0068a3597e..acb12347348 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -1121,7 +1121,7 @@ def outer(a, b, out=None): return result -def tensordot(a, b, axes=2): +def tensordot(a, b, /, *, axes=2): r""" Compute tensor dot product along specified axes. @@ -1148,7 +1148,10 @@ def tensordot(a, b, axes=2): axes must match. * (2,) array_like: A list of axes to be summed over, first sequence applying to `a`, second to `b`. Both elements array_like must be of - the same length. + the same length. Each axis may appear at most once; repeated axes are + not allowed. + + Default: ``2``. Returns ------- @@ -1178,6 +1181,13 @@ def tensordot(a, b, axes=2): two sequences of the same length, with the first axis to sum over given first in both sequences, the second axis second, and so forth. + For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then + ``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both + arrays and produces an output of shape ``(2, 5)``. + + Each summation axis corresponds to a distinct contraction index; repeating + an axis (for example ``axes=([1, 1], [0, 0])``) is invalid. + The shape of the result consists of the non-contracted axes of the first tensor, followed by the non-contracted axes of the second. diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 95bb849afa3..6959565ecf1 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -1975,9 +1975,10 @@ def tensordot(a, b, /, *, axes=2): axes must match. * (2,) array_like: A list of axes to be summed over, first sequence applying to `a`, second to `b`. Both elements array_like must be of - the same length. + the same length. Each axis may appear at most once; repeated axes are + not allowed. - Default: ``2``. + Default: ``2``. Returns ------- @@ -2007,6 +2008,13 @@ def tensordot(a, b, /, *, axes=2): two sequences of the same length, with the first axis to sum over given first in both sequences, the second axis second, and so forth. + For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then + ``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both + arrays and produces an output of shape ``(2, 5)``. + + Each summation axis corresponds to a distinct contraction index; repeating + an axis (for example ``axes=([1, 1], [0, 0])``) is invalid. + The shape of the result consists of the non-contracted axes of the first tensor, followed by the non-contracted axes of the second. diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 763049e8791..afe767a5e5d 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1842,6 +1842,13 @@ def test_error(self): with pytest.raises(ValueError): dpnp.tensordot(dpnp.arange(4), dpnp.array(5), axes=-1) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_repeated_axes(self, xp): + a = xp.ones((2, 3, 3)) + b = xp.ones((3, 3, 4)) + with pytest.raises(ValueError): + xp.tensordot(a, b, axes=([1, 1], [0, 0])) + class TestVdot: @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))