Skip to content

Commit fc640db

Browse files
authored
[AutoDiff] Emit a diagnostic for non-differentiable active values (#67697)
For some values we cannot compute types for differentiation (for example, tangent vector type), so it is better to diagnose them earlier. Otherwise we hit assertions when generating code for such invalid values. The LIT test is a reduced reproducer from the issue #66996. Before the patch the compiler crashed while trying to get a tangent vector type for the following value (partial_apply): %54 = function_ref @$s4null1o2ffSdAA1FV_tFSdyKXEfu0_ : $@convention(thin) @Substituted <τ_0_0> (@inout_aliasable Double) -> (@out τ_0_0, @error any Error) for <Double> %55 = partial_apply [callee_guaranteed] %54(%2) : $@convention(thin) @Substituted <τ_0_0> (@inout_aliasable Double) -> (@out τ_0_0, @error any Error) for <Double> Now we emit a diagnostic instead. The patch resolves issues #66996 and #63331
1 parent 53d2cf1 commit fc640db

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,6 +1864,16 @@ bool PullbackCloner::Implementation::run() {
18641864
// become projections into their adjoint base buffer.
18651865
if (Projection::isAddressProjection(v))
18661866
return false;
1867+
1868+
// Check that active values are differentiable. Otherwise we may crash
1869+
// later when tangent space is required, but not available.
1870+
if (!getTangentSpace(remapType(type).getASTType())) {
1871+
getContext().emitNondifferentiabilityError(
1872+
v, getInvoker(), diag::autodiff_expression_not_differentiable_note);
1873+
errorOccurred = true;
1874+
return true;
1875+
}
1876+
18671877
// Record active value.
18681878
bbActiveValues.push_back(v);
18691879
return false;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
import _Differentiation
4+
5+
// expected-error @+1 {{function is not differentiable}}
6+
@differentiable(reverse)
7+
// expected-note @+1 {{when differentiating this function definition}}
8+
func o(ff: F) -> Double {
9+
var y = ff.i?.first { $0 >= 0.0 } ?? 0.0
10+
while 0.0 < y {
11+
// expected-note @+1 {{expression is not differentiable}}
12+
y = ff.g() ?? y
13+
}
14+
return y
15+
}
16+
17+
public struct F: Differentiable {
18+
@noDerivative var i: [Double]? {return nil}
19+
func g() -> Double? {return nil}
20+
}

0 commit comments

Comments
 (0)