diff --git a/examples/pybind11/onemkl_gemv/solve.py b/examples/pybind11/onemkl_gemv/solve.py index 6304808078..cb415becc9 100644 --- a/examples/pybind11/onemkl_gemv/solve.py +++ b/examples/pybind11/onemkl_gemv/solve.py @@ -21,10 +21,6 @@ import dpctl.tensor as dpt -def empty_like(A): - return dpt.empty(A.shape, A.dtype, device=A.device) - - def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]): """Chebyshev iterative solver using SYCL routines""" d = (lMax + lMin) / 2 @@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]): x = dpt.copy(x0) exec_queue = A.sycl_queue assert exec_queue == x.sycl_queue - Ax = empty_like(A[:, 0]) - r = empty_like(Ax) - p = empty_like(Ax) + Ax = dpt.empty_like(A[:, 0]) + r = dpt.empty_like(Ax) + p = dpt.empty_like(Ax) e_x = dpctl.SyclEvent() # Ax = A @ x @@ -131,12 +127,13 @@ def cg_solve(A, b): converged is False if solver has not converged, or the iteration number """ exec_queue = A.sycl_queue - x = dpt.zeros(b.shape, dtype=b.dtype) - Ap = empty_like(x) + x = dpt.zeros_like(b) + Ap = dpt.empty_like(x) all_host_tasks = [] - r = dpt.copy(b) - p = dpt.copy(b) + r = dpt.copy(b) # synchronous copy + p = dpt.copy(b) # synchronous copy + rsold = sycl_gemm.norm_squared_blocking(exec_queue, r) if rsold < 1e-20: return (b, 0) @@ -147,22 +144,21 @@ def cg_solve(A, b): e_x = dpctl.SyclEvent() for i in range(max_iters): # Ap = A @ p - he_dot, e_dot = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p]) - all_host_tasks.append(he_dot) + he_gemv, e_gemv = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p]) + all_host_tasks.append(he_gemv) # alpha = rsold / dot(p, Ap) alpha = rsold / sycl_gemm.dot_blocking( - exec_queue, p, Ap, depends=[e_dot] + exec_queue, p, Ap, depends=[e_p, e_gemv] ) # x = x + alpha * p he1_x_update, e1_x_update = sycl_gemm.axpby_inplace( - exec_queue, alpha, p, 1, x, depends=[e_p, e_x] + exec_queue, alpha, p, 1, x, depends=[e_x] ) all_host_tasks.append(he1_x_update) - e_x = e1_x_update # r = r - alpha * Ap he2_r_update, e2_r_update = sycl_gemm.axpby_inplace( - exec_queue, -alpha, Ap, 1, r, depends=[e_p] + exec_queue, -alpha, Ap, 1, r ) all_host_tasks.append(he2_r_update) diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp index 0531b8880f..1fda627c2a 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp @@ -497,6 +497,7 @@ py::object py_dot_blocking(sycl::queue q, reinterpret_cast(v2_typeless_ptr), 1, res_usm, depends); T res_v{}; q.copy(res_usm, &res_v, 1, {dot_ev}).wait_and_throw(); + sycl::free(res_usm, q); res = py::float_(res_v); } else if (v1_typenum == UAR_FLOAT) { @@ -507,6 +508,7 @@ py::object py_dot_blocking(sycl::queue q, reinterpret_cast(v2_typeless_ptr), 1, res_usm, depends); T res_v(0); q.copy(res_usm, &res_v, 1, {dot_ev}).wait_and_throw(); + sycl::free(res_usm, q); res = py::float_(res_v); } else if (v1_typenum == UAR_CDOUBLE) { @@ -517,6 +519,7 @@ py::object py_dot_blocking(sycl::queue q, reinterpret_cast(v2_typeless_ptr), 1, res_usm, depends); T res_v{}; q.copy(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw(); + sycl::free(res_usm, q); res = py::cast(res_v); } else if (v1_typenum == UAR_CFLOAT) { @@ -527,6 +530,7 @@ py::object py_dot_blocking(sycl::queue q, reinterpret_cast(v2_typeless_ptr), 1, res_usm, depends); T res_v{}; q.copy(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw(); + sycl::free(res_usm, q); res = py::cast(res_v); } else { diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp b/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp index 6e0396c7a1..55d23c73c3 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp @@ -177,16 +177,16 @@ int cg_solve(sycl::queue exec_q, } int converged_at = max_iters; - sycl::event prev_dep = copy_to_p_ev; + sycl::event e_p = copy_to_p_ev; sycl::event e_x = fill_ev; for (std::int64_t i = 0; i < max_iters; ++i) { sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv( exec_q, oneapi::mkl::transpose::N, n, n, T(1), Amat, n, p, 1, T(0), - Ap, 1, {prev_dep}); + Ap, 1, {e_p}); sycl::event pAp_dot_ev = oneapi::mkl::blas::row_major::dot( - exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {prev_dep, gemv_ev}); + exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {e_p, gemv_ev}); T pAp_dot_host{}; exec_q.copy(pAp_dot_dev, &pAp_dot_host, 1, {pAp_dot_ev}) @@ -212,8 +212,7 @@ int cg_solve(sycl::queue exec_q, T beta = rs_new / rs_old; // p = r + beta * p - prev_dep = - detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev}); + e_p = detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev}); e_x = x_update_ev; rs_old = rs_new; diff --git a/examples/pybind11/onemkl_gemv/sycl_timing_solver.py b/examples/pybind11/onemkl_gemv/sycl_timing_solver.py index 42178ee0ed..8e42a02dc0 100644 --- a/examples/pybind11/onemkl_gemv/sycl_timing_solver.py +++ b/examples/pybind11/onemkl_gemv/sycl_timing_solver.py @@ -55,6 +55,12 @@ A = dpt.asarray(Anp, "d", device=api_dev) b = dpt.asarray(bnp, "d", device=api_dev) +assert A.sycl_queue == b.sycl_queue + +# allocate buffers for computation of residual +r = dpt.empty_like(b) +delta = dpt.empty_like(b) + timer = dpctl.SyclTimer(time_scale=1e3) iters = [] @@ -64,17 +70,22 @@ print(i, "(host_dt, device_dt)=", timer.dt) iters.append(conv_in) + assert x.usm_type == A.usm_type + assert x.usm_type == b.usm_type + assert x.sycl_queue == A.sycl_queue + assert x.sycl_queue == b.sycl_queue print("Converged in: ", iters) -r = dpt.empty_like(b) hev, ev = sycl_gemm.gemv(q, A, x, r) -delta = dpt.empty_like(b) hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) rs = sycl_gemm.norm_squared_blocking(q, delta) dpctl.SyclEvent.wait_for([hev, hev2]) print(f"Python solution residual norm squared: {rs}") +assert q == api_dev.sycl_queue +print("") + x_cpp = dpt.empty_like(b) iters = [] for i in range(6):