mlir/Affine: sibling loop fusion with fusion-maximal on leads to invalid optimization #61820

@rohany

The following MLIR file:

func.func @f(%input : memref<10xf32>, %output : memref<10xf32>, %reduc : memref<10xf32>) {
  %zero = arith.constant 0. : f32
  %one = arith.constant 1. : f32
  affine.for %i = 0 to 10 {
    %0 = affine.load %input[%i] : memref<10xf32>
    %2 = arith.addf %0, %one : f32
    affine.store %2, %output[%i] : memref<10xf32>
  }
  affine.for %i = 0 to 10 {
    %0 = affine.load %input[%i] : memref<10xf32>
    %1 = affine.load %reduc[0] : memref<10xf32>
    %2 = arith.addf %0, %1 : f32
    affine.store %2, %reduc[0] : memref<10xf32>
  }
  return
}

when run with

bin/mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=sibling fusion-maximal}))'

produces:

module {
  func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    affine.for %arg3 = 0 to 10 {
      %0 = affine.load %arg0[%arg3] : memref<10xf32>
      %1 = arith.addf %0, %cst_0 : f32
      affine.store %1, %arg1[%arg3] : memref<10xf32>
      affine.for %arg4 = 0 to 10 {
        %2 = affine.load %arg0[%arg4] : memref<10xf32>
        %3 = affine.load %arg2[0] : memref<10xf32>
        %4 = arith.addf %2, %3 : f32
        affine.store %4, %arg2[0] : memref<10xf32>
      }
    }
    return
  }
}

This looks incorrect to me: the second loop has been nested inside the first instead of being merged with it, so the reduction into %arg2[0] now runs 100 times instead of 10. Without fusion-maximal, the correct fusion is applied:

module {
  func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    affine.for %arg3 = 0 to 10 {
      %0 = affine.load %arg0[%arg3] : memref<10xf32>
      %1 = arith.addf %0, %cst_0 : f32
      affine.store %1, %arg1[%arg3] : memref<10xf32>
      %2 = affine.load %arg0[%arg3] : memref<10xf32>
      %3 = affine.load %arg2[0] : memref<10xf32>
      %4 = arith.addf %2, %3 : f32
      affine.store %4, %arg2[0] : memref<10xf32>
    }
    return
  }
}
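
To make the difference concrete, here is a minimal Python sketch that models the three loop nests above on plain lists. It is not MLIR and says nothing about how the pass works internally; the function names and the zero-initialized `reduc` buffer are assumptions made only for illustration.

```python
# Hypothetical Python model of the three loop nests in this report.
# Assumption: the reduc buffer starts out as all zeros.

N = 10


def original(inp):
    """Two sibling loops, as in the input file."""
    out = [0.0] * N
    reduc = [0.0] * N
    for i in range(N):
        out[i] = inp[i] + 1.0
    for i in range(N):
        reduc[0] = inp[i] + reduc[0]
    return out, reduc


def fused_expected(inp):
    """Both bodies merged into one loop (output without fusion-maximal)."""
    out = [0.0] * N
    reduc = [0.0] * N
    for i in range(N):
        out[i] = inp[i] + 1.0
        reduc[0] = inp[i] + reduc[0]
    return out, reduc


def fused_maximal(inp):
    """Second loop nested inside the first (output with fusion-maximal)."""
    out = [0.0] * N
    reduc = [0.0] * N
    for i in range(N):
        out[i] = inp[i] + 1.0
        for j in range(N):
            reduc[0] = inp[j] + reduc[0]
    return out, reduc


if __name__ == "__main__":
    inp = [float(i) for i in range(N)]
    print("original      reduc[0] =", original(inp)[1][0])
    print("fused_expected reduc[0] =", fused_expected(inp)[1][0])
    print("fused_maximal reduc[0] =", fused_maximal(inp)[1][0])
```

With this input, the original and the non-maximal fusion agree on `reduc[0]`, while the nested version accumulates the full sum once per outer iteration, i.e. ten times too much.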
