Skip to content

Commit ff3b627

Browse files
committed
Less widening of literal types in type inference
1 parent 694705f commit ff3b627

File tree

2 files changed

+73
-30
lines changed

2 files changed

+73
-30
lines changed

src/compiler/checker.ts

+71-30
Original file line numberDiff line numberDiff line change
@@ -7131,15 +7131,36 @@ namespace ts {
71317131
return true;
71327132
}
71337133

7134+
function literalTypesWithSameBaseType(types: Type[]): boolean {
7135+
let commonBaseType: Type;
7136+
for (const t of types) {
7137+
const baseType = getBaseTypeOfLiteralType(t);
7138+
if (!commonBaseType) {
7139+
commonBaseType = baseType;
7140+
}
7141+
if (baseType === t || baseType !== commonBaseType) {
7142+
return false;
7143+
}
7144+
}
7145+
return true;
7146+
}
7147+
7148+
// When the candidate types are all literal types with the same base type, the common
7149+
// supertype is a union of those literal types. Otherwise, the common supertype is the
7150+
// first type that is a supertype of each of the other types.
7151+
function getSupertypeOrUnion(types: Type[]): Type {
7152+
return literalTypesWithSameBaseType(types) ? getUnionType(types) : forEach(types, t => isSupertypeOfEach(t, types) ? t : undefined);
7153+
}
7154+
71347155
function getCommonSupertype(types: Type[]): Type {
71357156
if (!strictNullChecks) {
7136-
return forEach(types, t => isSupertypeOfEach(t, types) ? t : undefined);
7157+
return getSupertypeOrUnion(types);
71377158
}
71387159
const primaryTypes = filter(types, t => !(t.flags & TypeFlags.Nullable));
71397160
if (!primaryTypes.length) {
71407161
return getUnionType(types, /*subtypeReduction*/ true);
71417162
}
7142-
const supertype = forEach(primaryTypes, t => isSupertypeOfEach(t, primaryTypes) ? t : undefined);
7163+
const supertype = getSupertypeOrUnion(primaryTypes);
71437164
return supertype && includeFalsyTypes(supertype, getFalsyFlagsOfTypes(types) & TypeFlags.Nullable);
71447165
}
71457166

@@ -7468,11 +7489,13 @@ namespace ts {
74687489
}
74697490
}
74707491

