diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 3c31d4a4cd378..c173e3dd7d0e5 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -177,7 +177,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, assert((!isa(Stride) || cast(Stride)->getZExtValue() >= NumElements) && "Stride must be >= the number of elements in the result vector."); - unsigned AS = cast(BasePtr->getType())->getAddressSpace(); // Compute the start of the vector with index VecIdx as VecIdx * Stride. Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); @@ -189,11 +188,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, else VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); - // Cast elementwise vector start pointer to a pointer to a vector - // (EltType x NumElements)*. - auto *VecType = FixedVectorType::get(EltType, NumElements); - Type *VecPtrType = PointerType::get(VecType, AS); - return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); + return VecStart; } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -1060,13 +1055,6 @@ class LowerMatrixIntrinsics { return Changed; } - /// Turns \p BasePtr into an elementwise pointer to \p EltType. - Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { - unsigned AS = cast(BasePtr->getType())->getAddressSpace(); - Type *EltPtrType = PointerType::get(EltType, AS); - return Builder.CreatePointerCast(BasePtr, EltPtrType); - } - /// Replace intrinsic calls bool VisitCallInst(CallInst *Inst) { if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) @@ -1118,7 +1106,7 @@ class LowerMatrixIntrinsics { auto *VType = cast(Ty); Type *EltTy = VType->getElementType(); Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); - Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); + Value *EltPtr = Ptr; MatrixTy Result; for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { Value *GEP = computeVectorAddr( @@ -1144,17 +1132,11 @@ class LowerMatrixIntrinsics { Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - return loadMatrix(TileTy, TilePtr, Align, + return loadMatrix(TileTy, TileStart, Align, Builder.getInt64(MatrixShape.getStride()), IsVolatile, ResultShape, Builder); } @@ -1190,17 +1172,11 @@ class LowerMatrixIntrinsics { Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * StoreVal.getNumColumns()); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - storeMatrix(TileTy, StoreVal, TilePtr, MAlign, + storeMatrix(TileTy, StoreVal, TileStart, MAlign, Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); } @@ -1210,7 +1186,7 @@ class LowerMatrixIntrinsics { MaybeAlign MAlign, Value *Stride, bool IsVolatile, IRBuilder<> &Builder) { auto VType = cast(Ty); - Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); + Value *EltPtr = Ptr; for (auto Vec : enumerate(StoreVal.vectors())) { Value *GEP = computeVectorAddr( EltPtr,