@@ -80,10 +80,65 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
80
80
{
81
81
dstTy fill_v = py::cast<dstTy>(py_value);
82
82
83
- using dpctl::tensor::kernels::constructors::full_contig_impl ;
83
+ sycl::event fill_ev ;
84
84
85
- sycl::event fill_ev =
86
- full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
85
+ if constexpr (sizeof (dstTy) == sizeof (char )) {
86
+ const auto memset_val = sycl::bit_cast<unsigned char >(fill_v);
87
+ fill_ev = exec_q.submit ([&](sycl::handler &cgh) {
88
+ cgh.depends_on (depends);
89
+
90
+ cgh.memset (reinterpret_cast <void *>(dst_p), memset_val,
91
+ nelems * sizeof (dstTy));
92
+ });
93
+ }
94
+ else {
95
+ bool is_zero = false ;
96
+ if constexpr (sizeof (dstTy) == 1 ) {
97
+ is_zero = (std::uint8_t {0 } == sycl::bit_cast<std::uint8_t >(fill_v));
98
+ }
99
+ else if constexpr (sizeof (dstTy) == 2 ) {
100
+ is_zero =
101
+ (std::uint16_t {0 } == sycl::bit_cast<std::uint16_t >(fill_v));
102
+ }
103
+ else if constexpr (sizeof (dstTy) == 4 ) {
104
+ is_zero =
105
+ (std::uint32_t {0 } == sycl::bit_cast<std::uint32_t >(fill_v));
106
+ }
107
+ else if constexpr (sizeof (dstTy) == 8 ) {
108
+ is_zero =
109
+ (std::uint64_t {0 } == sycl::bit_cast<std::uint64_t >(fill_v));
110
+ }
111
+ else if constexpr (sizeof (dstTy) == 16 ) {
112
+ struct UInt128
113
+ {
114
+
115
+ constexpr UInt128 () : v1{}, v2{} {}
116
+ UInt128 (const UInt128 &) = default ;
117
+
118
+ operator bool () const { return bool (v1) && bool (v2); }
119
+
120
+ std::uint64_t v1;
121
+ std::uint64_t v2;
122
+ };
123
+ is_zero = static_cast <bool >(sycl::bit_cast<UInt128>(fill_v));
124
+ }
125
+
126
+ if (is_zero) {
127
+ constexpr int memset_val = 0 ;
128
+ fill_ev = exec_q.submit ([&](sycl::handler &cgh) {
129
+ cgh.depends_on (depends);
130
+
131
+ cgh.memset (reinterpret_cast <void *>(dst_p), memset_val,
132
+ nelems * sizeof (dstTy));
133
+ });
134
+ }
135
+ else {
136
+ using dpctl::tensor::kernels::constructors::full_contig_impl;
137
+
138
+ fill_ev =
139
+ full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
140
+ }
141
+ }
87
142
88
143
return fill_ev;
89
144
}
@@ -126,7 +181,6 @@ usm_ndarray_full(const py::object &py_value,
126
181
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
127
182
128
183
char *dst_data = dst.get_data ();
129
- sycl::event full_event;
130
184
131
185
if (dst_nelems == 1 || dst.is_c_contiguous () || dst.is_f_contiguous ()) {
132
186
auto fn = full_contig_dispatch_vector[dst_typeid];
0 commit comments