
dpt.put raises a TypeError when vals is a usm_ndarray with a different dtype than x #1382

@vlad-perevezentsev

Description

The documentation for dpt.put does not describe the case where vals is a usm_ndarray with a different data type than x.
Calling dpt.put in this case raises a TypeError from the dpctl backend.

I think we should cast vals to the data type of x when the two differ, as NumPy does.

The example below demonstrates this case:

import dpctl.tensor as dpt

x = dpt.arange(10)
ind = dpt.asarray([0])
vals = dpt.asarray([10], dtype='f4')

dpt.put(x, ind, vals)

Traceback (excerpt):

    hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
    hev.wait()

TypeError: Array data types are not the same.

# numpy.put casts vals to the dtype of x and succeeds:

import numpy

x_np = dpt.asnumpy(x)

numpy.put(x_np, dpt.asnumpy(ind), dpt.asnumpy(vals))
x_np
# array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9])
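For reference, the cast NumPy performs here is unchecked: float values are truncated toward zero when written into an integer array, instead of raising. A small check using NumPy alone (the values 10.7 and -1.5 are illustrative, not from the original report):

```python
import numpy as np

x_np = np.arange(10)  # integer destination array
vals_np = np.asarray([10.7, -1.5], dtype=np.float32)

# numpy.put converts vals_np to x_np.dtype before writing,
# discarding the fractional parts (truncation toward zero).
np.put(x_np, [0, 1], vals_np)
print(x_np[:2])
```

If dpt.put adopts NumPy's behavior, it would presumably need to decide whether such lossy float-to-int casts are allowed silently.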

Metadata


Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
