-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[llvm][IR] Extend BranchWeightMetadata to track provenance of weights #86609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
1e9dbac
ffb5bb8
ee503bf
7217003
7b541e1
e5bd278
af27efe
bc338aa
5a5ae01
6e16264
56c4658
319aaa6
37af7e2
21382a0
6729b47
4749bdc
7760282
947f9e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5199,7 +5199,11 @@ void SwitchInstProfUpdateWrapper::init() { | |
if (!ProfileData) | ||
return; | ||
|
||
if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) { | ||
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to | ||
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable | ||
// makes them slightly different. | ||
if (ProfileData->getNumOperands() != | ||
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) { | ||
Comment on lines
+5202
to
+5206
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems simple enough to do something about it instead of adding a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another good suggestion. Thank you. |
||
llvm_unreachable("number of prof branch_weights metadata operands does " | ||
"not correspond to number of succesors"); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,9 +40,6 @@ namespace { | |
// We maintain some constants here to ensure that we access the branch weights | ||
// correctly, and can change the behavior in the future if the layout changes | ||
|
||
// The index at which the weights vector starts | ||
constexpr unsigned WeightsIdx = 1; | ||
|
||
// the minimum number of operands for MD_prof nodes with branch weights | ||
constexpr unsigned MinBWOps = 3; | ||
|
||
|
@@ -75,15 +72,16 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData, | |
assert(isBranchWeightMD(ProfileData) && "wrong metadata"); | ||
|
||
unsigned NOps = ProfileData->getNumOperands(); | ||
unsigned WeightsIdx = getBranchWeightOffset(ProfileData); | ||
assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); | ||
Weights.resize(NOps - WeightsIdx); | ||
|
||
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { | ||
ConstantInt *Weight = | ||
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); | ||
assert(Weight && "Malformed branch_weight in MD_prof node"); | ||
assert(Weight->getValue().getActiveBits() <= 32 && | ||
"Too many bits for uint32_t"); | ||
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && | ||
"Too many bits for MD_prof branch_weight"); | ||
Weights[Idx - WeightsIdx] = Weight->getZExtValue(); | ||
} | ||
} | ||
|
@@ -123,6 +121,25 @@ bool hasValidBranchWeightMD(const Instruction &I) { | |
return getValidBranchWeightMDNode(I); | ||
} | ||
|
||
bool hasBranchWeightProvenance(const Instruction &I) { | ||
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); | ||
return hasBranchWeightProvenance(ProfileData); | ||
} | ||
|
||
bool hasBranchWeightProvenance(const MDNode *ProfileData) { | ||
if (!isBranchWeightMD(ProfileData)) | ||
return false; | ||
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); | ||
// NOTE: if we ever have more types of branch weight provenance, | ||
// we need to check the string value is "expected". For now, we | ||
// supply a more generic API, and avoid the spurious comparisons. | ||
return ProfDataName; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a debug assert for now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
} | ||
|
||
unsigned getBranchWeightOffset(const MDNode *ProfileData) { | ||
return hasBranchWeightProvenance(ProfileData) ? 2 : 1; | ||
} | ||
|
||
MDNode *getBranchWeightMDNode(const Instruction &I) { | ||
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); | ||
if (!isBranchWeightMD(ProfileData)) | ||
|
@@ -132,7 +149,9 @@ MDNode *getBranchWeightMDNode(const Instruction &I) { | |
|
||
MDNode *getValidBranchWeightMDNode(const Instruction &I) { | ||
auto *ProfileData = getBranchWeightMDNode(I); | ||
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) | ||
auto Offset = getBranchWeightOffset(ProfileData); | ||
if (ProfileData && | ||
ProfileData->getNumOperands() == Offset + I.getNumSuccessors()) | ||
return ProfileData; | ||
return nullptr; | ||
} | ||
|
@@ -190,8 +209,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { | |
if (!ProfDataName) | ||
return false; | ||
|
||
if (ProfDataName->getString() == "branch_weights") { | ||
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { | ||
if (ProfDataName->getString().equals("branch_weights")) { | ||
unsigned Offset = getBranchWeightOffset(ProfileData); | ||
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { | ||
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); | ||
assert(V && "Malformed branch_weight in MD_prof node"); | ||
TotalVal += V->getValue().getZExtValue(); | ||
|
@@ -212,9 +232,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { | |
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); | ||
} | ||
|
||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { | ||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, | ||
bool IsExpected) { | ||
MDBuilder MDB(I.getContext()); | ||
MDNode *BranchWeights = MDB.createBranchWeights(Weights); | ||
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); | ||
I.setMetadata(LLVMContext::MD_prof, BranchWeights); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,6 +104,7 @@ | |
#include "llvm/IR/Module.h" | ||
#include "llvm/IR/ModuleSlotTracker.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/IR/ProfDataUtils.h" | ||
#include "llvm/IR/Statepoint.h" | ||
#include "llvm/IR/Type.h" | ||
#include "llvm/IR/Use.h" | ||
|
@@ -4807,9 +4808,11 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { | |
StringRef ProfName = MDS->getString(); | ||
|
||
// Check consistency of !prof branch_weights metadata. | ||
if (ProfName == "branch_weights") { | ||
if (ProfName.equals("branch_weights")) { | ||
unsigned int Offset = getBranchWeightOffset(MD); | ||
if (isa<InvokeInst>(&I)) { | ||
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3, | ||
Check(MD->getNumOperands() == (1 + Offset) || | ||
MD->getNumOperands() == (2 + Offset), | ||
Comment on lines
+4814
to
+4815
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More opportunities for a possible |
||
"Wrong number of InvokeInst branch_weights operands", MD); | ||
} else { | ||
unsigned ExpectedNumOperands = 0; | ||
|
@@ -4829,10 +4832,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { | |
CheckFailed("!prof branch_weights are not allowed for this instruction", | ||
MD); | ||
|
||
Check(MD->getNumOperands() == 1 + ExpectedNumOperands, | ||
Check(MD->getNumOperands() == Offset + ExpectedNumOperands, | ||
"Wrong number of operands", MD); | ||
} | ||
for (unsigned i = 1; i < MD->getNumOperands(); ++i) { | ||
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) { | ||
auto &MDO = MD->getOperand(i); | ||
Check(MDO, "second operand should not be null", MD); | ||
Check(mdconst::dyn_extract<ConstantInt>(MDO), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it clearer to name it 'IsBranchWeightUserExpected'?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, I see your point, given its current use, but I do think we'll want to track more things in the future. Some other options:
IsBranchWeightFromLlvmIntrinsic
,hasOptionalMetadataField
, orhasBranchWeightOrigin
? The last is basically the same as the current, but avoids the use ofProvenance
like @MatzeB brought up earlier.WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hasBranchWeightOrigin sounds good to me -- it is straightforward for the reader to understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.