7471-
function createInferenceContext(typeParameters: TypeParameter[], inferUnionTypes: boolean): InferenceContext {
7472-
const inferences = map(typeParameters, createTypeInferencesObject);
7473-
7492+
function createInferenceContext(signature: Signature, inferUnionTypes: boolean): InferenceContext {
7493+
const typeParameters = signature.typeParameters;
7494+
const returnType = getReturnTypeOfSignature(signature);
7495+
const inferences = map(signature.typeParameters, createTypeInferencesObject);
74747496
return {
74757497
typeParameters,
7498+
returnType,
74767499
inferUnionTypes,
74777500
inferences,
74787501
inferredTypes: new Array(typeParameters.length),
@@ -7483,6 +7506,7 @@ namespace ts {
74837506
return {
74847507
primary: undefined,
74857508
secondary: undefined,
7509+
shallow: true,
74867510
isFixed: false,
74877511
};
74887512
}
@@ -7504,21 +7528,13 @@ namespace ts {
75047528
return type.couldContainTypeParameters;
75057529
}
75067530

7507-
function hasPrimitiveConstraint(type: TypeParameter): boolean {
7508-
const constraint = getConstraintOfTypeParameter(type);
7509-
return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive);
7510-
}
7511-
75127531
function inferTypes(context: InferenceContext, source: Type, target: Type) {
75137532
let sourceStack: Type[];
75147533
let targetStack: Type[];
75157534
let depth = 0;
75167535
let inferiority = 0;
75177536
const visited = createMap<boolean>();
7518-
// We widen a literal source type only if we're inferring directly to a type parameter
7519-
// that has no primitive or literal constraint.
7520-
const shouldWiden = isLiteralType(source) && target.flags & TypeFlags.TypeParameter && !hasPrimitiveConstraint(<TypeParameter>target);
7521-
inferFromTypes(shouldWiden ? getBaseTypeOfLiteralType(source) : source, target);
7537+
inferFromTypes(source, target, /*nested*/ false);
75227538

75237539
function isInProcess(source: Type, target: Type) {
75247540
for (let i = 0; i < depth; i++) {
@@ -7529,7 +7545,7 @@ namespace ts {
75297545
return false;
75307546
}
75317547

7532-
function inferFromTypes(source: Type, target: Type) {
7548+
function inferFromTypes(source: Type, target: Type, nested: boolean) {
75337549
if (!couldContainTypeParameters(target)) {
75347550
return;
75357551
}
@@ -7539,7 +7555,7 @@ namespace ts {
75397555
// are the same type, just relate each constituent type to itself.
75407556
if (source === target) {
75417557
for (const t of (<UnionOrIntersectionType>source).types) {
7542-
inferFromTypes(t, t);
7558+
inferFromTypes(t, t, /*nested*/ false);
75437559
}
75447560
return;
75457561
}
@@ -7551,7 +7567,7 @@ namespace ts {
75517567
for (const t of (<UnionOrIntersectionType>target).types) {
75527568
if (typeIdenticalToSomeType(t, (<UnionOrIntersectionType>source).types)) {
75537569
(matchingTypes || (matchingTypes = [])).push(t);
7554-
inferFromTypes(t, t);
7570+
inferFromTypes(t, t, /*nested*/ false);
75557571
}
75567572
}
75577573
// Next, to improve the quality of inferences, reduce the source and target types by
@@ -7589,6 +7605,9 @@ namespace ts {
75897605
if (!contains(candidates, source)) {
75907606
candidates.push(source);
75917607
}
7608+
if (nested) {
7609+
inferences.shallow = false;
7610+
}
75927611
}
75937612
return;
75947613
}
@@ -7600,7 +7619,7 @@ namespace ts {
76007619
const targetTypes = (<TypeReference>target).typeArguments || emptyArray;
76017620
const count = sourceTypes.length < targetTypes.length ? sourceTypes.length : targetTypes.length;
76027621
for (let i = 0; i < count; i++) {
7603-
inferFromTypes(sourceTypes[i], targetTypes[i]);
7622+
inferFromTypes(sourceTypes[i], targetTypes[i], /*nested*/ true);
76047623
}
76057624
}
76067625
else if (target.flags & TypeFlags.UnionOrIntersection) {
@@ -7614,23 +7633,23 @@ namespace ts {
76147633
typeParameterCount++;
76157634
}
76167635
else {
7617-
inferFromTypes(source, t);
7636+
inferFromTypes(source, t, /*nested*/ false);
76187637
}
76197638
}
76207639
// Next, if target containings a single naked type parameter, make a secondary inference to that type
76217640
// parameter. This gives meaningful results for union types in co-variant positions and intersection
76227641
// types in contra-variant positions (such as callback parameters).
76237642
if (typeParameterCount === 1) {
76247643
inferiority++;
7625-
inferFromTypes(source, typeParameter);
7644+
inferFromTypes(source, typeParameter, /*nested*/ false);
76267645
inferiority--;
76277646
}
76287647
}
76297648
else if (source.flags & TypeFlags.UnionOrIntersection) {
76307649
// Source is a union or intersection type, infer from each constituent type
76317650
const sourceTypes = (<UnionOrIntersectionType>source).types;
76327651
for (const sourceType of sourceTypes) {
7633-
inferFromTypes(sourceType, target);
7652+
inferFromTypes(sourceType, target, /*nested*/ false);
76347653
}
76357654
}
76367655
else {
@@ -7668,7 +7687,7 @@ namespace ts {
76687687
for (const targetProp of properties) {
76697688
const sourceProp = getPropertyOfObjectType(source, targetProp.name);
76707689
if (sourceProp) {
7671-
inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp));
7690+
inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp), /*nested*/ true);
76727691
}
76737692
}
76747693
}
@@ -7684,14 +7703,18 @@ namespace ts {
76847703
}
76857704
}
76867705

7706+
function inferFromParameterTypes(source: Type, target: Type) {
7707+
return inferFromTypes(source, target, /*nested*/ true);
7708+
}
7709+
76877710
function inferFromSignature(source: Signature, target: Signature) {
7688-
forEachMatchingParameterType(source, target, inferFromTypes);
7711+
forEachMatchingParameterType(source, target, inferFromParameterTypes);
76897712

76907713
if (source.typePredicate && target.typePredicate && source.typePredicate.kind === target.typePredicate.kind) {
7691-
inferFromTypes(source.typePredicate.type, target.typePredicate.type);
7714+
inferFromTypes(source.typePredicate.type, target.typePredicate.type, /*nested*/ true);
76927715
}
76937716
else {
7694-
inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target));
7717+
inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target), /*nested*/ true);
76957718
}
76967719
}
76977720

@@ -7701,7 +7724,7 @@ namespace ts {
77017724
const sourceIndexType = getIndexTypeOfType(source, IndexKind.String) ||
77027725
getImplicitIndexTypeOfType(source, IndexKind.String);
77037726
if (sourceIndexType) {
7704-
inferFromTypes(sourceIndexType, targetStringIndexType);
7727+
inferFromTypes(sourceIndexType, targetStringIndexType, /*nested*/ true);
77057728
}
77067729
}
77077730
const targetNumberIndexType = getIndexTypeOfType(target, IndexKind.Number);
@@ -7710,7 +7733,7 @@ namespace ts {
77107733
getIndexTypeOfType(source, IndexKind.String) ||
77117734
getImplicitIndexTypeOfType(source, IndexKind.Number);
77127735
if (sourceIndexType) {
7713-
inferFromTypes(sourceIndexType, targetNumberIndexType);
7736+
inferFromTypes(sourceIndexType, targetNumberIndexType, /*nested*/ true);
77147737
}
77157738
}
77167739
}
@@ -7744,14 +7767,32 @@ namespace ts {
77447767
return inferences.primary || inferences.secondary || emptyArray;
77457768
}
77467769

