diff --git a/src/services/refactors/inferFunctionReturnType.ts b/src/services/refactors/inferFunctionReturnType.ts
index 535e25c330b4d..919aaf5413dfa 100644
--- a/src/services/refactors/inferFunctionReturnType.ts
+++ b/src/services/refactors/inferFunctionReturnType.ts
@@ -17,8 +17,7 @@ namespace ts.refactor.inferFunctionReturnType {
function getEditsForAction(context: RefactorContext): RefactorEditInfo | undefined {
const info = getInfo(context);
if (info && !isRefactorErrorInfo(info)) {
- const edits = textChanges.ChangeTracker.with(context, t =>
- t.tryInsertTypeAnnotation(context.file, info.declaration, info.returnTypeNode));
+ const edits = textChanges.ChangeTracker.with(context, t => doChange(context.file, t, info.declaration, info.returnTypeNode));
return { renameFilename: undefined, renameLocation: undefined, edits };
}
return undefined;
@@ -55,6 +54,19 @@ namespace ts.refactor.inferFunctionReturnType {
returnTypeNode: TypeNode;
}
+ function doChange(sourceFile: SourceFile, changes: textChanges.ChangeTracker, declaration: ConvertibleDeclaration, typeNode: TypeNode) {
+ const closeParen = findChildOfKind(declaration, SyntaxKind.CloseParenToken, sourceFile);
+ const needParens = isArrowFunction(declaration) && closeParen === undefined;
+ const endNode = needParens ? first(declaration.parameters) : closeParen;
+ if (endNode) {
+ if (needParens) {
+ changes.insertNodeBefore(sourceFile, endNode, factory.createToken(SyntaxKind.OpenParenToken));
+ changes.insertNodeAfter(sourceFile, endNode, factory.createToken(SyntaxKind.CloseParenToken));
+ }
+ changes.insertNodeAt(sourceFile, endNode.end, typeNode, { prefix: ": " });
+ }
+ }
+
function getInfo(context: RefactorContext): FunctionInfo | RefactorErrorInfo | undefined {
if (isInJSFile(context.file) || !refactorKindBeginsWith(inferReturnTypeAction.kind, context.kind)) return;
diff --git a/tests/cases/fourslash/refactorInferFunctionReturnType22.ts b/tests/cases/fourslash/refactorInferFunctionReturnType22.ts
new file mode 100644
index 0000000000000..fc2b5314d7046
--- /dev/null
+++ b/tests/cases/fourslash/refactorInferFunctionReturnType22.ts
@@ -0,0 +1,16 @@
+///
+
+////const foo = async /*a*//*b*/a => {
+//// return 1;
+////}
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Infer function return type",
+ actionName: "Infer function return type",
+ actionDescription: "Infer function return type",
+ newContent:
+`const foo = async (a): Promise => {
+ return 1;
+}`
+});
diff --git a/tests/cases/fourslash/refactorInferFunctionReturnType23.ts b/tests/cases/fourslash/refactorInferFunctionReturnType23.ts
new file mode 100644
index 0000000000000..f6d3d3edb1d80
--- /dev/null
+++ b/tests/cases/fourslash/refactorInferFunctionReturnType23.ts
@@ -0,0 +1,16 @@
+///
+
+////const foo = async /*a*//*b*/(a) => {
+//// return 1;
+////}
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Infer function return type",
+ actionName: "Infer function return type",
+ actionDescription: "Infer function return type",
+ newContent:
+`const foo = async (a): Promise => {
+ return 1;
+}`
+});