Skip to content

Commit f31fb28

Browse files
Add synchronization points in example
1 parent 61f8c54 commit f31fb28

File tree

7 files changed

+73
-77
lines changed

7 files changed

+73
-77
lines changed

examples/cython/usm_memory/blackscholes/blackscholes.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,15 @@ def black_scholes_price(c_dpt.usm_ndarray option_params_arr):
106106
call_put_prices = dpt.empty((n_opts, 2), dtype='d', sycl_queue=q)
107107
dp1 = <double *>option_params_arr.get_data()
108108
dp2 = <double *>call_put_prices.get_data()
109+
# wait on the queue to ensure the content of dp1 and dp2 is no longer being worked on
110+
exec_q_ptr[0].wait()
109111
cpp_blackscholes[double](exec_q_ptr[0], n_opts, dp1, dp2)
110112
elif (typenum_ == c_dpt.UAR_FLOAT):
111113
call_put_prices = dpt.empty((n_opts, 2), dtype='f', sycl_queue=q)
112114
fp1 = <float *>option_params_arr.get_data()
113115
fp2 = <float *>call_put_prices.get_data()
116+
# wait on the queue to ensure the content of fp1 and fp2 is no longer being worked on
117+
exec_q_ptr[0].wait()
114118
cpp_blackscholes[float](exec_q_ptr[0], n_opts, fp1, fp2)
115119
else:
116120
raise ValueError("Unsupported data-type")
@@ -196,11 +200,13 @@ def populate_params(
196200

197201
if (typenum_ == c_dpt.UAR_DOUBLE):
198202
dp = <double *>option_params_arr.get_data()
203+
exec_q_ptr[0].wait()
199204
cpp_populate_params[double](
200205
exec_q_ptr[0], n_opts, dp, pl, ph, sl, sh, tl, th, rl, rh, vl, vh, seed
201206
)
202207
elif (typenum_ == c_dpt.UAR_FLOAT):
203208
fp = <float *>option_params_arr.get_data()
209+
exec_q_ptr[0].wait()
204210
cpp_populate_params[float](
205211
exec_q_ptr[0], n_opts, fp, pl, ph, sl, sh, tl, th, rl, rh, vl, vh, seed
206212
)

examples/cython/usm_memory/src/sycl_blackscholes.hpp

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ constexpr int CALL = 0;
4747
constexpr int PUT = 1;
4848

4949
template <typename T>
50-
void cpp_blackscholes(sycl::queue q, size_t n_opts, T *params, T *callput)
50+
void cpp_blackscholes(sycl::queue &q, size_t n_opts, T *params, T *callput)
5151
{
5252
using data_t = T;
5353

@@ -57,56 +57,54 @@ void cpp_blackscholes(sycl::queue q, size_t n_opts, T *params, T *callput)
5757
data_t half = one / two;
5858

5959
cgh.parallel_for<class black_scholes_kernel<T>>(
60-
sycl::range<1>(n_opts),
61-
[=](sycl::id<1> idx)
62-
{
63-
const size_t i = n_params * idx[0];
64-
const data_t opt_price = params[i + PRICE];
65-
const data_t opt_strike = params[i + STRIKE];
66-
const data_t opt_maturity = params[i + MATURITY];
67-
const data_t opt_rate = params[i + RATE];
68-
const data_t opt_volatility = params[i + VOLATILITY];
69-
data_t a, b, c, y, z, e, d1, d1c, d2, d2c, w1, w2;
70-
data_t mr = -opt_rate,
71-
sig_sig_two = two * opt_volatility * opt_volatility;
72-
73-
a = sycl::log(opt_price / opt_strike);
74-
b = opt_maturity * mr;
75-
z = opt_maturity * sig_sig_two;
76-
77-
c = quarter * z;
78-
e = sycl::exp(b);
79-
y = sycl::rsqrt(z);
80-
81-
a = b - a;
82-
w1 = (a - c) * y;
83-
w2 = (a + c) * y;
84-
85-
if (w1 < zero) {
86-
d1 = sycl::erfc(w1) * half;
87-
d1c = one - d1;
88-
}
89-
else {
90-
d1c = sycl::erfc(-w1) * half;
91-
d1 = one - d1c;
92-
}
93-
if (w2 < zero) {
94-
d2 = sycl::erfc(w2) * half;
95-
d2c = one - d2;
96-
}
97-
else {
98-
d2c = sycl::erfc(-w2) * half;
99-
d2 = one - d2c;
100-
}
101-
102-
e *= opt_strike;
103-
data_t call_price = opt_price * d1 - e * d2;
104-
data_t put_price = e * d2c - opt_price * d1c;
105-
106-
const size_t callput_i = n_prices * idx[0];
107-
callput[callput_i + CALL] = call_price;
108-
callput[callput_i + PUT] = put_price;
109-
});
60+
sycl::range<1>(n_opts), [=](sycl::id<1> idx) {
61+
const size_t i = n_params * idx[0];
62+
const data_t opt_price = params[i + PRICE];
63+
const data_t opt_strike = params[i + STRIKE];
64+
const data_t opt_maturity = params[i + MATURITY];
65+
const data_t opt_rate = params[i + RATE];
66+
const data_t opt_volatility = params[i + VOLATILITY];
67+
data_t a, b, c, y, z, e, d1, d1c, d2, d2c, w1, w2;
68+
data_t mr = -opt_rate,
69+
sig_sig_two = two * opt_volatility * opt_volatility;
70+
71+
a = sycl::log(opt_price / opt_strike);
72+
b = opt_maturity * mr;
73+
z = opt_maturity * sig_sig_two;
74+
75+
c = quarter * z;
76+
e = sycl::exp(b);
77+
y = sycl::rsqrt(z);
78+
79+
a = b - a;
80+
w1 = (a - c) * y;
81+
w2 = (a + c) * y;
82+
83+
if (w1 < zero) {
84+
d1 = sycl::erfc(w1) * half;
85+
d1c = one - d1;
86+
}
87+
else {
88+
d1c = sycl::erfc(-w1) * half;
89+
d1 = one - d1c;
90+
}
91+
if (w2 < zero) {
92+
d2 = sycl::erfc(w2) * half;
93+
d2c = one - d2;
94+
}
95+
else {
96+
d2c = sycl::erfc(-w2) * half;
97+
d2 = one - d2c;
98+
}
99+
100+
e *= opt_strike;
101+
data_t call_price = opt_price * d1 - e * d2;
102+
data_t put_price = e * d2c - opt_price * d1c;
103+
104+
const size_t callput_i = n_prices * idx[0];
105+
callput[callput_i + CALL] = call_price;
106+
callput[callput_i + PUT] = put_price;
107+
});
110108
});
111109

