
Should tensordot broadcast the contracted dimensions? #294

@asmeurer


Should tensordot broadcast the contracted dimensions? For example, say we contract the first dimensions here:

tensordot(ones((3, 3)), ones((1, 3)), axes=((0,), (0,)))

The dimensions 3 and 1 do not match, but if we broadcast the arrays together first, they both become shape (3, 3), after which they do match.

The spec is a little unclear about this (https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#tensordot-x1-x2-axes-2). It says that x2 must be compatible with x1 by broadcasting, which seems to imply unconditional broadcasting. But it also says: "Each axis (dimension) x1_axes[i] for x1 must have the same size as the respective axis (dimension) x2_axes[i] for x2."

NumPy disallows broadcasting in contracted dimensions (the sizes of non-contracted dimensions do not need to match):

>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((0,), (0,)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<__array_function__ internals>", line 181, in tensordot
  File "./numpy/core/numeric.py", line 1110, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((1,), (1,)))
array([[3.],
       [3.],
       [3.]])
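The NumPy rule shown above can be sketched as a pure-Python shape computation. This is a minimal sketch, not NumPy's code: the helper name `strict_tensordot_shape` is hypothetical, and axes are assumed to be in range. Contracted sizes must match exactly, and the result shape is the non-contracted axes of x1 followed by the non-contracted axes of x2.

```python
def strict_tensordot_shape(shape1, shape2, axes1, axes2):
    """Result shape of tensordot under NumPy-style rules (hypothetical helper).

    Contracted axes must have exactly equal sizes; no broadcasting is done.
    Axes are assumed to be valid (in range) for their shapes.
    """
    if len(axes1) != len(axes2):
        raise ValueError("axes must have the same length")
    # Normalize negative axes, e.g. -1 -> ndim - 1.
    axes1 = [a % len(shape1) for a in axes1]
    axes2 = [a % len(shape2) for a in axes2]
    for a1, a2 in zip(axes1, axes2):
        if shape1[a1] != shape2[a2]:
            raise ValueError("shape-mismatch for sum")
    # Free (non-contracted) axes of x1 come first, then those of x2.
    free1 = [n for i, n in enumerate(shape1) if i not in axes1]
    free2 = [n for i, n in enumerate(shape2) if i not in axes2]
    return tuple(free1 + free2)


try:
    strict_tensordot_shape((3, 3), (1, 3), (0,), (0,))
except ValueError as e:
    print("error:", e)  # error: shape-mismatch for sum
print(strict_tensordot_shape((3, 3), (1, 3), (1,), (1,)))  # (3, 1)
```

The second call succeeds even though the free sizes 3 and 1 differ, because free axes are never compared, only placed side by side in the output.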

PyTorch broadcasts all dimensions, including contracted ones (note that PyTorch still calls its axes argument dims):

>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((0,), (0,)))
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((1,), (1,)))
tensor([[3.],
        [3.],
        [3.]])
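The PyTorch results above are consistent with applying the ordinary single-dimension broadcasting rule (equal sizes, or one of them is 1) to each contracted pair. The sketch below assumes that rule; it is an inference from the outputs shown, not PyTorch's implementation, and `broadcasting_tensordot_shape` is a hypothetical helper name. With all-ones inputs, each output entry is the sum over a broadcast contracted axis of length 3, which is why every entry is 3.

```python
def broadcasting_tensordot_shape(shape1, shape2, axes1, axes2):
    """Result shape of tensordot when contracted axes may broadcast (hypothetical helper).

    A contracted pair is accepted if the sizes are equal or either one is 1,
    mirroring the usual broadcasting rule for a single dimension.
    Axes are assumed to be valid (in range) for their shapes.
    """
    if len(axes1) != len(axes2):
        raise ValueError("axes must have the same length")
    # Normalize negative axes, e.g. -1 -> ndim - 1.
    axes1 = [a % len(shape1) for a in axes1]
    axes2 = [a % len(shape2) for a in axes2]
    for a1, a2 in zip(axes1, axes2):
        n1, n2 = shape1[a1], shape2[a2]
        if n1 != n2 and 1 not in (n1, n2):
            raise ValueError(f"cannot broadcast contracted sizes {n1} and {n2}")
    # The output shape still comes from the original, non-broadcast shapes.
    free1 = [n for i, n in enumerate(shape1) if i not in axes1]
    free2 = [n for i, n in enumerate(shape2) if i not in axes2]
    return tuple(free1 + free2)


print(broadcasting_tensordot_shape((3, 3), (1, 3), (0,), (0,)))  # (3, 3)
print(broadcasting_tensordot_shape((3, 3), (1, 3), (1,), (1,)))  # (3, 1)
```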

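Note that in either case, the shape of the result is based on the non-broadcasted input shapes, so it's not as simple as wrapping the call with broadcast_arrays. For example, broadcasting (3, 3) against (2, 3, 3) would turn the first call below into the second, which has a different output shape: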

>>> np.tensordot(np.ones((3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(3, 2, 3)
>>> np.tensordot(np.ones((2, 3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(2, 3, 2, 3)
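The shape difference can be reproduced with a pure-Python sketch, assuming the standard right-aligned broadcasting rule and the "free axes of x1, then free axes of x2" output layout. The helper names `broadcast_shapes` and `tensordot_shape` are hypothetical here; the sketch only computes shapes and skips the size check on contracted axes.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Broadcast shapes using the standard right-aligned broadcasting rule."""
    result = []
    # Walk dimensions from the right; missing leading dimensions count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = set(dims) - {1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def tensordot_shape(shape1, shape2, axes1, axes2):
    """Output shape of tensordot: free axes of x1, then free axes of x2."""
    axes1 = {a % len(shape1) for a in axes1}
    axes2 = {a % len(shape2) for a in axes2}
    free1 = [n for i, n in enumerate(shape1) if i not in axes1]
    free2 = [n for i, n in enumerate(shape2) if i not in axes2]
    return tuple(free1 + free2)


s1, s2 = (3, 3), (2, 3, 3)
print(tensordot_shape(s1, s2, (-1,), (2,)))  # (3, 2, 3)

# Broadcasting the inputs first turns s1 into (2, 3, 3) ...
b = broadcast_shapes(s1, s2)
print(b)  # (2, 3, 3)
# ... so the output gains an extra leading axis from x1.
print(tensordot_shape(b, b, (-1,), (2,)))  # (2, 3, 2, 3)
```

Broadcasting the inputs up front adds axes to x1's free dimensions, which is why pre-broadcasting does not reproduce PyTorch's output shape.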
