Skip to content

[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. #145395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}

def AMDGPU_TransposeLoadOp :
    AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
    Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
  let summary = "MLIR wrapper for CDNA Transpose Load instructions";
  let description = [{
    The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
    The transpose load op represents a subgroup load from LDS memory,
    where the subgroup of threads collectively reads a matrix from the source
    memref, with each thread reading a vector of the matrix, and gets a
    transposed matrix as the result. That is, each thread reads a vector of the
    col-major matrix at different indices, and the thread's read result is a
    vector of the corresponding row of the transposed matrix.

    This op is a direct wrapper around the ROCDL `ds_read_tr` family of
    intrinsics. Please refer to the CDNA4 ISA documentation for more details
    about its exact semantics.

    Example:
    ```mlir
    %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
    ```
    Operands:
    * `$src`: LDS memref to read from.
    * `$srcIndices`: indices into `$src` to read from for this thread.
    * `$result`: target register this transpose load instruction will write to.

    Note: Lowering is only supported on gfx950 and up.
  }];
  let assemblyFormat = [{
    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
  }];
  let hasVerifier = 1;
}

def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Expand Down
79 changes: 77 additions & 2 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};

struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Transpose loads are only available on gfx950.
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    // Elements of sub-byte memrefs are not stored contiguously, so a strided
    // pointer computation would be incorrect; callers must use emulated
    // (byte-sized) memrefs instead.
    size_t srcBitWidth = srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcBitWidth < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcBitWidth;

    auto resultType = cast<VectorType>(op.getResult().getType());
    Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType,
                                        adaptor.getSrc(),
                                        adaptor.getSrcIndices());

    size_t resultElems = resultType.getNumElements();
    size_t resultElemBits =
        resultType.getElementType().getIntOrFloatBitWidth();

    // For element types narrower than 16 bits, the ROCDL intrinsics return a
    // vector of 32-bit integers that we bitcast back to the requested type.
    Type rocdlResultType = VectorType::get(
        (resultElems * resultElemBits) / 32, rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    // Replace `op` with a bitcast of the intrinsic's i32-vector result to the
    // converted result type.
    auto replaceWithBitcast = [&](Value rocdlResult) {
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType,
                                                   rocdlResult);
    };

    switch (resultElemBits) {
    case 4:
      assert(resultElems == 16);
      replaceWithBitcast(rewriter.create<ROCDL::ds_read_tr4_b64>(
          loc, rocdlResultType, srcPtr));
      return success();
    case 6:
      assert(resultElems == 16);
      replaceWithBitcast(rewriter.create<ROCDL::ds_read_tr6_b96>(
          loc, rocdlResultType, srcPtr));
      return success();
    case 8:
      assert(resultElems == 8);
      replaceWithBitcast(rewriter.create<ROCDL::ds_read_tr8_b64>(
          loc, rocdlResultType, srcPtr));
      return success();
    case 16:
      // 16-bit elements already match the intrinsic's native result type; no
      // bitcast is needed.
      assert(resultElems == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                           srcPtr);
      return success();
    default:
      return op.emitOpError("Unsupported element size for transpose load");
    }
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
Expand Down Expand Up @@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
Expand Down Expand Up @@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}

LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  // ds_read_tr reads from LDS, so the source must live in workgroup memory.
  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Each supported element size (in bits) admits exactly one per-thread
  // vector length. A switch avoids constructing a DenseMap on every
  // verification.
  size_t expectedNumElements;
  switch (elementTypeSize) {
  case 4:
  case 6:
    expectedNumElements = 16;
    break;
  case 8:
    expectedNumElements = 8;
    break;
  case 16:
    expectedNumElements = 4;
    break;
  default:
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }

  if (numElements != expectedNumElements) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << expectedNumElements;
  }

  return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD

// Lowering tests for amdgpu.transpose_load -> rocdl.ds.read.tr{4,6,8,16}.*
// on gfx950. The CHECK-OLD run verifies that every op is rejected on a
// pre-gfx950 chipset.

// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
// CHECK: rocdl.ds.read.tr16.b64
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
return %0 : vector<4xf16>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
// CHECK-SAME: -> vector<2xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
return %0 : vector<8xi8>
}

// -----

// Sub-byte result elements loaded from an i8 (emulated) memref.
// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
// CHECK-SAME: -> vector<2xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
return %0 : vector<16xi4>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
// CHECK-SAME: -> vector<3xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
return %0 : vector<16xi6>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
// CHECK: rocdl.ds.read.tr16.b64
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
return %0 : vector<4xi16>
}
17 changes: 17 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s

// The lowering rejects sub-byte source memrefs: their elements are not
// stored contiguously, so emulated (byte-sized) memrefs must be used.

// -----

func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
// CHECK: memref to have at least 8 bits element size, got 4
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
return %0 : vector<16xi4>
}

// -----

func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
// CHECK: memref to have at least 8 bits element size, got 6
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
return %0 : vector<16xi6>
}
58 changes: 58 additions & 0 deletions mlir/test/Dialect/AMDGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,61 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
func.return %0 : vector<[4]xf32>
}

// -----

// Verifier: source must live in workgroup (LDS) memory.
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
func.return %0 : vector<4xf16>
}

// -----

// NOTE(review): this test duplicates the previous one verbatim (same name,
// same body) — presumably one of them was meant to cover a different address
// space; confirm and deduplicate.
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
func.return %0 : vector<4xf16>
}

// -----

// Verifier: 32-bit elements are not supported by any ds_read_tr variant.
func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> {
// expected-error@+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32>
func.return %0 : vector<4xf32>
}

// -----

// Verifier: 16-bit elements require exactly 4 elements per thread.
func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16>
func.return %0 : vector<2xf16>
}

// -----

// Verifier: 4-bit elements require exactly 16 elements per thread.
func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4>
func.return %0 : vector<20xi4>
}

// -----

// Verifier: 8-bit elements require exactly 8 elements per thread.
func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8>
func.return %0 : vector<20xi8>
}

// -----

// Verifier: 6-bit elements require exactly 16 elements per thread.
// (Repaired: the original was clang-format-mangled MLIR — tokens such as
// `%0` and `%mem` were split, and the expected-error comment was wrapped,
// which breaks the @+1 offset. The function is also renamed from the
// duplicated `..._i8` to `..._i6` to match the element type under test.)
func.func @transpose_load_vector_size_i6(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
  // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
  func.return %0 : vector<8xi6>
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/AMDGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}

// Round-trip (parse/print) test for the custom assembly format of
// amdgpu.transpose_load.
// CHECK-LABEL: func @transpose_load
func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> {
// CHECK: amdgpu.transpose_load
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
func.return %0 : vector<4xf16>
}
Loading