From d5c41ed566118bdc174f03bb816a199a056a1048 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 28 Aug 2024 14:10:55 +0100 Subject: [PATCH 1/3] [MLIR][AMDGPU] Add support for fp8 ops on gfx12 --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 7 +++- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 37 +++++++++++-------- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 +- .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir | 9 +++++ mlir/test/Target/LLVMIR/rocdl.mlir | 10 +++++ 6 files changed, 50 insertions(+), 19 deletions(-) create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index aa2b4543927a7..35789984c9221 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64, VectorOfLengthAndType<[4, 16, 32], [I32]>, VectorOfLengthAndType<[4], [F64]>]>; // wmma -def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>; +def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>; def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, VectorOfLengthAndType<[8, 16], [F16, BF16]>]>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 868208ff74a52..bbb6e666d8295 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -165,7 +165,7 @@ def ROCDL_BallotOp : let summary = "Vote across thread group"; let description = [{ - Ballot provides a bit mask containing the 1-bit predicate value from each lane. + Ballot provides a bit mask containing the 1-bit predicate value from each lane. The nth bit of the result contains the 1 bit contributed by the nth warp lane. 
}]; @@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp overloadedOperands, "$args attr-dict `:` functional-type($args, $res)"; } -// Available on RDNA3 +// Available from gfx11 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>; def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>; def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>; def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>; def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>; def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>; +// Available from gfx12 +def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>; +def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; //===---------------------------------------------------------------------===// // Operations on raw buffer resources (stride of 0, bounds checks either off or in diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index b808738804030..45c5070333b52 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVector &operands) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast(inputType); @@ -398,23 +399,25 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } + auto mlirInputType = dyn_cast(mlirInput.getType()); + if (mlirInputType.getElementType().isInteger(8)) { + // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag + bool localIsUnsigned = isUnsigned; + if (elemType.isUnsignedInteger(8)) { + localIsUnsigned = true; 
+ } else if (elemType.isSignedInteger(8)) { + localIsUnsigned = false; + } + Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); + operands.push_back(sign); + } + int64_t numBytes = vectorType.getNumElements(); Type i32 = rewriter.getI32Type(); VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); - Value result = rewriter.createOrFold( loc, llvmVectorType32bits, llvmInput); - - // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag - bool localIsUnsigned = isUnsigned; - if (elemType.isUnsignedInteger(8)) { - localIsUnsigned = true; - } else if (elemType.isSignedInteger(8)) { - localIsUnsigned = false; - } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); operands.push_back(result); } @@ -601,6 +604,10 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); + } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) { + return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); + } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) { + return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); } return std::nullopt; } @@ -662,8 +669,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); - if (chipset.majorVersion != 11) - return op->emitOpError("WMMA only supported on gfx11"); + if (chipset.majorVersion != 11 && chipset.majorVersion != 12) + return op->emitOpError("WMMA only supported on gfx11 and gfx12"); std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -675,9 +682,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { SmallVector operands; wmmaPushInputOperand(rewriter, loc, 
typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), operands); + adaptor.getSourceA(), op.getSourceA(), operands); wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), operands); + adaptor.getSourceB(), op.getSourceB(), operands); wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), op.getSubwordOffset(), op.getClamp(), operands); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index e3beceaa3bbb5..a8d6ccdc1a471 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -235,7 +235,9 @@ LogicalResult WMMAOp::verify() { bool isDestFloat = (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16()); - bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16()); + bool isSrcFloat = + (sourceAElemType.isF16() || sourceAElemType.isBF16() || + sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2()); if (isDestFloat && !isSrcFloat) { return emitOpError("Expected float sources with float destination"); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir new file mode 100644 index 0000000000000..7b2b524d4af42 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s +func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) { + // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> + 
func.return +} diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 78c3987fab648..79f5c133503d4 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, llvm.return %rsrc : !llvm.ptr<8> } +llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}}) + %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}}) + %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + + llvm.return %r0 : vector<8 x f32> +} + llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>, %offset : i32, %soffset : i32, %vdata1 : i32, From dbfc608b39b0e34fd7ca2b92e8c81fcb275488d5 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 29 Aug 2024 10:48:07 +0100 Subject: [PATCH 2/3] Address review feedback --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 45c5070333b52..8c739de0ab151 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -399,8 +399,12 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } - auto mlirInputType = dyn_cast(mlirInput.getType()); - if (mlirInputType.getElementType().isInteger(8)) { + // We need to check the type of the input before conversion to properly test + // for int8.
This is because, in LLVM, fp8 type is converted to int8, so the + // fp8/int8 information is lost during the conversion process. + auto mlirInputType = cast(mlirInput.getType()); + bool isInputInt8 = mlirInputType.getElementType().isInteger(8); + if (isInputInt8) { // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag bool localIsUnsigned = isUnsigned; if (elemType.isUnsignedInteger(8)) { @@ -593,22 +597,20 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, auto elemSourceType = sourceVectorType.getElementType(); auto elemDestType = destVectorType.getElementType(); - if (elemSourceType.isF16() && elemDestType.isF32()) { + if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); - } - if (elemSourceType.isBF16() && elemDestType.isF32()) { + if (elemSourceType.isBF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); - } else if (elemSourceType.isF16() && elemDestType.isF16()) { + if (elemSourceType.isF16() && elemDestType.isF16()) return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); - } else if (elemSourceType.isBF16() && elemDestType.isBF16()) { + if (elemSourceType.isBF16() && elemDestType.isBF16()) return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); - } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - } else if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) { + if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); - } else if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) { + if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); - } return std::nullopt; } From b0ceff8150303b7d6e1f72e1b40b1dc290abbc98 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Mon, 
2 Sep 2024 20:01:53 +0100 Subject: [PATCH 3/3] Address review feedback - 2 --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index a8d6ccdc1a471..1bc41ba9c8cf5 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -233,11 +233,10 @@ LogicalResult WMMAOp::verify() { Type sourceAElemType = sourceVectorAType.getElementType(); Type destElemType = destVectorType.getElementType(); - bool isDestFloat = - (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16()); + bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType); bool isSrcFloat = - (sourceAElemType.isF16() || sourceAElemType.isBF16() || - sourceAElemType.isFloat8E4M3FN() || sourceAElemType.isFloat8E5M2()); + isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>( + sourceAElemType); if (isDestFloat && !isSrcFloat) { return emitOpError("Expected float sources with float destination");