Skip to content

Commit 21f3a60

Browse files
authored
[Offload] Only initialize a plugin if it is needed (#92765)
Summary: Initializing the plugins requires initializing the runtime like CUDA or HSA. This has a considerable overhead on most platforms, so we should only actually initialize a plugin if it is needed by any image that is loaded.
1 parent 7e476eb commit 21f3a60

File tree

5 files changed

+63
-37
lines changed

5 files changed

+63
-37
lines changed

offload/plugins-nextgen/common/include/JIT.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ struct JITEngine {
5555
process(const __tgt_device_image &Image,
5656
target::plugin::GenericDeviceTy &Device);
5757

58-
/// Return true if \p Image is a bitcode image that can be JITed for the given
59-
/// architecture.
60-
Expected<bool> checkBitcodeImage(StringRef Buffer) const;
61-
6258
private:
6359
/// Compile the bitcode image \p Image and generate the binary image that can
6460
/// be loaded to the target device of the triple \p Triple architecture \p

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,10 @@ struct GenericPluginTy {
10521052
/// given target. Returns true if the \p Image is compatible with the plugin.
10531053
Expected<bool> checkELFImage(StringRef Image) const;
10541054

1055+
/// Return true if the \p Image can be compiled to run on the platform's
1056+
/// target architecture.
1057+
Expected<bool> checkBitcodeImage(StringRef Image) const;
1058+
10551059
/// Indicate if an image is compatible with the plugin devices. Notice that
10561060
/// this function may be called before actually initializing the devices. So
10571061
/// we could not move this function into GenericDeviceTy.
@@ -1066,8 +1070,11 @@ struct GenericPluginTy {
10661070
public:
10671071
// TODO: This plugin interface needs to be cleaned up.
10681072

1073+
/// Returns true if the plugin has been initialized.
1074+
int32_t is_initialized() const;
1075+
10691076
/// Returns non-zero if the provided \p Image can be executed by the runtime.
1070-
int32_t is_valid_binary(__tgt_device_image *Image);
1077+
int32_t is_valid_binary(__tgt_device_image *Image, bool Initialized = true);
10711078

10721079
/// Initialize the device inside of the plugin.
10731080
int32_t init_device(int32_t DeviceId);
@@ -1187,6 +1194,9 @@ struct GenericPluginTy {
11871194
void **KernelPtr);
11881195

11891196
private:
1197+
/// Indicates if the platform runtime has been fully initialized.
1198+
bool Initialized = false;
1199+
11901200
/// Number of devices available for the plugin.
11911201
int32_t NumDevices = 0;
11921202

offload/plugins-nextgen/common/src/JIT.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,19 +323,3 @@ JITEngine::process(const __tgt_device_image &Image,
323323

324324
return &Image;
325325
}
326-
327-
Expected<bool> JITEngine::checkBitcodeImage(StringRef Buffer) const {
328-
TimeTraceScope TimeScope("Check bitcode image");
329-
330-
assert(identify_magic(Buffer) == file_magic::bitcode &&
331-
"Input is not bitcode");
332-
333-
LLVMContext Context;
334-
auto ModuleOrErr = getLazyBitcodeModule(MemoryBufferRef(Buffer, ""), Context,
335-
/*ShouldLazyLoadMetadata=*/true);
336-
if (!ModuleOrErr)
337-
return ModuleOrErr.takeError();
338-
Module &M = **ModuleOrErr;
339-
340-
return Triple(M.getTargetTriple()).getArch() == TT.getArch();
341-
}

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "omp-tools.h"
2525
#endif
2626

27+
#include "llvm/Bitcode/BitcodeReader.h"
2728
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2829
#include "llvm/Support/Error.h"
2930
#include "llvm/Support/JSON.h"
@@ -1495,6 +1496,7 @@ Error GenericPluginTy::init() {
14951496
if (!NumDevicesOrErr)
14961497
return NumDevicesOrErr.takeError();
14971498

1499+
Initialized = true;
14981500
NumDevices = *NumDevicesOrErr;
14991501
if (NumDevices == 0)
15001502
return Plugin::success();
@@ -1578,14 +1580,27 @@ Expected<bool> GenericPluginTy::checkELFImage(StringRef Image) const {
15781580
if (!MachineOrErr)
15791581
return MachineOrErr.takeError();
15801582

1581-
if (!*MachineOrErr)
1583+
return MachineOrErr;
1584+
}
1585+
1586+
Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
1587+
if (identify_magic(Image) != file_magic::bitcode)
15821588
return false;
15831589

1584-
// Perform plugin-dependent checks for the specific architecture if needed.
1585-
return isELFCompatible(Image);
1590+
LLVMContext Context;
1591+
auto ModuleOrErr = getLazyBitcodeModule(MemoryBufferRef(Image, ""), Context,
1592+
/*ShouldLazyLoadMetadata=*/true);
1593+
if (!ModuleOrErr)
1594+
return ModuleOrErr.takeError();
1595+
Module &M = **ModuleOrErr;
1596+
1597+
return Triple(M.getTargetTriple()).getArch() == getTripleArch();
15861598
}
15871599

1588-
int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image) {
1600+
int32_t GenericPluginTy::is_initialized() const { return Initialized; }
1601+
1602+
int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
1603+
bool Initialized) {
15891604
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
15901605
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
15911606

@@ -1603,10 +1618,17 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image) {
16031618
auto MatchOrErr = checkELFImage(Buffer);
16041619
if (Error Err = MatchOrErr.takeError())
16051620
return HandleError(std::move(Err));
1606-
return *MatchOrErr;
1621+
if (!Initialized || !*MatchOrErr)
1622+
return *MatchOrErr;
1623+
1624+
// Perform plugin-dependent checks for the specific architecture if needed.
1625+
auto CompatibleOrErr = isELFCompatible(Buffer);
1626+
if (Error Err = CompatibleOrErr.takeError())
1627+
return HandleError(std::move(Err));
1628+
return *CompatibleOrErr;
16071629
}
16081630
case file_magic::bitcode: {
1609-
auto MatchOrErr = getJIT().checkBitcodeImage(Buffer);
1631+
auto MatchOrErr = checkBitcodeImage(Buffer);
16101632
if (Error Err = MatchOrErr.takeError())
16111633
return HandleError(std::move(Err));
16121634
return *MatchOrErr;

offload/src/PluginManager.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,8 @@ void PluginManager::init() {
3434
// Attempt to create an instance of each supported plugin.
3535
#define PLUGIN_TARGET(Name) \
3636
do { \
37-
auto Plugin = std::unique_ptr<GenericPluginTy>(createPlugin_##Name()); \
38-
if (auto Err = Plugin->init()) { \
39-
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); \
40-
DP("Failed to init plugin: %s\n", InfoMsg.c_str()); \
41-
} else { \
42-
DP("Registered plugin %s with %d visible device(s)\n", \
43-
Plugin->getName(), Plugin->number_of_devices()); \
44-
Plugins.emplace_back(std::move(Plugin)); \
45-
} \
37+
Plugins.emplace_back( \
38+
std::unique_ptr<GenericPluginTy>(createPlugin_##Name())); \
4639
} while (false);
4740
#include "Shared/Targets.def"
4841

@@ -160,6 +153,27 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
160153
if (Entry.flags == OMP_REGISTER_REQUIRES)
161154
PM->addRequirements(Entry.data);
162155

156+
// Initialize all the plugins that have associated images.
157+
for (auto &Plugin : Plugins) {
158+
if (Plugin->is_initialized())
159+
continue;
160+
161+
// Extract the exectuable image and extra information if availible.
162+
for (int32_t i = 0; i < Desc->NumDeviceImages; ++i) {
163+
if (!Plugin->is_valid_binary(&Desc->DeviceImages[i],
164+
/*Initialized=*/false))
165+
continue;
166+
167+
if (auto Err = Plugin->init()) {
168+
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
169+
DP("Failed to init plugin: %s\n", InfoMsg.c_str());
170+
} else {
171+
DP("Registered plugin %s with %d visible device(s)\n",
172+
Plugin->getName(), Plugin->number_of_devices());
173+
}
174+
}
175+
}
176+
163177
// Extract the exectuable image and extra information if availible.
164178
for (int32_t i = 0; i < Desc->NumDeviceImages; ++i)
165179
PM->addDeviceImage(*Desc, Desc->DeviceImages[i]);
@@ -177,7 +191,7 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
177191
if (!R.number_of_devices())
178192
continue;
179193

180-
if (!R.is_valid_binary(Img)) {
194+
if (!R.is_valid_binary(Img, /*Initialized=*/true)) {
181195
DP("Image " DPxMOD " is NOT compatible with RTL %s!\n",
182196
DPxPTR(Img->ImageStart), R.getName());
183197
continue;

0 commit comments

Comments
 (0)