-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. #145395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
lialan
commented
Jun 23, 2025
- 1-to-1 mapping wrapper op.
- Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-mlir-amdgpu Author: Alan Li (lialan) Changes
Full diff: https://github.com/llvm/llvm-project/pull/145395.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..003aff6d38da0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_TransposeLoadOp :
+ AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+ Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+ Results<(outs MFMAInTypes:$dst)> {
+ let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+ let description = [{
+ The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+
+ Operands:
+ * `$src`: LDS memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$dst`: target register this transpose load instruction will write to.
+
+ Note: Lowering is only supported on gfx950 and up.
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..62ed1d871bcfd 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct TransposeLoadOpLowering
+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+ TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx950)
+ return op.emitOpError("Non-gfx950 chipset not supported");
+
+ Location loc = op.getLoc();
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()));
+ auto elementTypeSize = cast<VectorType>(op.getDst().getType())
+ .getElementType()
+ .getIntOrFloatBitWidth();
+
+ // TODO: support ds_read_tr16_b64 intrinsic.
+ switch (elementTypeSize) {
+ case 4:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ case 8:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ default:
+ return op.emitOpError("Unsupported element size for transpose load");
+ }
+ return success();
+ }
+};
+
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
- chipset);
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..00e9019b79647 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+LogicalResult TransposeLoadOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+ if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+ return emitOpError("source memory address space must be Workgroup");
+
+ // TODO: support 6-bit element type vectors.
+ auto transferType = dyn_cast<VectorType>(getDst().getType());
+ if (!transferType)
+ return emitOpError("destination type must be a vector type");
+ size_t transferSize =
+ transferType.getNumElements() * transferType.getElementTypeBitWidth();
+ if (transferSize != 64)
+ return emitOpError("Transfering type size must be 64 bits");
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
|
@llvm/pr-subscribers-mlir Author: Alan Li (lialan) Changes
Full diff: https://github.com/llvm/llvm-project/pull/145395.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..003aff6d38da0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_TransposeLoadOp :
+ AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+ Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+ Results<(outs MFMAInTypes:$dst)> {
+ let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+ let description = [{
+ The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+
+ Operands:
+ * `$src`: LDS memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$dst`: target register this transpose load instruction will write to.
+
+ Note: Lowering is only supported on gfx950 and up.
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..62ed1d871bcfd 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct TransposeLoadOpLowering
+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+ TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx950)
+ return op.emitOpError("Non-gfx950 chipset not supported");
+
+ Location loc = op.getLoc();
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()));
+ auto elementTypeSize = cast<VectorType>(op.getDst().getType())
+ .getElementType()
+ .getIntOrFloatBitWidth();
+
+ // TODO: support ds_read_tr16_b64 intrinsic.
+ switch (elementTypeSize) {
+ case 4:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ case 8:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
+ op, op.getDst().getType(), srcPtr);
+ break;
+ default:
+ return op.emitOpError("Unsupported element size for transpose load");
+ }
+ return success();
+ }
+};
+
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
- chipset);
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..00e9019b79647 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+LogicalResult TransposeLoadOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+ if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+ return emitOpError("source memory address space must be Workgroup");
+
+ // TODO: support 6-bit element type vectors.
+ auto transferType = dyn_cast<VectorType>(getDst().getType());
+ if (!transferType)
+ return emitOpError("destination type must be a vector type");
+ size_t transferSize =
+ transferType.getNumElements() * transferType.getElementTypeBitWidth();
+ if (transferSize != 64)
+ return emitOpError("Transfering type size must be 64 bits");
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new amdgpu.transpose_load
wrapper operation with verification, TableGen definition, and direct lowering to ROCDL intrinsics.
- Added TableGen op definition for
TransposeLoadOp
in the AMDGPU dialect. - Implemented
TransposeLoadOp::verify()
to enforce memory space and type constraints. - Created a conversion pattern to lower
TransposeLoadOp
to ROCDL ds_read_tr intrinsics.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Added TransposeLoadOp::verify() |
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Defined AMDGPU_TransposeLoadOp in TableGen |
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Added TransposeLoadOpLowering and registered it |
Comments suppressed due to low confidence (1)
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp:1796
- There are no corresponding tests for the new TransposeLoadOp and its lowering. Consider adding unit tests to cover verification and lowering paths for different element sizes and unsupported cases.
TransposeLoadOpLowering>(converter, chipset);
* 1-to-1 mapping wrapper op. * Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing AMDGPU dialect tests to show the op
Missing tests for the lowering
Maybe missing a narrow type emulation pattern
Value srcPtr = | ||
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), | ||
(adaptor.getSrcIndices())); | ||
auto elementTypeSize = cast<VectorType>(op.getDst().getType()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, re small types and the like, I'm going to observe that we want to be able to load from a memref<...xi8>
or the like and then have the result be a vector of 4-bit or 6-bit types.
We'll probably want an EmulateNarrowTypes pattern - probably in the AMDGPU dialect's Transforms/ - that just substitutes in the byte memref and the relevant indices
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. I added tests for i8 memrefs.
def AMDGPU_TransposeLoadOp : | ||
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>, | ||
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>, | ||
Results<(outs MFMAInTypes:$dst)> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need to restrict to the MFMAInTypes - that'll give the wrong impression. We just load a vector, and then impose the condition that the total returned size of the vector is 64-bits and that your element type is {4, 6, 8, 16} bits long.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually think the idea is good. I updated the code accordingly.
@krzysz00 My bad — I forgot to include the test file in the PR; it's updated now. But what do we need for emulating narrow types?
We'll want to make a pattern on this op that's analogous to the ones in the existing narrow-type emulation transforms. In short, such a pass turns a load of sub-byte element types into a load from the equivalent byte-element memref with the relevant indices adjusted accordingly.
|
||
Note: Lowering is only supported on gfx950 and up. | ||
}]; | ||
let assemblyFormat = [{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know other ops here don't provide examples, but I think it would be worth adding going forward -- I rely on these all the time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like your idea. So I tried to add a very simple example to show the format of the op. In terms of the semantics of the instruction, it is too hard to explain in a few sentences so I wrote that "please refer to the actual document for detailed explanation".
F8E3M4 // 3 exponent, 4 mantissa | ||
]>; | ||
def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>; | ||
def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BF16 exists ... and also, we can probably leave this open and rely on a getIntOrFloatBitWidth()
check in the verifier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, now it accepts any vectors and the verifier will serve as the checker.
// CHECK: rocdl.ds.read.tr6.b96 | ||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, 3> -> vector<3xi32> | ||
return %0 : vector<3xi32> | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll want tests for loading i4 and i6 from memrefs of bytes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
✅ With the latest revision this PR passed the C/C++ code formatter.
// ElementSize -> LoadSize | ||
const std::map<size_t, size_t> KValidLoadSizeMap = { | ||
{4, 64}, | ||
{32, 96}, // 6-bit element loads use casted vector<3xi32> | ||
{8, 64}, | ||
{16, 64}, | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome suggestion. Since this is a map instead of set, I use SmallDenseMap instead, which also avoids heap allocation!