Skip to content

[UR][L0] Add the multi-device-compile experimental feature #924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ class ur_function_v(IntEnum):
KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
COMMAND_BUFFER_APPEND_USM_PREFETCH_EXP = 195 ## Enumerator for ::urCommandBufferAppendUSMPrefetchExp
COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 196 ## Enumerator for ::urCommandBufferAppendUSMAdviseExp
LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK = 197 ## Enumerator for ::urLoaderConfigSetCodeLocationCallback
PROGRAM_BUILD_EXP = 197 ## Enumerator for ::urProgramBuildExp
PROGRAM_COMPILE_EXP = 198 ## Enumerator for ::urProgramCompileExp
PROGRAM_LINK_EXP = 199 ## Enumerator for ::urProgramLinkExp
LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK = 200 ## Enumerator for ::urLoaderConfigSetCodeLocationCallback

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -2315,6 +2318,11 @@ class ur_exp_command_buffer_handle_t(c_void_p):
## which is returned when querying device extensions.
UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP = "ur_exp_cooperative_kernels"

###############################################################################
## @brief The extension string which defines support for multi-device compile
## which is returned when querying device extensions.
UR_MULTI_DEVICE_COMPILE_EXTENSION_STRING_EXP = "ur_exp_multi_device_compile"

###############################################################################
## @brief Supported peer info
class ur_exp_peer_info_v(IntEnum):
Expand Down Expand Up @@ -2631,6 +2639,37 @@ class ur_program_dditable_t(Structure):
("pfnCreateWithNativeHandle", c_void_p) ## _urProgramCreateWithNativeHandle_t
]

###############################################################################
## @brief Function-pointer for urProgramBuildExp
## ctypes prototype for urProgramBuildExp. The factory is WINFUNCTYPE when
## `__use_win_types` is truthy and CFUNCTYPE otherwise; only the selected
## factory name is ever evaluated.
_urProgramBuildExp_t = (WINFUNCTYPE if __use_win_types else CFUNCTYPE)(
    ur_result_t,                  # return type
    ur_program_handle_t,          # hProgram
    c_ulong,                      # numDevices
    POINTER(ur_device_handle_t),  # phDevices
    c_char_p,                     # pOptions
)

###############################################################################
## @brief Function-pointer for urProgramCompileExp
## ctypes prototype for urProgramCompileExp. The factory is WINFUNCTYPE when
## `__use_win_types` is truthy and CFUNCTYPE otherwise; only the selected
## factory name is ever evaluated.
_urProgramCompileExp_t = (WINFUNCTYPE if __use_win_types else CFUNCTYPE)(
    ur_result_t,                  # return type
    ur_program_handle_t,          # hProgram
    c_ulong,                      # numDevices
    POINTER(ur_device_handle_t),  # phDevices
    c_char_p,                     # pOptions
)

###############################################################################
## @brief Function-pointer for urProgramLinkExp
## ctypes prototype for urProgramLinkExp. The factory is WINFUNCTYPE when
## `__use_win_types` is truthy and CFUNCTYPE otherwise; only the selected
## factory name is ever evaluated.
_urProgramLinkExp_t = (WINFUNCTYPE if __use_win_types else CFUNCTYPE)(
    ur_result_t,                   # return type
    ur_context_handle_t,           # hContext
    c_ulong,                       # numDevices
    POINTER(ur_device_handle_t),   # phDevices
    c_ulong,                       # count
    POINTER(ur_program_handle_t),  # phPrograms
    c_char_p,                      # pOptions
    POINTER(ur_program_handle_t),  # phProgram
)


###############################################################################
## @brief Table of ProgramExp function pointers
class ur_program_exp_dditable_t(Structure):
    ## One untyped pointer slot per experimental program entry point, kept in
    ## table order. The prototype associated with each slot is:
    ##   pfnBuildExp   -> _urProgramBuildExp_t
    ##   pfnCompileExp -> _urProgramCompileExp_t
    ##   pfnLinkExp    -> _urProgramLinkExp_t
    _fields_ = [
        (slot, c_void_p)
        for slot in ("pfnBuildExp", "pfnCompileExp", "pfnLinkExp")
    ]

