Skip to content

Commit 1941996

Browse files
kovdan01asl
andauthored
[AutoDiff] Fix assert on missing struct decl on cross-file derivative search (#77183)
Consider: 1. File struct.swift defining `struct Struct` with `static func max` member 2. File derivatives.swift defining `extension Struct` with custom derivative of the `max` function 3. File error.swift defining a differentiable function which uses `Struct.max`. Previously, when passing error.swift as primary file and derivatives.swift as a secondary file to swift-frontend (and forgetting to pass struct.swift as a secondary file as well), an assertion failure was triggered. This patch fixes the issue by adding a check against `ErrorType` in `findAutoDiffOriginalFunctionDecl` before calling `lookupMember`. Co-authored-by: Anton Korobeynikov <[email protected]>
1 parent 45657fe commit 1941996

File tree

6 files changed

+79
-16
lines changed

6 files changed

+79
-16
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5460,7 +5460,8 @@ static AbstractFunctionDecl *findAutoDiffOriginalFunctionDecl(
54605460
if (!baseType && lookupContext->isTypeContext())
54615461
baseType = lookupContext->getSelfTypeInContext();
54625462
if (baseType) {
5463-
results = TypeChecker::lookupMember(lookupContext, baseType, funcName);
5463+
if (!baseType->hasError())
5464+
results = TypeChecker::lookupMember(lookupContext, baseType, funcName);
54645465
} else {
54655466
results = TypeChecker::lookupUnqualified(
54665467
lookupContext, funcName, funcNameLoc.getBaseNameLoc(), lookupOptions);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import _Differentiation
2+
3+
@inlinable
4+
@derivative(of: min)
5+
func minVJP<T: Comparable & Differentiable>(
6+
_ x: T,
7+
_ y: T
8+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
9+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
10+
if x <= y {
11+
return (v, .zero)
12+
}
13+
else {
14+
return (.zero, v)
15+
}
16+
}
17+
return (value: min(x, y), pullback: pullback)
18+
}
19+
20+
extension Struct {
21+
@inlinable
22+
@derivative(of: max) // expected-error {{cannot find 'max' in scope}}
23+
static func maxVJP<T: Comparable & Differentiable>(
24+
_ x: T,
25+
_ y: T
26+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
27+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
28+
if x < y {
29+
return (.zero, v)
30+
}
31+
else {
32+
return (v, .zero)
33+
}
34+
}
35+
return (value: max(x, y), pullback: pullback)
36+
}
37+
}

test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,21 @@ func minVJP<T: Comparable & Differentiable>(
1717
return (value: min(x, y), pullback: pullback)
1818
}
1919

20-
@inlinable
21-
@derivative(of: max)
22-
func maxVJP<T: Comparable & Differentiable>(
23-
_ x: T,
24-
_ y: T
25-
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
26-
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
27-
if x < y {
28-
return (.zero, v)
29-
}
30-
else {
31-
return (v, .zero)
20+
extension Struct {
21+
@inlinable
22+
@derivative(of: max)
23+
static func maxVJP<T: Comparable & Differentiable>(
24+
_ x: T,
25+
_ y: T
26+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
27+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
28+
if x < y {
29+
return (.zero, v)
30+
}
31+
else {
32+
return (v, .zero)
33+
}
3234
}
35+
return (value: max(x, y), pullback: pullback)
3336
}
34-
return (value: max(x, y), pullback: pullback)
3537
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import _Differentiation
2+
3+
struct Struct {
4+
@inlinable
5+
static func max<T: Comparable>(
6+
_ x: T,
7+
_ y: T
8+
) -> T {
9+
if x > y
10+
return y
11+
else
12+
return x
13+
}
14+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives-error.swift -module-name main -o /dev/null
2+
3+
import _Differentiation
4+
5+
@differentiable(reverse)
6+
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
7+
return Struct.max(min(value, upperBound), lowerBound) // expected-error {{cannot find 'Struct' in scope}}
8+
}
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s \
2+
// RUN: %S/Inputs/derivatives.swift %S/Inputs/struct.swift -module-name main -o /dev/null
23

34
import _Differentiation
45

56
@differentiable(reverse)
67
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
78
// No error expected
8-
return max(min(value, upperBound), lowerBound)
9+
return Struct.max(min(value, upperBound), lowerBound)
910
}

0 commit comments

Comments
 (0)