|
| 1 | +# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +"""Provides a helpers to populate the list of kernel arguments that will |
| 6 | +be passed to a DPCTLQueue_Submit function call by a KernelLaunchIRBuilder |
| 7 | +object. |
| 8 | +""" |
| 9 | + |
| 10 | +from typing import NamedTuple |
| 11 | + |
| 12 | +import dpctl |
| 13 | +from llvmlite import ir as llvmir |
| 14 | +from numba.core import types |
| 15 | +from numba.core.cpu import CPUContext |
| 16 | + |
| 17 | +from numba_dpex import utils |
| 18 | +from numba_dpex.core.types import USMNdArray |
| 19 | +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType |
| 20 | +from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum |
| 21 | + |
| 22 | + |
| 23 | +class KernelArg(NamedTuple): |
| 24 | + """Stores the llvm IR value and the dpctl typeid for a kernel argument.""" |
| 25 | + |
| 26 | + llvm_val: llvmir.Instruction |
| 27 | + typeid: int |
| 28 | + |
| 29 | + |
| 30 | +class KernelLaunchArgBuilder: |
| 31 | + """Helper to generate the flattened list of kernel arguments to be |
| 32 | + passed to a DPCTLQueue_Submit function. |
| 33 | +
|
| 34 | + **Note** Two separate data models are used when building a flattened |
| 35 | + kernel argument for the following reason: |
| 36 | +
|
| 37 | + Different numba-dpex targets can use different data models for the same |
| 38 | + data type that may have different number of attributes and a different |
| 39 | + type for each attribute. |
| 40 | +
|
| 41 | + In the case the DpnpNdArray type, two separate data models are used for |
| 42 | + the CPUTarget and for the SPIRVTarget. The SPIRVTarget does not have the |
| 43 | + ``parent``, ``meminfo`` and ``sycl_queue`` attributes that are present |
| 44 | + in the data model used by the CPUTarget. The SPIRVTarget's data model |
| 45 | + for DpnpNdArray also requires an explicit address space qualifier for |
| 46 | + the ``data`` attribute. |
| 47 | +
|
| 48 | + When generating the LLVM IR for the host-side control code for executing |
| 49 | + a SPIR-V kernel, the kernel arguments are represented using the |
| 50 | + CPUTarget's data model for each argument's type. However, the actual |
| 51 | + kernel function generated as a SPIR-V binary by the SPIRVTarget uses its |
| 52 | + own data model manager to build the flattened kernel function argument |
| 53 | + list. For this reason, when building the flattened argument list for a |
| 54 | + kernel launch call the host data model is used to extract the |
| 55 | + required attributes and then the kernel data model is used to get the |
| 56 | + correct type for the attribute. |
| 57 | + """ |
| 58 | + |
| 59 | + def __init__( # pylint: disable=too-many-arguments |
| 60 | + self, |
| 61 | + target_context: CPUContext, |
| 62 | + irbuilder: llvmir.IRBuilder, |
| 63 | + arg_type, |
| 64 | + arg_packed_llvm_val, |
| 65 | + arg_kernel_datamodel, |
| 66 | + ) -> list[KernelArg]: |
| 67 | + |
| 68 | + self._context = target_context |
| 69 | + self._builder = irbuilder |
| 70 | + self._arg_type = arg_type |
| 71 | + self._arg_llvm_val = arg_packed_llvm_val |
| 72 | + self._arg_host_datamodel = self._context.data_model_manager.lookup( |
| 73 | + self._arg_type |
| 74 | + ) |
| 75 | + self._arg_kernel_datamodel = arg_kernel_datamodel |
| 76 | + |
| 77 | + def get_kernel_arg_list(self) -> list[KernelArg]: |
| 78 | + """Returns a list of KernelArg objects representing a flattened kernel |
| 79 | + argument. |
| 80 | +
|
| 81 | + Returns: |
| 82 | + list[KernelArg]: List of flattened KernelArg objects |
| 83 | + """ |
| 84 | + kernel_arg_list = [] |
| 85 | + |
| 86 | + if isinstance(self._arg_type, USMNdArray): |
| 87 | + kernel_arg_list.extend( |
| 88 | + self._build_array_arg( |
| 89 | + llvm_array_val=self._arg_llvm_val, |
| 90 | + is_local_accessor=isinstance( |
| 91 | + self._arg_type, LocalAccessorType |
| 92 | + ), |
| 93 | + ) |
| 94 | + ) |
| 95 | + elif self._arg_type == types.complex64: |
| 96 | + kernel_arg_list.extend( |
| 97 | + self._build_complex_arg( |
| 98 | + llvm_val=self._arg_llvm_val, numba_type=types.float32 |
| 99 | + ) |
| 100 | + ) |
| 101 | + elif self._arg_type == types.complex128: |
| 102 | + kernel_arg_list.extend( |
| 103 | + self._build_complex_arg( |
| 104 | + llvm_val=self._arg_llvm_val, numba_type=types.float64 |
| 105 | + ) |
| 106 | + ) |
| 107 | + else: |
| 108 | + kernel_arg_list.extend( |
| 109 | + self._build_arg( |
| 110 | + llvm_val=self._arg_llvm_val, numba_type=self._arg_type |
| 111 | + ) |
| 112 | + ) |
| 113 | + |
| 114 | + return kernel_arg_list |
| 115 | + |
| 116 | + def print_kernel_arg_list(self, args_list: list[KernelArg]) -> None: |
| 117 | + """Prints out the kernel argument list in a human readable format. |
| 118 | +
|
| 119 | + Args: |
| 120 | + args_list (list[KernelArg]): List of kernel arguments to be printed |
| 121 | + """ |
| 122 | + print(f"Number of flattened kernel arguments: {len(args_list)}") |
| 123 | + for karg in args_list: |
| 124 | + print(f" {karg.llvm_val} of typeid {karg.typeid}") |
| 125 | + |
| 126 | + def _allocate_local_accessor_metadata_struct(self): |
| 127 | + """Allocates a struct into the current function to store the metadata |
| 128 | + that should be passed to libsyclinterface to allocate a |
| 129 | + sycl::local_accessor object. The constructor of the sycl::local_accessor |
| 130 | + class is: local_accessor<Ty, Ndim>(range<Ndims> r). |
| 131 | +
|
| 132 | + For this reason, the struct is allocated as: |
| 133 | +
|
| 134 | + LOCAL_ACCESSOR_MDSTRUCT_TYPE = llvmir.LiteralStructType( |
| 135 | + [ |
| 136 | + llvmir.IntType(64), # Ndim (0..3] |
| 137 | + llvmir.IntType(32), # typeid |
| 138 | + llvmir.IntType(64), # Dim0 extent |
| 139 | + llvmir.IntType(64), # Dim1 extent or NULL |
| 140 | + llvmir.IntType(64), # Dim2 extent or NULL |
| 141 | + ] |
| 142 | + ) |
| 143 | + """ |
| 144 | + local_accessor_mdstruct_type = llvmir.LiteralStructType( |
| 145 | + [ |
| 146 | + llvmir.IntType(64), |
| 147 | + llvmir.IntType(32), |
| 148 | + llvmir.IntType(64), |
| 149 | + llvmir.IntType(64), |
| 150 | + llvmir.IntType(64), |
| 151 | + ] |
| 152 | + ) |
| 153 | + |
| 154 | + struct_ref = None |
| 155 | + with self._builder.goto_entry_block(): |
| 156 | + struct_ref = self._builder.alloca(typ=local_accessor_mdstruct_type) |
| 157 | + |
| 158 | + return struct_ref |
| 159 | + |
| 160 | + def _build_arg(self, llvm_val, numba_type): |
| 161 | + """Returns a KernelArg to be passed to a DPCTLQueue_Submit call. |
| 162 | +
|
| 163 | + The passed in LLVM IR Value is bitcast to a void* and the |
| 164 | + numba/numba_dpex type object is mapped to the corresponding |
| 165 | + DPCTLKernelArgType enum value and returned back as a KernelArg object. |
| 166 | +
|
| 167 | + Args: |
| 168 | + llvm_val: An LLVM IR Value that will be stored into the arguments |
| 169 | + array |
| 170 | + numba_type: A Numba type that will be converted to a |
| 171 | + DPCTLKernelArgType enum and stored into the argument types |
| 172 | + list array |
| 173 | + Returns: |
| 174 | + KernelArg: Tuple corresponding to the LLVM IR Instruction and |
| 175 | + DPCTLKernelArgType enum value. |
| 176 | + """ |
| 177 | + llvm_val = self._builder.bitcast( |
| 178 | + llvm_val, |
| 179 | + utils.get_llvm_type(context=self._context, type=types.voidptr), |
| 180 | + ) |
| 181 | + typeid = numba_type_to_dpctl_typenum(self._context, numba_type) |
| 182 | + |
| 183 | + return [KernelArg(llvm_val, typeid)] |
| 184 | + |
| 185 | + def _build_unituple_member_arg(self, llvm_val, attr_pos, ndims): |
| 186 | + kernel_arg_list = [] |
| 187 | + array_attr = self._builder.gep( |
| 188 | + llvm_val, |
| 189 | + [ |
| 190 | + self._context.get_constant(types.int32, 0), |
| 191 | + self._context.get_constant(types.int32, attr_pos), |
| 192 | + ], |
| 193 | + ) |
| 194 | + |
| 195 | + for ndim in range(ndims): |
| 196 | + kernel_arg_list.extend( |
| 197 | + self._build_collections_attr_arg( |
| 198 | + llvm_val=array_attr, |
| 199 | + attr_index=ndim, |
| 200 | + attr_type=types.int64, |
| 201 | + ) |
| 202 | + ) |
| 203 | + |
| 204 | + return kernel_arg_list |
| 205 | + |
| 206 | + def _build_collections_attr_arg(self, llvm_val, attr_index, attr_type): |
| 207 | + array_attr = self._builder.gep( |
| 208 | + llvm_val, |
| 209 | + [ |
| 210 | + self._context.get_constant(types.int32, 0), |
| 211 | + self._context.get_constant(types.int32, attr_index), |
| 212 | + ], |
| 213 | + ) |
| 214 | + |
| 215 | + if isinstance(attr_type, (types.misc.RawPointer, types.misc.CPointer)): |
| 216 | + array_attr = self._builder.load(array_attr) |
| 217 | + |
| 218 | + return self._build_arg(llvm_val=array_attr, numba_type=attr_type) |
| 219 | + |
| 220 | + def _build_complex_arg(self, llvm_val, numba_type): |
| 221 | + """Creates a list of LLVM Values for an unpacked complex kernel |
| 222 | + argument. |
| 223 | + """ |
| 224 | + kernel_arg_list = [] |
| 225 | + |
| 226 | + kernel_arg_list.extend( |
| 227 | + self._build_collections_attr_arg( |
| 228 | + llvm_val=llvm_val, |
| 229 | + attr_index=0, |
| 230 | + attr_type=numba_type, |
| 231 | + ) |
| 232 | + ) |
| 233 | + kernel_arg_list.extend( |
| 234 | + self._build_collections_attr_arg( |
| 235 | + llvm_val=llvm_val, |
| 236 | + attr_index=1, |
| 237 | + attr_type=numba_type, |
| 238 | + ) |
| 239 | + ) |
| 240 | + |
| 241 | + return kernel_arg_list |
| 242 | + |
| 243 | + def _build_local_accessor_metadata_arg( |
| 244 | + self, |
| 245 | + llvm_val, |
| 246 | + data_attr_ty, |
| 247 | + ): |
| 248 | + """Handles the special case of building the kernel argument for the data |
| 249 | + attribute of a kernel_api.LocalAccessor object. |
| 250 | +
|
| 251 | + A kernel_api.LocalAccessor conceptually represents a device-only memory |
| 252 | + allocation. The mock kernel_api.LocalAccessor uses a numpy.ndarray to |
| 253 | + represent the data allocation. The numpy.ndarray cannot be passed to the |
| 254 | + kernel and is ignored when building the kernel argument. Instead, a |
| 255 | + struct is allocated to store the metadata about the size of the device |
| 256 | + memory allocation and a reference to the struct is passed to the |
| 257 | + DPCTLQueue_Submit call. The DPCTLQueue_Submit then constructs a |
| 258 | + sycl::local_accessor object using the metadata and passes the |
| 259 | + sycl::local_accessor as the kernel argument, letting the DPC++ runtime |
| 260 | + handle proper device memory allocation. |
| 261 | + """ |
| 262 | + shape_member = self._arg_kernel_datamodel.get_member_fe_type("shape") |
| 263 | + shape_member_pos = self._arg_host_datamodel.get_field_position("shape") |
| 264 | + ndim = shape_member.count |
| 265 | + |
| 266 | + mdstruct_ref = self._allocate_local_accessor_metadata_struct() |
| 267 | + mdstruct = self._builder.load(mdstruct_ref) |
| 268 | + pos = 0 |
| 269 | + # Store the number of dimensions in the local accessor |
| 270 | + self._builder.insert_value( |
| 271 | + mdstruct, |
| 272 | + self._context.get_constant(types.int64, ndim), |
| 273 | + idx=pos, |
| 274 | + ) |
| 275 | + # Get the underlying dtype of the data (a CPointer) attribute of a |
| 276 | + # local_accessor object |
| 277 | + pos += 1 |
| 278 | + self._builder.insert_value( |
| 279 | + mdstruct, |
| 280 | + numba_type_to_dpctl_typenum(self._context, data_attr_ty.dtype), |
| 281 | + idx=pos, |
| 282 | + ) |
| 283 | + # Extract and store the shape values from array into mdstruct |
| 284 | + shape_attr = self._builder.gep( |
| 285 | + llvm_val, |
| 286 | + [ |
| 287 | + self._context.get_constant(types.int32, 0), |
| 288 | + self._context.get_constant(types.int32, shape_member_pos), |
| 289 | + ], |
| 290 | + ) |
| 291 | + for dim in range(ndim): |
| 292 | + shape_ext = self._builder.gep( |
| 293 | + shape_attr, |
| 294 | + [ |
| 295 | + self._context.get_constant(types.int32, 0), |
| 296 | + self._context.get_constant(types.int32, dim), |
| 297 | + ], |
| 298 | + ) |
| 299 | + shape_ext_val = self._builder.load(shape_ext) |
| 300 | + pos += 1 |
| 301 | + self._builder.insert_value(mdstruct, shape_ext_val, idx=pos) |
| 302 | + |
| 303 | + return self._build_arg( |
| 304 | + llvm_val=mdstruct_ref, |
| 305 | + numba_type=LocalAccessorType( |
| 306 | + ndim, dpctl.tensor.dtype(data_attr_ty.dtype.name) |
| 307 | + ), |
| 308 | + ) |
| 309 | + |
| 310 | + def _build_array_arg( |
| 311 | + self, |
| 312 | + llvm_array_val, |
| 313 | + is_local_accessor=False, |
| 314 | + ): |
| 315 | + """Creates a list of LLVM Values for an unpacked USMNdArray kernel |
| 316 | + argument. |
| 317 | + """ |
| 318 | + kernel_arg_list = [] |
| 319 | + |
| 320 | + kernel_arg_list.extend( |
| 321 | + self._build_collections_attr_arg( |
| 322 | + llvm_val=llvm_array_val, |
| 323 | + attr_index=self._arg_host_datamodel.get_field_position( |
| 324 | + "nitems" |
| 325 | + ), |
| 326 | + attr_type=self._arg_kernel_datamodel.get_member_fe_type( |
| 327 | + "nitems" |
| 328 | + ), |
| 329 | + ) |
| 330 | + ) |
| 331 | + # Argument itemsize |
| 332 | + kernel_arg_list.extend( |
| 333 | + self._build_collections_attr_arg( |
| 334 | + llvm_val=llvm_array_val, |
| 335 | + attr_index=self._arg_host_datamodel.get_field_position( |
| 336 | + "itemsize" |
| 337 | + ), |
| 338 | + attr_type=self._arg_kernel_datamodel.get_member_fe_type( |
| 339 | + "itemsize" |
| 340 | + ), |
| 341 | + ) |
| 342 | + ) |
| 343 | + # Argument data |
| 344 | + data_attr_pos = self._arg_host_datamodel.get_field_position("data") |
| 345 | + data_attr_ty = self._arg_kernel_datamodel.get_member_fe_type("data") |
| 346 | + |
| 347 | + if is_local_accessor: |
| 348 | + kernel_arg_list.extend( |
| 349 | + self._build_local_accessor_metadata_arg( |
| 350 | + llvm_val=llvm_array_val, |
| 351 | + data_attr_ty=data_attr_ty, |
| 352 | + ) |
| 353 | + ) |
| 354 | + else: |
| 355 | + kernel_arg_list.extend( |
| 356 | + self._build_collections_attr_arg( |
| 357 | + llvm_val=llvm_array_val, |
| 358 | + attr_index=data_attr_pos, |
| 359 | + attr_type=data_attr_ty, |
| 360 | + ) |
| 361 | + ) |
| 362 | + # Arguments for shape |
| 363 | + kernel_arg_list.extend( |
| 364 | + self._build_unituple_member_arg( |
| 365 | + llvm_val=llvm_array_val, |
| 366 | + attr_pos=self._arg_host_datamodel.get_field_position("shape"), |
| 367 | + ndims=self._arg_kernel_datamodel.get_member_fe_type( |
| 368 | + "shape" |
| 369 | + ).count, |
| 370 | + ) |
| 371 | + ) |
| 372 | + # Arguments for strides |
| 373 | + kernel_arg_list.extend( |
| 374 | + self._build_unituple_member_arg( |
| 375 | + llvm_val=llvm_array_val, |
| 376 | + attr_pos=self._arg_host_datamodel.get_field_position("strides"), |
| 377 | + ndims=self._arg_kernel_datamodel.get_member_fe_type( |
| 378 | + "strides" |
| 379 | + ).count, |
| 380 | + ) |
| 381 | + ) |
| 382 | + |
| 383 | + return kernel_arg_list |
0 commit comments