Skip to content

[ASTScope] Allow try in unfolded sequence to cover following elements #80461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions include/swift/AST/ASTScope.h
Original file line number Diff line number Diff line change
Expand Up @@ -1871,16 +1871,21 @@ class BraceStmtScope final : public AbstractStmtScope {
class TryScope final : public ASTScopeImpl {
public:
AnyTryExpr *const expr;
TryScope(AnyTryExpr *e)
: ASTScopeImpl(ScopeKind::Try), expr(e) {}

/// The end location of the scope. This may be past the TryExpr for
/// cases where the `try` is at the top-level of an unfolded SequenceExpr. In
/// such cases, the `try` covers all elements to the right.
SourceLoc endLoc;

TryScope(AnyTryExpr *e, SourceLoc endLoc)
: ASTScopeImpl(ScopeKind::Try), expr(e), endLoc(endLoc) {
ASSERT(endLoc.isValid());
}
virtual ~TryScope() {}

protected:
ASTScopeImpl *expandSpecifically(ScopeCreator &scopeCreator) override;

private:
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);

public:
SourceRange
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
Expand Down
16 changes: 16 additions & 0 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,22 @@ class alignas(8) Expr : public ASTAllocated<Expr> {
return const_cast<Expr *>(this)->getValueProvidingExpr();
}

/// Checks whether this expression is always hoisted above a binary operator
/// when it appears on the left-hand side, e.g 'try x + y' becomes
/// 'try (x + y)'. If so, returns the sub-expression, \c nullptr otherwise.
///
/// Note that such expressions may not appear on the right-hand side of a
/// binary operator, except for assignment and ternaries.
Expr *getAlwaysLeftFoldedSubExpr() const;

/// Checks whether this expression is always hoisted above a binary operator
/// when it appears on the left-hand side, e.g 'try x + y' becomes
/// 'try (x + y)'.
///
/// Note that such expressions may not appear on the right-hand side of a
/// binary operator, except for assignment and ternaries.
bool isAlwaysLeftFolded() const { return bool(getAlwaysLeftFoldedSubExpr()); }

/// Find the original expression value, looking through various
/// implicit conversions.
const Expr *findOriginalValue() const;
Expand Down
38 changes: 31 additions & 7 deletions lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,39 @@ class ScopeCreator final : public ASTAllocated<ScopeCreator> {

// If we have a try/try!/try?, we need to add a scope for it
if (auto anyTry = dyn_cast<AnyTryExpr>(E)) {
scopeCreator.constructExpandAndInsert<TryScope>(parent, anyTry);
auto *scope = scopeCreator.constructExpandAndInsert<TryScope>(
parent, anyTry, anyTry->getEndLoc());
scopeCreator.addExprToScopeTree(anyTry->getSubExpr(), scope);
return Action::SkipNode(E);
}

// If we have an unfolded SequenceExpr, make sure any `try` covers all
// the following elements in the sequence. It's possible it doesn't
// end up covering some of the following elements in the folded tree,
// e.g `0 * try foo() + bar()` and `_ = try foo() ^ bar()` where `^` is
// lower precedence than `=`, but those cases are invalid and will be
// diagnosed during sequence folding.
if (auto *seqExpr = dyn_cast<SequenceExpr>(E)) {
if (!seqExpr->getFoldedExpr()) {
auto *scope = parent;
for (auto *elt : seqExpr->getElements()) {
// Make sure we look through any always-left-folded expr,
// including e.g `await` and `unsafe`.
while (auto *subExpr = elt->getAlwaysLeftFoldedSubExpr()) {
// Only `try` current receives a scope.
if (auto *ATE = dyn_cast<AnyTryExpr>(elt)) {
scope = scopeCreator.constructExpandAndInsert<TryScope>(
scope, ATE, seqExpr->getEndLoc());
}
elt = subExpr;
}
scopeCreator.addExprToScopeTree(elt, scope);
}
// Already walked.
return Action::SkipNode(E);
}
}

return Action::Continue(E);
}
PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
Expand Down Expand Up @@ -802,11 +831,11 @@ NO_NEW_INSERTION_POINT(MacroDefinitionScope)
NO_NEW_INSERTION_POINT(MacroExpansionDeclScope)
NO_NEW_INSERTION_POINT(SwitchStmtScope)
NO_NEW_INSERTION_POINT(WhileStmtScope)
NO_NEW_INSERTION_POINT(TryScope)

NO_EXPANSION(GenericParamScope)
NO_EXPANSION(SpecializeAttributeScope)
NO_EXPANSION(DifferentiableAttributeScope)
NO_EXPANSION(TryScope)

#undef CREATES_NEW_INSERTION_POINT
#undef NO_NEW_INSERTION_POINT
Expand Down Expand Up @@ -1433,11 +1462,6 @@ IterableTypeBodyPortion::insertionPointForDeferredExpansion(
return s->getParent().get();
}

void TryScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
ScopeCreator &scopeCreator) {
scopeCreator.addToScopeTree(expr->getSubExpr(), this);
}

