diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h index 6c4643da18849..c258513ed4878 100644 --- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h +++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h @@ -30,11 +30,21 @@ class MMAMatrixType; void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); +/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, +/// using the KHR Cooperative Matrix extension. +void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( + SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); + /// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, /// using the NV Cooperative Matrix extension. void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns( SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns); +/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType +/// `type`. +spirv::CooperativeMatrixType +convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type); + /// Returns an NV cooperative matrix type corresponding to the MMAMatrixType /// `type`. spirv::CooperativeMatrixNVType diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 3218760931b8c..11008baa0160e 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> { let options = [ Option<"use64bitIndex", "use-64bit-index", "bool", /*default=*/"false", - "Use 64-bit integers to convert index types"> + "Use 64-bit integers to convert index types">, + Option<"useCoopMatrixNV", "use-coop-matrix-nv", + "bool", /*default=*/"true", + "Use the NV cooperative matrix extension insted of the KHR extension" + " to lower GPU WMMA ops">, ]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index b5ea0774f589d..34c76c5e93823 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" let results = (outs SPIRV_AnyCooperativeMatrix:$result ); + + let builders = [ + OpBuilder<(ins "Type":$result, "Value":$pointer, + "spirv::ConstantOp":$stride, + "spirv::CooperativeMatrixLayoutKHR":$layout), [{ + build($_builder, $_state, result, pointer, layout, stride, + spirv::MemoryAccessAttr{}); + }]> + ]; } // ----- @@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor ); let results = (outs); + + let builders = [ + OpBuilder<(ins "Value":$pointer, "Value":$object, + "spirv::ConstantOp":$stride, + "spirv::CooperativeMatrixLayoutKHR":$layout), [{ + build($_builder, $_state, pointer, object, layout, stride, + spirv::MemoryAccessAttr{}); + }]> + ]; } // ----- @@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul let results = (outs SPIRV_AnyCooperativeMatrix:$result ); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{ + build($_builder, $_state, a, b, c, + spirv::CooperativeMatrixOperandsKHRAttr{}); + }]> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index d0ce58597f980..5b05c45bf6025 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.use64bitIndex = this->use64bitIndex; SPIRVTypeConverter typeConverter(targetAttr, options); - typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type { - return convertMMAToSPIRVCoopMatrixNVType(type); + + typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()]( + gpu::MMAMatrixType type) -> Type { + if (useNV) + return convertMMAToSPIRVCoopMatrixNVType(type); + + return convertMMAToSPIRVCoopMatrixType(type); }); + RewritePatternSet patterns(context); populateGPUToSPIRVPatterns(typeConverter, patterns); - populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter, - patterns); + if (this->useCoopMatrixNV) { + populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter, + patterns); + } else { + populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter, + patterns); + } + // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index bf3fff027fe38..d73cd5686d66e 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -18,22 +18,28 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/StringSwitch.h" -namespace mlir::nv { -namespace { +#include +namespace mlir { /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op /// when the elementwise op directly supports with cooperative matrix type. /// Returns false if cannot. /// /// See SPV_NV_cooperative_matrix for supported elementwise ops. static bool createElementwiseOp(ConversionPatternRewriter &builder, - gpu::SubgroupMmaElementwiseOp op, - spirv::CooperativeMatrixNVType coopType, + gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands) { + assert((isa( + coopType))); + switch (op.getOpType()) { case gpu::MMAElementwiseOp::ADDF: builder.replaceOpWithNewOp(op, coopType, operands); @@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder, return false; } +//===----------------------------------------------------------------------===// +// SPV_KHR_cooperative_matrix +//===----------------------------------------------------------------------===// + +namespace khr { +namespace { + +/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV +/// dialect. +struct WmmaLoadOpToSPIRVLowering final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *getTypeConverter(); + Location loc = op->getLoc(); + + auto retType = cast(op.getRes().getType()); + MemRefType memrefType = op.getSrcMemref().getType(); + Value bufferPtr = + spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(), + adaptor.getIndices(), loc, rewriter); + + auto coopType = + typeConverter.convertType(retType); + if (!coopType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + int64_t stride = op.getLeadDimension().getSExtValue(); + IntegerType i32Type = rewriter.getI32Type(); + auto strideValue = rewriter.create( + loc, i32Type, IntegerAttr::get(i32Type, stride)); + + bool isColMajor = op.getTranspose().value_or(false); + auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor + : spirv::CooperativeMatrixLayoutKHR::RowMajor; + + rewriter.replaceOpWithNewOp( + op, coopType, bufferPtr, strideValue, layout); + return success(); + } +}; + +/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV +/// dialect. +struct WmmaStoreOpToSPIRVLowering final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *getTypeConverter(); + Location loc = op->getLoc(); + + auto memrefType = cast(op.getDstMemref().getType()); + Value bufferPtr = + spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(), + adaptor.getIndices(), loc, rewriter); + + int64_t stride = op.getLeadDimension().getSExtValue(); + IntegerType i32Type = rewriter.getI32Type(); + auto strideValue = rewriter.create( + loc, i32Type, IntegerAttr::get(i32Type, stride)); + + bool isColMajor = op.getTranspose().value_or(false); + auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor + : spirv::CooperativeMatrixLayoutKHR::RowMajor; + + rewriter.replaceOpWithNewOp( + op, bufferPtr, adaptor.getSrc(), strideValue, layout); + return success(); + } +}; + +/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV +/// dialect. +struct WmmaMmaOpToSPIRVLowering final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(), + adaptor.getOpC()); + return success(); + } +}; + +} // namespace +} // namespace khr + +//===----------------------------------------------------------------------===// +// SPV_NV_cooperative_matrix +//===----------------------------------------------------------------------===// + +namespace nv { +namespace { + /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV /// dialect. struct WmmaLoadOpToSPIRVLowering final @@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final }; } // namespace -} // namespace mlir::nv +} // namespace nv +} // namespace mlir mlir::spirv::CooperativeMatrixNVType mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) { @@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) { elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]); } +mlir::spirv::CooperativeMatrixType +mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) { + ArrayRef retTypeShape = type.getShape(); + Type elementType = type.getElementType(); + + auto use = + llvm::StringSwitch(type.getOperand()) + .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA) + .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB) + .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc); + + return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0], + retTypeShape[1], + spirv::Scope::Subgroup, use); +} + +void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( + SPIRVTypeConverter &converter, RewritePatternSet &patterns) { + using namespace mlir; + MLIRContext *context = patterns.getContext(); + patterns.add(converter, context); +} + void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns( SPIRVTypeConverter &converter, RewritePatternSet &patterns) { using namespace mlir; diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir new file mode 100644 index 0000000000000..0818791b98471 --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir @@ -0,0 +1,80 @@ +// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \ +// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, + #spirv.resource_limits<>>} { + + gpu.module @kernels { + // CHECK-LABEL: spirv.func @gpu_wmma_load_op + // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> + gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32 + // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : + memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> + + // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : + memref<32x32xf16, #spirv.storage_class> -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: spirv.Return + gpu.return + } + + // CHECK-LABEL: spirv.func @gpu_wmma_store_op + // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class>, + %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %i = arith.constant 16 : index + %j = arith.constant 16 : index + // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32 + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} : + !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> + + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} : + !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class> + // CHECK: spirv.Return + gpu.return + } + + // CHECK-LABEL: spirv.func @gpu_wmma_mma_op + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB> + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">, + %B: !gpu.mma_matrix<16x16xf16, "BOp">, + %C: !gpu.mma_matrix<16x16xf16, "COp">, + %ptr: memref<16x16xf16, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, + // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB> + // CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, + !gpu.mma_matrix<16x16xf16, "BOp"> + -> !gpu.mma_matrix<16x16xf16, "COp"> + + %i = arith.constant 0 : index + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, + gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} : + !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class> + // CHECK: spirv.Return + gpu.return + } + + } +} diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir index 5811c791f308d..ec7da92704c07 100644 --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \ +// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s module attributes { gpu.container_module,