Skip to content

Commit 4344c83

Browse files
Merge pull request #896 from ndgrigorian/dpctl-eye
Implemented dpctl.tensor.eye constructor and tests
2 parents 8cf06c6 + ba3df76 commit 4344c83

File tree

4 files changed

+255
-0
lines changed

4 files changed

+255
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
asarray,
2828
empty,
2929
empty_like,
30+
eye,
3031
full,
3132
full_like,
3233
linspace,
@@ -62,6 +63,7 @@
6263
"zeros",
6364
"ones",
6465
"full",
66+
"eye",
6567
"linspace",
6668
"empty_like",
6769
"zeros_like",

dpctl/tensor/_ctors.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,3 +1035,84 @@ def linspace(
10351035
)
10361036
hev.wait()
10371037
return res
1038+
1039+
1040+
def eye(
1041+
n_rows,
1042+
n_cols=None,
1043+
/,
1044+
*,
1045+
k=0,
1046+
dtype=None,
1047+
order="C",
1048+
device=None,
1049+
usm_type="device",
1050+
sycl_queue=None,
1051+
):
1052+
"""
1053+
eye(n_rows, n_cols = None, /, *, k = 0, dtype = None, \
1054+
device = None, usm_type="device", sycl_queue=None) -> usm_ndarray
1055+
1056+
Creates `usm_ndarray` with ones on the `k`th diagonal.
1057+
1058+
Args:
1059+
n_rows: number of rows in the output array.
1060+
n_cols (optional): number of columns in the output array. If None,
1061+
n_cols = n_rows. Default: `None`.
1062+
k: index of the diagonal, with 0 as the main diagonal.
1063+
A positive value of k is a superdiagonal, a negative value
1064+
is a subdiagonal.
1065+
Raises `TypeError` if k is not an integer.
1066+
Default: `0`.
1067+
dtype (optional): data type of the array. Can be typestring,
1068+
a `numpy.dtype` object, `numpy` char string, or a numpy
1069+
scalar type. Default: None
1070+
order ("C" or F"): memory layout for the array. Default: "C"
1071+
device (optional): array API concept of device where the output array
1072+
is created. `device` can be `None`, a oneAPI filter selector string,
1073+
an instance of :class:`dpctl.SyclDevice` corresponding to a
1074+
non-partitioned SYCL device, an instance of
1075+
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
1076+
`dpctl.tensor.usm_array.device`. Default: `None`.
1077+
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
1078+
allocation for the output array. Default: `"device"`.
1079+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
1080+
for output array allocation and copying. `sycl_queue` and `device`
1081+
are exclusive keywords, i.e. use one or another. If both are
1082+
specified, a `TypeError` is raised unless both imply the same
1083+
underlying SYCL queue to be used. If both are `None`, the
1084+
`dpctl.SyclQueue()` is used for allocation and copying.
1085+
Default: `None`.
1086+
"""
1087+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
1088+
raise ValueError(
1089+
"Unrecognized order keyword value, expecting 'F' or 'C'."
1090+
)
1091+
else:
1092+
order = order[0].upper()
1093+
n_rows = operator.index(n_rows)
1094+
n_cols = n_rows if n_cols is None else operator.index(n_cols)
1095+
k = operator.index(k)
1096+
if k >= n_cols or -k >= n_rows:
1097+
return dpt.zeros(
1098+
(n_rows, n_cols),
1099+
dtype=dtype,
1100+
order=order,
1101+
device=device,
1102+
usm_type=usm_type,
1103+
sycl_queue=sycl_queue,
1104+
)
1105+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1106+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1107+
dtype = _get_dtype(dtype, sycl_queue)
1108+
res = dpt.usm_ndarray(
1109+
(n_rows, n_cols),
1110+
dtype=dtype,
1111+
buffer=usm_type,
1112+
order=order,
1113+
buffer_ctor_kwargs={"queue": sycl_queue},
1114+
)
1115+
if n_rows != 0 and n_cols != 0:
1116+
hev, _ = ti._eye(k, dst=res, sycl_queue=sycl_queue)
1117+
hev.wait()
1118+
return res

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ template <typename srcT, typename dstT, int nd> class copy_cast_spec_kernel;
4545
template <typename Ty> class copy_for_reshape_generic_kernel;
4646
template <typename Ty> class linear_sequence_step_kernel;
4747
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
48+
template <typename Ty> class eye_kernel;
4849

