[LoopIdiom] Initial support for generating memset_pattern intrinsic (disabled by default) #98311

Closed
56 changes: 56 additions & 0 deletions llvm/docs/LangRef.rst
@@ -15230,6 +15230,62 @@ The behavior of '``llvm.memset.inline.*``' is equivalent to the behavior of
'``llvm.memset.*``', but the generated code is guaranteed not to call any
external functions.

.. _int_memset_pattern:

'``llvm.memset_pattern``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.memset_pattern`` on
any integer bit width and for different address spaces. Not all targets
support all bit widths however.

::

      declare void @llvm.memset_pattern.p0.i128.i64(ptr <dest>, i128 <val>,
                                                     i64 <len>, i1 <isvolatile>)

Overview:
"""""""""

The '``llvm.memset_pattern.*``' intrinsics fill a block of memory with
a particular byte pattern. This may be expanded to an inline loop, a sequence of
stores, or a libcall depending on what is available for the target and the
expected performance and code size impact.

Arguments:
""""""""""

The first argument is a pointer to the destination to fill, the second
is the pattern value with which to fill it, the third is an integer
specifying the number of bytes to fill, and the fourth is a boolean
indicating whether the access is volatile.

The :ref:`align <attr_align>` parameter attribute can be provided
for the first argument.

If the ``isvolatile`` parameter is ``true``, the
``llvm.memset_pattern`` call is a :ref:`volatile operation <volatile>`. The
detailed access behavior is not very cleanly specified and it is unwise to
depend on it.

Semantics:
""""""""""

The '``llvm.memset_pattern.*``' intrinsics fill "len" bytes of memory
starting at the destination location with the given pattern. If the destination
is known to be aligned to some boundary, this can be specified as an attribute
on the argument.

If ``<len>`` is not an integer multiple of the pattern width in bytes, then the
remaining bytes will be copied from ``<val>``.
If ``<len>`` is 0, it is a no-op modulo the behavior of attributes attached to
the arguments.
If ``<len>`` is not a well-defined value, the behavior is undefined.
If ``<len>`` is not zero, ``<dest>`` must be well-defined; otherwise the
behavior is undefined.
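
For example, the following call fills ``%dest`` with ``%len`` bytes of the
16-byte pattern given by the ``i128`` argument (the ``align`` attribute and
the concrete pattern value are illustrative only):

::

      call void @llvm.memset_pattern.p0.i128.i64(ptr align 8 %dest,
                                                 i128 18446744073709551617,
                                                 i64 %len, i1 false)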

.. _int_sqrt:

'``llvm.sqrt.*``' Intrinsic
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/InstVisitor.h
@@ -208,6 +208,7 @@ class InstVisitor {
RetTy visitDbgInfoIntrinsic(DbgInfoIntrinsic &I){ DELEGATE(IntrinsicInst); }
RetTy visitMemSetInst(MemSetInst &I) { DELEGATE(MemIntrinsic); }
RetTy visitMemSetInlineInst(MemSetInlineInst &I){ DELEGATE(MemSetInst); }
RetTy visitMemSetPatternInst(MemSetPatternInst &I) { DELEGATE(MemSetInst); }
RetTy visitMemCpyInst(MemCpyInst &I) { DELEGATE(MemTransferInst); }
RetTy visitMemCpyInlineInst(MemCpyInlineInst &I){ DELEGATE(MemCpyInst); }
RetTy visitMemMoveInst(MemMoveInst &I) { DELEGATE(MemTransferInst); }
@@ -295,6 +296,8 @@ class InstVisitor {
case Intrinsic::memset: DELEGATE(MemSetInst);
case Intrinsic::memset_inline:
DELEGATE(MemSetInlineInst);
case Intrinsic::memset_pattern:
DELEGATE(MemSetPatternInst);
case Intrinsic::vastart: DELEGATE(VAStartInst);
case Intrinsic::vaend: DELEGATE(VAEndInst);
case Intrinsic::vacopy: DELEGATE(VACopyInst);
22 changes: 21 additions & 1 deletion llvm/include/llvm/IR/IntrinsicInst.h
@@ -1208,6 +1208,7 @@ class MemIntrinsic : public MemIntrinsicBase<MemIntrinsic> {
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern:
case Intrinsic::memcpy_inline:
return true;
default:
@@ -1219,14 +1220,16 @@ class MemIntrinsic : public MemIntrinsicBase<MemIntrinsic> {
}
};

/// This class wraps the llvm.memset and llvm.memset.inline intrinsics.
/// This class wraps the llvm.memset, llvm.memset.inline, and
/// llvm.memset_pattern intrinsics.
class MemSetInst : public MemSetBase<MemIntrinsic> {
public:
// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const IntrinsicInst *I) {
switch (I->getIntrinsicID()) {
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern:
return true;
default:
return false;
@@ -1249,6 +1252,21 @@ class MemSetInlineInst : public MemSetInst {
}
};

/// This class wraps the llvm.memset_pattern intrinsic.
class MemSetPatternInst : public MemSetInst {
public:
ConstantInt *getLength() const {
return cast<ConstantInt>(MemSetInst::getLength());
}
// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const IntrinsicInst *I) {
return I->getIntrinsicID() == Intrinsic::memset_pattern;
}
static bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
};

/// This class wraps the llvm.memcpy/memmove intrinsics.
class MemTransferInst : public MemTransferBase<MemIntrinsic> {
public:
@@ -1328,6 +1346,7 @@ class AnyMemIntrinsic : public MemIntrinsicBase<AnyMemIntrinsic> {
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern:
case Intrinsic::memcpy_element_unordered_atomic:
case Intrinsic::memmove_element_unordered_atomic:
case Intrinsic::memset_element_unordered_atomic:
@@ -1350,6 +1369,7 @@ class AnyMemSetInst : public MemSetBase<AnyMemIntrinsic> {
switch (I->getIntrinsicID()) {
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern:
case Intrinsic::memset_element_unordered_atomic:
return true;
default:
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
@@ -1003,6 +1003,14 @@ def int_memset_inline
NoCapture<ArgIndex<0>>, WriteOnly<ArgIndex<0>>,
ImmArg<ArgIndex<3>>]>;

// Memset variant that writes a given pattern.
def int_memset_pattern
: Intrinsic<[],
[llvm_anyptr_ty, llvm_anyint_ty, llvm_anyint_ty, llvm_i1_ty],
[IntrWriteMem, IntrArgMemOnly, IntrWillReturn, IntrNoFree, IntrNoCallback,
NoCapture<ArgIndex<0>>, WriteOnly<ArgIndex<0>>,
ImmArg<ArgIndex<3>>], "llvm.memset_pattern">;
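
For reference, one concrete instantiation of this overloaded definition looks
roughly like the IR declaration below (an approximation; the parameter
attributes follow from the NoCapture/WriteOnly/ImmArg properties above, and the
full attribute set is produced by the intrinsic emitter):

declare void @llvm.memset_pattern.p0.i128.i64(ptr nocapture writeonly, i128,
                                              i64, i1 immarg)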

// FIXME: Add version of these floating point intrinsics which allow non-default
// rounding modes and FP exception handling.

8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp
@@ -276,6 +276,13 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(Function &F) const {
Memset->eraseFromParent();
break;
}
case Intrinsic::memset_pattern: {
auto *Memset = cast<MemSetPatternInst>(Inst);
expandMemSetAsLoop(Memset);
Changed = true;
Memset->eraseFromParent();
break;
}
default:
llvm_unreachable("unhandled intrinsic");
}
@@ -294,6 +301,7 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const {
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern:
Changed |= expandMemIntrinsicUses(F);
break;
case Intrinsic::load_relative:
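
expandMemSetAsLoop is the same generic expansion used for llvm.memset and
llvm.memset.inline above. Conceptually, it replaces the intrinsic call with a
store loop along these lines (a hand-written sketch that assumes the length is
a non-zero multiple of the 16-byte pattern width; the IR actually emitted by
the pass may differ):

define void @memset_pattern_expanded(ptr %dst, i128 %pattern, i64 %len) {
entry:
  br label %loop

loop:                                   ; one 16-byte store per iteration
  %off = phi i64 [ 0, %entry ], [ %off.next, %loop ]
  %addr = getelementptr inbounds i8, ptr %dst, i64 %off
  store i128 %pattern, ptr %addr, align 1
  %off.next = add nuw i64 %off, 16
  %done = icmp uge i64 %off.next, %len
  br i1 %done, label %exit, label %loop

exit:
  ret void
}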
3 changes: 2 additions & 1 deletion llvm/lib/IR/Verifier.cpp
@@ -5435,7 +5435,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
case Intrinsic::memcpy_inline:
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline: {
case Intrinsic::memset_inline:
case Intrinsic::memset_pattern: {
break;
}
case Intrinsic::memcpy_element_unordered_atomic:
110 changes: 89 additions & 21 deletions llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -133,6 +133,11 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
"with -Os/-Oz"),
cl::init(true), cl::Hidden);

static cl::opt<bool> EnableMemsetPatternIntrinsic(
"loop-idiom-enable-memset-pattern-intrinsic",
cl::desc("Enable use of the memset_pattern intrinsic."), cl::init(false),
cl::Hidden);

namespace {

class LoopIdiomRecognize {
@@ -300,7 +305,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
HasMemcpy = TLI->has(LibFunc_memcpy);

if (HasMemset || HasMemsetPattern || HasMemcpy)
if (HasMemset || HasMemsetPattern || EnableMemsetPatternIntrinsic ||
HasMemcpy)
if (SE->hasLoopInvariantBackedgeTakenCount(L))
return runOnCountableLoop();

@@ -457,7 +463,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
// It looks like we can use SplatValue.
return LegalStoreKind::Memset;
}
if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
if (!UnorderedAtomic && (HasMemsetPattern || EnableMemsetPatternIntrinsic) &&
!DisableLIRP::Memset &&
// Don't create memset_pattern16s with address spaces.
StorePtr->getType()->getPointerAddressSpace() == 0 &&
getMemSetPatternValue(StoredVal, DL)) {
@@ -993,6 +1000,46 @@ static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
SCEV::FlagNUW);
}

ConstantInt *memSetPatternValueToI128ConstantInt(LLVMContext &Context,
Value *MemSetPatternValue) {
if (auto CIMemSetPatternValue = dyn_cast<ConstantInt>(MemSetPatternValue)) {
return CIMemSetPatternValue;
}

if (auto Array = dyn_cast<ConstantDataArray>(MemSetPatternValue)) {
Type *ElementType = Array->getElementType();
unsigned ElementSize = Array->getElementByteSize() * 8;

APInt Result(128, 0);
unsigned totalBits = 0;

for (unsigned i = 0; i < Array->getNumElements(); ++i) {
if (totalBits + ElementSize > 128) {
report_fatal_error("Pattern value unexpectedly greater than 128 bits");
}

APInt ElementBits;
if (ElementType->isIntegerTy()) {
ElementBits = Array->getElementAsAPInt(i);
} else if (ElementType->isFloatingPointTy()) {
APFloat APF = Array->getElementAsAPFloat(i);
ElementBits = APF.bitcastToAPInt();
} else {
llvm_unreachable("Unexpected element type");
}

// Shift the existing result left by the element's size and OR in the new
// value
Result = (Result << ElementSize) | ElementBits.zextOrTrunc(128);
totalBits += ElementSize;
}

// Create and return a ConstantInt with the resulting value
return ConstantInt::get(Context, Result);
}
report_fatal_error("Encountered unrecognised type");
}
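
A worked example of the packing order (with made-up values): each array
element is shifted in ahead of the next, so element 0 ends up in the most
significant bits of the returned i128.

; [2 x i64] [i64 1, i64 2]               packs to (1 << 64) | 2
;                                        = i128 18446744073709551618
; [4 x i32] [i32 1, i32 2, i32 3, i32 4] packs to
;                                        i128 0x00000001000000020000000300000004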

/// processLoopStridedStore - We see a strided store of some value. If we can
/// transform this into a memset or memset_pattern in the loop preheader, do so.
bool LoopIdiomRecognize::processLoopStridedStore(
@@ -1070,7 +1117,8 @@ bool LoopIdiomRecognize::processLoopStridedStore(
Value *NumBytes =
Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());

if (!SplatValue && !isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16))
if (!SplatValue && !(isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16) ||
EnableMemsetPatternIntrinsic))
return Changed;

AAMDNodes AATags = TheStore->getAAMetadata();
@@ -1087,24 +1135,44 @@ bool LoopIdiomRecognize::processLoopStridedStore(
BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment),
/*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias);
} else {
assert (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16));
// Everything is emitted in default address space
Type *Int8PtrTy = DestInt8PtrTy;

StringRef FuncName = "memset_pattern16";
FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16,
Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntIdxTy);
inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI);

// Otherwise we should form a memset_pattern16. PatternValue is known to be
// an constant array of 16-bytes. Plop the value into a mergable global.
GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
GlobalValue::PrivateLinkage,
PatternValue, ".memset_pattern");
GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
GV->setAlignment(Align(16));
Value *PatternPtr = GV;
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
assert(isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16) ||
EnableMemsetPatternIntrinsic);
if (EnableMemsetPatternIntrinsic) {
// Everything is emitted in default address space

// Get or insert the intrinsic declaration
Function *MemsetPatternIntrinsic = Intrinsic::getDeclaration(
M, Intrinsic::memset_pattern,
{DestInt8PtrTy, Builder.getInt128Ty(), Builder.getInt64Ty()});

// Create the call to the intrinsic
NewCall = Builder.CreateCall(
MemsetPatternIntrinsic,
{BasePtr,
memSetPatternValueToI128ConstantInt(M->getContext(), PatternValue),

A collaborator left an inline review comment on this line:

Two high-level suggestions to simplify this code:

  • Change getMemSetPatternValue to directly return the i128 result. There's no reason it should return an array and then be converted again. The global variable can just be an i128 type, and the pointer can be passed. The Array is (should be?) an irrelevant implementation detail.
  • Move the libcall creation to lowering for the intrinsic, and always generate the intrinsic. The enable becomes not whether to use the intrinsic, but when to generate the intrinsic (i.e. are we limiting ourselves to the case where we know we have a libcall to expand to).

NumBytes, ConstantInt::getFalse(M->getContext())});
} else {
// Everything is emitted in default address space
Type *Int8PtrTy = DestInt8PtrTy;

StringRef FuncName = "memset_pattern16";
FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16,
Builder.getVoidTy(), Int8PtrTy,
Int8PtrTy, IntIdxTy);
inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI);

// Otherwise we should form a memset_pattern16. PatternValue is known to
// be a constant array of 16 bytes. Plop the value into a mergeable
// global.
GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
GlobalValue::PrivateLinkage,
PatternValue, ".memset_pattern");
GV->setUnnamedAddr(
GlobalValue::UnnamedAddr::Global); // Ok to merge these.
GV->setAlignment(Align(16));
Value *PatternPtr = GV;
NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
}

// Set the TBAA info if present.
if (AATags.TBAA)
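
To make the two emission paths concrete: for a countable loop that stores the
i64 value 1 to consecutive array elements (a pattern a plain memset cannot
express), the preheader code inserted by each branch would look roughly like
the following hand-written sketch (value names, alignment, and exact operands
are assumptions, not output copied from the patch's tests):

; With the flag enabled, e.g.
;   opt -passes=loop-idiom -loop-idiom-enable-memset-pattern-intrinsic
; the packed pattern is passed inline as an i128 immediate ((1 << 64) | 1):
  call void @llvm.memset_pattern.p0.i128.i64(ptr %baseptr, i128 18446744073709551617,
                                             i64 %numbytes, i1 false)

; Otherwise, on targets that provide memset_pattern16 (Darwin), the pattern is
; placed in a private, mergeable 16-byte global and passed by pointer:
@.memset_pattern = private unnamed_addr constant [2 x i64] [i64 1, i64 1], align 16
  call void @memset_pattern16(ptr %baseptr, ptr @.memset_pattern, i64 %numbytes)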