|
| 1 | +/* @internal */ |
| 2 | +namespace ts.codefix { |
| 3 | + type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[]; |
| 4 | + const fixId = "addMissingAwait"; |
| 5 | + const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code; |
| 6 | + const callableConstructableErrorCodes = [ |
| 7 | + Diagnostics.This_expression_is_not_callable.code, |
| 8 | + Diagnostics.This_expression_is_not_constructable.code, |
| 9 | + ]; |
| 10 | + const errorCodes = [ |
| 11 | + Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code, |
| 12 | + Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code, |
| 13 | + Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code, |
| 14 | + Diagnostics.Operator_0_cannot_be_applied_to_type_1.code, |
| 15 | + Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code, |
| 16 | + Diagnostics.This_condition_will_always_return_0_since_the_types_1_and_2_have_no_overlap.code, |
| 17 | + Diagnostics.Type_0_is_not_an_array_type.code, |
| 18 | + Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code, |
| 19 | + Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_Use_compiler_option_downlevelIteration_to_allow_iterating_of_iterators.code, |
| 20 | + Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code, |
| 21 | + Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code, |
| 22 | + Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code, |
| 23 | + Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code, |
| 24 | + Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code, |
| 25 | + propertyAccessCode, |
| 26 | + ...callableConstructableErrorCodes, |
| 27 | + ]; |
| 28 | + |
| 29 | + registerCodeFix({ |
| 30 | + fixIds: [fixId], |
| 31 | + errorCodes, |
| 32 | + getCodeActions: context => { |
| 33 | + const { sourceFile, errorCode, span, cancellationToken, program } = context; |
| 34 | + const expression = getAwaitableExpression(sourceFile, errorCode, span, cancellationToken, program); |
| 35 | + if (!expression) { |
| 36 | + return; |
| 37 | + } |
| 38 | + |
| 39 | + const checker = context.program.getTypeChecker(); |
| 40 | + const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb); |
| 41 | + return compact([ |
| 42 | + getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges), |
| 43 | + getUseSiteFix(context, expression, errorCode, checker, trackChanges)]); |
| 44 | + }, |
| 45 | + getAllCodeActions: context => { |
| 46 | + const { sourceFile, program, cancellationToken } = context; |
| 47 | + const checker = context.program.getTypeChecker(); |
| 48 | + return codeFixAll(context, errorCodes, (t, diagnostic) => { |
| 49 | + const expression = getAwaitableExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program); |
| 50 | + if (!expression) { |
| 51 | + return; |
| 52 | + } |
| 53 | + const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []); |
| 54 | + return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges) |
| 55 | + || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges); |
| 56 | + }); |
| 57 | + }, |
| 58 | + }); |
| 59 | + |
| 60 | + function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction) { |
| 61 | + const { sourceFile } = context; |
| 62 | + const awaitableInitializer = findAwaitableInitializer(expression, sourceFile, checker); |
| 63 | + if (awaitableInitializer) { |
| 64 | + const initializerChanges = trackChanges(t => makeChange(t, errorCode, sourceFile, checker, awaitableInitializer)); |
| 65 | + return createCodeFixActionNoFixId( |
| 66 | + "addMissingAwaitToInitializer", |
| 67 | + initializerChanges, |
| 68 | + [Diagnostics.Add_await_to_initializer_for_0, expression.getText(sourceFile)]); |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction) { |
| 73 | + const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression)); |
| 74 | + return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await); |
| 75 | + } |
| 76 | + |
| 77 | + function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) { |
| 78 | + const checker = program.getDiagnosticsProducingTypeChecker(); |
| 79 | + const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken); |
| 80 | + return some(diagnostics, ({ start, length, relatedInformation, code }) => |
| 81 | + isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) && |
| 82 | + code === errorCode && |
| 83 | + !!relatedInformation && |
| 84 | + some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code)); |
| 85 | + } |
| 86 | + |
| 87 | + function getAwaitableExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program): Expression | undefined { |
| 88 | + const token = getTokenAtPosition(sourceFile, span.start); |
| 89 | + // Checker has already done work to determine that await might be possible, and has attached |
| 90 | + // related info to the node, so start by finding the expression that exactly matches up |
| 91 | + // with the diagnostic range. |
| 92 | + const expression = findAncestor(token, node => { |
| 93 | + if (node.getStart(sourceFile) < span.start || node.getEnd() > textSpanEnd(span)) { |
| 94 | + return "quit"; |
| 95 | + } |
| 96 | + return isExpression(node) && textSpansEqual(span, createTextSpanFromNode(node, sourceFile)); |
| 97 | + }) as Expression | undefined; |
| 98 | + |
| 99 | + return expression |
| 100 | + && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program) |
| 101 | + && isInsideAwaitableBody(expression) |
| 102 | + ? expression |
| 103 | + : undefined; |
| 104 | + } |
| 105 | + |
| 106 | + function findAwaitableInitializer(expression: Node, sourceFile: SourceFile, checker: TypeChecker): Expression | undefined { |
| 107 | + if (!isIdentifier(expression)) { |
| 108 | + return; |
| 109 | + } |
| 110 | + |
| 111 | + const symbol = checker.getSymbolAtLocation(expression); |
| 112 | + if (!symbol) { |
| 113 | + return; |
| 114 | + } |
| 115 | + |
| 116 | + const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration); |
| 117 | + const variableName = tryCast(declaration && declaration.name, isIdentifier); |
| 118 | + const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement); |
| 119 | + if (!declaration || !variableStatement || |
| 120 | + declaration.type || |
| 121 | + !declaration.initializer || |
| 122 | + variableStatement.getSourceFile() !== sourceFile || |
| 123 | + hasModifier(variableStatement, ModifierFlags.Export) || |
| 124 | + !variableName || |
| 125 | + !isInsideAwaitableBody(declaration.initializer)) { |
| 126 | + return; |
| 127 | + } |
| 128 | + |
| 129 | + const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, identifier => { |
| 130 | + return identifier !== expression; |
| 131 | + }); |
| 132 | + |
| 133 | + if (isUsedElsewhere) { |
| 134 | + return; |
| 135 | + } |
| 136 | + |
| 137 | + return declaration.initializer; |
| 138 | + } |
| 139 | + |
| 140 | + function isInsideAwaitableBody(node: Node) { |
| 141 | + return node.kind & NodeFlags.AwaitContext || !!findAncestor(node, ancestor => |
| 142 | + ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor || |
| 143 | + isBlock(ancestor) && ( |
| 144 | + ancestor.parent.kind === SyntaxKind.FunctionDeclaration || |
| 145 | + ancestor.parent.kind === SyntaxKind.FunctionExpression || |
| 146 | + ancestor.parent.kind === SyntaxKind.ArrowFunction || |
| 147 | + ancestor.parent.kind === SyntaxKind.MethodDeclaration)); |
| 148 | + } |
| 149 | + |
| 150 | + function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression) { |
| 151 | + if (isBinaryExpression(insertionSite)) { |
| 152 | + const { left, right } = insertionSite; |
| 153 | + const leftType = checker.getTypeAtLocation(left); |
| 154 | + const rightType = checker.getTypeAtLocation(right); |
| 155 | + const newLeft = checker.getPromisedTypeOfPromise(leftType) ? createAwait(left) : left; |
| 156 | + const newRight = checker.getPromisedTypeOfPromise(rightType) ? createAwait(right) : right; |
| 157 | + changeTracker.replaceNode(sourceFile, left, newLeft); |
| 158 | + changeTracker.replaceNode(sourceFile, right, newRight); |
| 159 | + } |
| 160 | + else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) { |
| 161 | + changeTracker.replaceNode( |
| 162 | + sourceFile, |
| 163 | + insertionSite.parent.expression, |
| 164 | + createParen(createAwait(insertionSite.parent.expression))); |
| 165 | + } |
| 166 | + else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) { |
| 167 | + changeTracker.replaceNode(sourceFile, insertionSite, createParen(createAwait(insertionSite))); |
| 168 | + } |
| 169 | + else { |
| 170 | + changeTracker.replaceNode(sourceFile, insertionSite, createAwait(insertionSite)); |
| 171 | + } |
| 172 | + } |
| 173 | +} |
0 commit comments