4950
static dpctl::tensor::detail::usm_ndarray_types array_types;
5051

@@ -1740,6 +1741,144 @@ usm_ndarray_full(py::object py_value,
17401741
}
17411742
}
17421743

1744+
/* ================ Eye ================== */
1745+
1746+
typedef sycl::event (*eye_fn_ptr_t)(sycl::queue,
1747+
size_t nelems, // num_elements
1748+
py::ssize_t start,
1749+
py::ssize_t end,
1750+
py::ssize_t step,
1751+
char *, // dst_data_ptr
1752+
const std::vector<sycl::event> &);
1753+
1754+
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
1755+
1756+
template <typename Ty> class EyeFunctor
1757+
{
1758+
private:
1759+
Ty *p = nullptr;
1760+
py::ssize_t start_v;
1761+
py::ssize_t end_v;
1762+
py::ssize_t step_v;
1763+
1764+
public:
1765+
EyeFunctor(char *dst_p,
1766+
const py::ssize_t v0,
1767+
const py::ssize_t v1,
1768+
const py::ssize_t dv)
1769+
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), end_v(v1), step_v(dv)
1770+
{
1771+
}
1772+
1773+
void operator()(sycl::id<1> wiid) const
1774+
{
1775+
Ty set_v = 0;
1776+
py::ssize_t i = static_cast<py::ssize_t>(wiid.get(0));
1777+
if (i >= start_v and i <= end_v) {
1778+
if ((i - start_v) % step_v == 0) {
1779+
set_v = 1;
1780+
}
1781+
}
1782+
p[i] = set_v;
1783+
}
1784+
};
1785+
1786+
template <typename Ty>
1787+
sycl::event eye_impl(sycl::queue exec_q,
1788+
size_t nelems,
1789+
const py::ssize_t start,
1790+
const py::ssize_t end,
1791+
const py::ssize_t step,
1792+
char *array_data,
1793+
const std::vector<sycl::event> &depends)
1794+
{
1795+
sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
1796+
cgh.depends_on(depends);
1797+
cgh.parallel_for<eye_kernel<Ty>>(
1798+
sycl::range<1>{nelems},
1799+
EyeFunctor<Ty>(array_data, start, end, step));
1800+
});
1801+
1802+
return eye_event;
1803+
}
1804+
1805+
template <typename fnT, typename Ty> struct EyeFactory
1806+
{
1807+
fnT get()
1808+
{
1809+
fnT f = eye_impl<Ty>;
1810+
return f;
1811+
}
1812+
};
1813+
1814+
std::pair<sycl::event, sycl::event>
1815+
eye(py::ssize_t k,
1816+
dpctl::tensor::usm_ndarray dst,
1817+
sycl::queue exec_q,
1818+
const std::vector<sycl::event> &depends = {})
1819+
{
1820+
// dst must be 2D
1821+
1822+
if (dst.get_ndim() != 2) {
1823+
throw py::value_error(
1824+
"usm_ndarray_eye: Expecting 2D array to populate");
1825+
}
1826+
1827+
sycl::queue dst_q = dst.get_queue();
1828+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
1829+
throw py::value_error("Execution queue is not compatible with the "
1830+
"allocation queue");
1831+
}
1832+
1833+
int dst_typenum = dst.get_typenum();
1834+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
1835+
1836+
const py::ssize_t nelem = dst.get_size();
1837+
const py::ssize_t rows = dst.get_shape(0);
1838+
const py::ssize_t cols = dst.get_shape(1);
1839+
if (rows == 0 || cols == 0) {
1840+
// nothing to do
1841+
return std::make_pair(sycl::event{}, sycl::event{});
1842+
}
1843+
1844+
bool is_dst_c_contig = ((dst.get_flags() & USM_ARRAY_C_CONTIGUOUS) != 0);
1845+
bool is_dst_f_contig = ((dst.get_flags() & USM_ARRAY_F_CONTIGUOUS) != 0);
1846+
if (!is_dst_c_contig && !is_dst_f_contig) {
1847+
throw py::value_error("USM array is not contiguous");
1848+
}
1849+
1850+
py::ssize_t start;
1851+
if (is_dst_c_contig) {
1852+
start = (k < 0) ? -k * cols : k;
1853+
}
1854+
else {
1855+
start = (k < 0) ? -k : k * rows;
1856+
}
1857+
1858+
const py::ssize_t *strides = dst.get_strides_raw();
1859+
py::ssize_t step;
1860+
if (strides == nullptr) {
1861+
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
1862+
}
1863+
else {
1864+
step = strides[0] + strides[1];
1865+
}
1866+
1867+
const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
1868+
const py::ssize_t end = start + step * (length - 1);
1869+
1870+
char *dst_data = dst.get_data();
1871+
sycl::event eye_event;
1872+
1873+
auto fn = eye_dispatch_vector[dst_typeid];
1874+
1875+
eye_event = fn(exec_q, static_cast<size_t>(nelem), start, end, step,
1876+
dst_data, depends);
1877+
1878+
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
1879+
eye_event);
1880+
}
1881+
17431882
// populate dispatch tables
17441883
void init_copy_and_cast_dispatch_tables(void)
17451884
{
@@ -1794,6 +1933,9 @@ void init_copy_for_reshape_dispatch_vector(void)
17941933
dvb3;
17951934
dvb3.populate_dispatch_vector(full_contig_dispatch_vector);
17961935

1936+
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb4;
1937+
dvb4.populate_dispatch_vector(eye_dispatch_vector);
1938+
17971939
return;
17981940
}
17991941

