Skip to content

[NFC][PGO] Use constants rather than free strings for metadata labels #145721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
#include "llvm/Support/Compiler.h"

namespace llvm {
struct MDProfLabels {
static const char *BranchWeights;
static const char *ValueProfile;
static const char *FunctionEntryCount;
static const char *SyntheticFunctionEntryCount;
static const char *ExpectedBranchWeights;
};

/// Checks if an Instruction has MD_prof Metadata
LLVM_ABI bool hasProfMD(const Instruction &I);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7031,7 +7031,7 @@ Error BitcodeReader::materialize(GlobalValue *GV) {
MDString *MDS = cast<MDString>(MD->getOperand(0));
StringRef ProfName = MDS->getString();
// Check consistency of !prof branch_weights metadata.
if (ProfName != "branch_weights")
if (ProfName != MDProfLabels::BranchWeights)
continue;
unsigned ExpectedNumOperands = 0;
if (BranchInst *BI = dyn_cast<BranchInst>(&I))
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/IR/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/SymbolTableListTraits.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
Expand Down Expand Up @@ -1115,7 +1116,7 @@ std::optional<ProfileCount> Function::getEntryCount(bool AllowSynthetic) const {
MDNode *MD = getMetadata(LLVMContext::MD_prof);
if (MD && MD->getOperand(0))
if (MDString *MDS = dyn_cast<MDString>(MD->getOperand(0))) {
if (MDS->getString() == "function_entry_count") {
if (MDS->getString() == MDProfLabels::FunctionEntryCount) {
ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(1));
uint64_t Count = CI->getValue().getZExtValue();
// A value of -1 is used for SamplePGO when there were no samples.
Expand All @@ -1124,7 +1125,8 @@ std::optional<ProfileCount> Function::getEntryCount(bool AllowSynthetic) const {
return std::nullopt;
return ProfileCount(Count, PCT_Real);
} else if (AllowSynthetic &&
MDS->getString() == "synthetic_function_entry_count") {
MDS->getString() ==
MDProfLabels::SyntheticFunctionEntryCount) {
ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(1));
uint64_t Count = CI->getValue().getZExtValue();
return ProfileCount(Count, PCT_Synthetic);
Expand All @@ -1137,7 +1139,7 @@ DenseSet<GlobalValue::GUID> Function::getImportGUIDs() const {
DenseSet<GlobalValue::GUID> R;
if (MDNode *MD = getMetadata(LLVMContext::MD_prof))
if (MDString *MDS = dyn_cast<MDString>(MD->getOperand(0)))
if (MDS->getString() == "function_entry_count")
if (MDS->getString() == MDProfLabels::FunctionEntryCount)
for (unsigned i = 2; i < MD->getNumOperands(); i++)
R.insert(mdconst::extract<ConstantInt>(MD->getOperand(i))
->getValue()
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/IR/MDBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/ProfDataUtils.h"
using namespace llvm;

MDString *MDBuilder::createString(StringRef Str) {
Expand Down Expand Up @@ -55,9 +56,9 @@ MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,

unsigned int Offset = IsExpected ? 2 : 1;
SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
Vals[0] = createString("branch_weights");
Vals[0] = createString(MDProfLabels::BranchWeights);
if (IsExpected)
Vals[1] = createString("expected");
Vals[1] = createString(MDProfLabels::ExpectedBranchWeights);

Type *Int32Ty = Type::getInt32Ty(Context);
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
Expand All @@ -74,9 +75,9 @@ MDNode *MDBuilder::createFunctionEntryCount(
Type *Int64Ty = Type::getInt64Ty(Context);
SmallVector<Metadata *, 8> Ops;
if (Synthetic)
Ops.push_back(createString("synthetic_function_entry_count"));
Ops.push_back(createString(MDProfLabels::SyntheticFunctionEntryCount));
else
Ops.push_back(createString("function_entry_count"));
Ops.push_back(createString(MDProfLabels::FunctionEntryCount));
Ops.push_back(createConstant(ConstantInt::get(Int64Ty, Count)));
if (Imports) {
SmallVector<GlobalValue::GUID, 2> OrderID(Imports->begin(), Imports->end());
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/IR/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1202,14 +1202,15 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
"first operand should be a non-null MDString");
StringRef AProfName = AMDS->getString();
StringRef BProfName = BMDS->getString();
if (AProfName == "branch_weights" && BProfName == "branch_weights") {
if (AProfName == MDProfLabels::BranchWeights &&
BProfName == MDProfLabels::BranchWeights) {
ConstantInt *AInstrWeight = mdconst::dyn_extract<ConstantInt>(
A->getOperand(getBranchWeightOffset(A)));
ConstantInt *BInstrWeight = mdconst::dyn_extract<ConstantInt>(
B->getOperand(getBranchWeightOffset(B)));
assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
return MDNode::get(Ctx,
{MDHelper.createString("branch_weights"),
{MDHelper.createString(MDProfLabels::BranchWeights),
MDHelper.createConstant(ConstantInt::get(
Type::getInt64Ty(Ctx),
SaturatingAdd(AInstrWeight->getZExtValue(),
Expand Down
28 changes: 19 additions & 9 deletions llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,23 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,

namespace llvm {

const char *MDProfLabels::BranchWeights = "branch_weights";
const char *MDProfLabels::ExpectedBranchWeights = "expected";
const char *MDProfLabels::ValueProfile = "VP";
const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
const char *MDProfLabels::SyntheticFunctionEntryCount =
"synthetic_function_entry_count";

bool hasProfMD(const Instruction &I) {
return I.hasMetadata(LLVMContext::MD_prof);
}

bool isBranchWeightMD(const MDNode *ProfileData) {
return isTargetMD(ProfileData, "branch_weights", MinBWOps);
return isTargetMD(ProfileData, MDProfLabels::BranchWeights, MinBWOps);
}

static bool isValueProfileMD(const MDNode *ProfileData) {
return isTargetMD(ProfileData, "VP", MinVPOps);
return isTargetMD(ProfileData, MDProfLabels::ValueProfile, MinVPOps);
}

bool hasBranchWeightMD(const Instruction &I) {
Expand Down Expand Up @@ -131,7 +138,8 @@ bool hasBranchWeightOrigin(const MDNode *ProfileData) {
// NOTE: if we ever have more types of branch weight provenance,
// we need to check the string value is "expected". For now, we
// supply a more generic API, and avoid the spurious comparisons.
assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
assert(ProfDataName == nullptr ||
ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
return ProfDataName != nullptr;
}

Expand Down Expand Up @@ -210,7 +218,7 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
if (!ProfDataName)
return false;

if (ProfDataName->getString() == "branch_weights") {
if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
unsigned Offset = getBranchWeightOffset(ProfileData);
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx));
Expand All @@ -219,7 +227,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return true;
}

if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
ProfileData->getNumOperands() > 3) {
TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
->getValue()
.getZExtValue();
Expand All @@ -246,8 +255,9 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
return;

auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
ProfDataName->getString() != "VP"))
if (!ProfDataName ||
(ProfDataName->getString() != MDProfLabels::BranchWeights &&
ProfDataName->getString() != MDProfLabels::ValueProfile))
return;

if (!hasCountTypeMD(I))
Expand All @@ -259,7 +269,7 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
SmallVector<Metadata *, 3> Vals;
Vals.push_back(ProfileData->getOperand(0));
APInt APS(128, S), APT(128, T);
if (ProfDataName->getString() == "branch_weights" &&
if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
ProfileData->getNumOperands() > 0) {
// Using APInt::div may be expensive, but most cases should fit 64 bits.
APInt Val(128,
Expand All @@ -270,7 +280,7 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
Val *= APS;
Vals.push_back(MDB.createConstant(ConstantInt::get(
Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
} else if (ProfDataName->getString() == "VP")
} else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
// The first value is the key of the value profile, which will not change.
Vals.push_back(ProfileData->getOperand(Idx));
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2536,8 +2536,8 @@ void Verifier::verifyFunctionMetadata(
"expected string with name of the !prof annotation", MD);
MDString *MDS = cast<MDString>(MD->getOperand(0));
StringRef ProfName = MDS->getString();
Check(ProfName == "function_entry_count" ||
ProfName == "synthetic_function_entry_count",
Check(ProfName == MDProfLabels::FunctionEntryCount ||
ProfName == MDProfLabels::SyntheticFunctionEntryCount,
"first operand should be 'function_entry_count'"
" or 'synthetic_function_entry_count'",
MD);
Expand Down Expand Up @@ -4993,7 +4993,7 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
StringRef ProfName = MDS->getString();

// Check consistency of !prof branch_weights metadata.
if (ProfName == "branch_weights") {
if (ProfName == MDProfLabels::BranchWeights) {
unsigned NumBranchWeights = getNumBranchWeights(*MD);
if (isa<InvokeInst>(&I)) {
Check(NumBranchWeights == 1 || NumBranchWeights == 2,
Expand Down Expand Up @@ -5027,8 +5027,8 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
"!prof brunch_weights operand is not a const int");
}
} else {
Check(ProfName == "VP", "expected either branch_weights or VP profile name",
MD);
Check(ProfName == MDProfLabels::ValueProfile,
"expected either branch_weights or VP profile name", MD);
}
}

Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/ProfileData/InstrProf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1358,7 +1359,7 @@ void annotateValueSite(Module &M, Instruction &Inst,
MDBuilder MDHelper(Ctx);
SmallVector<Metadata *, 3> Vals;
// Tag
Vals.push_back(MDHelper.createString("VP"));
Vals.push_back(MDHelper.createString(MDProfLabels::ValueProfile));
// Value Kind
Vals.push_back(MDHelper.createConstant(
ConstantInt::get(Type::getInt32Ty(Ctx), ValueKind)));
Expand Down Expand Up @@ -1389,7 +1390,7 @@ MDNode *mayHaveValueProfileOfKind(const Instruction &Inst,
return nullptr;

MDString *Tag = cast<MDString>(MD->getOperand(0));
if (!Tag || Tag->getString() != "VP")
if (!Tag || Tag->getString() != MDProfLabels::ValueProfile)
return nullptr;

// Now check kind:
Expand Down