Skip to content

Commit 3e4e078

Browse files
authored
Merge pull request #29563 from DougGregor/generalize-solve
[Constraint system] Make solver less expression-centric, part N/M
2 parents a0fbeb9 + ce35731 commit 3e4e078

File tree

6 files changed

+250
-300
lines changed

6 files changed

+250
-300
lines changed

lib/Sema/BuilderTransform.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -1107,8 +1107,14 @@ Optional<BraceStmt *> TypeChecker::applyFunctionBuilderBodyTransform(
11071107
}
11081108

11091109
// Apply the solution to the function body.
1110-
return cast_or_null<BraceStmt>(
1111-
cs.applySolutionToBody(solutions.front(), func));
1110+
if (auto result = cs.applySolution(
1111+
solutions.front(),
1112+
SolutionApplicationTarget(func),
1113+
/*performingDiagnostics=*/false)) {
1114+
return result->getFunctionBody();
1115+
}
1116+
1117+
return nullptr;
11121118
}
11131119

11141120
ConstraintSystem::TypeMatchResult ConstraintSystem::matchFunctionBuilder(

lib/Sema/CSApply.cpp

+41-30
Original file line numberDiff line numberDiff line change
@@ -7231,14 +7231,14 @@ bool ConstraintSystem::applySolutionFixes(const Solution &solution) {
72317231

72327232
/// Apply a given solution to the expression, producing a fully
72337233
/// type-checked expression.
7234-
llvm::PointerUnion<Expr *, Stmt *> ConstraintSystem::applySolutionImpl(
7235-
Solution &solution, SolutionApplicationTarget target, Type convertType,
7236-
bool discardedExpr, bool performingDiagnostics) {
7234+
Optional<SolutionApplicationTarget> ConstraintSystem::applySolution(
7235+
Solution &solution, SolutionApplicationTarget target,
7236+
bool performingDiagnostics) {
72377237
// If any fixes needed to be applied to arrive at this solution, resolve
72387238
// them to specific expressions.
72397239
if (!solution.Fixes.empty()) {
72407240
if (shouldSuppressDiagnostics())
7241-
return nullptr;
7241+
return None;
72427242

72437243
bool diagnosedErrorsViaFixes = applySolutionFixes(solution);
72447244
// If all of the available fixes would result in a warning,
@@ -7248,22 +7248,26 @@ llvm::PointerUnion<Expr *, Stmt *> ConstraintSystem::applySolutionImpl(
72487248
})) {
72497249
// If we already diagnosed any errors via fixes, that's it.
72507250
if (diagnosedErrorsViaFixes)
7251-
return nullptr;
7251+
return None;
72527252

72537253
// If we didn't manage to diagnose anything well, so fall back to
72547254
// diagnosing mining the system to construct a reasonable error message.
72557255
diagnoseFailureFor(target);
7256-
return nullptr;
7256+
return None;
72577257
}
72587258
}
72597259

72607260
ExprRewriter rewriter(*this, solution, shouldSuppressDiagnostics());
72617261
ExprWalker walker(rewriter);
72627262

72637263
// Apply the solution to the target.
7264-
llvm::PointerUnion<Expr *, Stmt *> result;
7264+
SolutionApplicationTarget result = target;
72657265
if (auto expr = target.getAsExpr()) {
7266-
result = expr->walk(walker);
7266+
Expr *rewrittenExpr = expr->walk(walker);
7267+
if (!rewrittenExpr)
7268+
return None;
7269+
7270+
result.setExpr(rewrittenExpr);
72677271
} else {
72687272
auto fn = *target.getAsFunction();
72697273

@@ -7284,14 +7288,11 @@ llvm::PointerUnion<Expr *, Stmt *> ConstraintSystem::applySolutionImpl(
72847288
});
72857289

72867290
if (!newBody)
7287-
return result;
7291+
return None;
72887292

7289-
result = newBody;
7293+
result.setFunctionBody(newBody);
72907294
}
72917295

7292-
if (result.isNull())
7293-
return result;
7294-
72957296
// If we're re-typechecking an expression for diagnostics, don't
72967297
// visit closures that have non-single expression bodies.
72977298
if (!performingDiagnostics) {
@@ -7309,15 +7310,17 @@ llvm::PointerUnion<Expr *, Stmt *> ConstraintSystem::applySolutionImpl(
73097310

73107311
// If any of them failed to type check, bail.
73117312
if (hadError)
7312-
return nullptr;
7313+
return None;
73137314
}
73147315

7315-
if (auto resultExpr = result.dyn_cast<Expr *>()) {
7316+
if (auto resultExpr = result.getAsExpr()) {
73167317
Expr *expr = target.getAsExpr();
73177318
assert(expr && "Can't have expression result without expression target");
7319+
73187320
// We are supposed to use contextual type only if it is present and
73197321
// this expression doesn't represent the implicit return of the single
73207322
// expression function which got deduced to be `Never`.
7323+
Type convertType = target.getExprConversionType();
73217324
auto shouldCoerceToContextualType = [&]() {
73227325
return convertType &&
73237326
!(getType(resultExpr)->isUninhabited() &&
@@ -7328,19 +7331,22 @@ llvm::PointerUnion<Expr *, Stmt *> ConstraintSystem::applySolutionImpl(
73287331
// If we're supposed to convert the expression to some particular type,
73297332
// do so now.
73307333
if (shouldCoerceToContextualType()) {
7331-
result = rewriter.coerceToType(resultExpr, convertType,
7332-
getConstraintLocator(expr));
7333-
if (!result)
7334-
return nullptr;
7335-
} else if (getType(resultExpr)->hasLValueType() && !discardedExpr) {
7334+
resultExpr = rewriter.coerceToType(resultExpr,
7335+
simplifyType(convertType),
7336+
getConstraintLocator(expr));
7337+
} else if (getType(resultExpr)->hasLValueType() &&
7338+
!target.isDiscardedExpr()) {
73367339
// We referenced an lvalue. Load it.
7337-
result = rewriter.coerceToType(resultExpr,
7338-
getType(resultExpr)->getRValueType(),
7339-
getConstraintLocator(expr));
7340+
resultExpr = rewriter.coerceToType(resultExpr,
7341+
getType(resultExpr)->getRValueType(),
7342+
getConstraintLocator(expr));
73407343
}
73417344

7342-
if (resultExpr)
7343-
solution.setExprTypes(resultExpr);
7345+
if (!resultExpr)
7346+
return None;
7347+
7348+
solution.setExprTypes(resultExpr);
7349+
result.setExpr(resultExpr);
73447350
}
73457351

73467352
rewriter.finalize();
@@ -7509,16 +7515,21 @@ ArrayRef<Solution> SolutionResult::getAmbiguousSolutions() const {
75097515

75107516
MutableArrayRef<Solution> SolutionResult::takeAmbiguousSolutions() && {
75117517
assert(getKind() == Ambiguous);
7518+
markAsDiagnosed();
75127519
return MutableArrayRef<Solution>(solutions, numSolutions);
75137520
}
75147521

7515-
llvm::PointerUnion<Expr *, Stmt *> SolutionApplicationTarget::walk(
7516-
ASTWalker &walker) {
7522+
SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
75177523
switch (kind) {
7518-
case Kind::expression:
7519-
return getAsExpr()->walk(walker);
7524+
case Kind::expression: {
7525+
SolutionApplicationTarget result = *this;
7526+
result.setExpr(getAsExpr()->walk(walker));
7527+
return result;
7528+
}
75207529

75217530
case Kind::function:
7522-
return getAsFunction()->getBody()->walk(walker);
7531+
return SolutionApplicationTarget(
7532+
*getAsFunction(),
7533+
cast_or_null<BraceStmt>(getFunctionBody()->walk(walker)));
75237534
}
75247535
}

lib/Sema/CSSolver.cpp

+68-46
Original file line numberDiff line numberDiff line change
@@ -1074,22 +1074,22 @@ void ConstraintSystem::shrink(Expr *expr) {
10741074
}
10751075
}
10761076

1077-
static bool debugConstraintSolverForExpr(ASTContext &C, Expr *expr) {
1077+
static bool debugConstraintSolverForTarget(
1078+
ASTContext &C, SolutionApplicationTarget target) {
10781079
if (C.TypeCheckerOpts.DebugConstraintSolver)
10791080
return true;
10801081

10811082
if (C.TypeCheckerOpts.DebugConstraintSolverOnLines.empty())
10821083
// No need to compute the line number to find out it's not present.
10831084
return false;
10841085

1085-
// Get the lines on which the expression starts and ends.
1086+
// Get the lines on which the target starts and ends.
10861087
unsigned startLine = 0, endLine = 0;
1087-
if (expr->getSourceRange().isValid()) {
1088-
auto range =
1089-
Lexer::getCharSourceRangeFromSourceRange(C.SourceMgr,
1090-
expr->getSourceRange());
1091-
startLine = C.SourceMgr.getLineNumber(range.getStart());
1092-
endLine = C.SourceMgr.getLineNumber(range.getEnd());
1088+
SourceRange range = target.getSourceRange();
1089+
if (range.isValid()) {
1090+
auto charRange = Lexer::getCharSourceRangeFromSourceRange(C.SourceMgr, range);
1091+
startLine = C.SourceMgr.getLineNumber(charRange.getStart());
1092+
endLine = C.SourceMgr.getLineNumber(charRange.getEnd());
10931093
}
10941094

10951095
assert(startLine <= endLine && "expr ends before it starts?");
@@ -1107,25 +1107,44 @@ static bool debugConstraintSolverForExpr(ASTContext &C, Expr *expr) {
11071107
return startBound != endBound;
11081108
}
11091109

1110-
bool ConstraintSystem::solve(Expr *&expr,
1111-
Type convertType,
1112-
ExprTypeCheckListener *listener,
1113-
SmallVectorImpl<Solution> &solutions,
1114-
FreeTypeVariableBinding allowFreeTypeVariables) {
1110+
/// If we aren't certain that we've emitted a diagnostic, emit a fallback
1111+
/// diagnostic.
1112+
static void maybeProduceFallbackDiagnostic(
1113+
ConstraintSystem &cs, SolutionApplicationTarget target) {
1114+
if (cs.Options.contains(ConstraintSystemFlags::SubExpressionDiagnostics) ||
1115+
cs.Options.contains(ConstraintSystemFlags::SuppressDiagnostics))
1116+
return;
1117+
1118+
// Before producing fatal error here, let's check if there are any "error"
1119+
// diagnostics already emitted or waiting to be emitted. Because they are
1120+
// a better indication of the problem.
1121+
ASTContext &ctx = cs.getASTContext();
1122+
if (ctx.Diags.hadAnyError() || ctx.hasDelayedConformanceErrors())
1123+
return;
1124+
1125+
ctx.Diags.diagnose(target.getLoc(), diag::failed_to_produce_diagnostic);
1126+
}
1127+
1128+
Optional<std::vector<Solution>> ConstraintSystem::solve(
1129+
SolutionApplicationTarget &target,
1130+
ExprTypeCheckListener *listener,
1131+
FreeTypeVariableBinding allowFreeTypeVariables
1132+
) {
11151133
llvm::SaveAndRestore<bool> debugForExpr(
11161134
getASTContext().TypeCheckerOpts.DebugConstraintSolver,
1117-
debugConstraintSolverForExpr(getASTContext(), expr));
1135+
debugConstraintSolverForTarget(getASTContext(), target));
11181136

11191137
/// Dump solutions for debugging purposes.
1120-
auto dumpSolutions = [&] {
1138+
auto dumpSolutions = [&](const SolutionResult &result) {
11211139
// Debug-print the set of solutions.
11221140
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
11231141
auto &log = getASTContext().TypeCheckerDebug->getStream();
1124-
if (solutions.size() == 1) {
1142+
if (result.getKind() == SolutionResult::Success) {
11251143
log << "---Solution---\n";
1126-
solutions[0].dump(log);
1127-
} else {
1128-
for (unsigned i = 0, e = solutions.size(); i != e; ++i) {
1144+
result.getSolution().dump(log);
1145+
} else if (result.getKind() == SolutionResult::Ambiguous) {
1146+
auto solutions = result.getAmbiguousSolutions();
1147+
for (unsigned i : indices(solutions)) {
11291148
log << "--- Solution #" << i << " ---\n";
11301149
solutions[i].dump(log);
11311150
}
@@ -1135,59 +1154,62 @@ bool ConstraintSystem::solve(Expr *&expr,
11351154

11361155
// Take up to two attempts at solving the system. The first attempts to
11371156
// solve a system that is expected to be well-formed, the second kicks in
1138-
// when there is an error and attempts to salvage an ill-formed expression.
1157+
// when there is an error and attempts to salvage an ill-formed program.
11391158
for (unsigned stage = 0; stage != 2; ++stage) {
11401159
auto solution = (stage == 0)
1141-
? solveImpl(expr, convertType, listener, allowFreeTypeVariables)
1160+
? solveImpl(target, listener, allowFreeTypeVariables)
11421161
: salvage();
11431162

11441163
switch (solution.getKind()) {
1145-
case SolutionResult::Success:
1164+
case SolutionResult::Success: {
11461165
// Return the successful solution.
1147-
solutions.clear();
1148-
solutions.push_back(std::move(solution).takeSolution());
1149-
dumpSolutions();
1150-
return false;
1166+
dumpSolutions(solution);
1167+
std::vector<Solution> result;
1168+
result.push_back(std::move(solution).takeSolution());
1169+
return std::move(result);
1170+
}
11511171

11521172
case SolutionResult::Error:
1153-
return true;
1173+
maybeProduceFallbackDiagnostic(*this, target);
1174+
return None;
11541175

11551176
case SolutionResult::TooComplex:
1156-
getASTContext().Diags.diagnose(expr->getLoc(), diag::expression_too_complex)
1157-
.highlight(expr->getSourceRange());
1177+
getASTContext().Diags.diagnose(
1178+
target.getLoc(), diag::expression_too_complex)
1179+
.highlight(target.getSourceRange());
11581180
solution.markAsDiagnosed();
1159-
return true;
1181+
return None;
11601182

11611183
case SolutionResult::Ambiguous:
11621184
// If salvaging produced an ambiguous result, it has already been
11631185
// diagnosed.
11641186
if (stage == 1) {
11651187
solution.markAsDiagnosed();
1166-
return true;
1188+
return None;
11671189
}
11681190

11691191
if (Options.contains(
11701192
ConstraintSystemFlags::AllowUnresolvedTypeVariables)) {
1193+
dumpSolutions(solution);
11711194
auto ambiguousSolutions = std::move(solution).takeAmbiguousSolutions();
1172-
solutions.assign(std::make_move_iterator(ambiguousSolutions.begin()),
1173-
std::make_move_iterator(ambiguousSolutions.end()));
1174-
dumpSolutions();
1175-
solution.markAsDiagnosed();
1176-
return false;
1195+
std::vector<Solution> result(
1196+
std::make_move_iterator(ambiguousSolutions.begin()),
1197+
std::make_move_iterator(ambiguousSolutions.end()));
1198+
return std::move(result);
11771199
}
11781200

11791201
LLVM_FALLTHROUGH;
11801202

11811203
case SolutionResult::UndiagnosedError:
11821204
if (shouldSuppressDiagnostics()) {
11831205
solution.markAsDiagnosed();
1184-
return true;
1206+
return None;
11851207
}
11861208

11871209
if (stage == 1) {
1188-
diagnoseFailureFor(expr);
1210+
diagnoseFailureFor(target);
11891211
solution.markAsDiagnosed();
1190-
return true;
1212+
return None;
11911213
}
11921214

11931215
// Loop again to try to salvage.
@@ -1200,14 +1222,13 @@ bool ConstraintSystem::solve(Expr *&expr,
12001222
}
12011223

12021224
SolutionResult
1203-
ConstraintSystem::solveImpl(Expr *&expr,
1204-
Type convertType,
1225+
ConstraintSystem::solveImpl(SolutionApplicationTarget &target,
12051226
ExprTypeCheckListener *listener,
12061227
FreeTypeVariableBinding allowFreeTypeVariables) {
12071228
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
12081229
auto &log = getASTContext().TypeCheckerDebug->getStream();
1209-
log << "---Constraint solving for the expression at ";
1210-
auto R = expr->getSourceRange();
1230+
log << "---Constraint solving at ";
1231+
auto R = target.getSourceRange();
12111232
if (R.isValid()) {
12121233
R.print(log, getASTContext().SourceMgr, /*PrintText=*/ false);
12131234
} else {
@@ -1219,6 +1240,7 @@ ConstraintSystem::solveImpl(Expr *&expr,
12191240
assert(!solverState && "cannot be used directly");
12201241

12211242
// Set up the expression type checker timer.
1243+
Expr *expr = target.getAsExpr();
12221244
Timer.emplace(expr, *this);
12231245

12241246
Expr *origExpr = expr;
@@ -1232,14 +1254,12 @@ ConstraintSystem::solveImpl(Expr *&expr,
12321254
if (auto generatedExpr = generateConstraints(expr, DC))
12331255
expr = generatedExpr;
12341256
else {
1235-
if (listener)
1236-
listener->constraintGenerationFailed(expr);
12371257
return SolutionResult::forError();
12381258
}
12391259

12401260
// If there is a type that we're expected to convert to, add the conversion
12411261
// constraint.
1242-
if (convertType) {
1262+
if (Type convertType = target.getExprConversionType()) {
12431263
// Determine whether we know more about the contextual type.
12441264
ContextualTypePurpose ctp = CTP_Unused;
12451265
bool isOpaqueReturnType = false;
@@ -1286,6 +1306,8 @@ ConstraintSystem::solveImpl(Expr *&expr,
12861306
if (getExpressionTooComplex(solutions))
12871307
return SolutionResult::forTooComplex();
12881308

1309+
target.setExpr(expr);
1310+
12891311
switch (solutions.size()) {
12901312
case 0:
12911313
return SolutionResult::forUndiagnosedError();

0 commit comments

Comments
 (0)