Skip to content

Commit 369db67

Browse files
committed
[CSGen] Make collection subscript result type inference more principled
Infer result type of a subscript with Array or Dictionary base type if argument type matches the key type exactly or it's a supported literal type. This helps to maintain the existing behavior without having to resort to "favored type" computation.
1 parent 62ec3db commit 369db67

File tree

2 files changed

+71
-45
lines changed

2 files changed

+71
-45
lines changed

lib/Sema/CSGen.cpp

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,61 @@ namespace {
299299
return tv;
300300
}
301301

302+
/// Attempt to infer a result type of a subscript reference where
303+
/// the base type is either a stdlib Array or a Dictionary type.
304+
/// This is a more principled version of the old performance hack
305+
/// that used "favored" types deduced by the constraint optimizer
306+
/// and is important to maintain pre-existing solver behavior.
307+
Type inferCollectionSubscriptResultType(Type baseTy,
308+
ArgumentList *argumentList) {
309+
auto isLValueBase = false;
310+
auto baseObjTy = baseTy;
311+
if (baseObjTy->is<LValueType>()) {
312+
isLValueBase = true;
313+
baseObjTy = baseObjTy->getWithoutSpecifierType();
314+
}
315+
316+
auto subscriptResultType = [&isLValueBase](Type valueTy,
317+
bool isOptional) -> Type {
318+
Type outputTy = isOptional ? OptionalType::get(valueTy) : valueTy;
319+
return isLValueBase ? LValueType::get(outputTy) : outputTy;
320+
};
321+
322+
if (auto *argument = argumentList->getUnlabeledUnaryExpr()) {
323+
auto argumentTy = CS.getType(argument);
324+
325+
auto elementTy = baseObjTy->getArrayElementType();
326+
327+
if (!elementTy)
328+
elementTy = baseObjTy->getInlineArrayElementType();
329+
330+
if (elementTy) {
331+
if (auto arraySliceTy =
332+
dyn_cast<ArraySliceType>(baseObjTy.getPointer())) {
333+
baseObjTy = arraySliceTy->getDesugaredType();
334+
}
335+
336+
if (argumentTy->isInt() || isExpr<IntegerLiteralExpr>(argument))
337+
return subscriptResultType(elementTy, /*isOptional*/ false);
338+
} else if (auto dictTy = CS.isDictionaryType(baseObjTy)) {
339+
auto [keyTy, valueTy] = *dictTy;
340+
341+
if (keyTy->isString() &&
342+
(isExpr<StringLiteralExpr>(argument) ||
343+
isExpr<InterpolatedStringLiteralExpr>(argument)))
344+
return subscriptResultType(valueTy, /*isOptional*/ true);
345+
346+
if (keyTy->isInt() && isExpr<IntegerLiteralExpr>(argument))
347+
return subscriptResultType(valueTy, /*isOptional*/ true);
348+
349+
if (keyTy->isEqual(argumentTy))
350+
return subscriptResultType(valueTy, /*isOptional*/ true);
351+
}
352+
}
353+
354+
return Type();
355+
}
356+
302357
/// Add constraints for a subscript operation.
303358
Type addSubscriptConstraints(
304359
Expr *anchor, Type baseTy, ValueDecl *declOrNull, ArgumentList *argList,
@@ -322,52 +377,10 @@ namespace {
322377

323378
Type outputTy;
324379

325-
// For an integer subscript expression on an array slice type, instead of
326-
// introducing a new type variable we can easily obtain the element type.
327-
if (isa<SubscriptExpr>(anchor)) {
328-
329-
auto isLValueBase = false;
330-
auto baseObjTy = baseTy;
331-
if (baseObjTy->is<LValueType>()) {
332-
isLValueBase = true;
333-
baseObjTy = baseObjTy->getWithoutSpecifierType();
334-
}
335-
336-
auto elementTy = baseObjTy->getArrayElementType();
337-
338-
if (!elementTy)
339-
elementTy = baseObjTy->getInlineArrayElementType();
340-
341-
if (elementTy) {
342-
343-
if (auto arraySliceTy =
344-
dyn_cast<ArraySliceType>(baseObjTy.getPointer())) {
345-
baseObjTy = arraySliceTy->getDesugaredType();
346-
}
347-
348-
if (argList->isUnlabeledUnary() &&
349-
isa<IntegerLiteralExpr>(argList->getExpr(0))) {
380+
// Attempt to infer the result type of a stdlib collection subscript.
381+
if (isa<SubscriptExpr>(anchor))
382+
outputTy = inferCollectionSubscriptResultType(baseTy, argList);
350383

351-
outputTy = elementTy;
352-
353-
if (isLValueBase)
354-
outputTy = LValueType::get(outputTy);
355-
}
356-
} else if (auto dictTy = CS.isDictionaryType(baseObjTy)) {
357-
auto keyTy = dictTy->first;
358-
auto valueTy = dictTy->second;
359-
360-
if (argList->isUnlabeledUnary()) {
361-
auto argTy = CS.getType(argList->getExpr(0));
362-
if (keyTy->isEqual(argTy)) {
363-
outputTy = OptionalType::get(valueTy);
364-
if (isLValueBase)
365-
outputTy = LValueType::get(outputTy);
366-
}
367-
}
368-
}
369-
}
370-
371384
if (outputTy.isNull()) {
372385
outputTy = CS.createTypeVariable(resultLocator,
373386
TVO_CanBindToLValue | TVO_CanBindToNoEscape);
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %scale-test --begin 1 --end 15 --step 1 --select NumLeafScopes %s --expected-exit-code 0
2+
// REQUIRES: asserts,no_asan
3+
4+
func test(carrierDict: [String : Double]) {
5+
var exhaustTemperature: Double
6+
exhaustTemperature = (
7+
(carrierDict[""] ?? 0.0) +
8+
%for i in range(N):
9+
(carrierDict[""] ?? 0.0) +
10+
%end
11+
(carrierDict[""] ?? 0.0)
12+
) / 4
13+
}

0 commit comments

Comments
 (0)