From 4e2096852f26edb361c7caadfbf56ac771011cd9 Mon Sep 17 00:00:00 2001
From: Philip Reames
Date: Fri, 8 Sep 2023 07:57:12 -0700
Subject: [PATCH] [RISCV] Match gather(splat(ptr)) as zero strided load

We were already handling the case where the broadcast was being done
via a GEP, but we hadn't handled the case of a broadcast via a
shuffle.
---
 .../RISCV/RISCVGatherScatterLowering.cpp      | 30 ++++++++----
 .../RISCV/rvv/fixed-vectors-masked-gather.ll  | 47 ++-----------------
 2 files changed, 23 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index fac3526c43148..0e9244d0aefa8 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -67,7 +67,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                  Value *AlignOp);
 
-  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
+  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                      IRBuilderBase &Builder);
 
   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
@@ -321,9 +321,19 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
 }
 
 std::pair<Value *, Value *>
-RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
+RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                    IRBuilderBase &Builder) {
 
+  // A gather/scatter of a splat is a zero strided load/store.
+  if (auto *BasePtr = getSplatValue(Ptr)) {
+    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
+    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
+  }
+
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return std::make_pair(nullptr, nullptr);
+
   auto I = StridedAddrs.find(GEP);
   if (I != StridedAddrs.end())
     return I->second;
@@ -452,17 +462,17 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
   if (!TLI->isTypeLegal(DataTypeVT))
     return false;
 
-  // Pointer should be a GEP.
-  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP)
+  // Pointer should be an instruction.
+  auto *PtrI = dyn_cast<Instruction>(Ptr);
+  if (!PtrI)
     return false;
 
-  LLVMContext &Ctx = GEP->getContext();
+  LLVMContext &Ctx = PtrI->getContext();
   IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
-  Builder.SetInsertPoint(GEP);
+  Builder.SetInsertPoint(PtrI);
 
   Value *BasePtr, *Stride;
-  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
+  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
   if (!BasePtr)
     return false;
   assert(Stride != nullptr);
@@ -485,8 +495,8 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
   II->replaceAllUsesWith(Call);
   II->eraseFromParent();
 
-  if (GEP->use_empty())
-    RecursivelyDeleteTriviallyDeadInstructions(GEP);
+  if (PtrI->use_empty())
+    RecursivelyDeleteTriviallyDeadInstructions(PtrI);
 
   return true;
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index cb3ee899dde7d..25ef59e111fae 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -12918,60 +12918,19 @@ define <4 x i32> @mgather_broadcast_load_unmasked2(ptr %base) {
 ; RV32-LABEL: mgather_broadcast_load_unmasked2:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vmv.v.x v8, a0
-; RV32-NEXT:    vluxei32.v v8, (zero), v8
+; RV32-NEXT:    vlse32.v v8, (a0), zero
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_broadcast_load_unmasked2:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
-; RV64V-NEXT:    vmv.v.x v10, a0
-; RV64V-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; RV64V-NEXT:    vluxei64.v v8, (zero), v10
+; RV64V-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT:    vlse32.v v8, (a0), zero
 ; RV64V-NEXT:    ret
 ;
 ; RV64ZVE32F-LABEL: mgather_broadcast_load_unmasked2:
 ; RV64ZVE32F:       # %bb.0:
-; RV64ZVE32F-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; RV64ZVE32F-NEXT:    vmset.m v8
-; RV64ZVE32F-NEXT:    vmv.x.s a1, v8
-; RV64ZVE32F-NEXT:    # implicit-def: $v8
-; RV64ZVE32F-NEXT:    beqz zero, .LBB100_5
-; RV64ZVE32F-NEXT:  # %bb.1: # %else
-; RV64ZVE32F-NEXT:    andi a2, a1, 2
-; RV64ZVE32F-NEXT:    bnez a2, .LBB100_6
-; RV64ZVE32F-NEXT:  .LBB100_2: # %else2
-; RV64ZVE32F-NEXT:    andi a2, a1, 4
-; RV64ZVE32F-NEXT:    bnez a2, .LBB100_7
-; RV64ZVE32F-NEXT:  .LBB100_3: # %else5
-; RV64ZVE32F-NEXT:    andi a1, a1, 8
-; RV64ZVE32F-NEXT:    bnez a1, .LBB100_8
-; RV64ZVE32F-NEXT:  .LBB100_4: # %else8
-; RV64ZVE32F-NEXT:    ret
-; RV64ZVE32F-NEXT:  .LBB100_5: # %cond.load
 ; RV64ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; RV64ZVE32F-NEXT:    vlse32.v v8, (a0), zero
-; RV64ZVE32F-NEXT:    andi a2, a1, 2
-; RV64ZVE32F-NEXT:    beqz a2, .LBB100_2
-; RV64ZVE32F-NEXT:  .LBB100_6: # %cond.load1
-; RV64ZVE32F-NEXT:    lw a2, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a2
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 1
-; RV64ZVE32F-NEXT:    andi a2, a1, 4
-; RV64ZVE32F-NEXT:    beqz a2, .LBB100_3
-; RV64ZVE32F-NEXT:  .LBB100_7: # %cond.load4
-; RV64ZVE32F-NEXT:    lw a2, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 3, e32, m1, tu, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a2
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 2
-; RV64ZVE32F-NEXT:    andi a1, a1, 8
-; RV64ZVE32F-NEXT:    beqz a1, .LBB100_4
-; RV64ZVE32F-NEXT:  .LBB100_8: # %cond.load7
-; RV64ZVE32F-NEXT:    lw a0, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a0
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 3
 ; RV64ZVE32F-NEXT:    ret
   %head = insertelement <4 x i1> poison, i1 true, i32 0
   %allones = shufflevector <4 x i1> %head, <4 x i1> poison, <4 x i32> zeroinitializer
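
For illustration, the splat case this patch handles looks roughly as follows
in LLVM IR. This is a sketch, not taken from the patch: the value names
%phead, %ptrs, and %v are made up, %base and %allones mirror the test above,
and the rewritten form assumes an RV64 target (i64 stride) together with the
llvm.riscv.masked.strided.load intrinsic this pass emits.

  ; Before: a masked.gather whose pointer vector is a splat of %base,
  ; built via insertelement + shufflevector rather than a GEP.
  %phead = insertelement <4 x ptr> poison, ptr %base, i32 0
  %ptrs = shufflevector <4 x ptr> %phead, <4 x ptr> poison, <4 x i32> zeroinitializer
  %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4,
                                                     <4 x i1> %allones, <4 x i32> poison)

  ; After: getSplatValue(%ptrs) returns %base, so the gather becomes a
  ; strided load with stride 0, which selects to "vlse32.v v8, (a0), zero".
  %v = call <4 x i32> @llvm.riscv.masked.strided.load.v4i32.p0.i64(
           <4 x i32> poison, ptr %base, i64 0, <4 x i1> %allones)

The zero stride comes from ConstantInt::get(IntPtrTy, 0) in the new code, so
on RV32 the stride operand would be i32 0 instead.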