diff --git a/examples/pybind11/onemkl_gemv/sycl_timing_solver.py b/examples/pybind11/onemkl_gemv/sycl_timing_solver.py index 8e42a02dc0..253e512bfe 100644 --- a/examples/pybind11/onemkl_gemv/sycl_timing_solver.py +++ b/examples/pybind11/onemkl_gemv/sycl_timing_solver.py @@ -63,41 +63,66 @@ timer = dpctl.SyclTimer(time_scale=1e3) -iters = [] -for i in range(6): - with timer(api_dev.sycl_queue): - x, conv_in = solve.cg_solve(A, b) - - print(i, "(host_dt, device_dt)=", timer.dt) - iters.append(conv_in) - assert x.usm_type == A.usm_type - assert x.usm_type == b.usm_type - assert x.sycl_queue == A.sycl_queue - assert x.sycl_queue == b.sycl_queue - -print("Converged in: ", iters) - -hev, ev = sycl_gemm.gemv(q, A, x, r) -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) -rs = sycl_gemm.norm_squared_blocking(q, delta) -dpctl.SyclEvent.wait_for([hev, hev2]) -print(f"Python solution residual norm squared: {rs}") + +def time_python_solver(num_iters=6): + """ + Time solver implemented in Python with use of asynchronous + SYCL kernel submission. + """ + global x + iters = [] + for i in range(num_iters): + with timer(api_dev.sycl_queue): + x, conv_in = solve.cg_solve(A, b) + + print(i, "(host_dt, device_dt)=", timer.dt) + iters.append(conv_in) + assert x.usm_type == A.usm_type + assert x.usm_type == b.usm_type + assert x.sycl_queue == A.sycl_queue + assert x.sycl_queue == b.sycl_queue + + return iters + + +def time_cpp_solver(num_iters=6): + """ + Time solver implemented in C++ but callable from Python. + C++ implementation uses the same algorithm and submits same + kernels asynchronously, but bypasses Python binding overhead + incurred when algorithm is driven from Python. 
+ """ + global x_cpp + x_cpp = dpt.empty_like(b) + iters = [] + for i in range(num_iters): + with timer(api_dev.sycl_queue): + conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) + + print(i, "(host_dt, device_dt)=", timer.dt) + iters.append(conv_in) + + return iters + + +def compute_residual(x): + """ + Computes quality of the solution, `norm_squared(A@x - b)`. + """ + assert isinstance(x, dpt.usm_ndarray) + q = A.sycl_queue + hev, ev = sycl_gemm.gemv(q, A, x, r) + hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) + rs = sycl_gemm.norm_squared_blocking(q, delta) + dpctl.SyclEvent.wait_for([hev, hev2]) + return rs + + +print("Converged in: ", time_python_solver()) +print(f"Python solution residual norm squared: {compute_residual(x)}") assert q == api_dev.sycl_queue print("") -x_cpp = dpt.empty_like(b) -iters = [] -for i in range(6): - with timer(api_dev.sycl_queue): - conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) - - print(i, "(host_dt, device_dt)=", timer.dt) - iters.append(conv_in) - -print("Converged in: ", iters) -hev, ev = sycl_gemm.gemv(q, A, x_cpp, r) -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) -rs = sycl_gemm.norm_squared_blocking(q, delta) -dpctl.SyclEvent.wait_for([hev, hev2]) -print(f"cpp_cg_solve solution residual norm squared: {rs}") +print("Converged in: ", time_cpp_solver()) +print(f"cpp_cg_solve solution residual norm squared: {compute_residual(x_cpp)}")