
dpt.take_along_axis and dpt.put_along_axis raise an error when the indices data type is uint64 #1936

@vtavana

Description


In dpctl, `take_along_axis` raises an error when the indices have dtype `uint64`, whereas the same call works in NumPy:

```python
import dpctl.tensor as dpt
a = dpt.asarray([[10, 30, 20], [60, 40, 50]])
ind = dpt.asarray([[2, 1, 0], [2, 1, 0]], dtype=dpt.uint64)
dpt.take_along_axis(a, ind, axis=1)
# ValueError: cannot safely promote indices to an integer data type
```

```python
import numpy
a = numpy.array([[10, 30, 20], [60, 40, 50]])
ind = numpy.array([[2, 1, 0], [2, 1, 0]], dtype=numpy.uint64)
numpy.take_along_axis(a, ind, axis=1)
# array([[20, 30, 10],
#        [50, 40, 60]])
```
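The error message suggests that dpctl looks for a signed integer type that can hold every possible index value, presumably so that negative indices can be supported (an assumption; the issue does not say why). No such type exists for `uint64`: safe casting from `uint64` to `int64` is refused, and NumPy promotes the pair to `float64`. The NumPy sketch below illustrates this and shows one possible user-side workaround: convert the indices to `int64` after checking that every value fits. The helper `as_int64_indices` is hypothetical; with dpctl, the analogous cast would presumably be `dpt.astype(ind, dpt.int64)` before calling `dpt.take_along_axis`.

```python
import numpy as np

# No signed integer type can represent every uint64 value:
# a safe cast uint64 -> int64 is refused, and NumPy promotes the pair to float64.
assert not np.can_cast(np.uint64, np.int64, casting="safe")
print(np.result_type(np.uint64, np.int64))  # float64


def as_int64_indices(ind):
    """Hypothetical helper: convert unsigned indices to int64 if every value fits."""
    ind = np.asarray(ind)
    if ind.dtype.kind == "u":
        # Reject values that would wrap around when reinterpreted as int64.
        if ind.size and int(ind.max()) > np.iinfo(np.int64).max:
            raise OverflowError("index value does not fit in int64")
        ind = ind.astype(np.int64)
    return ind


a = np.array([[10, 30, 20], [60, 40, 50]])
ind = np.array([[2, 1, 0], [2, 1, 0]], dtype=np.uint64)
# The cast is lossless here, so the result matches the uint64 call above.
result = np.take_along_axis(a, as_int64_indices(ind), axis=1)
print(result)
```

The bounds check matters only for indices of at least 2**63, which could never be valid positions along an axis in practice, but it keeps the conversion from silently wrapping.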

A similar behavior is observed for `dpt.put_along_axis`.
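For comparison, `numpy.put_along_axis` accepts `uint64` indices as well; in NumPy it builds the same fancy index as `take_along_axis`. A minimal sketch, with indices and a fill value chosen arbitrarily for illustration (they are not from the issue):

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])
# One uint64 index per row: overwrite column 0 of row 0 and column 2 of row 1.
ind = np.array([[0], [2]], dtype=np.uint64)
np.put_along_axis(a, ind, 99, axis=1)  # modifies `a` in place
print(a)
# [[99 30 20]
#  [60 40 99]]
```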
