Closed
Description
Hi there,
This code is somewhat out of date: the PyTorch documentation already states that `torch.squeeze` accepts a tuple of ints.
That alone is fine, but the bigger issue is that this `squeeze` does not behave like NumPy's. With NumPy I can lazily call `squeeze` without an axis to remove all size-1 dimensions, but here `axis` is required.
I can patch this part and open a PR if that's okay with your team.
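For reference, `numpy.squeeze` with no axis drops every size-1 dimension, while an explicit axis (int or tuple) must name only size-1 dimensions. The pure-Python sketch below applies those rules to a shape tuple. The helper name `squeezed_shape` is hypothetical and is not part of array-api-compat, NumPy, or PyTorch. It is only meant to make the expected behaviour concrete.

```python
from typing import Optional, Sequence, Tuple, Union


def squeezed_shape(
    shape: Sequence[int], axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Tuple[int, ...]:
    """Hypothetical helper: shape left after squeezing, NumPy-style."""
    ndim = len(shape)
    if axis is None:
        # No axis given: drop every size-1 dimension, as numpy.squeeze does.
        return tuple(n for n in shape if n != 1)
    if isinstance(axis, int):
        axis = (axis,)
    removed = set()
    for a in axis:
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} is out of bounds for {ndim} dimensions")
        a %= ndim  # map negative axes to their non-negative equivalents
        if a in removed:
            raise ValueError("repeated axis")
        if shape[a] != 1:
            raise ValueError("squeezed dimensions must be equal to 1")
        removed.add(a)
    return tuple(n for i, n in enumerate(shape) if i not in removed)
```

For a shape `(1, 3, 1, 4)`, calling it with no axis gives `(3, 4)`, while `axis=-2` removes only the third dimension and gives `(1, 3, 4)`.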
Thanks,
Sam
array-api-compat/array_api_compat/torch/_aliases.py
Lines 488 to 499 in 21aa31b
    def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
        if isinstance(axis, int):
            axis = (axis,)
        for a in axis:
            if x.shape[a] != 1:
                raise ValueError("squeezed dimensions must be equal to 1")
        axes = _normalize_axes(axis, x.ndim)
        # Remove this once pytorch 1.14 is released with the above PR #89017.
        sequence = [a - i for i, a in enumerate(axes)]
        for a in sequence:
            x = torch.squeeze(x, a)
        return x
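The `a - i` offset in the loop above exists because each single-axis `torch.squeeze` call removes one dimension, so every later axis index shifts left by the number of axes already removed. The sketch below simulates that on a plain shape list. The helper name is hypothetical, and it assumes the axes are non-negative and sorted in ascending order. Whether `_normalize_axes` guarantees sorted output is not shown in the snippet.

```python
from typing import List, Sequence, Tuple


def squeeze_one_at_a_time(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: simulate repeated single-axis squeezes on a shape.

    Assumes `axes` are non-negative and sorted in ascending order.
    """
    dims: List[int] = list(shape)
    # Each earlier squeeze deletes one dimension, so later axes shift left
    # by the number of axes already removed (the `a - i` offset).
    sequence = [a - i for i, a in enumerate(axes)]
    for a in sequence:
        if dims[a] != 1:
            raise ValueError("squeezed dimensions must be equal to 1")
        del dims[a]
    return tuple(dims)
```

With tuple support in `torch.squeeze`, as the PyTorch documentation mentioned above describes, this offset bookkeeping would no longer be needed. A single call that receives all the axes at once could replace the loop.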