diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7cd998eed2e08..a3826c56e1f62 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
     return getTileShape(op->getOpResult(0));
 
+  if (isa<vector::MultiDimReductionOp>(op))
+    return getTileShape(op->getOpOperand(0));
+
+  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
+    return getTileShape(op->getOpResult(0));
+
   return std::nullopt;
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f9114988686c8..8e3673d04eacb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %acc = arith.constant dense<0.0> : vector<64xf32>
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
+    %3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %acc = arith.constant dense<0.0> : vector<32xf32>
+
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+
+    %m = arith.muli %block_id_x, %c32 : index
+    %n = arith.muli %block_id_y, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
+
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
+
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
+    %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
+    // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
+    %2 = vector.broadcast %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
+    %11 = vector.shape_cast %1 : vector<32xf32> to vector<32x1xf32>
+    // CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
+    %2 = vector.broadcast %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 8]>
+#t = #xegpu.layout<inst_data = [8, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
+    // CHECK-COUNT-2: vector.transpose {{.*}} [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+    %2 = vector.transpose %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
+    xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
+    gpu.return
+  }
+}
\ No newline at end of file
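
For reference, a minimal sketch of the blocked form that the reduce_dim_0 CHECK lines describe: with inst_data = [16, 16] the pass splits the vector<16x64xf32> source into four 16x16 tiles and emits one vector.multi_reduction per tile. The slice offsets and value names below are illustrative, not actual pass output:

    // Hypothetical blocked expansion (names and offsets illustrative); the
    // pass also reassembles the four vector<16xf32> partial results into the
    // vector<64xf32> consumed by xegpu.store_nd.
    %acc16 = arith.constant dense<0.0> : vector<16xf32>
    %t0 = vector.extract_strided_slice %1 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]}
      : vector<16x64xf32> to vector<16x16xf32>
    %r0 = vector.multi_reduction <add>, %t0, %acc16 [0] : vector<16x16xf32> to vector<16xf32>
    // ...and likewise for column offsets 16, 32, and 48.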