Skip to content

Commit a21986b

Browse files
authored
[MLIR][PDL] Skip over all results in the PDL Bytecode if a Constraint/Rewrite failed (#139255)
Skipping only over the first results leads to the curCodeIt pointing to the wrong location in the bytecode, causing the execution to continue with a wrong instruction after the Constraint/Rewrite. Signed-off-by: Rickert, Jonas <[email protected]>
1 parent f87bcf1 commit a21986b

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

mlir/lib/Rewrite/ByteCode.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,22 +1496,24 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
14961496
void ByteCodeExecutor::processNativeFunResults(
14971497
ByteCodeRewriteResultList &results, unsigned numResults,
14981498
LogicalResult &rewriteResult) {
1499-
// Store the results in the bytecode memory or handle missing results on
1500-
// failure.
1501-
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1502-
PDLValue::Kind resultKind = read<PDLValue::Kind>();
1503-
1499+
if (failed(rewriteResult)) {
15041500
// Skip the according number of values on the buffer on failure and exit
15051501
// early as there are no results to process.
1506-
if (failed(rewriteResult)) {
1502+
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1503+
const PDLValue::Kind resultKind = read<PDLValue::Kind>();
15071504
if (resultKind == PDLValue::Kind::TypeRange ||
15081505
resultKind == PDLValue::Kind::ValueRange) {
15091506
skip(2);
15101507
} else {
15111508
skip(1);
15121509
}
1513-
return;
15141510
}
1511+
return;
1512+
}
1513+
1514+
// Store the results in the bytecode memory
1515+
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1516+
PDLValue::Kind resultKind = read<PDLValue::Kind>();
15151517
PDLValue result = results.getResults()[resultIdx];
15161518
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
15171519
assert(result.getKind() == resultKind &&

mlir/test/Rewrite/pdl-bytecode.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,36 @@ module @ir attributes { test.apply_constraint_4 } {
143143

144144
// -----
145145

146+
// Test returning a type from a native constraint.
147+
module @patterns {
148+
pdl_interp.func @matcher(%root : !pdl.operation) {
149+
%new_type:2 = pdl_interp.apply_constraint "op_multiple_returns_failure"(%root : !pdl.operation) : !pdl.type, !pdl.type -> ^pat2, ^end
150+
151+
^pat2:
152+
pdl_interp.record_match @rewriters::@success(%root, %new_type#0 : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
153+
154+
^end:
155+
pdl_interp.finalize
156+
}
157+
158+
module @rewriters {
159+
pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
160+
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
161+
pdl_interp.erase %root
162+
pdl_interp.finalize
163+
}
164+
}
165+
}
166+
167+
// CHECK-LABEL: test.apply_constraint_multi_result_failure
168+
// CHECK-NOT: "test.replaced_by_pattern"
169+
// CHECK: "test.success_op"
170+
module @ir attributes { test.apply_constraint_multi_result_failure } {
171+
"test.success_op"() : () -> ()
172+
}
173+
174+
// -----
175+
146176
// Test success and failure cases of native constraints with pdl.range results.
147177
module @patterns {
148178
pdl_interp.func @matcher(%root : !pdl.operation) {

mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
5555
return failure();
5656
}
5757

58+
// Custom constraint that always returns failure
59+
static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
60+
PDLResultList & /*results*/,
61+
ArrayRef<PDLValue> /*args*/) {
62+
return failure();
63+
}
64+
5865
// Custom constraint that returns a type range of variable length if the op is
5966
// named test.success_op
6067
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
@@ -150,6 +157,8 @@ struct TestPDLByteCodePass
150157
customValueResultConstraint);
151158
pdlPattern.registerConstraintFunction("op_constr_return_type",
152159
customTypeResultConstraint);
160+
pdlPattern.registerConstraintFunction("op_multiple_returns_failure",
161+
customConstraintFailure);
153162
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
154163
customTypeRangeResultConstraint);
155164
pdlPattern.registerRewriteFunction("creator", customCreate);

0 commit comments

Comments
 (0)