-
Notifications
You must be signed in to change notification settings - Fork 30
Implementation of permute_dims function #786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of permute_dims function #786
Conversation
dpctl/tensor/_permute_dims.py
Outdated
|
||
|
||
def permute_dims(X, axes): | ||
if type(X) is not dpt.usm_ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if type(X) is not dpt.usm_ndarray: | |
if not isinstance(X, dpt.usm_ndarray): |
ae95685
to
1c8970d
Compare
dpctl/tensor/_permute_dims.py
Outdated
|
||
def permute_dims(X, axes): | ||
""" | ||
permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since your branch does not build documentation we do not see how this comes out in docs. Please create a branch in the main repo and push there to see the rendered docs,
dpctl/tensor/_permute_dims.py
Outdated
"The length of the passed axes does not match " | ||
"to the number of usm_ndarray dimensions." | ||
) | ||
if axes not in permutations(range(0, X.ndim), X.ndim): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has horrible computational and space complexity in the dimensionality of the array d
, specifically O(d!)
.
More efficient way:
- Length of the tuple must be
d
- Minimum element must be
0
, maximumd-1
. - Array should have no duplicates.
1c8970d
to
aefab1d
Compare
aefab1d
to
30e98fe
Compare
Moved to #787 |
This PR adds
permute_dims
functions according to Python array API standard forusm_ndarray