Skip to content

[Concurrency] Fix checking of captures in concurrent closures. #35798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 66 additions & 22 deletions lib/Sema/TypeCheckConcurrency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,11 +1001,13 @@ namespace {

ConcurrentExecutionChecker concurrentExecutionChecker;

using MutableVarSource = llvm::PointerUnion<DeclRefExpr *, InOutExpr *>;
using MutableVarParent = llvm::PointerUnion<InOutExpr *, LoadExpr *>;

/// Mapping from mutable local variables to the parent expression, when
/// that parent is either a load or a inout expression.
llvm::SmallDenseMap<DeclRefExpr *, MutableVarParent, 4> mutableLocalVarParent;
/// Mapping from mutable local variables or inout expressions to the
/// parent expression, when that parent is either a load or a inout expression.
llvm::SmallDenseMap<MutableVarSource, MutableVarParent, 4>
mutableLocalVarParent;

const DeclContext *getDeclContext() const {
return contextStack.back();
Expand All @@ -1022,28 +1024,66 @@ namespace {
/// If the subexpression is a reference to a mutable local variable from a
/// different context, record its parent. We'll query this as part of
/// capture semantics in concurrent functions.
void recordMutableVarParent(MutableVarParent parent, Expr *subExpr) {
auto declRef = dyn_cast<DeclRefExpr>(subExpr);
if (!declRef)
return;
///
/// \returns true if we recorded anything, false otherwise.
bool recordMutableVarParent(MutableVarParent parent, Expr *subExpr) {
subExpr = subExpr->getValueProvidingExpr();

if (auto declRef = dyn_cast<DeclRefExpr>(subExpr)) {
auto var = dyn_cast_or_null<VarDecl>(declRef->getDecl());
if (!var)
return false;

auto var = dyn_cast_or_null<VarDecl>(declRef->getDecl());
if (!var)
return;
// Only mutable variables matter.
if (!var->supportsMutation())
return false;

// Only mutable variables matter.
if (!var->supportsMutation())
return;
// Only mutable variables outside of the current context. This is an
// optimization, because the parent map won't be queried in this case, and
// it is the most common case for variables to be referenced in their
// own context.
if (var->getDeclContext() == getDeclContext())
return false;

// Only mutable variables outside of the current context. This is an
// optimization, because the parent map won't be queried in this case, and
// it is the most common case for variables to be referenced in their
// own context.
if (var->getDeclContext() == getDeclContext())
return;
assert(mutableLocalVarParent[declRef].isNull());
mutableLocalVarParent[declRef] = parent;
return true;
}

// For a member reference, try to record a parent for the base
// expression.
if (auto memberRef = dyn_cast<MemberRefExpr>(subExpr)) {
return recordMutableVarParent(parent, memberRef->getBase());
}

// For a subscript, try to record a parent for the base expression.
if (auto subscript = dyn_cast<SubscriptExpr>(subExpr)) {
return recordMutableVarParent(parent, subscript->getBase());
}

assert(mutableLocalVarParent[declRef].isNull());
mutableLocalVarParent[declRef] = parent;
// Look through postfix '!'.
if (auto force = dyn_cast<ForceValueExpr>(subExpr)) {
return recordMutableVarParent(parent, force->getSubExpr());
}

// Look through postfix '?'.
if (auto bindOpt = dyn_cast<BindOptionalExpr>(subExpr)) {
return recordMutableVarParent(parent, bindOpt->getSubExpr());
}

if (auto optEval = dyn_cast<OptionalEvaluationExpr>(subExpr)) {
return recordMutableVarParent(parent, optEval->getSubExpr());
}

// & expressions can be embedded for references to mutable variables
// or subscribes inside a struct/enum.
if (auto inout = dyn_cast<InOutExpr>(subExpr)) {
// Record the parent of the inout so we don't look at it again later.
mutableLocalVarParent[inout] = parent;
return recordMutableVarParent(parent, inout->getSubExpr());
}

return false;
}

public:
Expand Down Expand Up @@ -1127,7 +1167,8 @@ namespace {
if (!applyStack.empty())
diagnoseInOutArg(applyStack.back(), inout, false);

recordMutableVarParent(inout, inout->getSubExpr());
if (mutableLocalVarParent.count(inout) == 0)
recordMutableVarParent(inout, inout->getSubExpr());
}

if (auto load = dyn_cast<LoadExpr>(expr)) {
Expand Down Expand Up @@ -1225,6 +1266,9 @@ namespace {
if (auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
mutableLocalVarParent.erase(declRefExpr);
}
if (auto *inoutExpr = dyn_cast<InOutExpr>(expr)) {
mutableLocalVarParent.erase(inoutExpr);
}

return expr;
}
Expand Down
44 changes: 20 additions & 24 deletions test/Concurrency/concurrentfunction_capturediagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,9 @@ struct NonTrivialValueType {
func testCaseNonTrivialValue() {
var i = NonTrivialValueType(17, Klass())
f {
// Currently emits a typechecker level error due to some sort of bug in the type checker.
// print(i.i + 17)
// print(i.i + 18)
// print(i.i + 19)
print(i.i + 17)
print(i.i + 18)
print(i.i + 19)
}

i.i = 20
Expand All @@ -155,28 +154,27 @@ func testCaseNonTrivialValue() {
// We only emit a warning here since we use the last write.
//
// TODO: Should we emit for all writes?
i.i.addOne() // xpected-warning {{'i' mutated after capture by concurrent closure}}
// xpected-note @-14 {{variable defined here}}
// xpected-note @-14 {{variable captured by concurrent closure}}
// xpected-note @-14 {{capturing use}}
// xpected-note @-14 {{capturing use}}
// xpected-note @-14 {{capturing use}}
i.i.addOne() // expected-warning {{'i' mutated after capture by concurrent closure}}
// expected-note @-14 {{variable defined here}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DougGregor Thanks for updating this test!

// expected-note @-14 {{variable captured by concurrent closure}}
// expected-note @-14 {{capturing use}}
// expected-note @-14 {{capturing use}}
// expected-note @-14 {{capturing use}}
}

func testCaseNonTrivialValueInout() {
var i = NonTrivialValueType(17, Klass())
f {
// Currently emits a typechecker level error due to some sort of bug in the type checker.
// print(i.i + 17)
// print(i.k ?? "none")
print(i.i + 17)
print(i.k ?? "none")
}

// We only emit a warning here since we use the last write.
inoutUserOptKlass(&i.k) // xpected-warning {{'i' mutated after capture by concurrent closure}}
// xpected-note @-8 {{variable defined here}}
// xpected-note @-8 {{variable captured by concurrent closure}}
// xpected-note @-8 {{capturing use}}
// xpected-note @-8 {{capturing use}}
inoutUserOptKlass(&i.k) // expected-warning {{'i' mutated after capture by concurrent closure}}
// expected-note @-8 {{variable defined here}}
// expected-note @-8 {{variable captured by concurrent closure}}
// expected-note @-8 {{capturing use}}
// expected-note @-8 {{capturing use}}
}

protocol MyProt {
Expand All @@ -187,9 +185,8 @@ protocol MyProt {
func testCaseAddressOnlyAllocBoxToStackable<T : MyProt>(i : T) {
var i2 = i
f {
// Currently emits an error due to some sort of bug in the type checker.
// print(i2.i + 17)
// print(i2.k ?? "none")
print(i2.i + 17)
print(i2.k ?? "none")
}

// TODO: Make sure we emit these once we support address only types!
Expand All @@ -206,9 +203,8 @@ func testCaseAddressOnlyNoAllocBoxToStackable<T : MyProt>(i : T) {
let f2 = F()
var i2 = i
f2.useConcurrent {
// Currently emits a typechecker level error due to some sort of bug in the type checker.
// print(i2.i + 17)
// print(i2.k ?? "none")
print(i2.i + 17)
print(i2.k ?? "none")
}

// TODO: Make sure we emit these once we support address only types!
Expand Down
25 changes: 25 additions & 0 deletions test/attr/attr_concurrent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,28 @@ func mutationOfLocal() {

localInt = 20
}

struct NonTrivialValueType {
var int: Int = 0
var array: [Int] = []
var optArray: [Int]? = nil
}

func testCaseNonTrivialValue() {
var i = NonTrivialValueType()
var j = 0
acceptsConcurrent { value in
print(i.int)
print(i.array[0])
print(i.array[j])
print(i.optArray?[j] ?? 0)
print(i.optArray![j])

i.int = 5 // expected-error{{mutation of captured var 'i' in concurrently-executing code}}
i.array[0] = 5 // expected-error{{mutation of captured var 'i' in concurrently-executing code}}

return value
}

j = 17
}