#pragma mark verification

void ast_scope::simple_display(llvm::raw_ostream &out,
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTScopeSourceRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,5 +410,5 @@ SourceLoc ast_scope::extractNearestSourceLoc(

SourceRange TryScope::getSourceRangeOfThisASTNode(
const bool omitAssertions) const {
return expr->getSourceRange();
return SourceRange(expr->getStartLoc(), endLoc);
}
11 changes: 11 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3017,6 +3017,17 @@ FrontendStatsTracer::getTraceFormatter<const Expr *>() {
return &TF;
}

Expr *Expr::getAlwaysLeftFoldedSubExpr() const {
if (auto *ATE = dyn_cast<AnyTryExpr>(this))
return ATE->getSubExpr();
if (auto *AE = dyn_cast<AwaitExpr>(this))
return AE->getSubExpr();
if (auto *UE = dyn_cast<UnsafeExpr>(this))
return UE->getSubExpr();

return nullptr;
}

const Expr *Expr::findOriginalValue() const {
auto *expr = this;
do {
Expand Down
26 changes: 2 additions & 24 deletions lib/ASTGen/Sources/ASTGen/Exprs.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1126,18 +1126,6 @@ extension ASTGenVisitor {
var elements: [BridgedExpr] = []
elements.reserveCapacity(node.elements.count)

// If the left-most sequence expr is a 'try', hoist it up to turn
// '(try x) + y' into 'try (x + y)'. This is necessary to do in the
// ASTGen because 'try' nodes are represented in the ASTScope tree
// to look up catch nodes. The scope tree must be syntactic because
// it's constructed before sequence folding happens during preCheckExpr.
// Otherwise, catch node lookup would find the incorrect catch node for
// 'try x + y' at the source location for 'y'.
//
// 'try' has restrictions for where it can appear within a sequence
// expr. This is still diagnosed in TypeChecker::foldSequence.
let firstTryExprSyntax = node.elements.first?.as(TryExprSyntax.self)

var iter = node.elements.makeIterator()
while let node = iter.next() {
switch node.as(ExprSyntaxEnum.self) {
Expand All @@ -1164,24 +1152,14 @@ extension ASTGenVisitor {
case .unresolvedTernaryExpr(let node):
elements.append(self.generate(unresolvedTernaryExpr: node).asExpr)
default:
if let firstTryExprSyntax, node.id == firstTryExprSyntax.id {
elements.append(self.generate(expr: firstTryExprSyntax.expression))
} else {
elements.append(self.generate(expr: node))
}
elements.append(self.generate(expr: node))
}
}

let seqExpr = BridgedSequenceExpr.createParsed(
return BridgedSequenceExpr.createParsed(
self.ctx,
exprs: elements.lazy.bridgedArray(in: self)
).asExpr

if let firstTryExprSyntax {
return self.generate(tryExpr: firstTryExprSyntax, overridingSubExpr: seqExpr)
} else {
return seqExpr
}
}

func generate(subscriptCallExpr node: SubscriptCallExprSyntax, postfixIfConfigBaseExpr: BridgedExpr? = nil) -> BridgedSubscriptExpr {
Expand Down
17 changes: 0 additions & 17 deletions lib/Parse/ParseExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,23 +370,6 @@ ParserResult<Expr> Parser::parseExprSequence(Diag<> Message,
if (SequencedExprs.size() == 1)
return makeParserResult(SequenceStatus, SequencedExprs[0]);

// If the left-most sequence expr is a 'try', hoist it up to turn
// '(try x) + y' into 'try (x + y)'. This is necessary to do in the
// parser because 'try' nodes are represented in the ASTScope tree
// to look up catch nodes. The scope tree must be syntactic because
// it's constructed before sequence folding happens during preCheckExpr.
// Otherwise, catch node lookup would find the incorrect catch node for
// 'try x + y' at the source location for 'y'.
//
// 'try' has restrictions for where it can appear within a sequence
// expr. This is still diagnosed in TypeChecker::foldSequence.
if (auto *tryEval = dyn_cast<AnyTryExpr>(SequencedExprs[0])) {
SequencedExprs[0] = tryEval->getSubExpr();
auto *sequence = SequenceExpr::create(Context, SequencedExprs);
tryEval->setSubExpr(sequence);
return makeParserResult(SequenceStatus, tryEval);
}

return makeParserResult(SequenceStatus,
SequenceExpr::create(Context, SequencedExprs));
}
Expand Down
41 changes: 22 additions & 19 deletions lib/Sema/TypeCheckExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,25 +198,28 @@ static Expr *makeBinOp(ASTContext &Ctx, Expr *Op, Expr *LHS, Expr *RHS,

// If the left-hand-side is a 'try', 'await', or 'unsafe', hoist it up
// turning "(try x) + y" into try (x + y).
if (auto *tryEval = dyn_cast<AnyTryExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, tryEval->getSubExpr(), RHS,
opPrecedence, isEndOfSequence);
tryEval->setSubExpr(sub);
return tryEval;
}

if (auto *await = dyn_cast<AwaitExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, await->getSubExpr(), RHS,
opPrecedence, isEndOfSequence);
await->setSubExpr(sub);
return await;
}
if (LHS->isAlwaysLeftFolded()) {
if (auto *tryEval = dyn_cast<AnyTryExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, tryEval->getSubExpr(), RHS, opPrecedence,
isEndOfSequence);
tryEval->setSubExpr(sub);
return tryEval;
}

if (auto *unsafe = dyn_cast<UnsafeExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, unsafe->getSubExpr(), RHS,
opPrecedence, isEndOfSequence);
unsafe->setSubExpr(sub);
return unsafe;
if (auto *await = dyn_cast<AwaitExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, await->getSubExpr(), RHS, opPrecedence,
isEndOfSequence);
await->setSubExpr(sub);
return await;
}

if (auto *unsafe = dyn_cast<UnsafeExpr>(LHS)) {
auto sub = makeBinOp(Ctx, Op, unsafe->getSubExpr(), RHS, opPrecedence,
isEndOfSequence);
unsafe->setSubExpr(sub);
return unsafe;
}
llvm_unreachable("Unhandled left-folded case!");
}

// If the right operand is a try, await, or unsafe, it's an error unless
Expand All @@ -235,7 +238,7 @@ static Expr *makeBinOp(ASTContext &Ctx, Expr *Op, Expr *LHS, Expr *RHS,
// x ? try foo() : try bar() $#! 1
// assuming $#! is some crazy operator with lower precedence
// than the conditional operator.
if (isa<AnyTryExpr>(RHS) || isa<AwaitExpr>(RHS) || isa<UnsafeExpr>(RHS)) {
if (RHS->isAlwaysLeftFolded()) {
// If you change this, also change TRY_KIND_SELECT in diagnostics.
enum class TryKindForDiagnostics : unsigned {
Try,
Expand Down
81 changes: 81 additions & 0 deletions test/stmt/typed_throws.swift
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ func takesThrowingAutoclosure(_: @autoclosure () throws(MyError) -> Int) {}
func takesNonThrowingAutoclosure(_: @autoclosure () throws(Never) -> Int) {}

func getInt() throws -> Int { 0 }
func getIntAsync() async throws -> Int { 0 }
func getBool() throws -> Bool { true }

func throwingAutoclosures() {
takesThrowingAutoclosure(try getInt())
Expand All @@ -337,3 +339,82 @@ func noThrow() throws(Never) {
// expected-error@-1 {{thrown expression type 'MyError' cannot be converted to error type 'Never'}}
} catch {}
}

precedencegroup LowerThanAssignment {
lowerThan: AssignmentPrecedence
}
infix operator ~~~ : LowerThanAssignment
func ~~~ <T, U> (lhs: T, rhs: U) {}

func testSequenceExpr() async throws(Never) {
// Make sure the `try` here covers both calls in the ASTScope tree.
try! getInt() + getInt() // expected-warning {{result of operator '+' is unused}}
try! _ = getInt() + getInt()
_ = try! getInt() + getInt()
_ = try! getInt() + (getInt(), 0).0

_ = try try! getInt() + getInt()
// expected-warning@-1 {{no calls to throwing functions occur within 'try' expression}}

_ = await try! getIntAsync() + getIntAsync()
// expected-warning@-1 {{'try' must precede 'await'}}

_ = unsafe await try! getIntAsync() + getIntAsync()
// expected-warning@-1 {{'try' must precede 'await'}}

_ = try unsafe await try! getIntAsync() + getIntAsync()
// expected-warning@-1 {{'try' must precede 'await'}}
// expected-warning@-2 {{no calls to throwing functions occur within 'try' expression}}

try _ = (try! getInt()) + getInt()
// expected-error@-1:29 {{thrown expression type 'any Error' cannot be converted to error type 'Never'}}

// `try` on the condition covers both branches.
_ = try! getBool() ? getInt() : getInt()

// `try` on "then" branch doesn't cover else.
try _ = getBool() ? try! getInt() : getInt()
// expected-error@-1:11 {{thrown expression type 'any Error' cannot be converted to error type 'Never'}}
// expected-error@-2:39 {{thrown expression type 'any Error' cannot be converted to error type 'Never'}}

// The `try` here covers everything, even if unassignable.
try! getBool() ? getInt() : getInt() = getInt()
// expected-error@-1 {{result of conditional operator '? :' is never mutable}}

// Same here.
try! getBool() ? getInt() : getInt() ~~~ getInt()

try _ = getInt() + try! getInt()
// expected-error@-1 {{thrown expression type 'any Error' cannot be converted to error type 'Never'}}
// expected-error@-2 {{'try!' cannot appear to the right of a non-assignment operator}}

// The ASTScope for `try` here covers both, but isn't covered in the folded
// expression. This is illegal anyway.
_ = 0 + try! getInt() + getInt()
// expected-error@-1 {{'try!' cannot appear to the right of a non-assignment operator}}
// expected-error@-2:27 {{call can throw but is not marked with 'try'}}
// expected-note@-3:27 3{{did you mean}}

try _ = 0 + try! getInt() + getInt()
// expected-error@-1 {{'try!' cannot appear to the right of a non-assignment operator}}

// The ASTScope for `try` here covers both, but isn't covered in the folded
// expression due `~~~` having a lower precedence than assignment. This is
// illegal anyway.
_ = try! getInt() ~~~ getInt()
// expected-error@-1 {{'try!' following assignment operator does not cover everything to its right}}
// expected-error@-2:25 {{call can throw but is not marked with 'try'}}
// expected-note@-3:25 3{{did you mean}}

try _ = try! getInt() ~~~ getInt()
// expected-error@-1 {{'try!' following assignment operator does not cover everything to its right}}

// Same here.
true ? 0 : try! getInt() ~~~ getInt()
// expected-error@-1 {{'try!' following conditional operator does not cover everything to its right}}
// expected-error@-2:32 {{call can throw but is not marked with 'try'}}
// expected-note@-3:32 3{{did you mean}}

try true ? 0 : try! getInt() ~~~ getInt()
// expected-error@-1 {{'try!' following conditional operator does not cover everything to its right}}
}