diff --git a/llvm/lib/SYCLLowerIR/LowerESIMD.cpp b/llvm/lib/SYCLLowerIR/LowerESIMD.cpp index 029263c0dc35c..00679c6fd79fb 100644 --- a/llvm/lib/SYCLLowerIR/LowerESIMD.cpp +++ b/llvm/lib/SYCLLowerIR/LowerESIMD.cpp @@ -43,6 +43,8 @@ namespace id = itanium_demangle; #define SLM_BTI 254 +#define MAX_DIMS 3 + namespace { SmallPtrSet collectGenXVolatileTypes(Module &); void generateKernelMetadata(Module &); @@ -846,145 +848,131 @@ static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) { auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false); NewI = CastInst::Create(CastOpcode, NewI, OITy, NewI->getName() + ".cast.ty", OldI); + NewI->setDebugLoc(OldI->getDebugLoc()); } return NewI; } -static int getIndexForSuffix(StringRef Suff) { - return llvm::StringSwitch(Suff) - .Case("x", 0) - .Case("y", 1) - .Case("z", 2) - .Default(-1); +/// Returns the index from the given extract element instruction \p EEI. +/// It is checked here that the index is either 0, 1, or 2. +static uint64_t getIndexFromExtract(ExtractElementInst *EEI) { + Value *IndexV = EEI->getIndexOperand(); + uint64_t IndexValue = cast(IndexV)->getZExtValue(); + assert(IndexValue < MAX_DIMS && + "Extract element index should be either 0, 1, or 2"); + return IndexValue; } -// Helper function to convert extractelement instruction associated with the -// load from SPIRV builtin global, into the GenX intrinsic that returns vector -// of coordinates. It also generates required extractelement and cast -// instructions. 
Example: -// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast -// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId -// to <3 x i64> addrspace(4)*), align 32 -// %1 = extractelement <3 x i64> %0, i64 0 -// -// => -// -// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32() -// %local_id.x = extractelement <3 x i32> %.esimd, i32 0 -// %local_id.x.cast.ty = zext i32 %local_id.x to i64 -static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI, - StringRef Suff, - const std::string &IntrinName, - StringRef ValueName) { - std::string IntrName = - std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName; - auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName); - LLVMContext &Ctx = EEI->getModule()->getContext(); - Type *I32Ty = Type::getInt32Ty(Ctx); - Function *NewFDecl = GenXIntrinsic::getGenXDeclaration( - EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)}); - Instruction *IntrI = - IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI); - int ExtractIndex = getIndexForSuffix(Suff); - assert(ExtractIndex != -1 && "Extract index is invalid."); - Twine ExtractName = ValueName + Suff; - - Instruction *ExtrI = ExtractElementInst::Create( - IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI); - Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI); - if (EEI->getDebugLoc()) { - IntrI->setDebugLoc(EEI->getDebugLoc()); - ExtrI->setDebugLoc(EEI->getDebugLoc()); - // It's OK if ExtrI and CastI is the same instruction - CastI->setDebugLoc(EEI->getDebugLoc()); +/// Generates the call of GenX intrinsic \p IntrinName and inserts it +/// right before the given extract element instruction \p EEI using the result +/// of vector load. The parameter \p IsVectorCall tells what version of GenX +/// intrinsic (scalar or vector) to use to lower the load from SPIRV global. 
+static Instruction *generateGenXCall(ExtractElementInst *EEI, + StringRef IntrinName, bool IsVectorCall) { + uint64_t IndexValue = getIndexFromExtract(EEI); + std::string Suffix = + IsVectorCall + ? ".v3i32" + : (Twine(".") + Twine(static_cast('x' + IndexValue))).str(); + std::string FullIntrinName = (Twine(GenXIntrinsic::getGenXIntrinsicPrefix()) + + Twine(IntrinName) + Suffix) + .str(); + auto ID = GenXIntrinsic::lookupGenXIntrinsicID(FullIntrinName); + Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext()); + Function *NewFDecl = + IsVectorCall + ? GenXIntrinsic::getGenXDeclaration( + EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS)) + : GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID); + + std::string ResultName = + (Twine(EEI->getNameOrAsOperand()) + "." + FullIntrinName).str(); + Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI); + Inst->setDebugLoc(EEI->getDebugLoc()); + + if (IsVectorCall) { + Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext()); + std::string ExtractName = + (Twine(Inst->getNameOrAsOperand()) + ".ext." + Twine(IndexValue)).str(); + Inst = ExtractElementInst::Create(Inst, ConstantInt::get(I32Ty, IndexValue), + ExtractName, EEI); + Inst->setDebugLoc(EEI->getDebugLoc()); } - return CastI; + Inst = addCastInstIfNeeded(EEI, Inst); + return Inst; } -// Helper function to convert extractelement instruction associated with the -// load from SPIRV builtin global, into the GenX intrinsic. It also generates -// required cast instructions. 
Example: -// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> -// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align -// 32 %1 = extractelement <3 x i64> %0, i64 0 -// => -// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> -// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align -// 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext -// i32 %group.id.x to i64 -static Instruction *generateGenXForSpirv(ExtractElementInst *EEI, - StringRef Suff, - const std::string &IntrinName) { - std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + - IntrinName + Suff.str(); - auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName); - Function *NewFDecl = - GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {}); - - Instruction *IntrI = - IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI); - Instruction *CastI = addCastInstIfNeeded(EEI, IntrI); - if (EEI->getDebugLoc()) { - IntrI->setDebugLoc(EEI->getDebugLoc()); - // It's OK if IntrI and CastI is the same instruction - CastI->setDebugLoc(EEI->getDebugLoc()); +/// Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX +/// intrinsic(s). The users of \p LI may also be transformed if needed for +/// def/use type correctness. +/// The replaced instructions are stored into the given container +/// \p InstsToErase. +static void +translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName, + SmallVectorImpl &InstsToErase) { + // TODO: Implement support for the following intrinsics: + // uint32_t __spirv_BuiltIn NumSubgroups; + // uint32_t __spirv_BuiltIn SubgroupId; + + // Translate those loads from _scalar_ SPIRV globals that can be replaced with + // a const value here. + // The loads from other scalar SPIRV globals may require insertion of GenX + // calls before each user, which is done in the loop by users of 'LI' below. 
+ Value *NewInst = nullptr; + if (SpirvGlobalName == "SubgroupLocalInvocationId") { + NewInst = llvm::Constant::getNullValue(LI->getType()); + } else if (SpirvGlobalName == "SubgroupSize" || + SpirvGlobalName == "SubgroupMaxSize") { + NewInst = llvm::Constant::getIntegerValue(LI->getType(), + llvm::APInt(32, 1, true)); + } + if (NewInst) { + LI->replaceAllUsesWith(NewInst); + InstsToErase.push_back(LI); + return; } - return CastI; -} -// This function translates one occurence of SPIRV builtin use into GenX -// intrinsic. -static Value *translateSpirvGlobalUse(ExtractElementInst *EEI, - StringRef SpirvGlobalName) { - Value *IndexV = EEI->getIndexOperand(); - assert(isa(IndexV) && - "Extract element index should be a constant"); + // Only loads from _vector_ SPIRV globals reach here now. Their users are + // expected to be ExtractElementInst only, and they are replaced in this loop. + // When loads from _scalar_ SPIRV globals are handled here as well, the users + // will not be replaced by new instructions, but the GenX call replacing the + // original load 'LI' should be inserted before each user. 
+  for (User *LU : LI->users()) {
+    ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
+    NewInst = nullptr;
+
+    if (SpirvGlobalName == "WorkgroupSize") {
+      NewInst = generateGenXCall(EEI, "local.size", true);
+    } else if (SpirvGlobalName == "LocalInvocationId") {
+      NewInst = generateGenXCall(EEI, "local.id", true);
+    } else if (SpirvGlobalName == "WorkgroupId") {
+      NewInst = generateGenXCall(EEI, "group.id", false);
+    } else if (SpirvGlobalName == "GlobalInvocationId") {
+      // GlobalId = LocalId + WorkGroupSize * GroupId
+      Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
+      Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
+      Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
+      Instruction *MulI =
+          BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
+      NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
+    } else if (SpirvGlobalName == "GlobalSize") {
+      // GlobalSize = WorkGroupSize * NumWorkGroups
+      Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
+      Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
+      NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
+    } else if (SpirvGlobalName == "GlobalOffset") {
+      // TODO: Support GlobalOffset SPIRV intrinsics
+      // Currently all users of load of GlobalOffset are replaced with 0.
+      NewInst = llvm::Constant::getNullValue(EEI->getType());
+    } else if (SpirvGlobalName == "NumWorkgroups") {
+      NewInst = generateGenXCall(EEI, "group.count", true);
+    }
-  // Get the suffix based on the index of extractelement instruction
-  ConstantInt *IndexC = cast<ConstantInt>(IndexV);
-  std::string Suff;
-  if (IndexC->equalsInt(0))
-    Suff = 'x';
-  else if (IndexC->equalsInt(1))
-    Suff = 'y';
-  else if (IndexC->equalsInt(2))
-    Suff = 'z';
-  else
-    assert(false && "Extract element index should be either 0, 1, or 2");
-
-  // Translate SPIRV into GenX intrinsic.
- if (SpirvGlobalName == "WorkgroupSize") { - return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); - } else if (SpirvGlobalName == "LocalInvocationId") { - return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id."); - } else if (SpirvGlobalName == "WorkgroupId") { - return generateGenXForSpirv(EEI, Suff, "group.id."); - } else if (SpirvGlobalName == "GlobalInvocationId") { - // GlobalId = LocalId + WorkGroupSize * GroupId - Instruction *LocalIdI = - generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id."); - Instruction *WGSizeI = - generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); - Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id."); - Instruction *MulI = - BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI); - return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI); - } else if (SpirvGlobalName == "GlobalSize") { - // GlobalSize = WorkGroupSize * NumWorkGroups - Instruction *WGSizeI = - generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); - Instruction *NumWGI = generateVectorGenXForSpirv( - EEI, Suff, "group.count.v3i32", "group_count."); - return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI); - } else if (SpirvGlobalName == "GlobalOffset") { - // TODO: Support GlobalOffset SPIRV intrinsics - return llvm::Constant::getNullValue(EEI->getType()); - } else if (SpirvGlobalName == "NumWorkgroups") { - return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32", - "group_count."); + assert(NewInst && "Load from global SPIRV builtin was not translated"); + EEI->replaceAllUsesWith(NewInst); + InstsToErase.push_back(EEI); } - - return nullptr; + InstsToErase.push_back(LI); } static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc, @@ -1370,8 +1358,7 @@ SmallPtrSet collectGenXVolatileTypes(Module &M) { } // namespace -PreservedAnalyses SYCLLowerESIMDPass::run(Module &M, - ModuleAnalysisManager &) { +PreservedAnalyses 
SYCLLowerESIMDPass::run(Module &M, ModuleAnalysisManager &) {
   generateKernelMetadata(M);
   SmallPtrSet GVTS = collectGenXVolatileTypes(M);
@@ -1507,23 +1494,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
       auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();
-      // Go through all the uses of the load instruction from SPIRV builtin
-      // globals, which are required to be extractelement instructions.
-      // Translate each of them.
-      for (auto *LU : LI->users()) {
-        auto *EEI = dyn_cast<ExtractElementInst>(LU);
-        assert(EEI && "User of load from global SPIRV builtin is not an "
-                      "extractelement instruction");
-        Value *TranslatedVal = translateSpirvGlobalUse(
-            EEI, SpirvGlobal->getName().drop_front(PrefLen));
-        assert(TranslatedVal &&
-               "Load from global SPIRV builtin was not translated");
-        EEI->replaceAllUsesWith(TranslatedVal);
-        ESIMDToErases.push_back(EEI);
-      }
-      // After all users of load were translated, we get rid of the load
-      // itself.
-      ESIMDToErases.push_back(LI);
+      // Translate all uses of the load instruction from SPIRV builtin global.
+      // Replaces the original global load and its uses, and stores the old
+      // instructions to ESIMDToErases.
+ translateSpirvGlobalUses(LI, SpirvGlobal->getName().drop_front(PrefLen), + ESIMDToErases); } } // Now demangle and translate found ESIMD intrinsic calls diff --git a/sycl/test/esimd/spirv_intrins_trans.cpp b/sycl/test/esimd/spirv_intrins_trans.cpp index 5384e5dd31cfc..7f7da671e01fb 100644 --- a/sycl/test/esimd/spirv_intrins_trans.cpp +++ b/sycl/test/esimd/spirv_intrins_trans.cpp @@ -18,12 +18,15 @@ size_t caller() { size_t DoNotOpt; cl::sycl::buffer buf(&DoNotOpt, 1); + uint32_t DoNotOpt32; + cl::sycl::buffer buf32(&DoNotOpt32, 1); size_t DoNotOptXYZ[3]; cl::sycl::buffer bufXYZ(&DoNotOptXYZ[0], sycl::range<1>(3)); cl::sycl::queue().submit([&](cl::sycl::handler &cgh) { auto DoNotOptimize = buf.get_access(cgh); + auto DoNotOptimize32 = buf32.get_access(cgh); kernel([=]() SYCL_ESIMD_KERNEL { *DoNotOptimize.get_pointer() = __spirv_GlobalInvocationId_x(); @@ -213,6 +216,33 @@ size_t caller() { // CHECK: {{.*}} call i32 @llvm.genx.group.id.x() // CHECK: {{.*}} call i32 @llvm.genx.group.id.y() // CHECK: {{.*}} call i32 @llvm.genx.group.id.z() + + kernel([=]() SYCL_ESIMD_KERNEL { + *DoNotOptimize.get_pointer() = __spirv_SubgroupLocalInvocationId(); + *DoNotOptimize32.get_pointer() = __spirv_SubgroupLocalInvocationId() + 3; + }); + // CHECK-LABEL: @{{.*}}kernel_SubgroupLocalInvocationId + // CHECK: [[ZEXT0:%.*]] = zext i32 0 to i64 + // CHECK: store i64 [[ZEXT0]] + // CHECK: add i32 0, 3 + + kernel([=]() SYCL_ESIMD_KERNEL { + *DoNotOptimize.get_pointer() = __spirv_SubgroupSize(); + *DoNotOptimize32.get_pointer() = __spirv_SubgroupSize() + 7; + }); + // CHECK-LABEL: @{{.*}}kernel_SubgroupSize + // CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64 + // CHECK: store i64 [[ZEXT0]] + // CHECK: add i32 1, 7 + + kernel([=]() SYCL_ESIMD_KERNEL { + *DoNotOptimize.get_pointer() = __spirv_SubgroupMaxSize(); + *DoNotOptimize32.get_pointer() = __spirv_SubgroupMaxSize() + 9; + }); + // CHECK-LABEL: @{{.*}}kernel_SubgroupMaxSize + // CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64 + // CHECK: store i64 
[[ZEXT0]] + // CHECK: add i32 1, 9 }); return DoNotOpt; }