Skip to content

Commit 46c2073

Browse files
authored
[AutoDiff] Start supporting loadable types with address-only tangents. (#32540)
Previously, PullbackEmitter assumed that original values' value category matches their `TangentVector` types' value category. This was problematic for loadable types with address-only `TangentVector` types. Now, PullbackEmitter starts to support differentiation of loadable types with address-only `TangentVector` types. This patch focuses on supporting and testing class types, more support can be added incrementally. Resolves TF-1149.
1 parent c106b2d commit 46c2073

File tree

8 files changed

+315
-117
lines changed

8 files changed

+315
-117
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -496,12 +496,6 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
496496
"cannot differentiate through multiple results", ())
497497
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
498498
"cannot differentiate through 'inout' arguments", ())
499-
// TODO(TF-1149): Remove this diagnostic.
500-
NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
501-
"cannot yet differentiate value whose type %0 has a compile-time known "
502-
"size, but whose 'TangentVector' contains stored properties of unknown "
503-
"size; consider modifying %1 to use fewer generic parameters in stored "
504-
"properties", (Type, Type))
505499
NOTE(autodiff_enums_unsupported,none,
506500
"differentiating enum values is not yet supported", ())
507501
NOTE(autodiff_stored_property_parent_not_differentiable,none,

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
218218

219219
Optional<TangentSpace> getTangentSpace(CanType type);
220220

221+
/// Returns the tangent value category of the given value.
222+
SILValueCategory getTangentValueCategory(SILValue v);
223+
221224
/// Assuming the given type conforms to `Differentiable` after remapping,
222225
/// returns the associated tangent space type.
223226
SILType getRemappedTangentType(SILType type);
@@ -264,6 +267,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
264267
SILValue getAdjointProjection(SILBasicBlock *origBB,
265268
SILValue originalProjection);
266269

270+
SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer);
271+
267272
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint();
268273

269274
/// Creates and returns a local allocation with the given type.
@@ -273,8 +278,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
273278
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
274279
AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc);
275280

276-
SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer);
277-
278281
/// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to
279282
/// `originalBuffer`.
280283
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 201 additions & 98 deletions
Large diffs are not rendered by default.

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,23 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
380380
for (unsigned resIdx : range(toType->getNumResults())) {
381381
auto fromRes = fromConv.getResults()[resIdx];
382382
auto toRes = toConv.getResults()[resIdx];
383+
// Check function-typed results.
384+
if (isa<SILFunctionType>(fromRes.getInterfaceType()) &&
385+
isa<SILFunctionType>(toRes.getInterfaceType())) {
386+
auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType());
387+
auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType());
388+
auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module);
389+
auto toUnsubstFnType = toFnType->getUnsubstitutedType(module);
390+
// If unsubstituted function types are not equal, perform reabstraction.
391+
if (fromUnsubstFnType != toUnsubstFnType) {
392+
auto fromFn = *fromDirResultsIter++;
393+
auto newFromFn = reabstractFunction(
394+
builder, fb, loc, fromFn, toFnType,
395+
[](SubstitutionMap substMap) { return substMap; });
396+
results.push_back(newFromFn);
397+
continue;
398+
}
399+
}
383400
// No abstraction mismatch.
384401
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
385402
// If result types are direct, add call result as direct thunk result.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,23 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
11701170
loc, derivativeFn,
11711171
SILType::getPrimitiveObjectType(expectedDerivativeFnTy));
11721172
}
1173+
// If derivative function value's type is not ABI-compatible with the
1174+
// expected derivative function type (i.e. parameter and result conventions
1175+
// do not match), perform reabstraction.
1176+
auto abiCompatibility = expectedDerivativeFnTy->isABICompatibleWith(
1177+
derivativeFn->getType().castTo<SILFunctionType>(), *dfi->getFunction());
1178+
if (!abiCompatibility.isCompatible()) {
1179+
SILOptFunctionBuilder fb(context.getTransform());
1180+
auto newDerivativeFn = reabstractFunction(
1181+
builder, fb, loc, derivativeFn, expectedDerivativeFnTy,
1182+
[](SubstitutionMap substMap) { return substMap; });
1183+
derivativeFn = newDerivativeFn;
1184+
assert(expectedDerivativeFnTy
1185+
->isABICompatibleWith(
1186+
derivativeFn->getType().castTo<SILFunctionType>(),
1187+
*dfi->getFunction())
1188+
.isCompatible());
1189+
}
11731190

