From dcd769d31f05a4ca32b6cd233f1e2dbfb08ee355 Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Mon, 21 Oct 2024 18:13:26 +0300 Subject: [PATCH 1/2] [AutoDiff] Fix assert on missing struct decl on cross-file derivative search Consider: 1. File struct.swift defining `struct Struct` with `static func max` member 2. File derivatives.swift defining `extension Struct` with custom derivative of the `max` function 3. File error.swift defining a differentiable function which uses `Struct.max`. Previously, when passing error.swift as primary file and derivatives.swift as a secondary file to swift-frontend (and forgetting to pass struct.swift as a secondary file as well), an assertion failure was triggered in the following call stack: ``` assert(type->mayHaveMembers()); // while type is ErrorType TypeChecker::lookupMember findAutoDiffOriginalFunctionDecl typeCheckDerivativeAttr DerivativeAttrOriginalDeclRequest::evaluate ``` This patch fixes the issue by adding a check against `ErrorType` in `findAutoDiffOriginalFunctionDecl` before calling `lookupMember`. --- lib/Sema/TypeCheckAttr.cpp | 3 +- .../Inputs/derivatives-error.swift | 37 +++++++++++++++++++ .../Inputs/derivatives.swift | 28 +++++++------- .../Inputs/struct.swift | 14 +++++++ .../error.swift | 8 ++++ .../main.swift | 5 ++- 6 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives-error.swift create mode 100644 test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/struct.swift create mode 100644 test/AutoDiff/Sema/DerivativeRegistrationCrossFile/error.swift diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index e6b8cdd927bd2..2709b4264c63b 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5459,7 +5459,8 @@ static AbstractFunctionDecl *findAutoDiffOriginalFunctionDecl( if (!baseType && lookupContext->isTypeContext()) baseType = lookupContext->getSelfTypeInContext(); if (baseType) { - results = TypeChecker::lookupMember(lookupContext, baseType, funcName); + if (!baseType.getPointer()->hasError()) + results = TypeChecker::lookupMember(lookupContext, baseType, funcName); } else { results = TypeChecker::lookupUnqualified( lookupContext, funcName, funcNameLoc.getBaseNameLoc(), lookupOptions); diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives-error.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives-error.swift new file mode 100644 index 0000000000000..1924e99efe641 --- /dev/null +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives-error.swift @@ -0,0 +1,37 @@ +import _Differentiation + +@inlinable +@derivative(of: min) +func minVJP( + _ x: T, + _ y: T +) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { + func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { + if x <= y { + return (v, .zero) + } + else { + return (.zero, v) + } + } + return (value: min(x, y), pullback: pullback) +} + +extension Struct { + @inlinable + @derivative(of: max) // expected-error {{cannot find 'max' in scope}} + static func maxVJP( + _ x: T, + _ y: T + ) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { + func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { + if x < y { + return (.zero, v) + } + else { + return (v, .zero) + } + } + return (value: max(x, y), pullback: pullback) + } +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift index 71dffec74e834..cfc5d8f5018b5 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift @@ -17,19 +17,21 @@ func minVJP( return (value: min(x, y), pullback: pullback) } -@inlinable -@derivative(of: max) -func maxVJP( - _ x: T, - _ y: T -) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { - func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { - if x < y { - return (.zero, v) - } - else { - return (v, .zero) +extension Struct { + @inlinable + @derivative(of: max) + static func maxVJP( + _ x: T, + _ y: T + ) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { + func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { + if x < y { + return (.zero, v) + } + else { + return (v, .zero) + } } + return (value: max(x, y), pullback: pullback) } - return (value: max(x, y), pullback: pullback) } diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/struct.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/struct.swift new file mode 100644 index 0000000000000..c43b87974fdb0 --- /dev/null +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/struct.swift @@ -0,0 +1,14 @@ +import _Differentiation + +struct Struct { + @inlinable + static func max( + _ x: T, + _ y: T + ) -> T { + if x > y + return y + else + return x + } +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/error.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/error.swift new file mode 100644 index 0000000000000..380ac70078a21 --- /dev/null +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/error.swift @@ -0,0 +1,8 @@ +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives-error.swift -module-name main -o /dev/null + +import _Differentiation + +@differentiable(reverse) +func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double { + return Struct.max(min(value, upperBound), lowerBound) // expected-error {{cannot find 'Struct' in scope}} +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift index 071510d508373..7f4604176c0c6 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift @@ -1,9 +1,10 @@ -// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s \ +// RUN: %S/Inputs/derivatives.swift %S/Inputs/struct.swift -module-name main -o /dev/null import _Differentiation @differentiable(reverse) func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double { // No error expected - return max(min(value, upperBound), lowerBound) + return Struct.max(min(value, upperBound), lowerBound) } From bcedb71f4daa10bc546b17c7f18a45ca27be59db Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Wed, 23 Oct 2024 19:54:20 +0300 Subject: [PATCH 2/2] Address review comment Co-authored-by: Anton Korobeynikov --- lib/Sema/TypeCheckAttr.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 2709b4264c63b..11905a8b39974 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5459,7 +5459,7 @@ static AbstractFunctionDecl *findAutoDiffOriginalFunctionDecl( if (!baseType && lookupContext->isTypeContext()) baseType = lookupContext->getSelfTypeInContext(); if (baseType) { - if (!baseType.getPointer()->hasError()) + if (!baseType->hasError()) results = TypeChecker::lookupMember(lookupContext, baseType, funcName); } else { results = TypeChecker::lookupUnqualified(