112110
e.wait_and_throw();

examples/pybind11/external_usm_allocation/external_usm_allocation/_usm_alloc_example.cpp

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,12 @@ struct DMatrix
5252
DMatrix(const DMatrix &) = default;
5353
DMatrix(DMatrix &&) = default;
5454

55-
size_t get_n() const
56-
{
57-
return n_;
58-
}
59-
size_t get_m() const
60-
{
61-
return m_;
62-
}
63-
vec_t &get_vector()
64-
{
65-
return vec_;
66-
}
67-
sycl::queue get_queue() const
68-
{
69-
return q_;
70-
}
55+
size_t get_n() const { return n_; }
56+
size_t get_m() const { return m_; }
57+
vec_t &get_vector() { return vec_; }
58+
sycl::queue get_queue() const { return q_; }
7159

72-
double get_element(size_t i, size_t j)
73-
{
74-
return vec_.at(i * m_ + j);
75-
}
60+
double get_element(size_t i, size_t j) { return vec_.at(i * m_ + j); }
7661

7762
private:
7863
size_t n_;
@@ -114,6 +99,9 @@ py::dict construct_sua_iface(DMatrix &m)
11499
iface["typestr"] = "|f8";
115100
iface["syclobj"] = syclobj;
116101

102+
// wait on the queue to ensure that pending work on the array's content has completed
103+
m.get_queue().wait();
104+
117105
return iface;
118106
}
119107

examples/pybind11/onemkl_gemv/solve.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,14 @@ def cg_solve(A, b):
127127
converged is False if solver has not converged, or the iteration number
128128
"""
129129
exec_queue = A.sycl_queue
130+
exec_queue.wait()
131+
130132
x = dpt.zeros_like(b)
131133
Ap = dpt.empty_like(x)
132134

133135
all_host_tasks = []
134-
r = dpt.copy(b) # synchronous copy
135-
p = dpt.copy(b) # synchronous copy
136+
r = dpt.copy(b)
137+
p = dpt.copy(b)
136138

137139
rsold = sycl_gemm.norm_squared_blocking(exec_queue, r)
138140
if rsold < 1e-20:

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ namespace py = pybind11;
4141
using dpctl::utils::keep_args_alive;
4242

4343
std::pair<sycl::event, sycl::event>
44-
py_gemv(sycl::queue q,
44+
py_gemv(sycl::queue &q,
4545
dpctl::tensor::usm_ndarray matrix,
4646
dpctl::tensor::usm_ndarray vector,
4747
dpctl::tensor::usm_ndarray result,
4848
const std::vector<sycl::event> &depends = {})
4949
{
5050
if (matrix.get_ndim() != 2 || vector.get_ndim() != 1 ||
51-
result.get_ndim() != 1) {
51+
result.get_ndim() != 1)
52+
{
5253
throw std::runtime_error(
5354
"Inconsistent dimensions, expecting matrix and a vector");
5455
}

examples/pybind11/use_dpctl_sycl_kernel/tests/test_user_kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_kernel_submit_through_extension():
6464
x = dpt.arange(0, stop=13, step=1, dtype="i4", sycl_queue=q)
6565
y = dpt.zeros_like(x)
6666

67+
q.wait()
6768
uk.submit_custom_kernel(q, krn, x, y, [])
6869

6970
assert np.array_equal(dpt.asnumpy(y), np.arange(0, 26, step=2, dtype="i4"))

examples/pybind11/use_dpctl_sycl_kernel/use_kernel/_example.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535

3636
namespace py = pybind11;
3737

38-
void submit_custom_kernel(sycl::queue q,
39-
sycl::kernel krn,
38+
void submit_custom_kernel(sycl::queue &q,
39+
sycl::kernel &krn,
4040
dpctl::tensor::usm_ndarray x,
4141
dpctl::tensor::usm_ndarray y,
4242
const std::vector<sycl::event> &depends = {})

0 commit comments

Comments
 (0)