Skip to content

Commit 56a9871

Browse files
committed
Add inference based for 'Promise' based on call to 'resolve'
1 parent 6aec7f4 commit 56a9871

12 files changed

+507
-95
lines changed

src/compiler/checker.ts

+94-20
Original file line numberDiff line numberDiff line change
@@ -19458,7 +19458,7 @@ namespace ts {
1945819458
source = getUnionType(sources);
1945919459
}
1946019460
else if (target.flags & TypeFlags.Intersection && some((<IntersectionType>target).types,
19461-
t => !!getInferenceInfoForType(t) || (isGenericMappedType(t) && !!getInferenceInfoForType(getHomomorphicTypeVariable(t) || neverType)))) {
19461+
t => !!getInferenceInfoForType(inferences, t) || (isGenericMappedType(t) && !!getInferenceInfoForType(inferences, getHomomorphicTypeVariable(t) || neverType)))) {
1946219462
// We reduce intersection types only when they contain naked type parameters. For example, when
1946319463
// inferring from 'string[] & { extra: any }' to 'string[] & T' we want to remove string[] and
1946419464
// infer { extra: any } for T. But when inferring to 'string[] & Iterable<T>' we want to keep the
@@ -19490,7 +19490,7 @@ namespace ts {
1949019490
(priority & InferencePriority.ReturnType && (source === autoType || source === autoArrayType)) || isFromInferenceBlockedSource(source)) {
1949119491
return;
1949219492
}
19493-
const inference = getInferenceInfoForType(target);
19493+
const inference = getInferenceInfoForType(inferences, target);
1949419494
if (inference) {
1949519495
if (!inference.isFixed) {
1949619496
if (inference.priority === undefined || priority < inference.priority) {
@@ -19687,21 +19687,10 @@ namespace ts {
1968719687
}
1968819688
}
1968919689

19690-
function getInferenceInfoForType(type: Type) {
19691-
if (type.flags & TypeFlags.TypeVariable) {
19692-
for (const inference of inferences) {
19693-
if (type === inference.typeParameter) {
19694-
return inference;
19695-
}
19696-
}
19697-
}
19698-
return undefined;
19699-
}
19700-
1970119690
function getSingleTypeVariableFromIntersectionTypes(types: Type[]) {
1970219691
let typeVariable: Type | undefined;
1970319692
for (const type of types) {
19704-
const t = type.flags & TypeFlags.Intersection && find((<IntersectionType>type).types, t => !!getInferenceInfoForType(t));
19693+
const t = type.flags & TypeFlags.Intersection && find((<IntersectionType>type).types, t => !!getInferenceInfoForType(inferences, t));
1970519694
if (!t || typeVariable && t !== typeVariable) {
1970619695
return undefined;
1970719696
}
@@ -19722,7 +19711,7 @@ namespace ts {
1972219711
// equal priority (i.e. of equal quality) to what we would infer for a naked type
1972319712
// parameter.
1972419713
for (const t of targets) {
19725-
if (getInferenceInfoForType(t)) {
19714+
if (getInferenceInfoForType(inferences, t)) {
1972619715
nakedTypeVariable = t;
1972719716
typeVariableCount++;
1972819717
}
@@ -19764,7 +19753,7 @@ namespace ts {
1976419753
// make from nested naked type variables and given slightly higher priority by virtue
1976519754
// of being first in the candidates array.
1976619755
for (const t of targets) {
19767-
if (getInferenceInfoForType(t)) {
19756+
if (getInferenceInfoForType(inferences, t)) {
1976819757
typeVariableCount++;
1976919758
}
1977019759
else {
@@ -19778,7 +19767,7 @@ namespace ts {
1977819767
// we only infer to single naked type variables.
1977919768
if (targetFlags & TypeFlags.Intersection ? typeVariableCount === 1 : typeVariableCount > 0) {
1978019769
for (const t of targets) {
19781-
if (getInferenceInfoForType(t)) {
19770+
if (getInferenceInfoForType(inferences, t)) {
1978219771
inferWithPriority(source, t, InferencePriority.NakedTypeVariable);
1978319772
}
1978419773
}
@@ -19798,7 +19787,7 @@ namespace ts {
1979819787
// where T is a type variable. Use inferTypeForHomomorphicMappedType to infer a suitable source
1979919788
// type and then make a secondary inference from that type to T. We make a secondary inference
1980019789
// such that direct inferences to T get priority over inferences to Partial<T>, for example.
19801-
const inference = getInferenceInfoForType((<IndexType>constraintType).type);
19790+
const inference = getInferenceInfoForType(inferences, (<IndexType>constraintType).type);
1980219791
if (inference && !inference.isFixed && !isFromInferenceBlockedSource(source)) {
1980319792
const inferredType = inferTypeForHomomorphicMappedType(source, target, <IndexType>constraintType);
1980419793
if (inferredType) {
@@ -19909,7 +19898,7 @@ namespace ts {
1990919898
const middleLength = targetArity - startLength - endLength;
1991019899
if (middleLength === 2 && elementFlags[startLength] & elementFlags[startLength + 1] & ElementFlags.Variadic && isTupleType(source)) {
1991119900
// Middle of target is [...T, ...U] and source is tuple type
19912-
const targetInfo = getInferenceInfoForType(elementTypes[startLength]);
19901+
const targetInfo = getInferenceInfoForType(inferences, elementTypes[startLength]);
1991319902
if (targetInfo && targetInfo.impliedArity !== undefined) {
1991419903
// Infer slices from source based on implied arity of T.
1991519904
inferFromTypes(sliceTupleType(source, startLength, sourceEndLength + sourceArity - targetInfo.impliedArity), elementTypes[startLength]);
@@ -20004,6 +19993,22 @@ namespace ts {
2000419993
}
2000519994
}
2000619995

19996+
function getInferenceInfoForType(inferences: InferenceInfo[], type: Type) {
19997+
if (type.flags & TypeFlags.TypeVariable) {
19998+
for (const inference of inferences) {
19999+
if (type === inference.typeParameter) {
20000+
return inference;
20001+
}
20002+
}
20003+
}
20004+
return undefined;
20005+
}
20006+
20007+
function hasHigherPriorityInference(inferences: InferenceInfo[], type: Type, priority: InferencePriority) {
20008+
const inference = getInferenceInfoForType(inferences, type);
20009+
return !!inference && (inference.isFixed || inference.priority !== undefined && inference.priority < priority);
20010+
}
20011+
2000720012
function isTypeOrBaseIdenticalTo(s: Type, t: Type) {
2000820013
return isTypeIdenticalTo(s, t) || !!(t.flags & TypeFlags.String && s.flags & TypeFlags.StringLiteral || t.flags & TypeFlags.Number && s.flags & TypeFlags.NumberLiteral);
2000920014
}
@@ -20661,7 +20666,7 @@ namespace ts {
2066120666
}
2066220667

2066320668
function isTypeSubsetOf(source: Type, target: Type) {
20664-
return source === target || target.flags & TypeFlags.Union && isTypeSubsetOfUnion(source, <UnionType>target);
20669+
return source === target || !!(target.flags & TypeFlags.Union) && isTypeSubsetOfUnion(source, <UnionType>target);
2066520670
}
2066620671

2066720672
function isTypeSubsetOfUnion(source: Type, target: UnionType) {
@@ -26020,6 +26025,75 @@ namespace ts {
2602026025
inferTypes(context.inferences, spreadType, restType);
2602126026
}
2602226027

26028+
// Attempt to solve for `T` in `new Promise<T>(resolve => resolve(t))` (also known as the "revealing constructor" pattern).
26029+
// To avoid too much complexity, we use a very restrictive heuristic:
26030+
// - Restrict to NewExpression to reduce overhead.
26031+
// - `signature` has a single parameter (`callbackType`)
26032+
// - `callbackType` has a single call signature (`callbackSignature`) (i.e., `executor: (resolve: (value: T | PromiseLike<T>) => void) => void`)
26033+
// - `callbackSignature` has at least one parameter (`innerCallbackType`)
26034+
// - `innerCallbackType` has a single call signature (`innerCallbackSignature`) (i.e., `resolve: (value: T | PromiseLike<T>) => void`)
26035+
// - `innerCallbackSignature` has a single parameter (`innerCallbackValueType`)
26036+
// - `innerCallbackValueType` contains type variable for which we are gathering inferences (i.e. `value: T | PromiseLike<T>`)
26037+
// - The function (`callbackFunc`) passed as the argument to the parameter `callbackType` must be inline (i.e., an arrow function or function expression)
26038+
// - `callbackFunc` must have one parameter (`innerCallbackParam`) that is untyped (and thus would be contextually typed by `innerCallbackType`)
26039+
// If the above conditions are met then:
26040+
// - Determine the name in function `callbackFunc` given to the parameter `innerCallbackParam`
26041+
// - Find all references to that name in the body of the function `callbackFunc`
26042+
// - If `innerCallbackParam` is called directly, collect inferences for the type of the argument passed to the parameter (`innerCallbackValueType`) each call to `innerCallbackParam`
26043+
// - If `innerCallbackParam` is passed as the argument to another function, we can attempt to use the contextual type of that parameter for inference.
26044+
if (isNewExpression(node) && argCount === 1) {
26045+
const callbackType = getTypeAtPosition(signature, 0); // executor: ...
26046+
const callbackSignature = getSingleCallSignature(callbackType); // (resolve: (...) => ...) => ...
26047+
const callbackFunc = skipParentheses(args[0]);
26048+
if (callbackSignature && isFunctionExpressionOrArrowFunction(callbackFunc)) {
26049+
const sourceFile = getSourceFileOfNode(callbackFunc);
26050+
for (let callbackParamIndex = 0; callbackParamIndex < callbackFunc.parameters.length; callbackParamIndex++) {
26051+
const innerCallbackType = tryGetTypeAtPosition(callbackSignature, callbackParamIndex); // resolve: ...
26052+
const innerCallbackSignature = innerCallbackType && getSingleCallSignature(innerCallbackType); // (value: T | PromiseLike<T>) => ...
26053+
const innerCallbackParam = callbackFunc.parameters[callbackParamIndex];
26054+
if (innerCallbackSignature && getParameterCount(innerCallbackSignature) === 1 && isIdentifier(innerCallbackParam.name) && !getEffectiveTypeAnnotationNode(innerCallbackParam)) {
26055+
const innerCallbackValueType = getTypeAtPosition(innerCallbackSignature, 0); // value: ...
26056+
// Don't do the work if we already have a higher-priority inference.
26057+
if (some(signature.typeParameters, typeParam => isTypeSubsetOf(typeParam, innerCallbackValueType) && !hasHigherPriorityInference(context.inferences, typeParam, InferencePriority.RevealingConstructor))) {
26058+
const innerCallbackSymbol = getSymbolOfNode(innerCallbackParam);
26059+
const positions = getPossibleSymbolReferencePositions(sourceFile, idText(innerCallbackParam.name), callbackFunc);
26060+
if (positions.length) {
26061+
const candidateReferences = findNodesAtPositions(callbackFunc, positions, sourceFile);
26062+
if (candidateReferences.length) {
26063+
// The callback will not have a type associated with it, so we temporarily assign it `anyFunctionType` so that
26064+
// we do not trigger implicit `any` errors and so that we do not create inferences from it.
26065+
const links = getSymbolLinks(innerCallbackSymbol);
26066+
const savedType = links.type;
26067+
links.type = anyFunctionType;
26068+
// collect types for inferences to ppB
26069+
for (const candidateReference of candidateReferences) {
26070+
if (!isIdentifier(candidateReference) || candidateReference === innerCallbackParam.name) continue;
26071+
const candidateReferenceSymbol = resolveName(candidateReference, candidateReference.escapedText, SymbolFlags.Value, /*nameNotFoundMessage*/ undefined, /*nameArg*/ undefined, /*isUse*/ false);
26072+
if (candidateReferenceSymbol !== innerCallbackSymbol) continue;
26073+
if (isCallExpression(candidateReference.parent) && candidateReference === candidateReference.parent.expression) {
26074+
const argType =
26075+
candidateReference.parent.arguments.length >= 1 ? checkExpression(candidateReference.parent.arguments[0]) :
26076+
voidType;
26077+
inferTypes(context.inferences, argType, innerCallbackValueType, InferencePriority.RevealingConstructor);
26078+
}
26079+
else if (isCallOrNewExpression(candidateReference.parent) && contains(candidateReference.parent.arguments, candidateReference)) {
26080+
const callbackType = getContextualType(candidateReference);
26081+
const callbackSignature = callbackType && getSingleCallSignature(callbackType);
26082+
const callbackParamType = callbackSignature && tryGetTypeAtPosition(callbackSignature, 0);
26083+
if (callbackParamType) {
26084+
inferTypes(context.inferences, callbackParamType, innerCallbackValueType, InferencePriority.RevealingConstructor);
26085+
}
26086+
}
26087+
}
26088+
links.type = savedType;
26089+
}
26090+
}
26091+
}
26092+
}
26093+
}
26094+
}
26095+
}
26096+
2602326097
return getInferredTypes(context);
2602426098
}
2602526099

src/compiler/types.ts

+4-3
Original file line numberDiff line numberDiff line change
@@ -5437,10 +5437,11 @@ namespace ts {
54375437
ReturnType = 1 << 6, // Inference made from return type of generic function
54385438
LiteralKeyof = 1 << 7, // Inference made from a string literal to a keyof T
54395439
NoConstraints = 1 << 8, // Don't infer from constraints of instantiable types
5440-
AlwaysStrict = 1 << 9, // Always use strict rules for contravariant inferences
5441-
MaxValue = 1 << 10, // Seed for inference priority tracking
5440+
RevealingConstructor = 1 << 9, // Inference made to a callback in a "revealing constructor" (i.e., `new Promise(resolve => resolve(1))`)
5441+
AlwaysStrict = 1 << 10, // Always use strict rules for contravariant inferences
5442+
MaxValue = 1 << 11, // Seed for inference priority tracking
54425443

5443-
PriorityImpliesCombination = ReturnType | MappedTypeConstraint | LiteralKeyof, // These priorities imply that the resulting type should be a combination of all candidates
5444+
PriorityImpliesCombination = ReturnType | MappedTypeConstraint | LiteralKeyof | RevealingConstructor, // These priorities imply that the resulting type should be a combination of all candidates
54445445
Circularity = -1, // Inference circularity (value less than all other priorities)
54455446
}
54465447

src/compiler/utilities.ts

+73
Original file line numberDiff line numberDiff line change
@@ -6903,4 +6903,77 @@ namespace ts {
69036903
return bindParentToChildIgnoringJSDoc(child, parent) || bindJSDoc(child);
69046904
}
69056905
}
6906+
6907+
export function getPossibleSymbolReferencePositions(sourceFile: SourceFile, symbolName: string, container: Node = sourceFile) {
6908+
const positions: number[] = [];
6909+
6910+
/// TODO: Cache symbol existence for files to save text search
6911+
// Also, need to make this work for unicode escapes.
6912+
6913+
// Be resilient in the face of a symbol with no name or zero length name
6914+
if (!symbolName || !symbolName.length) {
6915+
return positions as readonly number[] as SortedReadonlyArray<number>;
6916+
}
6917+
6918+
const text = sourceFile.text;
6919+
const sourceLength = text.length;
6920+
const symbolNameLength = symbolName.length;
6921+
6922+
let position = text.indexOf(symbolName, container.pos);
6923+
while (position >= 0) {
6924+
// If we are past the end, stop looking
6925+
if (position > container.end) break;
6926+
6927+
// We found a match. Make sure it's not part of a larger word (i.e. the char
6928+
// before and after it have to be a non-identifier char).
6929+
const endPosition = position + symbolNameLength;
6930+
6931+
if ((position === 0 || !isIdentifierPart(text.charCodeAt(position - 1), ScriptTarget.Latest)) &&
6932+
(endPosition === sourceLength || !isIdentifierPart(text.charCodeAt(endPosition), ScriptTarget.Latest))) {
6933+
// Found a real match. Keep searching.
6934+
positions.push(position);
6935+
}
6936+
position = text.indexOf(symbolName, position + symbolNameLength + 1);
6937+
}
6938+
6939+
return positions as readonly number[] as SortedReadonlyArray<number>;
6940+
}
6941+
6942+
export function findNodesAtPositions(container: Node, positions: SortedReadonlyArray<number>, sourceFile = getSourceFileOfNode(container)) {
6943+
let i = 0;
6944+
const results: Node[] = [];
6945+
visit(container);
6946+
return results;
6947+
function visit(node: Node) {
6948+
const startPos = skipTrivia(sourceFile.text, node.pos);
6949+
while (i < positions.length) {
6950+
const pos = positions[i];
6951+
const startOffset = i;
6952+
if (pos >= node.pos && pos < node.end) {
6953+
if (pos < startPos) {
6954+
// The position exists in the node's trivia, so we should skip it and
6955+
// move on to the next position
6956+
i++;
6957+
}
6958+
else {
6959+
const length = results.length;
6960+
forEachChild(node, visit);
6961+
if (length === results.length) {
6962+
// no children were added, so add this node
6963+
results.push(node);
6964+
// advance to the next position
6965+
i++;
6966+
}
6967+
}
6968+
}
6969+
else {
6970+
// If we've advanced past the end of our parent we should break out of
6971+
// the containing `forEachChild`. Otherwise, the position is not contained
6972+
// within this node so we should skip to the next node
6973+
return !!node.parent && pos > node.parent.end;
6974+
}
6975+
Debug.assert(i !== startOffset, "Position did not advance");
6976+
}
6977+
}
6978+
}
69066979
}

0 commit comments

Comments
 (0)