dpctl doesn't allow mixing integer dtypes of indices arrays #1482

@antonwolfy

Description

The example below raises an exception with dpctl but works with NumPy. The two index arrays have different integer dtypes (int32 and int64):

import numpy

a = numpy.ones((3, 4, 5), dtype='f4')
ai1 = numpy.ones((3, 4, 5), dtype='i4')
ai2 = numpy.reshape(numpy.arange(3, dtype='i8'), (3, 1, 1))

# no exception:
a[tuple([ai1, ai2])]

import dpctl.tensor as dpt

a = dpt.ones((3, 4, 5), dtype='f4')
ai1 = dpt.ones((3, 4, 5), dtype='i4')
ai2 = dpt.reshape(dpt.arange(3, dtype='i8'), (3, 1, 1))

a[tuple([ai1, ai2])]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 _a[tuple([_ai1, _ai2])]

File dpctl/tensor/_usmarray.pyx:761, in dpctl.tensor._usmarray.usm_ndarray.__getitem__()

File ~/miniconda3/envs/dpnp_dev/lib/python3.9/site-packages/dpctl/tensor/_copy_utils.py:755, in _take_multi_index(ary, inds, p)
    750 res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    751 res = dpt.empty(
    752     res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
    753 )
--> 755 hev, _ = ti._take(
    756     src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
    757 )
    758 hev.wait()
    760 return res

TypeError: Indices array data types are not all the same.
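Until dpctl accepts mixed index dtypes, a possible workaround is to cast all index arrays to one common integer dtype before indexing, since the error above indicates that `_take` requires identical index dtypes. The sketch below shows the idea with NumPy so the result can be compared against NumPy's own mixed-dtype indexing. `unify_index_dtypes` is a hypothetical helper, not part of NumPy or dpctl. With dpctl the cast would presumably be done with `dpctl.tensor.astype` (for example `dpt.astype(ai1, 'i8')`), but that has not been verified against the dpctl version in the traceback.

```python
import numpy


def unify_index_dtypes(inds):
    """Cast every index array in `inds` to one common integer dtype.

    Hypothetical helper (not a NumPy or dpctl API): it sketches the
    workaround of giving all index arrays the same dtype before
    advanced indexing.
    """
    common = numpy.result_type(*inds)  # int32 and int64 arrays -> int64
    return tuple(ind.astype(common, copy=False) for ind in inds)


a = numpy.ones((3, 4, 5), dtype='f4')
ai1 = numpy.ones((3, 4, 5), dtype='i4')
ai2 = numpy.reshape(numpy.arange(3, dtype='i8'), (3, 1, 1))

# Reference behavior: NumPy accepts the mixed int32/int64 indices directly.
expected = a[tuple([ai1, ai2])]

# Workaround: index with arrays that all share the same dtype.
unified = unify_index_dtypes([ai1, ai2])
result = a[unified]

# The index arrays broadcast to (3, 4, 5) and index axes 0 and 1,
# so the remaining axis of length 5 is appended to the result shape.
print(result.shape)
```

Because casting preserves the index values, both paths select the same elements; the only difference is that every index array reaching the indexing call shares a single dtype.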