Skip to content

Commit acc34aa

Browse files
author
Diptorup Deb
committed
Updates the kernel_launcher API to handle LocalAccessor
1 parent afb630a commit acc34aa

File tree

3 files changed

+428
-231
lines changed

3 files changed

+428
-231
lines changed
Lines changed: 383 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Provides helpers to populate the list of kernel arguments that will
6+
be passed to a DPCTLQueue_Submit function call by a KernelLaunchIRBuilder
7+
object.
8+
"""
9+
10+
from typing import NamedTuple
11+
12+
import dpctl
13+
from llvmlite import ir as llvmir
14+
from numba.core import types
15+
from numba.core.cpu import CPUContext
16+
17+
from numba_dpex import utils
18+
from numba_dpex.core.types import USMNdArray
19+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
20+
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
21+
22+
23+
class KernelArg(NamedTuple):
    """Stores the llvm IR value and the dpctl typeid for a kernel argument.

    One ``KernelArg`` describes a single entry of the flattened kernel
    argument list that is built by ``KernelLaunchArgBuilder``.
    """

    # LLVM IR value of the argument. ``KernelLaunchArgBuilder._build_arg``
    # stores the result of bitcasting the argument to a void pointer here.
    llvm_val: llvmir.Instruction
    # Type id of the argument as produced by ``numba_type_to_dpctl_typenum``.
    # NOTE(review): the result of that helper is also used as an LLVM operand
    # elsewhere in this module, so this may be an LLVM constant rather than a
    # Python ``int`` -- confirm against ``numba_type_to_dpctl_typenum``.
    typeid: int
28+
29+
30+
class KernelLaunchArgBuilder:
    """Helper to generate the flattened list of kernel arguments to be
    passed to a DPCTLQueue_Submit function.

    **Note** Two separate data models are used when building a flattened
    kernel argument for the following reason:

    Different numba-dpex targets can use different data models for the same
    data type that may have different number of attributes and a different
    type for each attribute.

    In the case of the DpnpNdArray type, two separate data models are used
    for the CPUTarget and for the SPIRVTarget. The SPIRVTarget does not have
    the ``parent``, ``meminfo`` and ``sycl_queue`` attributes that are present
    in the data model used by the CPUTarget. The SPIRVTarget's data model
    for DpnpNdArray also requires an explicit address space qualifier for
    the ``data`` attribute.

    When generating the LLVM IR for the host-side control code for executing
    a SPIR-V kernel, the kernel arguments are represented using the
    CPUTarget's data model for each argument's type. However, the actual
    kernel function generated as a SPIR-V binary by the SPIRVTarget uses its
    own data model manager to build the flattened kernel function argument
    list. For this reason, when building the flattened argument list for a
    kernel launch call the host data model is used to extract the
    required attributes and then the kernel data model is used to get the
    correct type for the attribute.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        target_context: CPUContext,
        irbuilder: llvmir.IRBuilder,
        arg_type,
        arg_packed_llvm_val,
        arg_kernel_datamodel,
    ) -> None:
        """Stores everything needed to flatten a single kernel argument.

        Args:
            target_context: Host target context. Its data model manager is
                used to look up the host data model of ``arg_type``.
            irbuilder: IR builder used to emit the host-side instructions.
            arg_type: Numba/numba-dpex type of the kernel argument.
            arg_packed_llvm_val: LLVM IR value of the (packed) argument.
            arg_kernel_datamodel: Data model of ``arg_type`` on the kernel
                side; used to get the type of each flattened attribute.
        """
        self._context = target_context
        self._builder = irbuilder
        self._arg_type = arg_type
        self._arg_llvm_val = arg_packed_llvm_val
        self._arg_host_datamodel = self._context.data_model_manager.lookup(
            self._arg_type
        )
        self._arg_kernel_datamodel = arg_kernel_datamodel

    def get_kernel_arg_list(self) -> list[KernelArg]:
        """Returns a list of KernelArg objects representing a flattened kernel
        argument.

        USMNdArray arguments (including LocalAccessorType ones) are unpacked
        attribute by attribute, complex64/complex128 arguments are unpacked
        into two float32/float64 arguments, and every other type is passed
        through as a single argument.

        Returns:
            list[KernelArg]: List of flattened KernelArg objects
        """
        if isinstance(self._arg_type, USMNdArray):
            return self._build_array_arg(
                llvm_array_val=self._arg_llvm_val,
                is_local_accessor=isinstance(
                    self._arg_type, LocalAccessorType
                ),
            )
        if self._arg_type == types.complex64:
            return self._build_complex_arg(
                llvm_val=self._arg_llvm_val, numba_type=types.float32
            )
        if self._arg_type == types.complex128:
            return self._build_complex_arg(
                llvm_val=self._arg_llvm_val, numba_type=types.float64
            )

        return self._build_arg(
            llvm_val=self._arg_llvm_val, numba_type=self._arg_type
        )

    def print_kernel_arg_list(self, args_list: list[KernelArg]) -> None:
        """Prints out the kernel argument list in a human readable format.

        Args:
            args_list (list[KernelArg]): List of kernel arguments to be printed
        """
        print(f"Number of flattened kernel arguments: {len(args_list)}")
        for karg in args_list:
            print(f" {karg.llvm_val} of typeid {karg.typeid}")

    def _allocate_local_accessor_metadata_struct(self):
        """Allocates a struct into the current function to store the metadata
        that should be passed to libsyclinterface to allocate a
        sycl::local_accessor object. The constructor of the sycl::local_accessor
        class is: local_accessor<Ty, Ndim>(range<Ndims> r).

        For this reason, the struct is allocated as:

            LOCAL_ACCESSOR_MDSTRUCT_TYPE = llvmir.LiteralStructType(
                [
                    llvmir.IntType(64), # Ndim (0..3]
                    llvmir.IntType(32), # typeid
                    llvmir.IntType(64), # Dim0 extent
                    llvmir.IntType(64), # Dim1 extent or NULL
                    llvmir.IntType(64), # Dim2 extent or NULL
                ]
            )

        Returns:
            The LLVM IR value returned by the ``alloca``, i.e. a reference to
            the newly allocated struct. The struct is not initialized here.
        """
        local_accessor_mdstruct_type = llvmir.LiteralStructType(
            [
                llvmir.IntType(64),
                llvmir.IntType(32),
                llvmir.IntType(64),
                llvmir.IntType(64),
                llvmir.IntType(64),
            ]
        )

        # The alloca is emitted into the entry block of the function rather
        # than at the builder's current position.
        with self._builder.goto_entry_block():
            struct_ref = self._builder.alloca(typ=local_accessor_mdstruct_type)

        return struct_ref

    def _store_struct_field(self, struct_ref, field_pos, value):
        """Stores ``value`` into field ``field_pos`` of the struct that
        ``struct_ref`` points to.
        """
        field_ref = self._builder.gep(
            struct_ref,
            [
                self._context.get_constant(types.int32, 0),
                self._context.get_constant(types.int32, field_pos),
            ],
        )
        self._builder.store(value, field_ref)

    def _build_arg(self, llvm_val, numba_type):
        """Returns the KernelArg to be passed to a DPCTLQueue_Submit call.

        The passed in LLVM IR Value is bitcast to a void* and the
        numba/numba_dpex type object is mapped to the corresponding
        DPCTLKernelArgType enum value.

        Args:
            llvm_val: An LLVM IR Value that will be stored into the arguments
                array
            numba_type: A Numba type that will be converted to a
                DPCTLKernelArgType enum and stored into the argument types
                list array
        Returns:
            list[KernelArg]: A single-element list holding the KernelArg made
            of the bitcast LLVM IR value and the DPCTLKernelArgType enum
            value. A list is returned so that callers can uniformly ``extend``
            their flattened argument lists.
        """
        llvm_val = self._builder.bitcast(
            llvm_val,
            utils.get_llvm_type(context=self._context, type=types.voidptr),
        )
        typeid = numba_type_to_dpctl_typenum(self._context, numba_type)

        return [KernelArg(llvm_val, typeid)]

    def _build_unituple_member_arg(self, llvm_val, attr_pos, ndims):
        """Creates one int64 kernel argument per element of a tuple-like
        attribute (used for ``shape`` and ``strides``).

        Args:
            llvm_val: Reference to the struct holding the attribute.
            attr_pos: Position of the attribute inside the struct.
            ndims: Number of elements of the attribute to flatten.
        Returns:
            list[KernelArg]: ``ndims`` kernel arguments, in element order.
        """
        kernel_arg_list = []
        array_attr = self._builder.gep(
            llvm_val,
            [
                self._context.get_constant(types.int32, 0),
                self._context.get_constant(types.int32, attr_pos),
            ],
        )

        for ndim in range(ndims):
            kernel_arg_list.extend(
                self._build_collections_attr_arg(
                    llvm_val=array_attr,
                    attr_index=ndim,
                    attr_type=types.int64,
                )
            )

        return kernel_arg_list

    def _build_collections_attr_arg(self, llvm_val, attr_index, attr_type):
        """Creates the kernel argument for one attribute of a struct.

        Args:
            llvm_val: Reference to the struct holding the attribute.
            attr_index: Position of the attribute inside the struct.
            attr_type: Numba type used to derive the argument's typeid.
        Returns:
            list[KernelArg]: Single-element list for the attribute.
        """
        array_attr = self._builder.gep(
            llvm_val,
            [
                self._context.get_constant(types.int32, 0),
                self._context.get_constant(types.int32, attr_index),
            ],
        )

        # For pointer-typed attributes the stored pointer itself is passed,
        # not the address of the struct field that holds it.
        if isinstance(attr_type, (types.misc.RawPointer, types.misc.CPointer)):
            array_attr = self._builder.load(array_attr)

        return self._build_arg(llvm_val=array_attr, numba_type=attr_type)

    def _build_complex_arg(self, llvm_val, numba_type):
        """Creates a list of LLVM Values for an unpacked complex kernel
        argument.

        Args:
            llvm_val: Reference to the complex value.
            numba_type: Numba type of each of the two components.
        Returns:
            list[KernelArg]: The arguments for the components stored at
            position 0 and position 1, in that order.
        """
        kernel_arg_list = []

        kernel_arg_list.extend(
            self._build_collections_attr_arg(
                llvm_val=llvm_val,
                attr_index=0,
                attr_type=numba_type,
            )
        )
        kernel_arg_list.extend(
            self._build_collections_attr_arg(
                llvm_val=llvm_val,
                attr_index=1,
                attr_type=numba_type,
            )
        )

        return kernel_arg_list

    def _build_local_accessor_metadata_arg(
        self,
        llvm_val,
        data_attr_ty,
    ):
        """Handles the special case of building the kernel argument for the data
        attribute of a kernel_api.LocalAccessor object.

        A kernel_api.LocalAccessor conceptually represents a device-only memory
        allocation. The mock kernel_api.LocalAccessor uses a numpy.ndarray to
        represent the data allocation. The numpy.ndarray cannot be passed to the
        kernel and is ignored when building the kernel argument. Instead, a
        struct is allocated to store the metadata about the size of the device
        memory allocation and a reference to the struct is passed to the
        DPCTLQueue_Submit call. The DPCTLQueue_Submit then constructs a
        sycl::local_accessor object using the metadata and passes the
        sycl::local_accessor as the kernel argument, letting the DPC++ runtime
        handle proper device memory allocation.

        Args:
            llvm_val: Reference to the host-side struct of the local accessor.
            data_attr_ty: Kernel-side type of the ``data`` attribute; its
                ``dtype`` determines the typeid stored in the metadata.
        Returns:
            list[KernelArg]: Single-element list whose value is the reference
            to the populated metadata struct.
        """
        shape_member = self._arg_kernel_datamodel.get_member_fe_type("shape")
        shape_member_pos = self._arg_host_datamodel.get_field_position("shape")
        ndim = shape_member.count

        mdstruct_ref = self._allocate_local_accessor_metadata_struct()

        # Each field is written through a GEP + store on the allocated
        # struct. An ``insertvalue`` on a loaded copy only produces a new SSA
        # value and would leave the memory that is passed on uninitialized.
        pos = 0
        # Store the number of dimensions in the local accessor
        self._store_struct_field(
            mdstruct_ref,
            pos,
            self._context.get_constant(types.int64, ndim),
        )
        # Get the underlying dtype of the data (a CPointer) attribute of a
        # local_accessor object
        pos += 1
        self._store_struct_field(
            mdstruct_ref,
            pos,
            numba_type_to_dpctl_typenum(self._context, data_attr_ty.dtype),
        )
        # Extract and store the shape values from array into mdstruct
        shape_attr = self._builder.gep(
            llvm_val,
            [
                self._context.get_constant(types.int32, 0),
                self._context.get_constant(types.int32, shape_member_pos),
            ],
        )
        for dim in range(ndim):
            shape_ext = self._builder.gep(
                shape_attr,
                [
                    self._context.get_constant(types.int32, 0),
                    self._context.get_constant(types.int32, dim),
                ],
            )
            shape_ext_val = self._builder.load(shape_ext)
            pos += 1
            self._store_struct_field(mdstruct_ref, pos, shape_ext_val)

        return self._build_arg(
            llvm_val=mdstruct_ref,
            numba_type=LocalAccessorType(
                ndim, dpctl.tensor.dtype(data_attr_ty.dtype.name)
            ),
        )

    def _build_array_arg(
        self,
        llvm_array_val,
        is_local_accessor=False,
    ):
        """Creates a list of LLVM Values for an unpacked USMNdArray kernel
        argument.

        The attributes are emitted in the order: ``nitems``, ``itemsize``,
        ``data``, every ``shape`` element, every ``strides`` element.
        Attribute positions come from the host data model and attribute types
        from the kernel data model.

        Args:
            llvm_array_val: Reference to the host-side array struct.
            is_local_accessor: When True, the ``data`` attribute is replaced
                by a reference to a local accessor metadata struct.
        Returns:
            list[KernelArg]: The flattened kernel arguments for the array.
        """
        kernel_arg_list = []

        # Argument nitems
        kernel_arg_list.extend(
            self._build_collections_attr_arg(
                llvm_val=llvm_array_val,
                attr_index=self._arg_host_datamodel.get_field_position(
                    "nitems"
                ),
                attr_type=self._arg_kernel_datamodel.get_member_fe_type(
                    "nitems"
                ),
            )
        )
        # Argument itemsize
        kernel_arg_list.extend(
            self._build_collections_attr_arg(
                llvm_val=llvm_array_val,
                attr_index=self._arg_host_datamodel.get_field_position(
                    "itemsize"
                ),
                attr_type=self._arg_kernel_datamodel.get_member_fe_type(
                    "itemsize"
                ),
            )
        )
        # Argument data
        data_attr_pos = self._arg_host_datamodel.get_field_position("data")
        data_attr_ty = self._arg_kernel_datamodel.get_member_fe_type("data")

        if is_local_accessor:
            kernel_arg_list.extend(
                self._build_local_accessor_metadata_arg(
                    llvm_val=llvm_array_val,
                    data_attr_ty=data_attr_ty,
                )
            )
        else:
            kernel_arg_list.extend(
                self._build_collections_attr_arg(
                    llvm_val=llvm_array_val,
                    attr_index=data_attr_pos,
                    attr_type=data_attr_ty,
                )
            )
        # Arguments for shape
        kernel_arg_list.extend(
            self._build_unituple_member_arg(
                llvm_val=llvm_array_val,
                attr_pos=self._arg_host_datamodel.get_field_position("shape"),
                ndims=self._arg_kernel_datamodel.get_member_fe_type(
                    "shape"
                ).count,
            )
        )
        # Arguments for strides
        kernel_arg_list.extend(
            self._build_unituple_member_arg(
                llvm_val=llvm_array_val,
                attr_pos=self._arg_host_datamodel.get_field_position("strides"),
                ndims=self._arg_kernel_datamodel.get_member_fe_type(
                    "strides"
                ).count,
            )
        )

        return kernel_arg_list

0 commit comments

Comments
 (0)