diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 6c9ac88fc2..5a31a05b61 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -33,6 +33,8 @@ linspace, ones, ones_like, + tril, + triu, zeros, zeros_like, ) @@ -83,4 +85,6 @@ "to_numpy", "asnumpy", "from_dlpack", + "tril", + "triu", ] diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index e3bb8fe3ec..0e50f38f05 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -1116,3 +1116,85 @@ def eye( hev, _ = ti._eye(k, dst=res, sycl_queue=sycl_queue) hev.wait() return res + + +def tril(X, k=0): + """ + tril(X: usm_ndarray, k: int) -> usm_ndarray + + Returns the lower triangular part of a matrix (or a stack of matrices) X. + """ + if type(X) is not dpt.usm_ndarray: + raise TypeError + + k = operator.index(k) + + # F_CONTIGUOUS = 2 + order = "F" if (X.flags & 2) else "C" + + shape = X.shape + nd = X.ndim + if nd < 2: + raise ValueError("Array dimensions less than 2.") + + if k >= shape[nd - 1] - 1: + res = dpt.empty( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=X, dst=res, sycl_queue=X.sycl_queue + ) + hev.wait() + elif k < -shape[nd - 2]: + res = dpt.zeros( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + else: + res = dpt.empty( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue) + hev.wait() + + return res + + +def triu(X, k=0): + """ + triu(X: usm_ndarray, k: int) -> usm_ndarray + + Returns the upper triangular part of a matrix (or a stack of matrices) X. + """ + if type(X) is not dpt.usm_ndarray: + raise TypeError + + k = operator.index(k) + + # F_CONTIGUOUS = 2 + order = "F" if (X.flags & 2) else "C" + + shape = X.shape + nd = X.ndim + if nd < 2: + raise ValueError("Array dimensions less than 2.") + + if k > shape[nd - 1]: + res = dpt.zeros( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + elif k <= -shape[nd - 2] + 1: + res = dpt.empty( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=X, dst=res, sycl_queue=X.sycl_queue + ) + hev.wait() + else: + res = dpt.empty( + X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + ) + hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue) + hev.wait() + + return res diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 6f625c5bb4..533a36e512 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -1879,6 +1879,313 @@ eye(py::ssize_t k, eye_event); } +/* =========================== Tril and triu ============================== */ +// define function type +typedef sycl::event (*tri_fn_ptr_t)(sycl::queue, + py::ssize_t, // inner_range //py::ssize_t + py::ssize_t, // outer_range + char *, // src_data_ptr + char *, // dst_data_ptr + py::ssize_t, // nd + py::ssize_t *, // shape_and_strides + py::ssize_t, // k + const std::vector &, + const std::vector &); + +template class tri_kernel; +template +sycl::event tri_impl(sycl::queue exec_q, + py::ssize_t inner_range, + py::ssize_t outer_range, + char *src_p, + char *dst_p, + py::ssize_t nd, + py::ssize_t *shape_and_strides, + py::ssize_t k, + const std::vector &depends, + const std::vector &additional_depends) +{ + constexpr int d2 = 2; + py::ssize_t src_s = nd; + py::ssize_t dst_s = 2 * nd; + py::ssize_t nd_1 = nd - 1; + py::ssize_t nd_2 = nd - 2; + Ty *src = reinterpret_cast(src_p); + Ty *dst = reinterpret_cast(dst_p); + + sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + cgh.parallel_for>( + sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) { + py::ssize_t outer_gid = idx[0] / inner_range; + py::ssize_t inner_gid = idx[0] - inner_range * outer_gid; + + py::ssize_t src_inner_offset, dst_inner_offset; + bool to_copy; + + { + // py::ssize_t inner_gid = idx.get_id(0); + CIndexer_array indexer_i( + {shape_and_strides[nd_2], shape_and_strides[nd_1]}); + indexer_i.set(inner_gid); + const std::array &inner = indexer_i.get(); + src_inner_offset = + inner[0] * shape_and_strides[src_s + nd_2] + + inner[1] * shape_and_strides[src_s + nd_1]; + dst_inner_offset = + inner[0] * shape_and_strides[dst_s + nd_2] + + inner[1] * shape_and_strides[dst_s + nd_1]; + + if (l) + to_copy = (inner[0] + k >= inner[1]); + else + to_copy = (inner[0] + k <= inner[1]); + } + + py::ssize_t src_offset = 0; + py::ssize_t dst_offset = 0; + { + // py::ssize_t outer_gid = idx.get_id(1); + CIndexer_vector outer(nd - d2); + outer.get_displacement( + outer_gid, shape_and_strides, shape_and_strides + src_s, + shape_and_strides + dst_s, src_offset, dst_offset); + } + + src_offset += src_inner_offset; + dst_offset += dst_inner_offset; + + dst[dst_offset] = (to_copy) ? src[src_offset] : Ty(0); + }); + }); + return tri_ev; +} + +static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types]; + +template struct TrilGenericFactory +{ + fnT get() + { + fnT f = tri_impl; + return f; + } +}; + +static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types]; + +template struct TriuGenericFactory +{ + fnT get() + { + fnT f = tri_impl; + return f; + } +}; + +std::pair +tri(sycl::queue &exec_q, + dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + char part, + py::ssize_t k = 0, + const std::vector &depends = {}) +{ + // array dimensions must be the same + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + if (src_nd != dst_nd) { + throw py::value_error("Array dimensions are not the same."); + } + + if (src_nd < 2) { + throw py::value_error("Array dimensions less than 2."); + } + + // shapes must be the same + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + + bool shapes_equal(true); + size_t src_nelems(1); + + for (int i = 0; shapes_equal && i < src_nd; ++i) { + src_nelems *= static_cast(src_shape[i]); + shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + if (src_nelems == 0) { + // nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // check that arrays do not overlap, and concurrent copying is safe. + char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + auto src_offsets = src.get_minmax_offsets(); + auto dst_offsets = dst.get_minmax_offsets(); + int src_elem_size = src.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + + bool memory_overlap = + ((dst_data - src_data > src_offsets.second * src_elem_size - + dst_offsets.first * dst_elem_size) && + (src_data - dst_data > dst_offsets.second * dst_elem_size - + src_offsets.first * src_elem_size)); + if (memory_overlap) { + // TODO: could use a temporary, but this is done by the caller + throw py::value_error("Arrays index overlapping segments of memory"); + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + if (dst_typeid != src_typeid) { + throw py::value_error("Array dtype are not the same."); + } + + // check same contexts + sycl::queue src_q = src.get_queue(); + sycl::queue dst_q = dst.get_queue(); + + if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) { + throw py::value_error( + "Execution queue context is not the same as allocation contexts"); + } + + using shT = std::vector; + int src_flags = src.get_flags(); + const py::ssize_t *src_strides_raw = src.get_strides_raw(); + shT src_strides(src_nd); + bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0); + bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0); + if (src_strides_raw == nullptr) { + if (is_src_c_contig) { + src_strides = c_contiguous_strides(src_nd, src_shape); + } + else if (is_src_f_contig) { + src_strides = f_contiguous_strides(src_nd, src_shape); + } + else { + throw std::runtime_error("Source array has null strides but has " + "neither C- nor F- contiguous flag set"); + } + } + else { + std::copy(src_strides_raw, src_strides_raw + src_nd, + src_strides.begin()); + } + + int dst_flags = dst.get_flags(); + const py::ssize_t *dst_strides_raw = dst.get_strides_raw(); + shT dst_strides(src_nd); + bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0); + bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0); + if (dst_strides_raw == nullptr) { + if (is_dst_c_contig) { + dst_strides = c_contiguous_strides(src_nd, src_shape); + } + else if (is_dst_f_contig) { + dst_strides = f_contiguous_strides(src_nd, src_shape); + } + else { + throw std::runtime_error("Source array has null strides but has " + "neither C- nor F- contiguous flag set"); + } + } + else { + std::copy(dst_strides_raw, dst_strides_raw + dst_nd, + dst_strides.begin()); + } + + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + constexpr py::ssize_t src_itemsize = 1; // item size in elements + constexpr py::ssize_t dst_itemsize = 1; // item size in elements + + int nd = src_nd - 2; + const py::ssize_t *shape = src_shape; + const py::ssize_t *p_src_strides = src_strides.data(); + const py::ssize_t *p_dst_strides = dst_strides.data(); + + simplify_iteration_space(nd, shape, p_src_strides, src_itemsize, + is_src_c_contig, is_src_f_contig, p_dst_strides, + dst_itemsize, is_dst_c_contig, is_dst_f_contig, + simplified_shape, simplified_src_strides, + simplified_dst_strides, src_offset, dst_offset); + + if (src_offset != 0 || dst_offset != 0) { + throw py::value_error("Reversed slice for dst is not supported"); + } + + nd += 2; + std::vector shape_and_strides(3 * nd); + + std::copy(simplified_shape.begin(), simplified_shape.end(), + shape_and_strides.begin()); + shape_and_strides[nd - 2] = src_shape[src_nd - 2]; + shape_and_strides[nd - 1] = src_shape[src_nd - 1]; + std::copy(simplified_src_strides.begin(), simplified_src_strides.end(), + shape_and_strides.begin() + nd); + shape_and_strides[2 * nd - 2] = src_strides[src_nd - 2]; + shape_and_strides[2 * nd - 1] = src_strides[src_nd - 1]; + std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(), + shape_and_strides.begin() + 2 * nd); + shape_and_strides[3 * nd - 2] = dst_strides[src_nd - 2]; + shape_and_strides[3 * nd - 1] = dst_strides[src_nd - 1]; + + std::shared_ptr shp_host_shape_and_strides = + std::make_shared(shape_and_strides); + + py::ssize_t *dev_shape_and_strides = + sycl::malloc_device(3 * nd, exec_q); + if (dev_shape_and_strides == nullptr) { + throw std::runtime_error("Unabled to allocate device memory"); + } + sycl::event copy_shape_and_strides = exec_q.copy( + shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd); + + py::ssize_t inner_range = + shape_and_strides[nd - 1] * shape_and_strides[nd - 2]; + py::ssize_t outer_range = src_nelems / inner_range; + + sycl::event tri_ev; + if (part == 'l') { + auto fn = tril_generic_dispatch_vector[src_typeid]; + tri_ev = + fn(exec_q, inner_range, outer_range, src_data, dst_data, nd, + dev_shape_and_strides, k, depends, {copy_shape_and_strides}); + } + else { + auto fn = triu_generic_dispatch_vector[src_typeid]; + tri_ev = + fn(exec_q, inner_range, outer_range, src_data, dst_data, nd, + dev_shape_and_strides, k, depends, {copy_shape_and_strides}); + } + + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on({tri_ev}); + auto ctx = exec_q.get_context(); + cgh.host_task( + [shp_host_shape_and_strides, dev_shape_and_strides, ctx]() { + // capture of shp_host_shape_and_strides ensure the underlying + // vector exists for the entire execution of copying kernel + sycl::free(dev_shape_and_strides, ctx); + }); + }); + return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}), + tri_ev); +} + // populate dispatch tables void init_copy_and_cast_dispatch_tables(void) { @@ -1936,6 +2243,12 @@ void init_copy_for_reshape_dispatch_vector(void) DispatchVectorBuilder dvb4; dvb4.populate_dispatch_vector(eye_dispatch_vector); + DispatchVectorBuilder dvb5; + dvb5.populate_dispatch_vector(tril_generic_dispatch_vector); + + DispatchVectorBuilder dvb6; + dvb6.populate_dispatch_vector(triu_generic_dispatch_vector); + return; } @@ -2081,4 +2394,27 @@ PYBIND11_MODULE(_tensor_impl, m) [](sycl::device dev) -> std::string { return get_default_device_complex_type(dev); }); + m.def( + "_tril", + [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, + py::ssize_t k, sycl::queue exec_q, + const std::vector depends) + -> std::pair { + return tri(exec_q, src, dst, 'l', k, depends); + }, + "Tril helper function.", py::arg("src"), py::arg("dst"), + py::arg("k") = 0, py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + m.def( + "_triu", + [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, + py::ssize_t k, sycl::queue exec_q, + const std::vector depends) + -> std::pair { + return tri(exec_q, src, dst, 'u', k, depends); + }, + "Triu helper function.", py::arg("src"), py::arg("dst"), + py::arg("k") = 0, py::arg("sycl_queue"), + py::arg("depends") = py::list()); } diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index c77933d56a..feabecb1ff 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1274,6 +1274,159 @@ def test_eye(dtype, usm_kind): assert np.array_equal(Xnp, dpt.asnumpy(X)) +@pytest.mark.parametrize("dtype", _all_dtypes[1:]) +def test_tril(dtype): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False: + pytest.skip( + "Device does not support double precision floating point type" + ) + shape = (2, 3, 4, 5, 5) + X = dpt.reshape( + dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape + ) + Y = dpt.tril(X) + Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + Ynp = np.tril(Xnp) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize("dtype", _all_dtypes[1:]) +def test_triu(dtype): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False: + pytest.skip( + "Device does not support double precision floating point type" + ) + shape = (4, 5) + X = dpt.reshape( + dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape + ) + Y = dpt.triu(X, 1) + Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + Ynp = np.triu(Xnp, 1) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_tril_slice(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + shape = (6, 10) + X = dpt.reshape( + dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape + )[1:, ::-2] + Y = dpt.tril(X) + Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2] + Ynp = np.tril(Xnp) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_triu_permute_dims(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + shape = (2, 3, 4, 5) + X = dpt.permute_dims( + dpt.reshape( + dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape + ), + (3, 2, 1, 0), + ) + Y = dpt.triu(X) + Xnp = np.transpose( + np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0) + ) + Ynp = np.triu(Xnp) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_tril_broadcast_to(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + shape = (5, 5) + X = dpt.broadcast_to(dpt.ones((1), dtype="int", sycl_queue=q), shape) + Y = dpt.tril(X) + Xnp = np.broadcast_to(np.ones((1), dtype="int"), shape) + Ynp = np.tril(Xnp) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +def test_triu_bool(): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + + shape = (4, 5) + X = dpt.ones((shape), dtype="bool", sycl_queue=q) + Y = dpt.triu(X) + Xnp = np.ones((shape), dtype="bool") + Ynp = np.triu(Xnp) + assert Y.dtype == Ynp.dtype + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("k", [-10, -2, -1, 3, 4, 10]) +def test_triu_order_k(order, k): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + shape = (3, 3) + X = dpt.reshape( + dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), + shape, + order=order, + ) + Y = dpt.triu(X, k) + Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order) + Ynp = np.triu(Xnp, k) + assert Y.dtype == Ynp.dtype + assert X.flags == Y.flags + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("k", [-10, -4, -3, 1, 2, 10]) +def test_tril_order_k(order, k): + try: + q = dpctl.SyclQueue() + except dpctl.SyclQueueCreationError: + pytest.skip("Queue could not be created") + shape = (3, 3) + X = dpt.reshape( + dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), + shape, + order=order, + ) + Y = dpt.tril(X, k) + Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order) + Ynp = np.tril(Xnp, k) + assert Y.dtype == Ynp.dtype + assert X.flags == Y.flags + assert np.array_equal(Ynp, dpt.asnumpy(Y)) + + def test_common_arg_validation(): order = "I" # invalid order must raise ValueError @@ -1306,3 +1459,7 @@ def test_common_arg_validation(): dpt.ones_like(X) with pytest.raises(TypeError): dpt.full_like(X, 1) + with pytest.raises(TypeError): + dpt.tril(X) + with pytest.raises(TypeError): + dpt.triu(X)