diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..a210a208d01c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1448,6 +1448,26 @@ defm int_nvvm_cp_async_ca_shared_global_8 : CP_ASYNC_SHARED_GLOBAL<"8", "ca">;
 defm int_nvvm_cp_async_ca_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "ca">;
 defm int_nvvm_cp_async_cg_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "cg">;
 
+// TODO(apaszke): Multicast TMA loads
+foreach dim = [1, 2, 3, 4, 5] in {
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_shared_cluster_global_tile_mbarrier_complete_tx_bytes :
+    Intrinsic<
+      [],
+      [llvm_shared_ptr_ty, llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_anyptr_ty],
+      [IntrArgMemOnly, IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<1>>, NoAlias<ArgIndex<!add(dim, 2)>>,
+       WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.shared_cluster.global.tile.mbarrier_complete_tx_bytes">;
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_global_shared_cta_tile_bulk_group :
+    Intrinsic<
+      [],
+      [llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_shared_ptr_ty],
+      [IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<!add(dim, 1)>>,
+       ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<!add(dim, 1)>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.global.shared_cta.tile.bulk_group">;
+}
+
 def int_nvvm_cp_async_commit_group :
     ClangBuiltin<"__nvvm_cp_async_commit_group">,
     Intrinsic<[],[],[]>;
@@ -1595,6 +1615,10 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
                    [llvm_anyptr_ty],
                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                    "llvm.nvvm.ptr.gen.to.param">;
+def int_nvvm_ptr_param_to_gen: Intrinsic<[llvm_anyptr_ty],
+                   [llvm_anyptr_ty],
+                   [IntrNoMem, IntrSpeculatable, IntrNoCallback],
+                   "llvm.nvvm.ptr.param.to.gen">;
 
 // Move intrinsics, used in nvvm internally
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 440af085cb8e9..e2a565defb95b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -403,6 +403,33 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 : CP_ASYNC_SHARED_GLOBAL_I<"cg", "16",
   int_nvvm_cp_async_cg_shared_global_16,
   int_nvvm_cp_async_cg_shared_global_16_s>;
 
+foreach dim = [1, 2, 3, 4, 5] in {
+  defvar idx_ptx = !interleave(!foreach(i, !range(dim), "$idx" # i), ", ");
+  defvar idx_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "idx" # i));
+  defvar intrinsic_g2s = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_shared_cluster_global_tile_mbarrier_complete_tx_bytes");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_SHARED_CLUSTER_GLOBAL_TILE_MBARRIER_COMPLETE_TX_BYTES_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$dst, Int64Regs:$desc), idx_dag, (ins Int64Regs:$mbar)),
+      "cp.async.bulk.tensor." # dim # "d.shared::cluster.global.tile.mbarrier::complete_tx::bytes [$dst], [$desc, {{" # idx_ptx # "}}], [$mbar];",
+      [!con((intrinsic_g2s Int64Regs:$dst, Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_g2s),
+            (intrinsic_g2s Int64Regs:$mbar))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+  defvar intrinsic_s2g = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_global_shared_cta_tile_bulk_group");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_GLOBAL_SHARED_CTA_TILE_BULK_GROUP_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$desc), idx_dag, (ins Int64Regs:$dst)),
+      "cp.async.bulk.tensor." # dim # "d.global.shared::cta.tile.bulk_group [$desc, {{" # idx_ptx # "}}], [$dst];",
+      [!con((intrinsic_s2g Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_s2g),
+            (intrinsic_s2g Int64Regs:$dst))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+}
+
 def CP_ASYNC_COMMIT_GROUP :
   NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
   Requires<[hasPTX<70>, hasSM<80>]>;
@@ -2475,6 +2502,7 @@ defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
 defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
 defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
 defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
 
 defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
 defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..06eb2ba848762 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -94,12 +94,17 @@
 #include "NVPTXUtilities.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include <optional>
 #include <numeric>
 #include <queue>
@@ -146,6 +151,28 @@ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
                     "Lower arguments (NVPTX)", false, false)
 
+static std::optional<unsigned> tmaDescriptorOperandIndex(Instruction *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+      return 1;
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
+      return 0;
+    default:
+      return std::nullopt;
+    }
+  }
+  return std::nullopt;
+}
+
 // =============================================================================
 // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
 // and we can't guarantee that the only accesses are loads,
@@ -166,14 +193,15 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 
 // Replaces the \p OldUser instruction with the same in parameter AS.
 // Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
+static void convertToParamAS(Value *OldUser, Value *OldParam, Value *NewParam) {
   Instruction *I = dyn_cast<Instruction>(OldUser);
   assert(I && "OldUser must be an instruction");
   struct IP {
     Instruction *OldInstruction;
+    Value *OldParam;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{I, Param}};
+  SmallVector<IP> ItemsToConvert = {{I, OldParam, NewParam}};
   SmallVector<Instruction *> InstructionsToDelete;
 
   auto CloneInstInParamAS = [](const IP &I) -> Value * {
@@ -200,6 +228,28 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+    if (auto *II = dyn_cast<IntrinsicInst>(I.OldInstruction)) {
+      // Assert that this is a TMA intrinsic.
+      assert(tmaDescriptorOperandIndex(II).has_value());
+      assert(I.OldInstruction->getOperand(*tmaDescriptorOperandIndex(II)) ==
+             I.OldParam);
+      // TMA descriptors can remain in param memory space, but need to be passed
+      // in the generic address space.
+      Type *ParamPtr = PointerType::get(II->getContext(), ADDRESS_SPACE_PARAM);
+      Type *GenericPtr =
+          PointerType::get(II->getContext(), ADDRESS_SPACE_GENERIC);
+      FunctionType *cast_func_ty =
+          FunctionType::get(GenericPtr, {ParamPtr}, false);
+      Module *M = I.OldInstruction->getModule();
+      FunctionCallee func =
+          M->getOrInsertFunction(getName(llvm::Intrinsic::nvvm_ptr_param_to_gen,
+                                         {GenericPtr, ParamPtr}, M),
+                                 cast_func_ty);
+      Instruction *NewInGeneric =
+          CallInst::Create(func, {I.NewParam}, "", II->getIterator());
+      II->replaceUsesOfWith(I.OldParam, NewInGeneric);
+      return II;
+    }
     llvm_unreachable("Unsupported instruction");
   };
 
@@ -212,7 +262,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
     // be converted and the instruction itself to be deleted. We can't delete
     // the old instruction yet, because it's still in use by a load somewhere.
     for (Value *V : I.OldInstruction->users())
-      ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+      ItemsToConvert.push_back(
+          {cast<Instruction>(V), I.OldInstruction, NewInst});
 
     InstructionsToDelete.push_back(I.OldInstruction);
   }
@@ -300,9 +351,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
       Worklist.push({I, Ctx.Offset + Offset});
      continue;
     }
+    if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+      assert(tmaDescriptorOperandIndex(II).has_value());
+      continue;
+    }
 
     llvm_unreachable("All users must be one of: load, "
-                     "bitcast, getelementptr.");
+                     "bitcast, getelementptr, TMA intrinsic.");
   }
 }
 
@@ -321,8 +376,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   assert(StructType && "Missing byval type");
 
   auto IsALoadChain = [&](Value *Start) {
-    SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsALoadChainInstr = [](Value *V) -> bool {
+    SmallVector<Use *, 16> UsesToCheck;
+    for (Use& u : Start->uses())
+      UsesToCheck.push_back(&u);
+    auto IsSupportedUse = [](Use *U) -> bool {
+      Value *V = U->get();
       if (isa<LoadInst>(V) || isa<GetElementPtrInst>(V) || isa<BitCastInst>(V))
         return true;
       // ASC to param space are OK, too -- we'll just strip them.
@@ -330,19 +388,26 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
           return true;
       }
+      // TMA descriptors passed to TMA intrinsics are OK, too.
+      if (auto *II = dyn_cast<IntrinsicInst>(V)) {
+        auto OI = tmaDescriptorOperandIndex(II);
+        return OI.has_value() && *OI == U->getOperandNo();
+      }
       return false;
     };
 
-    while (!ValuesToCheck.empty()) {
-      Value *V = ValuesToCheck.pop_back_val();
-      if (!IsALoadChainInstr(V)) {
-        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+    while (!UsesToCheck.empty()) {
+      Use* U = UsesToCheck.pop_back_val();
+      if (!IsSupportedUse(U)) {
+        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << U
                           << "\n");
         (void)Arg;
         return false;
       }
-      if (!isa<LoadInst>(V))
-        llvm::append_range(ValuesToCheck, V->users());
+      if (!isa<LoadInst>(*U)) {
+        for (Use& u : U->getUser()->uses())
+          UsesToCheck.push_back(&u);
+      }
     }
     return true;
   };
@@ -355,7 +420,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM),
         Arg->getName(), FirstInst);
     for (Value *V : UsersToUpdate)
-      convertToParamAS(V, ArgInParamAS);
+      convertToParamAS(V, Arg, ArgInParamAS);
     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
 
     const auto *TLI =
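
Note (editorial, not part of the patch): to illustrate the operand order the new intrinsic declarations and the instruction patterns above expect, here is a hypothetical LLVM IR call of the 2-D variants. The value names are made up, and the mangling suffixes assume the load's two overloaded pointer operands resolve to a generic descriptor pointer (p0) and a shared-memory mbarrier (p3), while the store direction overloads only the descriptor pointer.

  call void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
      ptr addrspace(3) %smem_dst, ptr %tma_desc, i32 %i0, i32 %i1, ptr addrspace(3) %mbar)
  call void @llvm.nvvm.cp.async.bulk.tensor.2d.global.shared_cta.tile.bulk_group.p0(
      ptr %tma_desc, i32 %i0, i32 %i1, ptr addrspace(3) %smem_src)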