diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td index 3e356373cbd73..42a611ee8e422 100644 --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -363,7 +363,7 @@ class DefaultValuedParameter : class StringRefParameter : AttrOrTypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let printer = [{$_printer << '"' << $_self << '"';}]; + let printer = [{$_printer.printString($_self);}]; let cppStorageType = "std::string"; let defaultValue = value; } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index f894ee64a27b0..8864ef02cd3cb 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -184,6 +184,10 @@ class AsmPrinter { /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); + /// Print the given string as a quoted string, escaping any special or + /// non-printable characters in it. + virtual void printString(StringRef string); + /// Print the given string as a symbol reference, i.e. a form representable by /// a SymbolRefAttr. A symbol reference is represented as a string prefixed /// with '@'. The reference is surrounded with ""'s and escaped if it has any diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c662edd592036..7b0da30541b16 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { os << "%"; } void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} void printSymbolName(StringRef) override {} void printSuccessor(Block *) override {} @@ -919,6 +920,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// determining potential aliases. void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} + void printString(StringRef) override {} void printSymbolName(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} @@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) { ::printKeywordOrString(keyword, impl->getStream()); } +void AsmPrinter::printString(StringRef keyword) { + assert(impl && "expected AsmPrinter::printString to be overriden"); + *this << '"'; + printEscapedString(keyword, getStream()); + *this << '"'; +} + void AsmPrinter::printSymbolName(StringRef symbolRef) { assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); ::printSymbolReference(symbolRef, impl->getStream()); diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir index 12289b4d73259..160c388cedf75 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -70,6 +70,7 @@ attributes { // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string // CHECK: !test.optional_type_string<"non default"> +// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F"> func.func private @test_roundtrip_default_parsers_struct( !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> @@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct( !test.custom_type_string<"bar" bar>, !test.optional_type_string, !test.optional_type_string<"default">, - !test.optional_type_string<"non default"> + !test.optional_type_string<"non default">, + !test.optional_type_string<"containing\n \"escape\" characters\0f"> )