diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 4bfa51e2cccdd..e236d646e66fc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -174,7 +174,7 @@ using namespace InstructionSet; namespace SPIRV { /// Parses the name part of the demangled builtin call. std::string lookupBuiltinNameHelper(StringRef DemangledCall, - std::string *Postfix) { + FPDecorationId *DecorationId) { const static std::string PassPrefix = "(anonymous namespace)::"; std::string BuiltinName; // Itanium Demangler result may have "(anonymous namespace)::" prefix @@ -232,12 +232,16 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|" "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|" "Convert|" - "UConvert|SConvert|FConvert|SatConvert).*)_R(.*)"); + "UConvert|SConvert|FConvert|SatConvert).*)_R[^_]*_?(\\w+)?.*"); std::smatch Match; - if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) { - BuiltinName = Match[1].str(); - if (Postfix) - *Postfix = Match[3].str(); + if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) { + std::ssub_match SubMatch; + if (DecorationId && Match.size() > 3) { + SubMatch = Match[3]; + *DecorationId = demangledPostfixToDecorationId(SubMatch.str()); + } + SubMatch = Match[1]; + BuiltinName = SubMatch.str(); } return BuiltinName; diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h index 0182d9652d18c..1a8641a8328dd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h @@ -21,7 +21,7 @@ namespace llvm { namespace SPIRV { /// Parses the name part of the demangled builtin call. std::string lookupBuiltinNameHelper(StringRef DemangledCall, - std::string *Postfix = nullptr); + FPDecorationId *DecorationId = nullptr); /// Lowers a builtin function call using the provided \p DemangledCall skeleton /// and external instruction \p Set. /// diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 433956f44917f..77b54219a9acc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1876,18 +1876,6 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, return true; } -static unsigned roundingModeMDToDecorationConst(StringRef S) { - if (S == "rte") - return SPIRV::FPRoundingMode::FPRoundingMode::RTE; - if (S == "rtz") - return SPIRV::FPRoundingMode::FPRoundingMode::RTZ; - if (S == "rtp") - return SPIRV::FPRoundingMode::FPRoundingMode::RTP; - if (S == "rtn") - return SPIRV::FPRoundingMode::FPRoundingMode::RTN; - return std::numeric_limits::max(); -} - void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B) { // TODO: extend the list of functions with known result types @@ -1905,9 +1893,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, Function *CalledF = CI->getCalledFunction(); std::string DemangledName = getOclOrSpirvBuiltinDemangledName(CalledF->getName()); - std::string Postfix; + FPDecorationId DecorationId = FPDecorationId::NONE; if (DemangledName.length() > 0) - DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix); + DemangledName = + SPIRV::lookupBuiltinNameHelper(DemangledName, &DecorationId); auto ResIt = ResTypeWellKnown.find(DemangledName); if (ResIt != ResTypeWellKnown.end()) { IsKnown = true; @@ -1919,18 +1908,29 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, break; } } - // check if a floating rounding mode info is present - StringRef S = Postfix; - SmallVector Parts; - S.split(Parts, "_", -1, false); - if (Parts.size() > 1) { - // Convert the info about rounding mode into a decoration record. - unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]); - if (RoundingModeDeco != std::numeric_limits::max()) - createRoundingModeDecoration(CI, RoundingModeDeco, B); - // Check if the SaturatedConversion info is present. - if (Parts[1] == "sat") - createSaturatedConversionDecoration(CI, B); + // check if a floating rounding mode or saturation info is present + switch (DecorationId) { + default: + break; + case FPDecorationId::SAT: + createSaturatedConversionDecoration(CI, B); + break; + case FPDecorationId::RTE: + createRoundingModeDecoration( + CI, SPIRV::FPRoundingMode::FPRoundingMode::RTE, B); + break; + case FPDecorationId::RTZ: + createRoundingModeDecoration( + CI, SPIRV::FPRoundingMode::FPRoundingMode::RTZ, B); + break; + case FPDecorationId::RTP: + createRoundingModeDecoration( + CI, SPIRV::FPRoundingMode::FPRoundingMode::RTP, B); + break; + case FPDecorationId::RTN: + createRoundingModeDecoration( + CI, SPIRV::FPRoundingMode::FPRoundingMode::RTN, B); + break; } } } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 3e913646d57c8..0c42447700106 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -157,28 +157,52 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { }); } +void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) { + // TODO: + // - take into account duplicate tracker case which is a known issue, + // - review other data structure wrt. possible issues related to removal + // of a machine instruction during instruction selection. + const MachineFunction *MF = MI->getParent()->getParent(); + auto It = LastInsertedTypeMap.find(MF); + if (It == LastInsertedTypeMap.end()) + return; + if (It->second == MI) + LastInsertedTypeMap.erase(MF); +} + SPIRVType *SPIRVGlobalRegistry::createOpType( MachineIRBuilder &MIRBuilder, std::function Op) { auto oldInsertPoint = MIRBuilder.getInsertPt(); MachineBasicBlock *OldMBB = &MIRBuilder.getMBB(); + MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin(); auto LastInsertedType = LastInsertedTypeMap.find(CurMF); if (LastInsertedType != LastInsertedTypeMap.end()) { auto It = LastInsertedType->second->getIterator(); - auto NewMBB = MIRBuilder.getMF().begin(); - MIRBuilder.setInsertPt(*NewMBB, It->getNextNode() - ? It->getNextNode()->getIterator() - : NewMBB->end()); + // It might happen that this instruction was removed from the first MBB, + // hence the Parent's check. + MachineBasicBlock::iterator InsertAt; + if (It->getParent() != NewMBB) + InsertAt = oldInsertPoint->getParent() == NewMBB + ? oldInsertPoint + : getInsertPtValidEnd(NewMBB); + else if (It->getNextNode()) + InsertAt = It->getNextNode()->getIterator(); + else + InsertAt = getInsertPtValidEnd(NewMBB); + MIRBuilder.setInsertPt(*NewMBB, InsertAt); } else { - MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(), - MIRBuilder.getMF().begin()->begin()); + MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin()); auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr); assert(Result.second); LastInsertedType = Result.first; } MachineInstr *Type = Op(MIRBuilder); + // We expect all users of this function to insert definitions at the insertion + // point set above that is always the first MBB. + assert(Type->getParent() == NewMBB); LastInsertedType->second = Type; MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index df92325ed1980..ec2386fa1e56e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -444,6 +444,10 @@ class SPIRVGlobalRegistry { bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const; + // Informs about removal of the machine instruction and invalidates data + // structures referring this instruction. + void invalidateMachineInstr(MachineInstr *MI); + private: SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index ff8d1f7485e16..eef7cefdeed4c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -431,6 +431,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) { } MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg)); MRI->replaceRegWith(SrcReg, DstReg); + GR.invalidateMachineInstr(&I); I.removeFromParent(); return true; } else if (I.getNumDefs() == 1) { @@ -445,6 +446,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) { // erase it LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n"); salvageDebugInfo(*MRI, I); + GR.invalidateMachineInstr(&I); I.eraseFromParent(); return true; } @@ -464,6 +466,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) { if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs). for (unsigned i = 0; i < I.getNumDefs(); ++i) MRI->setType(I.getOperand(i).getReg(), LLT::scalar(64)); + GR.invalidateMachineInstr(&I); I.removeFromParent(); return true; } @@ -2253,8 +2256,10 @@ bool SPIRVInstructionSelector::selectDiscard(Register ResVReg, } else { Opcode = SPIRV::OpKill; // OpKill must be the last operation of any basic block. - MachineInstr *NextI = I.getNextNode(); - NextI->removeFromParent(); + if (MachineInstr *NextI = I.getNextNode()) { + GR.invalidateMachineInstr(NextI); + NextI->removeFromParent(); + } } MachineBasicBlock &BB = *I.getParent(); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index b22027cd2cb93..fa5e0a80576d0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -319,7 +319,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // tighten these requirements. Many of these math functions are only legal on // specific bitwidths, so they are not selectable for // allFloatScalarsAndVectors. - getActionDefinitionsBuilder({G_FPOW, + getActionDefinitionsBuilder({G_STRICT_FSQRT, + G_FPOW, G_FEXP, G_FEXP2, G_FLOG, diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index ce90e335fe404..ddc66f98829a9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -194,6 +194,19 @@ MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) { return It; } +MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) { + MachineBasicBlock::iterator I = MBB->end(); + if (I == MBB->begin()) + return I; + --I; + while (I->isTerminator() || I->isDebugValue()) { + if (I == MBB->begin()) + break; + --I; + } + return I; +} + SPIRV::StorageClass::StorageClass addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { switch (AddrSpace) { diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index cc77e0afa275a..da2e24c0c9abe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -150,6 +150,10 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, // i.e., at the beginning of the first block of the function. MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I); +// Return a valid position for the instruction at the end of the block before +// terminators and debug instructions. +MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB); + // Convert a SPIR-V storage class to the corresponding LLVM IR address space. // TODO: maybe the following two functions should be handled in the subtarget // to allow for different OpenCL vs Vulkan handling. @@ -396,5 +400,18 @@ Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR, // Return true if there is an opaque pointer type nested in the argument. bool isNestedPointer(const Type *Ty); +enum FPDecorationId { NONE, RTE, RTZ, RTP, RTN, SAT }; + +inline FPDecorationId demangledPostfixToDecorationId(const std::string &S) { + static std::unordered_map Mapping = { + {"rte", FPDecorationId::RTE}, + {"rtz", FPDecorationId::RTZ}, + {"rtp", FPDecorationId::RTP}, + {"rtn", FPDecorationId::RTN}, + {"sat", FPDecorationId::SAT}}; + auto It = Mapping.find(S); + return It == Mapping.end() ? FPDecorationId::NONE : It->second; +} + } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H