From 1fc7cddad413f7a3e714fb2dea67b5d01e9e615f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Fri, 8 Sep 2023 23:38:38 +0200 Subject: [PATCH] [mlir] Make `StringRefParameter` roundtrippable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The current printer of `StringRefParameter` simply prints out the content of the string as is without escaping it any way. This leads to it generating invalid syntax, causing parser errors when read in again. This PR fixes that by adding ´printString` to `AsmPrinter`, allowing one to print a string that can be parsed with `parseString`, using the same escaping syntax as `StringAttr`. --- mlir/include/mlir/IR/AttrTypeBase.td | 2 +- mlir/include/mlir/IR/OpImplementation.h | 4 ++++ mlir/lib/IR/AsmPrinter.cpp | 9 +++++++++ mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir | 4 +++- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td index 3e356373cbd73..42a611ee8e422 100644 --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -363,7 +363,7 @@ class DefaultValuedParameter : class StringRefParameter : AttrOrTypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let printer = [{$_printer << '"' << $_self << '"';}]; + let printer = [{$_printer.printString($_self);}]; let cppStorageType = "std::string"; let defaultValue = value; } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f894ee64a27b0..8864ef02cd3cb 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -184,6 +184,10 @@ class AsmPrinter { /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); + /// Print the given string as a quoted string, escaping any special or + /// non-printable characters in it. + virtual void printString(StringRef string); + /// Print the given string as a symbol reference, i.e. a form representable by /// a SymbolRefAttr. A symbol reference is represented as a string prefixed /// with '@'. The reference is surrounded with ""'s and escaped if it has any diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c662edd592036..7b0da30541b16 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { os << "%"; } void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} void printSymbolName(StringRef) override {} void printSuccessor(Block *) override {} @@ -919,6 +920,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// determining potential aliases. void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printSymbolName(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} @@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) { ::printKeywordOrString(keyword, impl->getStream()); } +void AsmPrinter::printString(StringRef keyword) { + assert(impl && "expected AsmPrinter::printString to be overriden"); + *this << '"'; + printEscapedString(keyword, getStream()); + *this << '"'; +} + void AsmPrinter::printSymbolName(StringRef symbolRef) { assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); ::printSymbolReference(symbolRef, impl->getStream()); diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir index 12289b4d73259..160c388cedf75 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -70,6 +70,7 @@ attributes { // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string<"non default"> +// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F"> func.func private @test_roundtrip_default_parsers_struct( !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> @@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct( !test.custom_type_string<"bar" bar>, !test.optional_type_string, !test.optional_type_string<"default">, - !test.optional_type_string<"non default"> + !test.optional_type_string<"non default">, + !test.optional_type_string<"containing\n \"escape\" characters\0f"> )