7770+
function hasPrimitiveConstraint(type: TypeParameter): boolean {
7771+
const constraint = getConstraintOfTypeParameter(type);
7772+
return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive);
7773+
}
7774+
7775+
function hasTypeParameterAtTopLevel(type: Type, typeParameter: TypeParameter): boolean {
7776+
return type === typeParameter || type.flags & TypeFlags.UnionOrIntersection && forEach((<UnionOrIntersectionType>type).types, t => hasTypeParameterAtTopLevel(t, typeParameter));
7777+
}
7778+
7779+
77477780
function getInferredType(context: InferenceContext, index: number): Type {
77487781
let inferredType = context.inferredTypes[index];
77497782
let inferenceSucceeded: boolean;
77507783
if (!inferredType) {
77517784
const inferences = getInferenceCandidates(context, index);
77527785
if (inferences.length) {
7786+
// We keep inferences of literal types if
7787+
// we made at least one inference that wasn't shallow, or
7788+
// the type parameter has a primitive type constraint, or
7789+
// the type parameter wasn't fixed and is referenced at top level in the return type.
7790+
const keepLiteralTypes = !context.inferences[index].shallow ||
7791+
hasPrimitiveConstraint(context.typeParameters[index]) ||
7792+
!context.inferences[index].isFixed && hasTypeParameterAtTopLevel(context.returnType, context.typeParameters[index]);
7793+
const baseInferences = keepLiteralTypes ? inferences : map(inferences, getBaseTypeOfLiteralType);
77537794
// Infer widened union or supertype, or the unknown type for no common supertype
7754-
const unionOrSuperType = context.inferUnionTypes ? getUnionType(inferences, /*subtypeReduction*/ true) : getCommonSupertype(inferences);
7795+
const unionOrSuperType = context.inferUnionTypes ? getUnionType(baseInferences, /*subtypeReduction*/ true) : getCommonSupertype(baseInferences);
77557796
inferredType = unionOrSuperType ? getWidenedType(unionOrSuperType) : unknownType;
77567797
inferenceSucceeded = !!unionOrSuperType;
77577798
}
@@ -11203,7 +11244,7 @@ namespace ts {
1120311244

1120411245
// Instantiate a generic signature in the context of a non-generic signature (section 3.8.5 in TypeScript spec)
1120511246
function instantiateSignatureInContextOf(signature: Signature, contextualSignature: Signature, contextualMapper: TypeMapper): Signature {
11206-
const context = createInferenceContext(signature.typeParameters, /*inferUnionTypes*/ true);
11247+
const context = createInferenceContext(signature, /*inferUnionTypes*/ true);
1120711248
forEachMatchingParameterType(contextualSignature, signature, (source, target) => {
1120811249
// Type parameters from outer context referenced by source type are fixed by instantiation of the source type
1120911250
inferTypes(context, instantiateType(source, contextualMapper), target);
@@ -11859,7 +11900,7 @@ namespace ts {
1185911900
let candidate: Signature;
1186011901
let typeArgumentsAreValid: boolean;
1186111902
const inferenceContext = originalCandidate.typeParameters
11862-
? createInferenceContext(originalCandidate.typeParameters, /*inferUnionTypes*/ false)
11903+
? createInferenceContext(originalCandidate, /*inferUnionTypes*/ false)
1186311904
: undefined;
1186411905

1186511906
while (true) {

src/compiler/types.ts

+2
Original file line numberDiff line numberDiff line change
@@ -2503,13 +2503,15 @@ namespace ts {
25032503
export interface TypeInferences {
25042504
primary: Type[]; // Inferences made directly to a type parameter
25052505
secondary: Type[]; // Inferences made to a type parameter in a union type
2506+
shallow: boolean; // True if all inferences were made from shallow (not nested in object type) locations
25062507
isFixed: boolean; // Whether the type parameter is fixed, as defined in section 4.12.2 of the TypeScript spec
25072508
// If a type parameter is fixed, no more inferences can be made for the type parameter
25082509
}
25092510

25102511
/* @internal */
25112512
export interface InferenceContext {
25122513
typeParameters: TypeParameter[]; // Type parameters for which inferences are made
2514+
returnType: Type; // Return type used when determining whether to widen literal types
25132515
inferUnionTypes: boolean; // Infer union types for disjoint candidates (otherwise undefinedType)
25142516
inferences: TypeInferences[]; // Inferences made for each type parameter
25152517
inferredTypes: Type[]; // Inferred type for each type parameter

0 commit comments

Comments
 (0)