Skip to content

Commit b5d361f

Browse files
Merge pull request #833 from IntelPython/refine-gemv-example
Refine gemv example
2 parents b196e73 + 50a3243 commit b5d361f

File tree

4 files changed

+34
-24
lines changed

4 files changed

+34
-24
lines changed

examples/pybind11/onemkl_gemv/solve.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@
2121
import dpctl.tensor as dpt
2222

2323

24-
def empty_like(A):
25-
return dpt.empty(A.shape, A.dtype, device=A.device)
26-
27-
2824
def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
2925
"""Chebyshev iterative solver using SYCL routines"""
3026
d = (lMax + lMin) / 2
@@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
3329
x = dpt.copy(x0)
3430
exec_queue = A.sycl_queue
3531
assert exec_queue == x.sycl_queue
36-
Ax = empty_like(A[:, 0])
37-
r = empty_like(Ax)
38-
p = empty_like(Ax)
32+
Ax = dpt.empty_like(A[:, 0])
33+
r = dpt.empty_like(Ax)
34+
p = dpt.empty_like(Ax)
3935

4036
e_x = dpctl.SyclEvent()
4137
# Ax = A @ x
@@ -131,12 +127,13 @@ def cg_solve(A, b):
131127
converged is False if solver has not converged, or the iteration number
132128
"""
133129
exec_queue = A.sycl_queue
134-
x = dpt.zeros(b.shape, dtype=b.dtype)
135-
Ap = empty_like(x)
130+
x = dpt.zeros_like(b)
131+
Ap = dpt.empty_like(x)
136132

137133
all_host_tasks = []
138-
r = dpt.copy(b)
139-
p = dpt.copy(b)
134+
r = dpt.copy(b) # synchronous copy
135+
p = dpt.copy(b) # synchronous copy
136+
140137
rsold = sycl_gemm.norm_squared_blocking(exec_queue, r)
141138
if rsold < 1e-20:
142139
return (b, 0)
@@ -147,22 +144,21 @@ def cg_solve(A, b):
147144
e_x = dpctl.SyclEvent()
148145
for i in range(max_iters):
149146
# Ap = A @ p
150-
he_dot, e_dot = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
151-
all_host_tasks.append(he_dot)
147+
he_gemv, e_gemv = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
148+
all_host_tasks.append(he_gemv)
152149
# alpha = rsold / dot(p, Ap)
153150
alpha = rsold / sycl_gemm.dot_blocking(
154-
exec_queue, p, Ap, depends=[e_dot]
151+
exec_queue, p, Ap, depends=[e_p, e_gemv]
155152
)
156153
# x = x + alpha * p
157154
he1_x_update, e1_x_update = sycl_gemm.axpby_inplace(
158-
exec_queue, alpha, p, 1, x, depends=[e_p, e_x]
155+
exec_queue, alpha, p, 1, x, depends=[e_x]
159156
)
160157
all_host_tasks.append(he1_x_update)
161-
e_x = e1_x_update
162158

163159
# r = r - alpha * Ap
164160
he2_r_update, e2_r_update = sycl_gemm.axpby_inplace(
165-
exec_queue, -alpha, Ap, 1, r, depends=[e_p]
161+
exec_queue, -alpha, Ap, 1, r
166162
)
167163
all_host_tasks.append(he2_r_update)
168164

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ py::object py_dot_blocking(sycl::queue q,
497497
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
498498
T res_v{};
499499
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
500+
sycl::free(res_usm, q);
500501
res = py::float_(res_v);
501502
}
502503
else if (v1_typenum == UAR_FLOAT) {
@@ -507,6 +508,7 @@ py::object py_dot_blocking(sycl::queue q,
507508
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
508509
T res_v(0);
509510
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
511+
sycl::free(res_usm, q);
510512
res = py::float_(res_v);
511513
}
512514
else if (v1_typenum == UAR_CDOUBLE) {
@@ -517,6 +519,7 @@ py::object py_dot_blocking(sycl::queue q,
517519
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
518520
T res_v{};
519521
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
522+
sycl::free(res_usm, q);
520523
res = py::cast(res_v);
521524
}
522525
else if (v1_typenum == UAR_CFLOAT) {
@@ -527,6 +530,7 @@ py::object py_dot_blocking(sycl::queue q,
527530
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
528531
T res_v{};
529532
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
533+
sycl::free(res_usm, q);
530534
res = py::cast(res_v);
531535
}
532536
else {

examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,16 @@ int cg_solve(sycl::queue exec_q,
177177
}
178178

179179
int converged_at = max_iters;
180-
sycl::event prev_dep = copy_to_p_ev;
180+
sycl::event e_p = copy_to_p_ev;
181181
sycl::event e_x = fill_ev;
182182

183183
for (std::int64_t i = 0; i < max_iters; ++i) {
184184
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
185185
exec_q, oneapi::mkl::transpose::N, n, n, T(1), Amat, n, p, 1, T(0),
186-
Ap, 1, {prev_dep});
186+
Ap, 1, {e_p});
187187

188188
sycl::event pAp_dot_ev = oneapi::mkl::blas::row_major::dot(
189-
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {prev_dep, gemv_ev});
189+
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {e_p, gemv_ev});
190190

191191
T pAp_dot_host{};
192192
exec_q.copy<T>(pAp_dot_dev, &pAp_dot_host, 1, {pAp_dot_ev})
@@ -212,8 +212,7 @@ int cg_solve(sycl::queue exec_q,
212212
T beta = rs_new / rs_old;
213213

214214
// p = r + beta * p
215-
prev_dep =
216-
detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
215+
e_p = detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
217216
e_x = x_update_ev;
218217

219218
rs_old = rs_new;

examples/pybind11/onemkl_gemv/sycl_timing_solver.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@
5555
A = dpt.asarray(Anp, "d", device=api_dev)
5656
b = dpt.asarray(bnp, "d", device=api_dev)
5757

58+
assert A.sycl_queue == b.sycl_queue
59+
60+
# allocate buffers for computation of residual
61+
r = dpt.empty_like(b)
62+
delta = dpt.empty_like(b)
63+
5864
timer = dpctl.SyclTimer(time_scale=1e3)
5965

6066
iters = []
@@ -64,17 +70,22 @@
6470

6571
print(i, "(host_dt, device_dt)=", timer.dt)
6672
iters.append(conv_in)
73+
assert x.usm_type == A.usm_type
74+
assert x.usm_type == b.usm_type
75+
assert x.sycl_queue == A.sycl_queue
76+
assert x.sycl_queue == b.sycl_queue
6777

6878
print("Converged in: ", iters)
6979

70-
r = dpt.empty_like(b)
7180
hev, ev = sycl_gemm.gemv(q, A, x, r)
72-
delta = dpt.empty_like(b)
7381
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
7482
rs = sycl_gemm.norm_squared_blocking(q, delta)
7583
dpctl.SyclEvent.wait_for([hev, hev2])
7684
print(f"Python solution residual norm squared: {rs}")
7785

86+
assert q == api_dev.sycl_queue
87+
print("")
88+
7889
x_cpp = dpt.empty_like(b)
7990
iters = []
8091
for i in range(6):

0 commit comments

Comments
 (0)