Skip to content

Commit 8bc8b84

Browse files
[SPIR-V] Fix inconsistency between previously deduced element type of a pointer and function's return type (#109660)
This PR improves type inference and fixes inconsistency between previously deduced element type of a pointer and function's return type. It fixes #109401 by ensuring that OpPhi is consistent with respect to operand types.
1 parent 3ba4092 commit 8bc8b84

File tree

4 files changed

+159
-6
lines changed

4 files changed

+159
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics
144144
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
145145
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
146146
std::unordered_set<Function *> &FVisited);
147+
void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
148+
CallInst *AssignCI);
147149

148150
public:
149151
static char ID;
@@ -502,10 +504,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
502504
if (DemangledName.length() > 0)
503505
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
504506
auto AsArgIt = ResTypeByArg.find(DemangledName);
505-
if (AsArgIt != ResTypeByArg.end()) {
507+
if (AsArgIt != ResTypeByArg.end())
506508
Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
507509
Visited, UnknownElemTypeI8);
508-
}
510+
else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF))
511+
Ty = KnownRetTy;
509512
}
510513
}
511514

@@ -835,6 +838,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
835838
CallInst *PtrCastI =
836839
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
837840
I->setOperand(OpIt.second, PtrCastI);
841+
buildAssignPtr(B, KnownElemTy, PtrCastI);
838842
}
839843
}
840844
}
@@ -1736,6 +1740,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
17361740
return true;
17371741
}
17381742

1743+
void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
1744+
Type *KnownElemTy,
1745+
CallInst *AssignCI) {
1746+
updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy));
1747+
IRBuilder<> B(CI->getContext());
1748+
B.SetInsertPoint(*CI->getInsertionPointAfterDef());
1749+
B.SetCurrentDebugLocation(CI->getDebugLoc());
1750+
Type *OpTy = CI->getType();
1751+
SmallVector<Type *, 2> Types = {OpTy, OpTy};
1752+
SmallVector<Value *, 2> Args = {CI, buildMD(PoisonValue::get(KnownElemTy)),
1753+
B.getInt32(getPointerAddressSpace(OpTy))};
1754+
CallInst *PtrCasted =
1755+
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
1756+
SmallVector<User *> Users(CI->users());
1757+
for (auto *U : Users)
1758+
if (U != AssignCI && U != PtrCasted)
1759+
U->replaceUsesOfWith(CI, PtrCasted);
1760+
buildAssignPtr(B, KnownElemTy, PtrCasted);
1761+
}
1762+
17391763
// Try to deduce a better type for pointers to untyped ptr.
17401764
bool SPIRVEmitIntrinsics::postprocessTypes() {
17411765
bool Changed = false;
@@ -1747,6 +1771,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
17471771
Type *KnownTy = GR->findDeducedElementType(*IB);
17481772
if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
17491773
continue;
1774+
// Try to improve the type deduced after all Functions are processed.
1775+
if (auto *CI = dyn_cast<CallInst>(*IB)) {
1776+
if (Function *CalledF = CI->getCalledFunction()) {
1777+
Type *RetElemTy = GR->findDeducedElementType(CalledF);
1778+
// Fix inconsistency between known type and function's return type.
1779+
if (RetElemTy && RetElemTy != KnownTy) {
1780+
replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
1781+
Changed = true;
1782+
continue;
1783+
}
1784+
}
1785+
}
17501786
Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
17511787
for (User *U : I->users()) {
17521788
Instruction *Inst = dyn_cast<Instruction>(U);

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
341341
return {Reg, GetIdOp};
342342
}
343343

344+
static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
345+
MachineBasicBlock &MBB = *Def->getParent();
346+
MachineBasicBlock::iterator DefIt =
347+
Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
348+
// Skip all the PHI and debug instructions.
349+
while (DefIt != MBB.end() &&
350+
(DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
351+
DefIt = std::next(DefIt);
352+
MIB.setInsertPt(MBB, DefIt);
353+
}
354+
344355
// Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
345356
// a dst of the definition, assign SPIRVType to both registers. If SpvType is
346357
// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -350,11 +361,9 @@ namespace llvm {
350361
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
351362
SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
352363
MachineRegisterInfo &MRI) {
353-
MachineInstr *Def = MRI.getVRegDef(Reg);
354364
assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
355-
MIB.setInsertPt(*Def->getParent(),
356-
(Def->getNextNode() ? Def->getNextNode()->getIterator()
357-
: Def->getParent()->end()));
365+
MachineInstr *Def = MRI.getVRegDef(Reg);
366+
setInsertPtAfterDef(MIB, Def);
358367
SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
359368
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
360369
if (auto *RC = MRI.getRegClassOrNull(Reg)) {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
2+
; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
3+
4+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; CHECK: %[[#Char:]] = OpTypeInt 8 0
8+
; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
9+
; CHECK: %[[#Int:]] = OpTypeInt 32 0
10+
; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
11+
; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
12+
; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
13+
; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
14+
; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
15+
; CHECK: OpBranchConditional
16+
; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]]
17+
; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]]
18+
19+
define void @f0(ptr %arg) {
20+
entry:
21+
ret void
22+
}
23+
24+
define ptr @f1() {
25+
entry:
26+
%p = alloca i8
27+
store i8 8, ptr %p
28+
ret ptr %p
29+
}
30+
31+
define ptr @f2() {
32+
entry:
33+
%p = alloca i32
34+
store i32 32, ptr %p
35+
ret ptr %p
36+
}
37+
38+
define ptr @foo(i1 %arg) {
39+
entry:
40+
%r1 = tail call ptr @f1()
41+
%r2 = tail call ptr @f2()
42+
br i1 %arg, label %l1, label %l2
43+
44+
l1:
45+
br label %exit
46+
47+
l2:
48+
br label %exit
49+
50+
exit:
51+
%ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
52+
%ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
53+
tail call void @f0(ptr %ret)
54+
ret ptr %ret2
55+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
2+
; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
3+
4+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; CHECK: %[[#Char:]] = OpTypeInt 8 0
8+
; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
9+
; CHECK: %[[#Int:]] = OpTypeInt 32 0
10+
; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
11+
; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
12+
; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
13+
; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
14+
; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
15+
; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
16+
17+
define ptr @foo(i1 %arg) {
18+
entry:
19+
%r1 = tail call ptr @f1()
20+
%r2 = tail call ptr @f2()
21+
br i1 %arg, label %l1, label %l2
22+
23+
l1:
24+
br label %exit
25+
26+
l2:
27+
br label %exit
28+
29+
exit:
30+
%ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
31+
%ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
32+
tail call void @f0(ptr %ret)
33+
ret ptr %ret2
34+
}
35+
36+
define void @f0(ptr %arg) {
37+
entry:
38+
ret void
39+
}
40+
41+
define ptr @f1() {
42+
entry:
43+
%p = alloca i8
44+
store i8 8, ptr %p
45+
ret ptr %p
46+
}
47+
48+
define ptr @f2() {
49+
entry:
50+
%p = alloca i32
51+
store i32 32, ptr %p
52+
ret ptr %p
53+
}

0 commit comments

Comments
 (0)