@@ -7131,15 +7131,36 @@ namespace ts {
7131
7131
return true;
7132
7132
}
7133
7133
7134
+ function literalTypesWithSameBaseType(types: Type[]): boolean {
7135
+ let commonBaseType: Type;
7136
+ for (const t of types) {
7137
+ const baseType = getBaseTypeOfLiteralType(t);
7138
+ if (!commonBaseType) {
7139
+ commonBaseType = baseType;
7140
+ }
7141
+ if (baseType === t || baseType !== commonBaseType) {
7142
+ return false;
7143
+ }
7144
+ }
7145
+ return true;
7146
+ }
7147
+
7148
+ // When the candidate types are all literal types with the same base type, the common
7149
+ // supertype is a union of those literal types. Otherwise, the common supertype is the
7150
+ // first type that is a supertype of each of the other types.
7151
+ function getSupertypeOrUnion(types: Type[]): Type {
7152
+ return literalTypesWithSameBaseType(types) ? getUnionType(types) : forEach(types, t => isSupertypeOfEach(t, types) ? t : undefined);
7153
+ }
7154
+
7134
7155
function getCommonSupertype(types: Type[]): Type {
7135
7156
if (!strictNullChecks) {
7136
- return forEach (types, t => isSupertypeOfEach(t, types) ? t : undefined );
7157
+ return getSupertypeOrUnion (types);
7137
7158
}
7138
7159
const primaryTypes = filter(types, t => !(t.flags & TypeFlags.Nullable));
7139
7160
if (!primaryTypes.length) {
7140
7161
return getUnionType(types, /*subtypeReduction*/ true);
7141
7162
}
7142
- const supertype = forEach (primaryTypes, t => isSupertypeOfEach(t, primaryTypes) ? t : undefined );
7163
+ const supertype = getSupertypeOrUnion (primaryTypes);
7143
7164
return supertype && includeFalsyTypes(supertype, getFalsyFlagsOfTypes(types) & TypeFlags.Nullable);
7144
7165
}
7145
7166
@@ -7468,11 +7489,13 @@ namespace ts {
7468
7489
}
7469
7490
}
7470
7491
7471
- function createInferenceContext(typeParameters: TypeParameter[], inferUnionTypes: boolean): InferenceContext {
7472
- const inferences = map(typeParameters, createTypeInferencesObject);
7473
-
7492
+ function createInferenceContext(signature: Signature, inferUnionTypes: boolean): InferenceContext {
7493
+ const typeParameters = signature.typeParameters;
7494
+ const returnType = getReturnTypeOfSignature(signature);
7495
+ const inferences = map(signature.typeParameters, createTypeInferencesObject);
7474
7496
return {
7475
7497
typeParameters,
7498
+ returnType,
7476
7499
inferUnionTypes,
7477
7500
inferences,
7478
7501
inferredTypes: new Array(typeParameters.length),
@@ -7483,6 +7506,7 @@ namespace ts {
7483
7506
return {
7484
7507
primary: undefined,
7485
7508
secondary: undefined,
7509
+ shallow: true,
7486
7510
isFixed: false,
7487
7511
};
7488
7512
}
@@ -7504,21 +7528,13 @@ namespace ts {
7504
7528
return type.couldContainTypeParameters;
7505
7529
}
7506
7530
7507
- function hasPrimitiveConstraint(type: TypeParameter): boolean {
7508
- const constraint = getConstraintOfTypeParameter(type);
7509
- return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive);
7510
- }
7511
-
7512
7531
function inferTypes(context: InferenceContext, source: Type, target: Type) {
7513
7532
let sourceStack: Type[];
7514
7533
let targetStack: Type[];
7515
7534
let depth = 0;
7516
7535
let inferiority = 0;
7517
7536
const visited = createMap<boolean>();
7518
- // We widen a literal source type only if we're inferring directly to a type parameter
7519
- // that has no primitive or literal constraint.
7520
- const shouldWiden = isLiteralType(source) && target.flags & TypeFlags.TypeParameter && !hasPrimitiveConstraint(<TypeParameter>target);
7521
- inferFromTypes(shouldWiden ? getBaseTypeOfLiteralType(source) : source, target);
7537
+ inferFromTypes(source, target, /*nested*/ false);
7522
7538
7523
7539
function isInProcess(source: Type, target: Type) {
7524
7540
for (let i = 0; i < depth; i++) {
@@ -7529,7 +7545,7 @@ namespace ts {
7529
7545
return false;
7530
7546
}
7531
7547
7532
- function inferFromTypes(source: Type, target: Type) {
7548
+ function inferFromTypes(source: Type, target: Type, nested: boolean ) {
7533
7549
if (!couldContainTypeParameters(target)) {
7534
7550
return;
7535
7551
}
@@ -7539,7 +7555,7 @@ namespace ts {
7539
7555
// are the same type, just relate each constituent type to itself.
7540
7556
if (source === target) {
7541
7557
for (const t of (<UnionOrIntersectionType>source).types) {
7542
- inferFromTypes(t, t);
7558
+ inferFromTypes(t, t, /*nested*/ false );
7543
7559
}
7544
7560
return;
7545
7561
}
@@ -7551,7 +7567,7 @@ namespace ts {
7551
7567
for (const t of (<UnionOrIntersectionType>target).types) {
7552
7568
if (typeIdenticalToSomeType(t, (<UnionOrIntersectionType>source).types)) {
7553
7569
(matchingTypes || (matchingTypes = [])).push(t);
7554
- inferFromTypes(t, t);
7570
+ inferFromTypes(t, t, /*nested*/ false );
7555
7571
}
7556
7572
}
7557
7573
// Next, to improve the quality of inferences, reduce the source and target types by
@@ -7589,6 +7605,9 @@ namespace ts {
7589
7605
if (!contains(candidates, source)) {
7590
7606
candidates.push(source);
7591
7607
}
7608
+ if (nested) {
7609
+ inferences.shallow = false;
7610
+ }
7592
7611
}
7593
7612
return;
7594
7613
}
@@ -7600,7 +7619,7 @@ namespace ts {
7600
7619
const targetTypes = (<TypeReference>target).typeArguments || emptyArray;
7601
7620
const count = sourceTypes.length < targetTypes.length ? sourceTypes.length : targetTypes.length;
7602
7621
for (let i = 0; i < count; i++) {
7603
- inferFromTypes(sourceTypes[i], targetTypes[i]);
7622
+ inferFromTypes(sourceTypes[i], targetTypes[i], /*nested*/ true );
7604
7623
}
7605
7624
}
7606
7625
else if (target.flags & TypeFlags.UnionOrIntersection) {
@@ -7614,23 +7633,23 @@ namespace ts {
7614
7633
typeParameterCount++;
7615
7634
}
7616
7635
else {
7617
- inferFromTypes(source, t);
7636
+ inferFromTypes(source, t, /*nested*/ false );
7618
7637
}
7619
7638
}
7620
7639
// Next, if target containings a single naked type parameter, make a secondary inference to that type
7621
7640
// parameter. This gives meaningful results for union types in co-variant positions and intersection
7622
7641
// types in contra-variant positions (such as callback parameters).
7623
7642
if (typeParameterCount === 1) {
7624
7643
inferiority++;
7625
- inferFromTypes(source, typeParameter);
7644
+ inferFromTypes(source, typeParameter, /*nested*/ false );
7626
7645
inferiority--;
7627
7646
}
7628
7647
}
7629
7648
else if (source.flags & TypeFlags.UnionOrIntersection) {
7630
7649
// Source is a union or intersection type, infer from each constituent type
7631
7650
const sourceTypes = (<UnionOrIntersectionType>source).types;
7632
7651
for (const sourceType of sourceTypes) {
7633
- inferFromTypes(sourceType, target);
7652
+ inferFromTypes(sourceType, target, /*nested*/ false );
7634
7653
}
7635
7654
}
7636
7655
else {
@@ -7668,7 +7687,7 @@ namespace ts {
7668
7687
for (const targetProp of properties) {
7669
7688
const sourceProp = getPropertyOfObjectType(source, targetProp.name);
7670
7689
if (sourceProp) {
7671
- inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp));
7690
+ inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp), /*nested*/ true );
7672
7691
}
7673
7692
}
7674
7693
}
@@ -7684,14 +7703,18 @@ namespace ts {
7684
7703
}
7685
7704
}
7686
7705
7706
+ function inferFromParameterTypes(source: Type, target: Type) {
7707
+ return inferFromTypes(source, target, /*nested*/ true);
7708
+ }
7709
+
7687
7710
function inferFromSignature(source: Signature, target: Signature) {
7688
- forEachMatchingParameterType(source, target, inferFromTypes );
7711
+ forEachMatchingParameterType(source, target, inferFromParameterTypes );
7689
7712
7690
7713
if (source.typePredicate && target.typePredicate && source.typePredicate.kind === target.typePredicate.kind) {
7691
- inferFromTypes(source.typePredicate.type, target.typePredicate.type);
7714
+ inferFromTypes(source.typePredicate.type, target.typePredicate.type, /*nested*/ true );
7692
7715
}
7693
7716
else {
7694
- inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target));
7717
+ inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target), /*nested*/ true );
7695
7718
}
7696
7719
}
7697
7720
@@ -7701,7 +7724,7 @@ namespace ts {
7701
7724
const sourceIndexType = getIndexTypeOfType(source, IndexKind.String) ||
7702
7725
getImplicitIndexTypeOfType(source, IndexKind.String);
7703
7726
if (sourceIndexType) {
7704
- inferFromTypes(sourceIndexType, targetStringIndexType);
7727
+ inferFromTypes(sourceIndexType, targetStringIndexType, /*nested*/ true );
7705
7728
}
7706
7729
}
7707
7730
const targetNumberIndexType = getIndexTypeOfType(target, IndexKind.Number);
@@ -7710,7 +7733,7 @@ namespace ts {
7710
7733
getIndexTypeOfType(source, IndexKind.String) ||
7711
7734
getImplicitIndexTypeOfType(source, IndexKind.Number);
7712
7735
if (sourceIndexType) {
7713
- inferFromTypes(sourceIndexType, targetNumberIndexType);
7736
+ inferFromTypes(sourceIndexType, targetNumberIndexType, /*nested*/ true );
7714
7737
}
7715
7738
}
7716
7739
}
@@ -7744,14 +7767,32 @@ namespace ts {
7744
7767
return inferences.primary || inferences.secondary || emptyArray;
7745
7768
}
7746
7769
7770
+ function hasPrimitiveConstraint(type: TypeParameter): boolean {
7771
+ const constraint = getConstraintOfTypeParameter(type);
7772
+ return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive);
7773
+ }
7774
+
7775
+ function hasTypeParameterAtTopLevel(type: Type, typeParameter: TypeParameter): boolean {
7776
+ return type === typeParameter || type.flags & TypeFlags.UnionOrIntersection && forEach((<UnionOrIntersectionType>type).types, t => hasTypeParameterAtTopLevel(t, typeParameter));
7777
+ }
7778
+
7779
+
7747
7780
function getInferredType(context: InferenceContext, index: number): Type {
7748
7781
let inferredType = context.inferredTypes[index];
7749
7782
let inferenceSucceeded: boolean;
7750
7783
if (!inferredType) {
7751
7784
const inferences = getInferenceCandidates(context, index);
7752
7785
if (inferences.length) {
7786
+ // We keep inferences of literal types if
7787
+ // we made at least one inference that wasn't shallow, or
7788
+ // the type parameter has a primitive type constraint, or
7789
+ // the type parameter wasn't fixed and is referenced at top level in the return type.
7790
+ const keepLiteralTypes = !context.inferences[index].shallow ||
7791
+ hasPrimitiveConstraint(context.typeParameters[index]) ||
7792
+ !context.inferences[index].isFixed && hasTypeParameterAtTopLevel(context.returnType, context.typeParameters[index]);
7793
+ const baseInferences = keepLiteralTypes ? inferences : map(inferences, getBaseTypeOfLiteralType);
7753
7794
// Infer widened union or supertype, or the unknown type for no common supertype
7754
- const unionOrSuperType = context.inferUnionTypes ? getUnionType(inferences , /*subtypeReduction*/ true) : getCommonSupertype(inferences );
7795
+ const unionOrSuperType = context.inferUnionTypes ? getUnionType(baseInferences , /*subtypeReduction*/ true) : getCommonSupertype(baseInferences );
7755
7796
inferredType = unionOrSuperType ? getWidenedType(unionOrSuperType) : unknownType;
7756
7797
inferenceSucceeded = !!unionOrSuperType;
7757
7798
}
@@ -11203,7 +11244,7 @@ namespace ts {
11203
11244
11204
11245
// Instantiate a generic signature in the context of a non-generic signature (section 3.8.5 in TypeScript spec)
11205
11246
function instantiateSignatureInContextOf(signature: Signature, contextualSignature: Signature, contextualMapper: TypeMapper): Signature {
11206
- const context = createInferenceContext(signature.typeParameters , /*inferUnionTypes*/ true);
11247
+ const context = createInferenceContext(signature, /*inferUnionTypes*/ true);
11207
11248
forEachMatchingParameterType(contextualSignature, signature, (source, target) => {
11208
11249
// Type parameters from outer context referenced by source type are fixed by instantiation of the source type
11209
11250
inferTypes(context, instantiateType(source, contextualMapper), target);
@@ -11859,7 +11900,7 @@ namespace ts {
11859
11900
let candidate: Signature;
11860
11901
let typeArgumentsAreValid: boolean;
11861
11902
const inferenceContext = originalCandidate.typeParameters
11862
- ? createInferenceContext(originalCandidate.typeParameters , /*inferUnionTypes*/ false)
11903
+ ? createInferenceContext(originalCandidate, /*inferUnionTypes*/ false)
11863
11904
: undefined;
11864
11905
11865
11906
while (true) {
0 commit comments