From b9c971d76bbee4e861aeccd1fd5581f3c601e16a Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 8 Sep 2023 00:10:59 -0400 Subject: [PATCH 1/3] [mlir][spirv] Fix coop matrix store - Fix operand/attribute order - Use ODS for parsing/printing - Allow for stride to be any integer type --- .../SPIRV/IR/SPIRVCooperativeMatrixOps.td | 15 ++++--- .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 43 ------------------- .../SPIRV/IR/cooperative-matrix-ops.mlir | 35 +++++++++------ 3 files changed, 33 insertions(+), 60 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 7060aa80dc113..9da120d132227 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -171,10 +171,10 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor ``` {.ebnf} coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore ` - ssa-use `, ` ssa-use `, ` - ssa-use `, ` cooperative-matrix-layout `, ` - (`[` memory-operand `]`)? `:` - pointer-type `,` coop-matrix-type + ssa-use `,` ssa-use `,` + ssa-use `,` `<` cooperative-matrix-layout `> + (`,` `<` memory-operand `>`)? `:` + pointer-type `,` coop-matrix-type, stride-type ``` #### Example: @@ -185,6 +185,11 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor ``` }]; + let assemblyFormat = [{ + $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:` + type($pointer) `,` type($object) `,` type($stride) + }]; + let availability = [ MinVersion, MaxVersion, @@ -195,8 +200,8 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor let arguments = (ins SPIRV_AnyPtr:$pointer, SPIRV_AnyCooperativeMatrix:$object, - SPIRV_Integer:$stride, SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout, + SPIRV_Integer:$stride, OptionalAttr:$memory_operand ); diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index bc1d30f555183..36aea151a87e3 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -106,49 +106,6 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() { // spirv.KHR.CooperativeMatrixStore //===----------------------------------------------------------------------===// -ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser, - OperationState &result) { - std::array operandInfo = {}; - for (auto &op : operandInfo) { - if (parser.parseOperand(op) || parser.parseComma()) - return failure(); - } - - CooperativeMatrixLayoutKHR layout; - if (parseEnumKeywordAttr( - layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { - return failure(); - } - - if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) - return failure(); - - Type ptrType; - Type objectType; - if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() || - parser.parseType(objectType)) { - return failure(); - } - - Type strideType = parser.getBuilder().getIntegerType(32); - if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - -void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getObject() << ", " << getStride() - << ", " << getMatrixLayout(); - - // Print optional memory operand attribute. - if (auto memOperand = getMemoryOperand()) - printer << " [\"" << *memOperand << "\"]"; - printer << " : " << getPointer().getType() << ", " << getObject().getType(); -} - LogicalResult KHRCooperativeMatrixStoreOp::verify() { return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), getObject().getType()); diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir index aa6e072b03c5d..73be42aeeab90 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -60,10 +60,10 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr, % // CHECK-LABEL: @cooperative_matrix_store spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { - // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor : - // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> - spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor : - !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 spirv.Return } @@ -71,10 +71,21 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %str spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr, %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, %stride : i32) "None" { - // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] : - // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> - spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] : - !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB> + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, , : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32 + spirv.Return +} + +// CHECK-LABEL: @cooperative_matrix_store_stride_i16 +spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr, + %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, + %stride : i16) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16 spirv.Return } @@ -128,9 +139,9 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { - // expected-error @+1 {{expected valid keyword}} + // expected-error @+1 {{expected '<'}} spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : - !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA> + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 spirv.Return } @@ -139,8 +150,8 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}} - spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor : - !spirv.ptr, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, : + !spirv.ptr, i32, i32 spirv.Return } From c6d879d92346115dceb0e367f97787642c7a66ab Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 8 Sep 2023 00:45:39 -0400 Subject: [PATCH 2/3] Update examples --- .../mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 9da120d132227..2bf0eb6fd74bc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -180,8 +180,11 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor #### Example: ``` - spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride : - !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, : + !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + + spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, , : + !spirv.ptr, !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64 ``` }]; From 0ad9037782809b6dd67faa3b743dd2b6018a7e2d Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 8 Sep 2023 11:50:59 -0400 Subject: [PATCH 3/3] Add a TODO for stride optionality --- .../Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 2bf0eb6fd74bc..628d15fc70fae 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -170,13 +170,16 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor inactive. ``` {.ebnf} - coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore ` + coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore` ssa-use `,` ssa-use `,` - ssa-use `,` `<` cooperative-matrix-layout `> + ssa-use `,` `<` cooperative-matrix-layout `>` (`,` `<` memory-operand `>`)? `:` - pointer-type `,` coop-matrix-type, stride-type + pointer-type `,` coop-matrix-type `,` stride-type ``` + TODO: In the SPIR-V spec, `stride` is an optional argument. We should also + support this optionality in the SPIR-V dialect. + #### Example: ``` @@ -190,7 +193,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor let assemblyFormat = [{ $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:` - type($pointer) `,` type($object) `,` type($stride) + type(operands) }]; let availability = [