11741191
derivativeFns.push_back(derivativeFn);
11751192
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,14 @@ func testMultipleDiffAttrsClass<C: ClassMethodMultipleDifferentiableAttribute>(
136136

137137
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
138138
class C<T: Differentiable>: Differentiable {
139-
// expected-error @+1 {{function is not differentiable}}
140139
@differentiable
141-
// expected-note @+2 {{when differentiating this function definition}}
142-
// expected-note @+1 {{cannot yet differentiate value whose type 'C<T>' has a compile-time known size, but whose 'TangentVector' contains stored properties of unknown size; consider modifying 'C<τ_0_0>.TangentVector' to use fewer generic parameters in stored properties}}
143140
var stored: T
144141

145142
init(_ stored: T) {
146143
self.stored = stored
147144
}
148145

149-
// expected-error @+1 {{function is not differentiable}}
150146
@differentiable
151-
// expected-note @+2 {{when differentiating this function definition}}
152-
// expected-note @+1 {{cannot yet differentiate value whose type 'C<T>' has a compile-time known size, but whose 'TangentVector' contains stored properties of unknown size; consider modifying 'C<τ_0_0>.TangentVector' to use fewer generic parameters in stored properties}}
153147
func method(_ x: T) -> T {
154148
stored
155149
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import DifferentiationUnittest
6+
7+
var AddressOnlyTangentVectorTests = TestSuite("AddressOnlyTangentVector")
8+
9+
// TF-1149: Test loadable class type with an address-only `TangentVector` type.
10+
11+
AddressOnlyTangentVectorTests.test("LoadableClassAddressOnlyTangentVector") {
12+
final class LoadableClass<T: Differentiable>: Differentiable {
13+
@differentiable
14+
var stored: T
15+
16+
@differentiable
17+
init(_ stored: T) {
18+
self.stored = stored
19+
}
20+
21+
@differentiable
22+
func method(_ x: T) -> T {
23+
stored
24+
}
25+
}
26+
27+
@differentiable
28+
func projection<T: Differentiable>(_ s: LoadableClass<T>) -> T {
29+
var x = s.stored
30+
return x
31+
}
32+
expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), in: projection))
33+
34+
@differentiable
35+
func tuple<T: Differentiable>(_ s: LoadableClass<T>) -> T {
36+
var tuple = (s, (s, s))
37+
return tuple.1.0.stored
38+
}
39+
expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), in: tuple))
40+
41+
@differentiable
42+
func conditional<T: Differentiable>(_ s: LoadableClass<T>) -> T {
43+
var tuple = (s, (s, s))
44+
if false {}
45+
return tuple.1.0.stored
46+
}
47+
expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), in: conditional))
48+
49+
@differentiable
50+
func loop<T: Differentiable>(_ array: [LoadableClass<T>]) -> T {
51+
var result: [LoadableClass<T>] = []
52+
for i in withoutDerivative(at: array.indices) {
53+
result.append(array[i])
54+
}
55+
return result[0].stored
56+
}
57+
expectEqual([.init(stored: 1)], gradient(at: [LoadableClass<Float>(10)], in: loop))
58+
59+
@differentiable
60+
func arrayLiteral<T: Differentiable>(_ s: LoadableClass<T>) -> T {
61+
var result: [[LoadableClass<T>]] = [[s, s]]
62+
return result[0][1].stored
63+
}
64+
expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), in: arrayLiteral))
65+
}
66+
67+
runAllTests()

test/AutoDiff/validation-test/property_wrappers.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,7 @@ PropertyWrapperTests.test("GenericStruct") {
9898
*/
9999
}
100100

101-
// FIXME(TF-1149): Cannot differentiate active value with loadable type but
102-
// address-only tangent type. Triggered by marking properties with
103-
// `@differentiable`, which triggers derivative vtable thunk entries.
104-
/*
101+
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
105102
class Class: Differentiable {
106103
@differentiable
107104
@Wrapper @Wrapper var x: Tracked<Float> = 10
@@ -124,8 +121,15 @@ PropertyWrapperTests.test("SimpleClass") {
124121
c.x = c.x * x * c.z
125122
return c.x
126123
}
124+
// FIXME(TF-1175): Class operands should always be marked active.
125+
// This is relevant for `Class.x.setter`, which has type
126+
// `$@convention(method) (@in Tracked<Float>, @guaranteed Class) -> ()`.
127+
expectEqual((.init(x: 1, y: 0, z: 0), 0),
128+
gradient(at: Class(), 2, in: setter))
129+
/*
127130
expectEqual((.init(x: 60, y: 0, z: 20), 300),
128131
gradient(at: Class(), 2, in: setter))
132+
*/
129133

130134
// TODO(SR-12640): Support `modify` accessors.
131135
/*
@@ -138,7 +142,6 @@ PropertyWrapperTests.test("SimpleClass") {
138142
gradient(at: Class(), 2, in: modify))
139143
*/
140144
}
141-
*/
142145

143146
// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
144147
// Tests the following functionality:

0 commit comments

Comments
 (0)