-
Notifications
You must be signed in to change notification settings - Fork 14.7k
Closed
Labels
Description
The following MLIR file:
func.func @f(%input : memref<10xf32>, %output : memref<10xf32>, %reduc : memref<10xf32>) {
%zero = arith.constant 0. : f32
%one = arith.constant 1. : f32
affine.for %i = 0 to 10 {
%0 = affine.load %input[%i] : memref<10xf32>
%2 = arith.addf %0, %one : f32
affine.store %2, %output[%i] : memref<10xf32>
}
affine.for %i = 0 to 10 {
%0 = affine.load %input[%i] : memref<10xf32>
%1 = affine.load %reduc[0] : memref<10xf32>
%2 = arith.addf %0, %1 : f32
affine.store %2, %reduc[0] : memref<10xf32>
}
return
}
when run with:
bin/mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=sibling fusion-maximal}))'
produces:
module {
func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
affine.for %arg3 = 0 to 10 {
%0 = affine.load %arg0[%arg3] : memref<10xf32>
%1 = arith.addf %0, %cst_0 : f32
affine.store %1, %arg1[%arg3] : memref<10xf32>
affine.for %arg4 = 0 to 10 {
%2 = affine.load %arg0[%arg4] : memref<10xf32>
%3 = affine.load %arg2[0] : memref<10xf32>
%4 = arith.addf %2, %3 : f32
affine.store %4, %arg2[0] : memref<10xf32>
}
}
return
}
}
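This looks incorrect to me. Instead of merging the two loop bodies, the pass nested the entire second loop inside the first one. As a result, the reduction into `%arg2[0]` now executes 100 times (10 outer × 10 inner iterations) instead of 10 times. Without `fusion-maximal`, the correct fusion is applied: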
module {
func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
affine.for %arg3 = 0 to 10 {
%0 = affine.load %arg0[%arg3] : memref<10xf32>
%1 = arith.addf %0, %cst_0 : f32
affine.store %1, %arg1[%arg3] : memref<10xf32>
%2 = affine.load %arg0[%arg3] : memref<10xf32>
%3 = affine.load %arg2[0] : memref<10xf32>
%4 = arith.addf %2, %3 : f32
affine.store %4, %arg2[0] : memref<10xf32>
}
return
}
}