@@ -1899,6 +2041,16 @@ PYBIND11_MODULE(_tensor_impl, m)
18992041
py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
19002042
py::arg("depends") = py::list());
19012043

2044+
m.def("_eye", &eye,
2045+
"Fills input 2D contiguous usm_ndarray `dst` with "
2046+
"zeros outside of the diagonal "
2047+
"specified by "
2048+
"the diagonal index `k` "
2049+
"which is filled with ones."
2050+
"Returns a tuple of events: (ht_event, comp_event)",
2051+
py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
2052+
py::arg("depends") = py::list());
2053+
19022054
m.def("default_device_fp_type", [](sycl::queue q) -> std::string {
19032055
return get_default_device_fp_type(q.get_device());
19042056
});

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,24 @@ def test_full_like(dt, usm_kind):
12561256
assert np.array_equal(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))
12571257

12581258

1259+
@pytest.mark.parametrize("dtype", _all_dtypes)
1260+
@pytest.mark.parametrize("usm_kind", ["shared", "device", "host"])
1261+
def test_eye(dtype, usm_kind):
1262+
try:
1263+
q = dpctl.SyclQueue()
1264+
except dpctl.SyclQueueCreationError:
1265+
pytest.skip("Queue could not be created")
1266+
1267+
if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False:
1268+
pytest.skip(
1269+
"Device does not support double precision floating point type"
1270+
)
1271+
X = dpt.eye(4, 5, k=1, dtype=dtype, usm_type=usm_kind, sycl_queue=q)
1272+
Xnp = np.eye(4, 5, k=1, dtype=dtype)
1273+
assert X.dtype == Xnp.dtype
1274+
assert np.array_equal(Xnp, dpt.asnumpy(X))
1275+
1276+
12591277
def test_common_arg_validation():
12601278
order = "I"
12611279
# invalid order must raise ValueError
@@ -1267,6 +1285,8 @@ def test_common_arg_validation():
12671285
dpt.ones(10, order=order)
12681286
with pytest.raises(ValueError):
12691287
dpt.full(10, 1, order=order)
1288+
with pytest.raises(ValueError):
1289+
dpt.eye(10, order=order)
12701290
X = dpt.empty(10)
12711291
with pytest.raises(ValueError):
12721292
dpt.empty_like(X, order=order)

0 commit comments

Comments
 (0)