diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp index f0bd25f167d80..cff46e15251bc 100644 --- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp @@ -349,8 +349,27 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr, SmallVector Ops(GEP->operands()); + // If the base pointer is a vector, check if it's strided. + Value *Base = GEP->getPointerOperand(); + if (auto *BaseInst = dyn_cast(Base); + BaseInst && BaseInst->getType()->isVectorTy()) { + // If GEP's offset is scalar then we can add it to the base pointer's base. + auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); }; + if (all_of(GEP->indices(), IsScalar)) { + auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder); + if (BaseBase) { + Builder.SetInsertPoint(GEP); + SmallVector Indices(GEP->indices()); + Value *OffsetBase = + Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices, + GEP->getName() + "offset", GEP->isInBounds()); + return {OffsetBase, Stride}; + } + } + } + // Base pointer needs to be a scalar. - Value *ScalarBase = Ops[0]; + Value *ScalarBase = Base; if (ScalarBase->getType()->isVectorTy()) { ScalarBase = getSplatValue(ScalarBase); if (!ScalarBase) diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll index 8733c5dc83d65..70412de1d0e91 100644 --- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll +++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll @@ -301,10 +301,8 @@ define void @constant_stride( %x, ptr %p, i64 %stride) { define @vector_base_scalar_offset(ptr %p, i64 %offset) { ; CHECK-LABEL: @vector_base_scalar_offset( -; CHECK-NEXT: [[STEP:%.*]] = call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[PTRS1:%.*]] = getelementptr i64, ptr [[P:%.*]], [[STEP]] -; CHECK-NEXT: [[PTRS2:%.*]] = getelementptr i64, [[PTRS1]], i64 [[OFFSET:%.*]] -; CHECK-NEXT: [[X:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[PTRS2]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer), poison) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i64, ptr [[P:%.*]], i64 [[OFFSET:%.*]] +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP1]], i64 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) ; CHECK-NEXT: ret [[X]] ; %step = call @llvm.experimental.stepvector.nxv1i64() @@ -321,10 +319,8 @@ define @vector_base_scalar_offset(ptr %p, i64 %offset) { define @splat_base_scalar_offset(ptr %p, i64 %offset) { ; CHECK-LABEL: @splat_base_scalar_offset( -; CHECK-NEXT: [[HEAD:%.*]] = insertelement poison, ptr [[P:%.*]], i32 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[HEAD]], poison, zeroinitializer -; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i64, [[SPLAT]], i64 [[OFFSET:%.*]] -; CHECK-NEXT: [[X:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[PTRS]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer), poison) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i64, ptr [[P:%.*]], i64 [[OFFSET:%.*]] +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP1]], i64 0, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) ; CHECK-NEXT: ret [[X]] ; %head = insertelement poison, ptr %p, i32 0