Description
Should tensordot broadcast the contracted dimensions? For example, say we contract the first dimensions here:
tensordot(ones((3, 3)), ones((1, 3)), axes=((0,), (0,)))
The dimensions 3 and 1 do not match, but if we broadcast the arrays together first, they both become shape (3, 3), after which they do match.
The spec (https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#tensordot-x1-x2-axes-2) is a little unclear about this. It says x2 must be compatible with x1 by broadcasting, which seems to imply unconditional broadcasting. But it also says "Each axis (dimension) x1_axes[i] for x1 must have the same size as the respective axis (dimension) x2_axes[i] for x2."
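To make the conflict concrete, here is a small pure-Python sketch that checks the two clauses separately against the example shapes. The helper names broadcast_compatible and contracted_sizes_equal are hypothetical, and reading the first clause as "the full shapes broadcast together" is an assumption about what the spec intends.

```python
from typing import Sequence, Tuple

Axes = Tuple[Sequence[int], Sequence[int]]


def broadcast_compatible(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
    """Clause 1 (assumed reading): the full shapes are compatible under broadcasting."""
    # Compare trailing axes; a missing leading axis behaves like size 1.
    for d1, d2 in zip(reversed(shape1), reversed(shape2)):
        if d1 != d2 and 1 not in (d1, d2):
            return False
    return True


def contracted_sizes_equal(shape1: Sequence[int], shape2: Sequence[int], axes: Axes) -> bool:
    """Clause 2: each contracted axis pair has exactly the same size."""
    axes1, axes2 = axes
    return all(shape1[a1] == shape2[a2] for a1, a2 in zip(axes1, axes2))


x1_shape, x2_shape = (3, 3), (1, 3)
print(broadcast_compatible(x1_shape, x2_shape))                  # True
print(contracted_sizes_equal(x1_shape, x2_shape, ((0,), (0,))))  # False
```

For the example in this issue, the first clause is satisfied while the second is not, so the two sentences of the spec give different answers.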
NumPy disallows broadcasting in contracted dimensions (it does broadcast non-contracted dimensions):
>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((0,), (0,)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<__array_function__ internals>", line 181, in tensordot
  File "./numpy/core/numeric.py", line 1110, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((1,), (1,)))
array([[3.],
       [3.],
       [3.]])
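The NumPy behavior above can be summarized as a shape rule: contracted sizes must be exactly equal, and the result shape is the free (non-contracted) axes of x1 followed by those of x2. The sketch below is a hypothetical pure-Python helper, not NumPy's code; it only reuses the error message from the traceback above.

```python
from typing import Sequence, Tuple

Axes = Tuple[Sequence[int], Sequence[int]]


def tensordot_shape_strict(shape1: Sequence[int], shape2: Sequence[int], axes: Axes) -> Tuple[int, ...]:
    """Result shape of tensordot when contracted axis sizes must match exactly."""
    # Normalize negative axes so membership tests below work.
    axes1 = [a % len(shape1) for a in axes[0]]
    axes2 = [a % len(shape2) for a in axes[1]]
    if len(axes1) != len(axes2):
        raise ValueError("axes must have the same length")
    for a1, a2 in zip(axes1, axes2):
        if shape1[a1] != shape2[a2]:
            # No broadcasting of contracted axes: a size-1 axis is still a mismatch.
            raise ValueError("shape-mismatch for sum")
    # Free axes of x1 come first, then the free axes of x2.
    free1 = tuple(d for i, d in enumerate(shape1) if i not in axes1)
    free2 = tuple(d for i, d in enumerate(shape2) if i not in axes2)
    return free1 + free2


print(tensordot_shape_strict((3, 3), (1, 3), ((1,), (1,))))  # (3, 1)
```

Calling it with axes=((0,), (0,)) on the same shapes raises ValueError, matching the first NumPy call.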
PyTorch broadcasts all dimensions, including contracted ones (note that PyTorch still calls its axes argument dims):
>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((0,), (0,)))
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((1,), (1,)))
tensor([[3.],
        [3.],
        [3.]])
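One way to get PyTorch-style results on top of a strict tensordot would be to broadcast only the contracted axes, leaving the free axes alone. The sketch below computes the target shapes for that; broadcast_contracted_targets is a hypothetical helper, not part of NumPy, PyTorch, or the spec. With NumPy, each input could then be passed through np.broadcast_to with its target before calling np.tensordot.

```python
from typing import Sequence, Tuple

Axes = Tuple[Sequence[int], Sequence[int]]


def broadcast_contracted_targets(
    shape1: Sequence[int], shape2: Sequence[int], axes: Axes
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Target shapes that broadcast only the contracted axes (hypothetical helper)."""
    target1, target2 = list(shape1), list(shape2)
    for a1, a2 in zip(axes[0], axes[1]):
        d1, d2 = target1[a1], target2[a2]
        if d1 == d2:
            continue
        if d1 == 1:
            target1[a1] = d2  # stretch x1's size-1 contracted axis
        elif d2 == 1:
            target2[a2] = d1  # stretch x2's size-1 contracted axis
        else:
            raise ValueError(f"cannot broadcast contracted sizes {d1} and {d2}")
    # Free axes are untouched, so the result shape still comes from the original inputs.
    return tuple(target1), tuple(target2)


print(broadcast_contracted_targets((3, 3), (1, 3), ((0,), (0,))))  # ((3, 3), (3, 3))
```

For the two calls above, this gives the same result shapes as PyTorch: (3, 3) when contracting axis 0 and (3, 1) when contracting axis 1, where nothing needs to be stretched.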
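Note that in either case, the resulting array shape is based on the non-broadcasted input shapes, so it's not as simple as wrapping the call with broadcast_arrays. For example, broadcasting a (3, 3) first argument against a (2, 3, 3) second argument beforehand changes the result shape: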
>>> np.tensordot(np.ones((3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(3, 2, 3)
>>> np.tensordot(np.ones((2, 3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(2, 3, 2, 3)
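The shape change follows from the result-shape rule: free axes of x1 followed by free axes of x2. Broadcasting both inputs first adds x2's leading axis to x1, so it shows up twice in the output. A minimal pure-Python sketch reproduces both shapes; broadcast_shapes and tensordot_shape here are hypothetical helpers, and tensordot_shape does not check contracted sizes.

```python
from typing import Sequence, Tuple

Axes = Tuple[Sequence[int], Sequence[int]]


def broadcast_shapes(shape1: Sequence[int], shape2: Sequence[int]) -> Tuple[int, ...]:
    """Standard broadcasting of two shapes: align trailing axes, stretch size-1 axes."""
    ndim = max(len(shape1), len(shape2))
    padded1 = (1,) * (ndim - len(shape1)) + tuple(shape1)
    padded2 = (1,) * (ndim - len(shape2)) + tuple(shape2)
    result = []
    for d1, d2 in zip(padded1, padded2):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"shapes {shape1} and {shape2} are not broadcastable")
        result.append(d2 if d1 == 1 else d1)
    return tuple(result)


def tensordot_shape(shape1: Sequence[int], shape2: Sequence[int], axes: Axes) -> Tuple[int, ...]:
    """Free axes of x1 followed by free axes of x2 (contracted sizes are not checked)."""
    axes1 = {a % len(shape1) for a in axes[0]}
    axes2 = {a % len(shape2) for a in axes[1]}
    free1 = tuple(d for i, d in enumerate(shape1) if i not in axes1)
    free2 = tuple(d for i, d in enumerate(shape2) if i not in axes2)
    return free1 + free2


x1, x2 = (3, 3), (2, 3, 3)
axes = ((-1,), (2,))
print(tensordot_shape(x1, x2, axes))  # (3, 2, 3)

# What wrapping the call with broadcast_arrays would feed to tensordot.
b = broadcast_shapes(x1, x2)  # (2, 3, 3)
print(tensordot_shape(b, b, axes))  # (2, 3, 2, 3)
```

The example uses a negative axis for x1 so that it still refers to the last axis after a leading axis is added; a positive index would shift.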