Skip to content

Commit 9bb74b5

Browse files
[mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ing
By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
1 parent 0c0c5c4 commit 9bb74b5

File tree

6 files changed

+25
-3
lines changed

6 files changed

+25
-3
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ class GreedyRewriteConfig {
4747
/// Note: Only applicable when simplifying entire regions.
4848
bool enableRegionSimplification = true;
4949

50+
/// If set to "true", constants are CSE'd (even across multiple regions that
51+
/// are in a parent-ancestor relationship).
52+
bool cseConstants = true;
53+
5054
/// This specifies the maximum number of times the rewriter will iterate
5155
/// between applying patterns and simplifying regions. Use `kNoLimit` to
5256
/// disable this iteration limit.

mlir/include/mlir/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
3535
Option<"enableRegionSimplification", "region-simplify", "bool",
3636
/*default=*/"true",
3737
"Perform control flow optimizations to the region tree">,
38+
Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
39+
"CSE constant operations">,
3840
Option<"maxIterations", "max-iterations", "int64_t",
3941
/*default=*/"10",
4042
"Max. iterations between applying patterns / simplifying regions">,

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
3333
: config(config) {
3434
this->topDownProcessingEnabled = config.useTopDownTraversal;
3535
this->enableRegionSimplification = config.enableRegionSimplification;
36+
this->cseConstants = config.cseConstants;
3637
this->maxIterations = config.maxIterations;
3738
this->maxNumRewrites = config.maxNumRewrites;
3839
this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
4546
// Set the config from possible pass options set in the meantime.
4647
config.useTopDownTraversal = topDownProcessingEnabled;
4748
config.enableRegionSimplification = enableRegionSimplification;
49+
config.cseConstants = cseConstants;
4850
config.maxIterations = maxIterations;
4951
config.maxNumRewrites = maxNumRewrites;
5052

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
848848
if (!config.useTopDownTraversal) {
849849
// Add operations to the worklist in postorder.
850850
region.walk([&](Operation *op) {
851-
if (!insertKnownConstant(op))
851+
if (!config.cseConstants || !insertKnownConstant(op))
852852
addToWorklist(op);
853853
});
854854
} else {
855855
// Add all nested operations to the worklist in preorder.
856856
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
857-
if (!insertKnownConstant(op)) {
857+
if (!config.cseConstants || !insertKnownConstant(op)) {
858858
addToWorklist(op);
859859
return WalkResult::advance();
860860
}

mlir/test/Pass/run-reproducer.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func.func @bar() {
1414
external_resources: {
1515
mlir_reproducer: {
1616
verify_each: true,
17-
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
17+
// CHECK: builtin.module(func.func(cse,canonicalize{cse-constants=true max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
1818
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
1919
disable_threading: true
2020
}

mlir/test/Transforms/test-canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
22
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
3+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
34

45
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
56
func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
8990
^bb1:
9091
return
9192
}
93+
94+
// CHECK-LABEL: do_not_cse_constant
95+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
96+
// CHECK: return %[[c0]], %[[c0]]
97+
// NO-CSE-LABEL: do_not_cse_constant
98+
// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
99+
// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
100+
// NO-CSE: return %[[c0]], %[[c1]]
101+
func.func @do_not_cse_constant() -> (index, index) {
102+
%0 = arith.constant 0 : index
103+
%1 = arith.constant 0 : index
104+
return %0, %1 : index, index
105+
}

0 commit comments

Comments
 (0)