diff --git a/sycl/include/CL/sycl/detail/pi.h b/sycl/include/CL/sycl/detail/pi.h index 4a8b0fbf70185..f4fe05bdf5bcf 100644 --- a/sycl/include/CL/sycl/detail/pi.h +++ b/sycl/include/CL/sycl/detail/pi.h @@ -114,6 +114,7 @@ typedef enum { PI_INVALID_IMAGE_FORMAT_DESCRIPTOR = CL_INVALID_IMAGE_FORMAT_DESCRIPTOR, PI_IMAGE_FORMAT_NOT_SUPPORTED = CL_IMAGE_FORMAT_NOT_SUPPORTED, PI_MEM_OBJECT_ALLOCATION_FAILURE = CL_MEM_OBJECT_ALLOCATION_FAILURE, + PI_LINK_PROGRAM_FAILURE = CL_LINK_PROGRAM_FAILURE, PI_FUNCTION_ADDRESS_IS_NOT_AVAILABLE = -998, ///< PI_FUNCTION_ADDRESS_IS_NOT_AVAILABLE indicates a fallback ///< method determines the function exists but its address cannot be diff --git a/sycl/plugins/level_zero/pi_level_zero.cpp b/sycl/plugins/level_zero/pi_level_zero.cpp index 9a7c09418d6b7..cfb615b3f2292 100644 --- a/sycl/plugins/level_zero/pi_level_zero.cpp +++ b/sycl/plugins/level_zero/pi_level_zero.cpp @@ -374,8 +374,9 @@ static sycl::detail::SpinLock *PiPlatformsCacheMutex = new sycl::detail::SpinLock; static bool PiPlatformCachePopulated = false; -// Keeps track if the global offset extension is found +// Flags which tell whether various Level Zero extensions are available. static bool PiDriverGlobalOffsetExtensionFound = false; +static bool PiDriverModuleProgramExtensionFound = false; // TODO:: In the following 4 methods we may want to distinguish read access vs. // write (as it is OK for multiple threads to read the map without locking it). @@ -556,6 +557,7 @@ inline void zeParseError(ze_result_t ZeError, const char *&ErrorString) { ZE_ERRCASE(ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE) ZE_ERRCASE(ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE) ZE_ERRCASE(ZE_RESULT_ERROR_OVERLAPPING_REGIONS) + ZE_ERRCASE(ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED) ZE_ERRCASE(ZE_RESULT_ERROR_UNKNOWN) #undef ZE_ERRCASE @@ -1582,31 +1584,11 @@ extern "C" { // Forward declarations decltype(piEventCreate) piEventCreate; -static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices, - const pi_device *DeviceList, - const char *Options); -static pi_result copyModule(ze_context_handle_t ZeContext, - ze_device_handle_t ZeDevice, - ze_module_handle_t SrcMod, - ze_module_handle_t *DestMod); - -static bool setEnvVar(const char *var, const char *value); - -// Forward declarations for mock implementations of Level Zero APIs that -// do not yet work in the driver. -// TODO: Remove these mock definitions when they work in the driver. static ze_result_t -zeModuleDynamicLinkMock(uint32_t numModules, ze_module_handle_t *phModules, - ze_module_build_log_handle_t *phLinkLog); - -static ze_result_t -zeModuleGetPropertiesMock(ze_module_handle_t hModule, - ze_module_properties_t *pModuleProperties); - -static bool isOnlineLinkEnabled(); -// End forward declarations for mock Level Zero APIs +checkUnresolvedSymbols(ze_module_handle_t ZeModule, + ze_module_build_log_handle_t *ZeBuildLog); -// This function will ensure compatibility with both Linux and Windowns for +// This function will ensure compatibility with both Linux and Windows for // setting environment variables. static bool setEnvVar(const char *name, const char *value) { #ifdef _WIN32 @@ -1657,6 +1639,14 @@ pi_result _pi_platform::initialize() { PiDriverGlobalOffsetExtensionFound = true; } } + // Check if extension is available for "static linking" (compiling multiple + // SPIR-V modules together into one Level Zero module). + if (strncmp(extension.name, ZE_MODULE_PROGRAM_EXP_NAME, + strlen(ZE_MODULE_PROGRAM_EXP_NAME) + 1) == 0) { + if (extension.version == ZE_MODULE_PROGRAM_EXP_VERSION_1_0) { + PiDriverModuleProgramExtensionFound = true; + } + } zeDriverExtensionMap[extension.name] = extension.version; } @@ -3603,11 +3593,10 @@ pi_result piProgramCreate(pi_context Context, const void *ILBytes, PI_ASSERT(Program, PI_INVALID_PROGRAM); // NOTE: the Level Zero module creation is also building the program, so we - // are deferring it until the program is ready to be built in piProgramBuild - // and piProgramCompile. Also it is only then we know the build options. + // are deferring it until the program is ready to be built. try { - *Program = new _pi_program(Context, ILBytes, Length, _pi_program::IL); + *Program = new _pi_program(_pi_program::IL, Context, ILBytes, Length); } catch (const std::bad_alloc &) { return PI_OUT_OF_HOST_MEMORY; } catch (...) { @@ -3655,7 +3644,7 @@ pi_result piProgramCreateWithBinary( // information to distinguish the cases. try { - *Program = new _pi_program(Context, Binary, Length, _pi_program::Native); + *Program = new _pi_program(_pi_program::Native, Context, Binary, Length); } catch (const std::bad_alloc &) { return PI_OUT_OF_HOST_MEMORY; } catch (...) { @@ -3698,32 +3687,16 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName, // TODO: return all devices this program exists for. return ReturnValue(Program->Context->Devices[0]); case PI_PROGRAM_INFO_BINARY_SIZES: { + std::shared_lock Guard(Program->Mutex); size_t SzBinary; if (Program->State == _pi_program::IL || - Program->State == _pi_program::Native) { + Program->State == _pi_program::Native || + Program->State == _pi_program::Object) { SzBinary = Program->CodeLength; + } else if (Program->State == _pi_program::Exe) { + ZE_CALL(zeModuleGetNativeBinary, (Program->ZeModule, &SzBinary, nullptr)); } else { - PI_ASSERT(Program->State == _pi_program::Object || - Program->State == _pi_program::Exe || - Program->State == _pi_program::LinkedExe, - PI_INVALID_OPERATION); - - // If the program is in LinkedExe state it may contain several modules. - // We cannot handle this case because each module's contents is in its - // own address range, discontiguous from the others. The - // PI_PROGRAM_INFO_BINARY_SIZES API assume the entire linked program is - // one contiguous region, which is not the case for LinkedExe program - // in Level Zero. Therefore, this API is unimplemented when the Program - // has more than one module. - _pi_program::ModuleIterator ModIt(Program); - - PI_ASSERT(!ModIt.Done(), PI_INVALID_VALUE); - - if (ModIt.Count() > 1) { - die("piProgramGetInfo: PI_PROGRAM_INFO_BINARY_SIZES not implemented " - "for linked programs"); - } - ZE_CALL(zeModuleGetNativeBinary, (*ModIt, &SzBinary, nullptr)); + return PI_INVALID_PROGRAM; } // This is an array of 1 element, initialized as if it were scalar. return ReturnValue(size_t{SzBinary}); @@ -3736,76 +3709,57 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName, uint8_t **PBinary = pi_cast(ParamValue); if (!PBinary[0]) break; + + std::shared_lock Guard(Program->Mutex); if (Program->State == _pi_program::IL || - Program->State == _pi_program::Native) { + Program->State == _pi_program::Native || + Program->State == _pi_program::Object) { std::memcpy(PBinary[0], Program->Code.get(), Program->CodeLength); - } else { - PI_ASSERT(Program->State == _pi_program::Object || - Program->State == _pi_program::Exe || - Program->State == _pi_program::LinkedExe, - PI_INVALID_OPERATION); - - _pi_program::ModuleIterator ModIt(Program); - - PI_ASSERT(!ModIt.Done(), PI_INVALID_VALUE); - - if (ModIt.Count() > 1) { - die("piProgramGetInfo: PI_PROGRAM_INFO_BINARIES not implemented for " - "linked programs"); - } + } else if (Program->State == _pi_program::Exe) { size_t SzBinary = 0; - ZE_CALL(zeModuleGetNativeBinary, (*ModIt, &SzBinary, PBinary[0])); + ZE_CALL(zeModuleGetNativeBinary, + (Program->ZeModule, &SzBinary, PBinary[0])); + } else { + return PI_INVALID_PROGRAM; } break; } case PI_PROGRAM_INFO_NUM_KERNELS: { + std::shared_lock Guard(Program->Mutex); uint32_t NumKernels; if (Program->State == _pi_program::IL || Program->State == _pi_program::Native || Program->State == _pi_program::Object) { return PI_INVALID_PROGRAM_EXECUTABLE; - } else { - PI_ASSERT(Program->State == _pi_program::Exe || - Program->State == _pi_program::LinkedExe, - PI_INVALID_OPERATION); - + } else if (Program->State == _pi_program::Exe) { NumKernels = 0; - _pi_program::ModuleIterator ModIt(Program); - while (!ModIt.Done()) { - uint32_t Num; - ZE_CALL(zeModuleGetKernelNames, (*ModIt, &Num, nullptr)); - NumKernels += Num; - ModIt++; - } + ZE_CALL(zeModuleGetKernelNames, + (Program->ZeModule, &NumKernels, nullptr)); + } else { + return PI_INVALID_PROGRAM; } return ReturnValue(size_t{NumKernels}); } case PI_PROGRAM_INFO_KERNEL_NAMES: try { + std::shared_lock Guard(Program->Mutex); std::string PINames{""}; if (Program->State == _pi_program::IL || Program->State == _pi_program::Native || Program->State == _pi_program::Object) { return PI_INVALID_PROGRAM_EXECUTABLE; - } else { - PI_ASSERT(Program->State == _pi_program::Exe || - Program->State == _pi_program::LinkedExe, - PI_INVALID_PROGRAM_EXECUTABLE); - - bool First = true; - _pi_program::ModuleIterator ModIt(Program); - while (!ModIt.Done()) { - uint32_t Count = 0; - ZE_CALL(zeModuleGetKernelNames, (*ModIt, &Count, nullptr)); - std::unique_ptr PNames(new const char *[Count]); - ZE_CALL(zeModuleGetKernelNames, (*ModIt, &Count, PNames.get())); - for (uint32_t I = 0; I < Count; ++I) { - PINames += (!First ? ";" : ""); - PINames += PNames[I]; - First = false; - } - ModIt++; + } else if (Program->State == _pi_program::Exe) { + uint32_t Count = 0; + ZE_CALL(zeModuleGetKernelNames, (Program->ZeModule, &Count, nullptr)); + std::unique_ptr PNames(new const char *[Count]); + ZE_CALL(zeModuleGetKernelNames, + (Program->ZeModule, &Count, PNames.get())); + for (uint32_t I = 0; I < Count; ++I) { + PINames += (I > 0 ? ";" : ""); + PINames += PNames[I]; } + } else { + return PI_INVALID_PROGRAM; } return ReturnValue(PINames.c_str()); } catch (const std::bad_alloc &) { @@ -3826,8 +3780,6 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices, const pi_program *InputPrograms, void (*PFnNotify)(pi_program Program, void *UserData), void *UserData, pi_program *RetProgram) { - (void)Options; - // We only support one device with Level Zero currently. pi_device Device = Context->Devices[0]; if (NumDevices != 1) { @@ -3835,109 +3787,142 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices, return PI_INVALID_VALUE; } - PI_ASSERT(DeviceList && DeviceList[0] == Device, PI_INVALID_DEVICE); - PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); + // We do not support any link flags at this time because the Level Zero API + // does not have any way to pass flags that are specific to linking. + if (Options && *Options != '\0') { + std::string ErrorMessage( + "Level Zero does not support kernel link flags: \""); + ErrorMessage.append(Options); + ErrorMessage.push_back('\"'); + pi_program Program = + new _pi_program(_pi_program::Invalid, Context, Options, ErrorMessage); + *RetProgram = Program; + return PI_LINK_PROGRAM_FAILURE; + } // Validate input parameters. + PI_ASSERT(DeviceList && DeviceList[0] == Device, PI_INVALID_DEVICE); + PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); if (NumInputPrograms == 0 || InputPrograms == nullptr) return PI_INVALID_VALUE; - for (pi_uint32 I = 0; I < NumInputPrograms; I++) { - if (InputPrograms[I]->State != _pi_program::Object) { - return PI_INVALID_OPERATION; - } - PI_ASSERT(InputPrograms[I]->ZeModule, PI_INVALID_VALUE); - } - // Linking modules on Level Zero is different from OpenCL. With Level Zero, - // each input object module already has native code loaded onto the device. - // Linking two modules together causes the importing module to be changed - // such that its native code points to an address in the exporting module. - // As a result, a module that imports symbols can only be linked into one - // executable at a time. By contrast, modules that export symbols are not - // changed, so they can be safely linked into multiple executables - // simultaneously. - // - // Level Zero linking also differs from OpenCL because a link operation does - // not create a new module that represents the linked executable. Instead, - // we must keep track of all the input modules and refer to the entire list - // whenever we want to know something about the executable. - - // This vector hold the Level Zero modules that we will actually link - // together. This may be different from "InputPrograms" because some of - // those modules may import symbols and already be linked into other - // executables. In such a case, we must make a copy of the module before we - // can link it again. - std::vector<_pi_program::LinkedReleaser> Inputs; + pi_result PiResult = PI_SUCCESS; try { - Inputs.reserve(NumInputPrograms); + // Acquire a "shared" lock on each of the input programs, and also validate + // that they are all in Object state. + // + // There is no danger of deadlock here even if two threads call + // piProgramLink simultaneously with the same input programs in a different + // order. If we were acquiring these with "exclusive" access, this could + // lead to a classic lock ordering deadlock. However, there is no such + // deadlock potential with "shared" access. There could also be a deadlock + // potential if there was some other code that holds more than one of these + // locks simultaneously with "exclusive" access. However, there is no such + // code like that, so this is also not a danger. + std::vector> Guards(NumInputPrograms); + for (pi_uint32 I = 0; I < NumInputPrograms; I++) { + std::shared_lock Guard(InputPrograms[I]->Mutex); + Guards[I].swap(Guard); + if (InputPrograms[I]->State != _pi_program::Object) { + return PI_INVALID_OPERATION; + } + } - // We do several things in this loop. + // Previous calls to piProgramCompile did not actually compile the SPIR-V. + // Instead, we postpone compilation until this point, when all the modules + // are linked together. By doing compilation and linking together, the JIT + // compiler is able see all modules and do cross-module optimizations. // - // 1. We identify any modules that need to be copied because they import - // symbols and are already linked into some other program. - // 2. For any module that does not need to be copied, we bump its reference - // count because we will hold a reference to it. - // 3. We create a vector of Level Zero modules, which we can pass to the - // zeModuleDynamicLink() API. - std::vector ZeHandles; - ZeHandles.reserve(NumInputPrograms); + // Construct a ze_module_program_exp_desc_t which contains information about + // all of the modules that will be linked together. + ZeStruct ZeExtModuleDesc; + std::vector CodeSizes(NumInputPrograms); + std::vector CodeBufs(NumInputPrograms); + std::vector BuildFlagPtrs(NumInputPrograms); + std::vector SpecConstPtrs(NumInputPrograms); + std::vector<_pi_program::SpecConstantShim> SpecConstShims; + SpecConstShims.reserve(NumInputPrograms); + for (pi_uint32 I = 0; I < NumInputPrograms; I++) { - pi_program Input = InputPrograms[I]; - if (Input->HasImports) { - std::unique_lock Guard(Input->MutexHasImportsAndIsLinked); - if (!Input->HasImportsAndIsLinked) { - // This module imports symbols, but it isn't currently linked with - // any other module. Grab the flag to indicate that it is now - // linked. - PI_CALL(piProgramRetain(Input)); - Input->HasImportsAndIsLinked = true; - } else { - // This module imports symbols and is also linked with another module - // already, so it needs to be copied. We expect this to be quite - // rare since linking is mostly used to link against libraries which - // only export symbols. - Guard.unlock(); - ze_module_handle_t ZeModule; - pi_result res = copyModule(Context->ZeContext, Device->ZeDevice, - Input->ZeModule, &ZeModule); - if (res != PI_SUCCESS) { - return res; - } - Input = - new _pi_program(Input->Context, ZeModule, true /*own ZeModule*/, - _pi_program::Object, Input->HasImports); - Input->HasImportsAndIsLinked = true; - } + pi_program Program = InputPrograms[I]; + CodeSizes[I] = Program->CodeLength; + CodeBufs[I] = Program->Code.get(); + BuildFlagPtrs[I] = Program->BuildFlags.c_str(); + SpecConstShims.emplace_back(Program); + SpecConstPtrs[I] = SpecConstShims[I].ze(); + } + + ZeExtModuleDesc.count = NumInputPrograms; + ZeExtModuleDesc.inputSizes = CodeSizes.data(); + ZeExtModuleDesc.pInputModules = CodeBufs.data(); + ZeExtModuleDesc.pBuildFlags = BuildFlagPtrs.data(); + ZeExtModuleDesc.pConstants = SpecConstPtrs.data(); + + ZeStruct ZeModuleDesc; + ZeModuleDesc.pNext = &ZeExtModuleDesc; + ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + + // We need a Level Zero extension to compile multiple programs together into + // a single Level Zero module. However, we don't need that extension if + // there happens to be only one input program. + if (!PiDriverModuleProgramExtensionFound) { + if (NumInputPrograms == 1) { + ZeModuleDesc.pNext = nullptr; + ZeModuleDesc.inputSize = ZeExtModuleDesc.inputSizes[0]; + ZeModuleDesc.pInputModule = ZeExtModuleDesc.pInputModules[0]; + ZeModuleDesc.pBuildFlags = ZeExtModuleDesc.pBuildFlags[0]; + ZeModuleDesc.pConstants = ZeExtModuleDesc.pConstants[0]; } else { - PI_CALL(piProgramRetain(Input)); + zePrint("piProgramLink: level_zero driver does not have static linking " + "support."); + return PI_INVALID_VALUE; } - Inputs.emplace_back(Input); - ZeHandles.push_back(Input->ZeModule); } - // Link all the modules together. - ze_module_build_log_handle_t ZeBuildLog; + // Call the Level Zero API to compile, link, and create the module. + ze_device_handle_t ZeDevice = DeviceList[0]->ZeDevice; + ze_context_handle_t ZeContext = Context->ZeContext; + ze_module_handle_t ZeModule = nullptr; + ze_module_build_log_handle_t ZeBuildLog = nullptr; ze_result_t ZeResult = - ZE_CALL_NOCHECK(zeModuleDynamicLinkMock, - (ZeHandles.size(), ZeHandles.data(), &ZeBuildLog)); - - // Construct a new program object to represent the linked executable. This - // new object holds a reference to all the input programs. Note that we - // create this program object even if the link fails with "link failure" - // because we need the new program object to hold the buid log (which has - // the description of the failure). - if (ZeResult == ZE_RESULT_SUCCESS || - ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) { - *RetProgram = new _pi_program(Context, std::move(Inputs), ZeBuildLog); + ZE_CALL_NOCHECK(zeModuleCreate, (ZeContext, ZeDevice, &ZeModuleDesc, + &ZeModule, &ZeBuildLog)); + + // We still create a _pi_program object even if there is a BUILD_FAILURE + // because we need the object to hold the ZeBuildLog. There is no build + // log created for other errors, so we don't create an object. + PiResult = mapError(ZeResult); + if (ZeResult != ZE_RESULT_SUCCESS && + ZeResult != ZE_RESULT_ERROR_MODULE_BUILD_FAILURE) { + return PiResult; + } + + // The call to zeModuleCreate does not report an error if there are + // unresolved symbols because it thinks these could be resolved later via a + // call to zeModuleDynamicLink. However, modules created with piProgramLink + // are supposed to be fully linked and ready to use. Therefore, do an extra + // check now for unresolved symbols. Note that we still create a + // _pi_program if there are unresolved symbols because the ZeBuildLog tells + // which symbols are unresolved. + if (ZeResult == ZE_RESULT_SUCCESS) { + ZeResult = checkUnresolvedSymbols(ZeModule, &ZeBuildLog); + if (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) { + PiResult = PI_LINK_PROGRAM_FAILURE; + } else if (ZeResult != ZE_RESULT_SUCCESS) { + return mapError(ZeResult); + } } - if (ZeResult != ZE_RESULT_SUCCESS) - return mapError(ZeResult); + + _pi_program::state State = + (PiResult == PI_SUCCESS) ? _pi_program::Exe : _pi_program::Invalid; + *RetProgram = + new _pi_program(State, Context, ZeModule, ZeBuildLog, Options); } catch (const std::bad_alloc &) { return PI_OUT_OF_HOST_MEMORY; } catch (...) { return PI_ERROR_UNKNOWN; } - return PI_SUCCESS; + return PiResult; } pi_result piProgramCompile( @@ -3949,10 +3934,15 @@ pi_result piProgramCompile( (void)InputHeaders; (void)HeaderIncludeNames; - // The OpenCL spec says this should return CL_INVALID_PROGRAM, but there is - // no corresponding PI error code. - if (!Program) - return PI_INVALID_OPERATION; + PI_ASSERT(Program, PI_INVALID_PROGRAM); + + if ((NumDevices && !DeviceList) || (!NumDevices && DeviceList)) + return PI_INVALID_VALUE; + + // These aren't supported. + PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); + + std::scoped_lock Guard(Program->Mutex); // It's only valid to compile a program created from IL (we don't support // programs created from source code). @@ -3962,14 +3952,15 @@ pi_result piProgramCompile( if (Program->State != _pi_program::IL) return PI_INVALID_OPERATION; - // These aren't supported. - PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); - - pi_result res = compileOrBuild(Program, NumDevices, DeviceList, Options); - if (res != PI_SUCCESS) - return res; - + // We don't compile anything now. Instead, we delay compilation until + // piProgramLink, where we do both compilation and linking as a single step. + // This produces better code because the driver can do cross-module + // optimizations. Therefore, we just remember the compilation flags, so we + // can use them later. + if (Options) + Program->BuildFlags = Options; Program->State = _pi_program::Object; + return PI_SUCCESS; } @@ -3978,32 +3969,7 @@ pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices, void (*PFnNotify)(pi_program Program, void *UserData), void *UserData) { - // The OpenCL spec says this should return CL_INVALID_PROGRAM, but there is - // no corresponding PI error code. PI_ASSERT(Program, PI_INVALID_PROGRAM); - - // It is legal to build a program created from either IL or from native - // device code. - if (Program->State != _pi_program::IL && - Program->State != _pi_program::Native) - return PI_INVALID_OPERATION; - - // These aren't supported. - PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); - - pi_result res = compileOrBuild(Program, NumDevices, DeviceList, Options); - if (res != PI_SUCCESS) - return res; - - Program->State = _pi_program::Exe; - return PI_SUCCESS; -} - -// Perform common operations for compiling or building a program. -static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices, - const pi_device *DeviceList, - const char *Options) { - if ((NumDevices && !DeviceList) || (!NumDevices && DeviceList)) return PI_INVALID_VALUE; @@ -4011,63 +3977,61 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices, // TODO: we should eventually build to the possibly multiple root // devices in the context. if (NumDevices != 1) { - zePrint("compileOrBuild: level_zero supports only one device."); + zePrint("piProgramBuild: level_zero supports only one device."); return PI_INVALID_VALUE; } - PI_ASSERT(DeviceList, PI_INVALID_DEVICE); + // These aren't supported. + PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE); + + std::scoped_lock Guard(Program->Mutex); + + // It is legal to build a program created from either IL or from native + // device code. + if (Program->State != _pi_program::IL && + Program->State != _pi_program::Native) + return PI_INVALID_OPERATION; // We should have either IL or native device code. PI_ASSERT(Program->Code, PI_INVALID_PROGRAM); - // Specialization constants are used only if the program was created from - // IL. Translate them to the Level Zero format. - ze_module_constants_t ZeSpecConstants = {}; - std::vector ZeSpecContantsIds; - std::vector ZeSpecContantsValues; - if (Program->State == _pi_program::IL) { - std::lock_guard Guard(Program->MutexZeSpecConstants); - - ZeSpecConstants.numConstants = Program->ZeSpecConstants.size(); - ZeSpecContantsIds.reserve(ZeSpecConstants.numConstants); - ZeSpecContantsValues.reserve(ZeSpecConstants.numConstants); - - for (auto &SpecConstant : Program->ZeSpecConstants) { - ZeSpecContantsIds.push_back(SpecConstant.first); - ZeSpecContantsValues.push_back(SpecConstant.second); - } - ZeSpecConstants.pConstantIds = ZeSpecContantsIds.data(); - ZeSpecConstants.pConstantValues = const_cast( - reinterpret_cast(ZeSpecContantsValues.data())); - } - // Ask Level Zero to build and load the native code onto the device. ZeStruct ZeModuleDesc; + _pi_program::SpecConstantShim Shim(Program); ZeModuleDesc.format = (Program->State == _pi_program::IL) ? ZE_MODULE_FORMAT_IL_SPIRV : ZE_MODULE_FORMAT_NATIVE; ZeModuleDesc.inputSize = Program->CodeLength; ZeModuleDesc.pInputModule = Program->Code.get(); ZeModuleDesc.pBuildFlags = Options; - ZeModuleDesc.pConstants = &ZeSpecConstants; + ZeModuleDesc.pConstants = Shim.ze(); ze_device_handle_t ZeDevice = DeviceList[0]->ZeDevice; ze_context_handle_t ZeContext = Program->Context->ZeContext; - ze_module_handle_t ZeModule; + ze_module_handle_t ZeModule = nullptr; ZE_CALL(zeModuleCreate, (ZeContext, ZeDevice, &ZeModuleDesc, &ZeModule, &Program->ZeBuildLog)); - // Check if this module imports any symbols, which we need to know if we - // end up linking this module later. See comments in piProgramLink() for - // details. - ZeStruct ZeModuleProps; - ZE_CALL(zeModuleGetPropertiesMock, (ZeModule, &ZeModuleProps)); - Program->HasImports = (ZeModuleProps.flags & ZE_MODULE_PROPERTY_FLAG_IMPORTS); + // The call to zeModuleCreate does not report an error if there are + // unresolved symbols because it thinks these could be resolved later via a + // call to zeModuleDynamicLink. However, modules created with piProgramBuild + // are supposed to be fully linked and ready to use. Therefore, do an extra + // check now for unresolved symbols. + ze_result_t ZeResult = checkUnresolvedSymbols(ZeModule, &Program->ZeBuildLog); + if (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) { + return PI_BUILD_PROGRAM_FAILURE; + } else if (ZeResult != ZE_RESULT_SUCCESS) { + return mapError(ZeResult); + } // We no longer need the IL / native code. - // The caller must set the State to Object or Exe as appropriate. Program->Code.reset(); + + if (Options) + Program->BuildFlags = Options; Program->ZeModule = ZeModule; + Program->State = _pi_program::Exe; + return PI_SUCCESS; } @@ -4077,35 +4041,40 @@ pi_result piProgramGetBuildInfo(pi_program Program, pi_device Device, size_t *ParamValueSizeRet) { (void)Device; + std::shared_lock Guard(Program->Mutex); ReturnHelper ReturnValue(ParamValueSize, ParamValue, ParamValueSizeRet); if (ParamName == CL_PROGRAM_BINARY_TYPE) { cl_program_binary_type Type = CL_PROGRAM_BINARY_TYPE_NONE; if (Program->State == _pi_program::Object) { Type = CL_PROGRAM_BINARY_TYPE_COMPILED_OBJECT; - } else if (Program->State == _pi_program::Exe || - Program->State == _pi_program::LinkedExe) { + } else if (Program->State == _pi_program::Exe) { Type = CL_PROGRAM_BINARY_TYPE_EXECUTABLE; } return ReturnValue(cl_program_binary_type{Type}); } if (ParamName == CL_PROGRAM_BUILD_OPTIONS) { - // TODO: how to get module build options out of Level Zero? - // For the programs that we compiled we can remember the options - // passed with piProgramCompile/piProgramBuild, but what can we - // return for programs that were built outside and registered - // with piProgramRegister? - return ReturnValue(""); + return ReturnValue(Program->BuildFlags.c_str()); } else if (ParamName == CL_PROGRAM_BUILD_LOG) { - // The OpenCL spec says an empty string is returned if there was no - // previous Compile, Build, or Link. - if (!Program->ZeBuildLog) - return ReturnValue(""); - size_t LogSize = ParamValueSize; - ZE_CALL(zeModuleBuildLogGetString, - (Program->ZeBuildLog, &LogSize, pi_cast(ParamValue))); - if (ParamValueSizeRet) { - *ParamValueSizeRet = LogSize; + // Check first to see if the plugin code recorded an error message. + if (!Program->ErrorMessage.empty()) { + return ReturnValue(Program->ErrorMessage.c_str()); + } + + // Next check if there is a Level Zero build log. + if (Program->ZeBuildLog) { + size_t LogSize = ParamValueSize; + ZE_CALL(zeModuleBuildLogGetString, + (Program->ZeBuildLog, &LogSize, pi_cast(ParamValue))); + if (ParamValueSizeRet) { + *ParamValueSizeRet = LogSize; + } + return PI_SUCCESS; } + + // Otherwise, there is no error. The OpenCL spec says to return an empty + // string if there ws no previous attempt to compile, build, or link the + // program. + return ReturnValue(""); } else { zePrint("piProgramGetBuildInfo: unsupported ParamName\n"); return PI_INVALID_VALUE; @@ -4136,20 +4105,10 @@ pi_result piextProgramGetNativeHandle(pi_program Program, auto ZeModule = pi_cast(NativeHandle); + std::shared_lock Guard(Program->Mutex); switch (Program->State) { - case _pi_program::Object: - case _pi_program::Exe: - case _pi_program::LinkedExe: { - _pi_program::ModuleIterator ModIt(Program); - PI_ASSERT(!ModIt.Done(), PI_INVALID_VALUE); - if (ModIt.Count() > 1) { - // Programs in LinkedExe state could have several corresponding - // Level Zero modules, so there is no right answer in this case. - // - // TODO: Maybe we should return PI_INVALID_OPERATION instead here? - die("piextProgramGetNativeHandle: Not implemented for linked programs"); - } - *ZeModule = *ModIt; + case _pi_program::Exe: { + *ZeModule = Program->ZeModule; break; } @@ -4176,7 +4135,7 @@ pi_result piextProgramCreateWithNativeHandle(pi_native_handle NativeHandle, try { *Program = - new _pi_program(Context, ZeModule, ownNativeHandle, _pi_program::Exe); + new _pi_program(_pi_program::Exe, Context, ZeModule, ownNativeHandle); } catch (const std::bad_alloc &) { return PI_OUT_OF_HOST_MEMORY; } catch (...) { @@ -4198,102 +4157,40 @@ _pi_program::~_pi_program() { } } -_pi_program::LinkedReleaser::~LinkedReleaser() { - if (Prog->HasImports) { - std::lock_guard Guard(Prog->MutexHasImportsAndIsLinked); - if (Prog->HasImportsAndIsLinked) - Prog->HasImportsAndIsLinked = false; - } - piProgramRelease(Prog); -} - -// Create a copy of a Level Zero module by extracting the native code and -// creating a new module from that native code. -static pi_result copyModule(ze_context_handle_t ZeContext, - ze_device_handle_t ZeDevice, - ze_module_handle_t SrcMod, - ze_module_handle_t *DestMod) { - size_t Length; - ZE_CALL(zeModuleGetNativeBinary, (SrcMod, &Length, nullptr)); - - std::unique_ptr Code(new uint8_t[Length]); - ZE_CALL(zeModuleGetNativeBinary, (SrcMod, &Length, Code.get())); - - ZeStruct ZeModuleDesc; - ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE; - ZeModuleDesc.inputSize = Length; - ZeModuleDesc.pInputModule = Code.get(); - ZeModuleDesc.pBuildFlags = nullptr; - ZeModuleDesc.pConstants = nullptr; - - ze_module_handle_t ZeModule; - ZE_CALL(zeModuleCreate, - (ZeContext, ZeDevice, &ZeModuleDesc, &ZeModule, nullptr)); - *DestMod = ZeModule; - return PI_SUCCESS; -} - -// TODO: Remove this mock implementation once the Level Zero driver -// implementation works. +// Check to see if a Level Zero module has any unresolved symbols. +// +// @param ZeModule The module handle to check. +// @param ZeBuildLog If there are unresolved symbols, this build log handle is +// modified to receive information telling which symbols +// are unresolved. +// +// @return ZE_RESULT_ERROR_MODULE_LINK_FAILURE indicates there are unresolved +// symbols. ZE_RESULT_SUCCESS indicates all symbols are resolved. Any other +// value indicates there was an error and we cannot tell if symbols are +// resolved. static ze_result_t -zeModuleDynamicLinkMock(uint32_t numModules, ze_module_handle_t *phModules, - ze_module_build_log_handle_t *phLinkLog) { - - // If enabled, try calling the real driver API instead. At the time this - // code was written, the "phLinkLog" parameter to zeModuleDynamicLink() - // doesn't work, so hard code it to NULL. - if (isOnlineLinkEnabled()) { - if (phLinkLog) - *phLinkLog = nullptr; - return ZE_CALL_NOCHECK(zeModuleDynamicLink, - (numModules, phModules, nullptr)); - } - - // The mock implementation can only handle the degenerate case where there - // is only a single module that is "linked" to itself. There is nothing to - // do in this degenerate case. - if (numModules > 1) { - die("piProgramLink: Program Linking is not supported yet in Level0"); - } +checkUnresolvedSymbols(ze_module_handle_t ZeModule, + ze_module_build_log_handle_t *ZeBuildLog) { - // The mock does not support the link log. - if (phLinkLog) - *phLinkLog = nullptr; - return ZE_RESULT_SUCCESS; -} - -// TODO: Remove this mock implementation once the Level Zero driver -// implementation works. -static ze_result_t -zeModuleGetPropertiesMock(ze_module_handle_t hModule, - ze_module_properties_t *pModuleProperties) { + // First check to see if the module has any imported symbols. If there are + // no imported symbols, it's not possible to have any unresolved symbols. We + // do this check first because we assume it's faster than the call to + // zeModuleDynamicLink below. + ZeStruct ZeModuleProps; + ze_result_t ZeResult = + ZE_CALL_NOCHECK(zeModuleGetProperties, (ZeModule, &ZeModuleProps)); + if (ZeResult != ZE_RESULT_SUCCESS) + return ZeResult; - // If enabled, try calling the real driver API first. At the time this code - // was written it always returns ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, so we - // fall back to the mock in this case. - if (isOnlineLinkEnabled()) { - ze_result_t ZeResult = - ZE_CALL_NOCHECK(zeModuleGetProperties, (hModule, pModuleProperties)); - if (ZeResult != ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) { - return ZeResult; - } + // If there are imported symbols, attempt to "link" the module with itself. + // As a side effect, this will return the error + // ZE_RESULT_ERROR_MODULE_LINK_FAILURE if there are any unresolved symbols. + if (ZeModuleProps.flags & ZE_MODULE_PROPERTY_FLAG_IMPORTS) { + return ZE_CALL_NOCHECK(zeModuleDynamicLink, (1, &ZeModule, ZeBuildLog)); } - - // The mock implementation assumes that the module has imported symbols. - // This is a conservative guess which may result in unnecessary calls to - // copyModule(), but it is always correct. - pModuleProperties->flags = ZE_MODULE_PROPERTY_FLAG_IMPORTS; return ZE_RESULT_SUCCESS; } -// Returns true if we should use the Level Zero driver online linking APIs. -// At the time this code was written, these APIs exist but do not work. We -// think that support in the DPC++ runtime is ready once the driver bugs are -// fixed, so runtime support can be enabled by setting an environment variable. -static bool isOnlineLinkEnabled() { - static bool IsEnabled = std::getenv("SYCL_ENABLE_LEVEL_ZERO_LINK"); - return IsEnabled; -} pi_result piKernelCreate(pi_program Program, const char *KernelName, pi_kernel *RetKernel) { @@ -4301,8 +4198,8 @@ pi_result piKernelCreate(pi_program Program, const char *KernelName, PI_ASSERT(RetKernel, PI_INVALID_VALUE); PI_ASSERT(KernelName, PI_INVALID_VALUE); - if (Program->State != _pi_program::Exe && - Program->State != _pi_program::LinkedExe) { + std::shared_lock Guard(Program->Mutex); + if (Program->State != _pi_program::Exe) { return PI_INVALID_PROGRAM_EXECUTABLE; } @@ -4310,29 +4207,8 @@ pi_result piKernelCreate(pi_program Program, const char *KernelName, ZeKernelDesc.flags = 0; ZeKernelDesc.pKernelName = KernelName; - // Search for the kernel name in each module. ze_kernel_handle_t ZeKernel; - ze_result_t ZeResult = ZE_RESULT_ERROR_INVALID_KERNEL_NAME; - _pi_program::ModuleIterator ModIt(Program); - while (!ModIt.Done()) { - // For a module with valid sycl kernel inside, zeKernelCreate API - // should return ZE_RESULT_SUCCESS if target kernel is found and - // ZE_RESULT_ERROR_INVALID_KERNEL_NAME otherwise. However, some module - // may not include any sycl kernel such as device library modules. For such - // modules, zeKernelCreate will return ZE_RESULT_ERROR_INVALID_ARGUMENT and - // we should skip them. - uint32_t KernelNum = 0; - ZE_CALL(zeModuleGetKernelNames, (*ModIt, &KernelNum, nullptr)); - if (KernelNum != 0) { - ZeResult = - ZE_CALL_NOCHECK(zeKernelCreate, (*ModIt, &ZeKernelDesc, &ZeKernel)); - if (ZeResult != ZE_RESULT_ERROR_INVALID_KERNEL_NAME) - break; - } - ModIt++; - } - if (ZeResult != ZE_RESULT_SUCCESS) - return mapError(ZeResult); + ZE_CALL(zeKernelCreate, (Program->ZeModule, &ZeKernelDesc, &ZeKernel)); try { *RetKernel = new _pi_kernel(ZeKernel, true, Program); @@ -6546,22 +6422,15 @@ pi_result piextGetDeviceFunctionPointer(pi_device Device, pi_program Program, (void)Device; PI_ASSERT(Program, PI_INVALID_PROGRAM); - if (Program->State != _pi_program::Exe && - Program->State != _pi_program::LinkedExe) { + std::shared_lock Guard(Program->Mutex); + if (Program->State != _pi_program::Exe) { return PI_INVALID_PROGRAM_EXECUTABLE; } - // Search for the function name in each module. - ze_result_t ZeResult = ZE_RESULT_ERROR_INVALID_FUNCTION_NAME; - _pi_program::ModuleIterator ModIt(Program); - while (!ModIt.Done()) { - ZeResult = ZE_CALL_NOCHECK( - zeModuleGetFunctionPointer, - (*ModIt, FunctionName, reinterpret_cast(FunctionPointerRet))); - if (ZeResult != ZE_RESULT_ERROR_INVALID_FUNCTION_NAME) - break; - ModIt++; - } + ze_result_t ZeResult = + ZE_CALL_NOCHECK(zeModuleGetFunctionPointer, + (Program->ZeModule, FunctionName, + reinterpret_cast(FunctionPointerRet))); // zeModuleGetFunctionPointer currently fails for all // kernels regardless of if the kernel exist or not @@ -7378,16 +7247,15 @@ pi_result piKernelSetExecInfo(pi_kernel Kernel, pi_kernel_exec_info ParamName, pi_result piextProgramSetSpecializationConstant(pi_program Prog, pi_uint32 SpecID, size_t, const void *SpecValue) { - // Level Zero sets spec constants when creating modules, - // so save them for when program is built. - std::lock_guard Guard(Prog->MutexZeSpecConstants); + std::scoped_lock Guard(Prog->Mutex); - // Pass SpecValue pointer. Spec constant value is retrieved - // by Level Zero when creating the module. + // Remember the value of this specialization constant until the program is + // built. Note that we only save the pointer to the buffer that contains the + // value. The caller is responsible for maintaining storage for this buffer. // // NOTE: SpecSize is unused in Level Zero, the size is known from SPIR-V by // SpecID. - Prog->ZeSpecConstants[SpecID] = reinterpret_cast(SpecValue); + Prog->SpecConstants[SpecID] = SpecValue; return PI_SUCCESS; } diff --git a/sycl/plugins/level_zero/pi_level_zero.hpp b/sycl/plugins/level_zero/pi_level_zero.hpp index 689c47ee14a10..4b2f722368ca6 100644 --- a/sycl/plugins/level_zero/pi_level_zero.hpp +++ b/sycl/plugins/level_zero/pi_level_zero.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -95,6 +96,10 @@ template <> ze_structure_type_t getZeStructureType() { template <> ze_structure_type_t getZeStructureType() { return ZE_STRUCTURE_TYPE_MODULE_DESC; } +template <> +ze_structure_type_t getZeStructureType() { + return ZE_STRUCTURE_TYPE_MODULE_PROGRAM_EXP_DESC; +} template <> ze_structure_type_t getZeStructureType() { return ZE_STRUCTURE_TYPE_KERNEL_DESC; } @@ -993,7 +998,7 @@ struct _pi_event : _pi_object { struct _pi_program : _pi_object { // Possible states of a program. typedef enum { - // The program has been created from intermediate language (SPIR-v), but it + // The program has been created from intermediate language (SPIR-V), but it // is not yet compiled. IL, @@ -1002,179 +1007,115 @@ struct _pi_program : _pi_object { // is loaded via clCreateProgramWithBinary(). Native, - // The program consists of native code (typically compiled from SPIR-v), - // but it has unresolved external dependencies which need to be resolved - // by linking with other Object state program(s). Programs in this state - // have a single Level Zero module. + // The program was notionally compiled from SPIR-V form. However, since we + // postpone compilation until the module is linked, the internal state + // still represents the module as SPIR-V. Object, - // The program consists of native code with no external dependencies. - // Programs in this state have a single Level Zero module, but no linking - // is needed in order to run kernels. + // The program has been built or linked, and it is represented as a Level + // Zero module. Exe, - // The program consists of several Level Zero modules, each of which - // contains native code. Some modules may import external symbols and - // other modules may export definitions of those external symbols. All of - // the modules have been linked together, so the imported references are - // resolved by the exported definitions. - // - // Module linking in Level Zero is quite different from program linking in - // OpenCL. OpenCL statically links several program objects together to - // form a new program that contains the linked result. Level Zero is more - // similar to shared libraries. When several Level Zero modules are linked - // together, each module is modified "in place" such that external - // references from one are linked to external definitions in another. - // Linking in Level Zero does not produce a new Level Zero module that - // represents the linked result, therefore a program in LinkedExe state - // holds a list of all the pi_programs that were linked together. Queries - // about the linked program need to query all the pi_programs in this list. - LinkedExe + // An error occurred during piProgramLink, but we created a _pi_program + // object anyways in order to hold the ZeBuildLog. Note that the ZeModule + // may or may not be nullptr in this state, depending on the error. + Invalid } state; - // This is a wrapper class used for programs in LinkedExe state. Such a - // program contains a list of pi_programs in Object state that have been - // linked together. The program in LinkedExe state increments the reference - // counter for each of the Object state programs, thus "retaining" a - // reference to them, and it may also set the "HasImportsAndIsLinked" flag - // in these Object state programs. The purpose of this wrapper is to - // decrement the reference count and clear the flag when the LinkedExe - // program is destroyed, so all the interesting code is in the wrapper's - // destructor. - // - // In order to ensure that the reference count is never decremented more - // than once, the wrapper has no copy constructor or copy assignment - // operator. Instead, we only allow move semantics for the wrapper. - class LinkedReleaser { - public: - LinkedReleaser(pi_program Prog) : Prog(Prog) {} - LinkedReleaser(LinkedReleaser &&Other) { - Prog = Other.Prog; - Other.Prog = nullptr; - } - LinkedReleaser(const LinkedReleaser &Other) = delete; - LinkedReleaser &operator=(LinkedReleaser &&Other) { - std::swap(Prog, Other.Prog); - return *this; - } - LinkedReleaser &operator=(const LinkedReleaser &Other) = delete; - ~LinkedReleaser(); - - pi_program operator->() const { return Prog; } - - private: - pi_program Prog; - }; - - // A utility class that iterates over the Level Zero modules contained by - // the program. This helps hide the difference between programs in Object - // or Exe state (which have one module) and programs in LinkedExe state - // (which have several modules). - class ModuleIterator { + // A utility class that converts specialization constants into the form + // required by the Level Zero driver. + class SpecConstantShim { public: - ModuleIterator(pi_program Prog) - : Prog(Prog), It(Prog->LinkedPrograms.begin()) { - if (Prog->State == LinkedExe) { - NumMods = Prog->LinkedPrograms.size(); - IsDone = (It == Prog->LinkedPrograms.end()); - Mod = IsDone ? nullptr : (*It)->ZeModule; - } else if (Prog->State == IL || Prog->State == Native) { - NumMods = 0; - IsDone = true; - Mod = nullptr; - } else { - NumMods = 1; - IsDone = false; - Mod = Prog->ZeModule; + SpecConstantShim(pi_program Program) { + ZeSpecConstants.numConstants = Program->SpecConstants.size(); + ZeSpecContantsIds.reserve(ZeSpecConstants.numConstants); + ZeSpecContantsValues.reserve(ZeSpecConstants.numConstants); + + for (auto &SpecConstant : Program->SpecConstants) { + ZeSpecContantsIds.push_back(SpecConstant.first); + ZeSpecContantsValues.push_back(SpecConstant.second); } + ZeSpecConstants.pConstantIds = ZeSpecContantsIds.data(); + ZeSpecConstants.pConstantValues = ZeSpecContantsValues.data(); } - bool Done() const { return IsDone; } - size_t Count() const { return NumMods; } - ze_module_handle_t operator*() const { return Mod; } - - void operator++(int) { - if (!IsDone && (Prog->State == LinkedExe) && - (++It != Prog->LinkedPrograms.end())) { - Mod = (*It)->ZeModule; - } else { - Mod = nullptr; - IsDone = true; - } - } + const ze_module_constants_t *ze() { return &ZeSpecConstants; } private: - pi_program Prog; - ze_module_handle_t Mod; - size_t NumMods; - bool IsDone; - std::vector::iterator It; + std::vector ZeSpecContantsIds; + std::vector ZeSpecContantsValues; + ze_module_constants_t ZeSpecConstants; }; // Construct a program in IL or Native state. - _pi_program(pi_context Context, const void *Input, size_t Length, state St) - : State(St), Context(Context), Code(new uint8_t[Length]), - CodeLength(Length), ZeModule(nullptr), OwnZeModule{true}, - HasImports(false), HasImportsAndIsLinked(false), ZeBuildLog(nullptr) { - + _pi_program(state St, pi_context Context, const void *Input, size_t Length) + : Context{Context}, + OwnZeModule{true}, State{St}, Code{new uint8_t[Length]}, + CodeLength{Length}, ZeModule{nullptr}, ZeBuildLog{nullptr} { std::memcpy(Code.get(), Input, Length); } - // Construct a program in either Object or Exe state. - _pi_program(pi_context Context, ze_module_handle_t ZeModule, bool OwnZeModule, - state St, bool HasImports = false) - : State(St), Context(Context), - ZeModule(ZeModule), OwnZeModule{OwnZeModule}, HasImports(HasImports), - HasImportsAndIsLinked(false), ZeBuildLog(nullptr) {} + // Construct a program in Exe or Invalid state. + _pi_program(state St, pi_context Context, ze_module_handle_t ZeModule, + ze_module_build_log_handle_t ZeBuildLog, const char *Options) + : Context{Context}, OwnZeModule{true}, State{St}, ZeModule{ZeModule}, + ZeBuildLog{ZeBuildLog} { + if (Options) + BuildFlags = Options; + } - // Construct a program in LinkedExe state. - _pi_program(pi_context Context, std::vector &&Inputs, - ze_module_build_log_handle_t ZeLog) - : State(LinkedExe), Context(Context), ZeModule(nullptr), - OwnZeModule(true), HasImports(false), HasImportsAndIsLinked(false), - LinkedPrograms(std::move(Inputs)), ZeBuildLog(ZeLog) {} + // Construct a program in Exe state (interop). + _pi_program(state St, pi_context Context, ze_module_handle_t ZeModule, + bool OwnZeModule) + : Context{Context}, OwnZeModule{OwnZeModule}, State{St}, + ZeModule{ZeModule}, ZeBuildLog{nullptr} {} + + // Construct a program in Invalid state with a custom error message. + _pi_program(state St, pi_context Context, const char *Options, + const std::string &ErrorMessage) + : Context{Context}, OwnZeModule{true}, ErrorMessage{ErrorMessage}, + State{St}, ZeModule{nullptr}, ZeBuildLog{nullptr} { + if (Options) + BuildFlags = Options; + } ~_pi_program(); - // Used for programs in all states. - state State; - pi_context Context; // Context of the program. + const pi_context Context; // Context of the program. - // Used for programs in IL or Native states. - std::unique_ptr Code; // Array containing raw IL / native code. - size_t CodeLength; // Size (bytes) of the array. + // Indicates if we own the ZeModule or it came from interop that + // asked to not transfer the ownership to SYCL RT. + const bool OwnZeModule; - // Level Zero specialization constants, used for programs in IL state. - std::unordered_map ZeSpecConstants; - std::mutex MutexZeSpecConstants; // Protects access to this field. + // This error message is used only in Invalid state to hold a custom error + // message from a call to piProgramLink. + const std::string ErrorMessage; - // Used for programs in Object or Exe state. - ze_module_handle_t ZeModule; // Level Zero module handle. + // Protects accesses to all the non-const member variables. Exclusive access + // is required to modify any of these members. + std::shared_mutex Mutex; - // Indicates if we own the ZeModule or it came from interop that - // asked to not transfer the ownership to SYCL RT. - bool OwnZeModule; + state State; - // Tells if module imports any symbols. - bool HasImports; + // In IL and Object states, this contains the SPIR-V representation of the + // module. In Native state, it contains the native code. + std::unique_ptr Code; // Array containing raw IL / native code. + size_t CodeLength; // Size (bytes) of the array. - // Used for programs in Object state. Tells if this module imports any - // symbols AND it is linked into some other program that has state LinkedExe. - // Such an Object is linked into exactly one other LinkedExe program. Access - // to this field needs to be locked in case there are two threads that try to - // simultaneously link with this module. - bool HasImportsAndIsLinked; - std::mutex MutexHasImportsAndIsLinked; // Protects access to this field. + // Used only in IL and Object states. Contains the SPIR-V specialization + // constants as a map from the SPIR-V "SpecID" to a buffer that contains the + // associated value. The caller of the PI layer is responsible for + // maintaining the storage of this buffer. + std::unordered_map SpecConstants; - // Used for programs in LinkedExe state. This is the set of Object programs - // that are linked together. - // - // Note that the Object programs in this vector might also be linked into - // other LinkedExe programs! - std::vector LinkedPrograms; + // Used only in Object and Exe states. Contains the build flags from the + // last call to piProgramCompile() or piProgramLink(). + std::string BuildFlags; + + // The Level Zero module handle. Used primarily in Exe state. + ze_module_handle_t ZeModule; - // Level Zero build or link log, used for programs in Obj, Exe, or LinkedExe - // state. + // The Level Zero build log from the last call to zeModuleCreate(). ze_module_build_log_handle_t ZeBuildLog; }; diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index aff8dc2f9f74b..757ef9a892528 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -963,8 +963,11 @@ ProgramManager::ProgramPtr ProgramManager::build( DeviceLibReqMask); } + static const char *ForceLinkEnv = std::getenv("SYCL_FORCE_LINK"); + static bool ForceLink = ForceLinkEnv && (*ForceLinkEnv == '1'); + const detail::plugin &Plugin = Context->getPlugin(); - if (LinkPrograms.empty()) { + if (LinkPrograms.empty() && !ForceLink) { RT::PiResult Error = Plugin.call_nocheck( Program.get(), /*num devices =*/1, &Device, CompileOptions.c_str(), nullptr, nullptr);