###############################################################################
## @brief Function-pointer for urKernelCreate
if __use_win_types:
Expand Down Expand Up @@ -3862,6 +3901,7 @@ class ur_dditable_t(Structure):
("Context", ur_context_dditable_t),
("Event", ur_event_dditable_t),
("Program", ur_program_dditable_t),
("ProgramExp", ur_program_exp_dditable_t),
("Kernel", ur_kernel_dditable_t),
("KernelExp", ur_kernel_exp_dditable_t),
("Sampler", ur_sampler_dditable_t),
Expand Down Expand Up @@ -3966,6 +4006,18 @@ def __init__(self, version : ur_api_version_t):
self.urProgramGetNativeHandle = _urProgramGetNativeHandle_t(self.__dditable.Program.pfnGetNativeHandle)
self.urProgramCreateWithNativeHandle = _urProgramCreateWithNativeHandle_t(self.__dditable.Program.pfnCreateWithNativeHandle)

# call driver to get function pointers
ProgramExp = ur_program_exp_dditable_t()
r = ur_result_v(self.__dll.urGetProgramExpProcAddrTable(version, byref(ProgramExp)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.ProgramExp = ProgramExp

# attach function interface to function address
self.urProgramBuildExp = _urProgramBuildExp_t(self.__dditable.ProgramExp.pfnBuildExp)
self.urProgramCompileExp = _urProgramCompileExp_t(self.__dditable.ProgramExp.pfnCompileExp)
self.urProgramLinkExp = _urProgramLinkExp_t(self.__dditable.ProgramExp.pfnLinkExp)

# call driver to get function pointers
Kernel = ur_kernel_dditable_t()
r = ur_result_v(self.__dll.urGetKernelProcAddrTable(version, byref(Kernel)))
Expand Down
166 changes: 165 additions & 1 deletion include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ typedef enum ur_function_t {
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
UR_FUNCTION_COMMAND_BUFFER_APPEND_USM_PREFETCH_EXP = 195, ///< Enumerator for ::urCommandBufferAppendUSMPrefetchExp
UR_FUNCTION_COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 196, ///< Enumerator for ::urCommandBufferAppendUSMAdviseExp
UR_FUNCTION_LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK = 197, ///< Enumerator for ::urLoaderConfigSetCodeLocationCallback
UR_FUNCTION_PROGRAM_BUILD_EXP = 197, ///< Enumerator for ::urProgramBuildExp
UR_FUNCTION_PROGRAM_COMPILE_EXP = 198, ///< Enumerator for ::urProgramCompileExp
UR_FUNCTION_PROGRAM_LINK_EXP = 199, ///< Enumerator for ::urProgramLinkExp
UR_FUNCTION_LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK = 200, ///< Enumerator for ::urLoaderConfigSetCodeLocationCallback
/// @cond
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand Down Expand Up @@ -8408,6 +8411,131 @@ urKernelSuggestMaxCooperativeGroupCountExp(
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
);

#if !defined(__GNUC__)
#pragma endregion
#endif
// Intel 'oneAPI' Unified Runtime Experimental APIs for multi-device compile
#if !defined(__GNUC__)
#pragma region multi device compile(experimental)
#endif
///////////////////////////////////////////////////////////////////////////////
#ifndef UR_MULTI_DEVICE_COMPILE_EXTENSION_STRING_EXP
/// @brief The extension string which defines support for multi-device compile
/// which is returned when querying device extensions.
#define UR_MULTI_DEVICE_COMPILE_EXTENSION_STRING_EXP "ur_exp_multi_device_compile"
#endif // UR_MULTI_DEVICE_COMPILE_EXTENSION_STRING_EXP

///////////////////////////////////////////////////////////////////////////////
/// @brief Produces an executable program from one program, negating the need
/// for a separate linking step.
///
/// @details
/// - The application may call this function from simultaneous threads.
/// - Following a successful call to this entry point, the program passed
/// will contain a binary of the ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type
/// for each device in `phDevices`.
///
/// @remarks
/// _Analogues_
/// - **clBuildProgram**
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hProgram`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevices`
/// - ::UR_RESULT_ERROR_INVALID_PROGRAM
/// + If `hProgram` isn't a valid program object.
/// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE
/// + If an error occurred when building `hProgram`.
UR_APIEXPORT ur_result_t UR_APICALL
urProgramBuildExp(
ur_program_handle_t hProgram, ///< [in] Handle of the program to build.
uint32_t numDevices, ///< [in] number of devices

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could say here that if numDevices is 0, then urProgramBuildExp would behave the same as urProgramBuild, where the program is built for all the devices in the context.

ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles
const char *pOptions ///< [in][optional] pointer to build options null-terminated string.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Compiles a program into a compiled object for each of the given
/// devices.
///
/// @details
/// - The application may call this function from simultaneous threads.
/// - Following a successful call to this entry point `hProgram` will
/// contain a binary of the ::UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT type
/// for each device in `phDevices`.
///
/// @remarks
/// _Analogues_
/// - **clCompileProgram**
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hProgram`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevices`
/// - ::UR_RESULT_ERROR_INVALID_PROGRAM
/// + If `hProgram` isn't a valid program object.
/// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE
/// + If an error occurred while compiling `hProgram`.
UR_APIEXPORT ur_result_t UR_APICALL
urProgramCompileExp(
ur_program_handle_t hProgram, ///< [in][out] handle of the program to compile.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles
const char *pOptions ///< [in][optional] pointer to build options null-terminated string.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Produces an executable program from one or more programs.
///
/// @details
/// - The application may call this function from simultaneous threads.
/// - Following a successful call to this entry point the program returned
/// in `phProgram` will contain a binary of the
/// ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in
/// `phDevices`.
///
/// @remarks
/// _Analogues_
/// - **clLinkProgram**
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevices`
/// + `NULL == phPrograms`
/// + `NULL == phProgram`
/// - ::UR_RESULT_ERROR_INVALID_PROGRAM
/// + If one of the programs in `phPrograms` isn't a valid program object.
/// - ::UR_RESULT_ERROR_INVALID_SIZE
/// + `count == 0`
/// - ::UR_RESULT_ERROR_PROGRAM_LINK_FAILURE
/// + If an error occurred while linking `phPrograms`.
UR_APIEXPORT ur_result_t UR_APICALL
urProgramLinkExp(
ur_context_handle_t hContext, ///< [in] handle of the context instance.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles
uint32_t count, ///< [in] number of program handles in `phPrograms`.
const ur_program_handle_t *phPrograms, ///< [in][range(0, count)] pointer to array of program handles.
const char *pOptions, ///< [in][optional] pointer to linker options null-terminated string.
ur_program_handle_t *phProgram ///< [out] pointer to handle of program object created.
);

#if !defined(__GNUC__)
#pragma endregion
#endif
Expand Down Expand Up @@ -8919,6 +9047,17 @@ typedef struct ur_program_build_params_t {
const char **ppOptions;
} ur_program_build_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramBuildExp
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_build_exp_params_t {
ur_program_handle_t *phProgram; ///< pointer to the `hProgram` argument
uint32_t *pnumDevices; ///< pointer to the `numDevices` argument
ur_device_handle_t **pphDevices; ///< pointer to the `phDevices` argument
const char **ppOptions; ///< pointer to the `pOptions` argument
} ur_program_build_exp_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramCompile
/// @details Each entry is a pointer to the parameter passed to the function;
Expand All @@ -8929,6 +9068,17 @@ typedef struct ur_program_compile_params_t {
const char **ppOptions;
} ur_program_compile_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramCompileExp
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_compile_exp_params_t {
ur_program_handle_t *phProgram; ///< pointer to the `hProgram` argument
uint32_t *pnumDevices; ///< pointer to the `numDevices` argument
ur_device_handle_t **pphDevices; ///< pointer to the `phDevices` argument
const char **ppOptions; ///< pointer to the `pOptions` argument
} ur_program_compile_exp_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramLink
/// @details Each entry is a pointer to the parameter passed to the function;
Expand All @@ -8941,6 +9091,20 @@ typedef struct ur_program_link_params_t {
ur_program_handle_t **pphProgram;
} ur_program_link_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramLinkExp
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_program_link_exp_params_t {
ur_context_handle_t *phContext; ///< pointer to the `hContext` argument
uint32_t *pnumDevices; ///< pointer to the `numDevices` argument
ur_device_handle_t **pphDevices; ///< pointer to the `phDevices` argument
uint32_t *pcount; ///< pointer to the `count` argument
const ur_program_handle_t **pphPrograms; ///< pointer to the `phPrograms` argument
const char **ppOptions; ///< pointer to the `pOptions` argument
ur_program_handle_t **pphProgram; ///< pointer to the `phProgram` argument
} ur_program_link_exp_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urProgramRetain
/// @details Each entry is a pointer to the parameter passed to the function;
Expand Down
57 changes: 57 additions & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,62 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetProgramProcAddrTable_t)(
ur_api_version_t,
ur_program_dditable_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramBuildExp
typedef ur_result_t(UR_APICALL *ur_pfnProgramBuildExp_t)(
ur_program_handle_t, // hProgram
uint32_t, // numDevices
ur_device_handle_t *, // phDevices
const char *); // pOptions

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramCompileExp
typedef ur_result_t(UR_APICALL *ur_pfnProgramCompileExp_t)(
ur_program_handle_t, // hProgram
uint32_t, // numDevices
ur_device_handle_t *, // phDevices
const char *); // pOptions

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urProgramLinkExp
typedef ur_result_t(UR_APICALL *ur_pfnProgramLinkExp_t)(
ur_context_handle_t, // hContext
uint32_t, // numDevices
ur_device_handle_t *, // phDevices
uint32_t, // count
const ur_program_handle_t *, // phPrograms
const char *, // pOptions
ur_program_handle_t *); // phProgram

///////////////////////////////////////////////////////////////////////////////
/// @brief Table of ProgramExp function pointers
typedef struct ur_program_exp_dditable_t {
ur_pfnProgramBuildExp_t pfnBuildExp; ///< function pointer for urProgramBuildExp
ur_pfnProgramCompileExp_t pfnCompileExp; ///< function pointer for urProgramCompileExp
ur_pfnProgramLinkExp_t pfnLinkExp; ///< function pointer for urProgramLinkExp
} ur_program_exp_dditable_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's ProgramExp table
/// with current process' addresses
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
UR_DLLEXPORT ur_result_t UR_APICALL
urGetProgramExpProcAddrTable(
ur_api_version_t version, ///< [in] API version requested
ur_program_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urGetProgramExpProcAddrTable
typedef ur_result_t(UR_APICALL *ur_pfnGetProgramExpProcAddrTable_t)(
ur_api_version_t, // version
ur_program_exp_dditable_t *); // pDdiTable

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urKernelCreate
typedef ur_result_t(UR_APICALL *ur_pfnKernelCreate_t)(
Expand Down Expand Up @@ -2250,6 +2306,7 @@ typedef struct ur_dditable_t {
ur_context_dditable_t Context;
ur_event_dditable_t Event;
ur_program_dditable_t Program;
ur_program_exp_dditable_t ProgramExp;
ur_kernel_dditable_t Kernel;
ur_kernel_exp_dditable_t KernelExp;
ur_sampler_dditable_t Sampler;
Expand Down
Loading