Skip to content

Commit c7a2160

Browse files
aslasavonicrxwei
authored
[AutoDiff] First cut of coroutines differentiation (#71461)
This PR implements first set of changes required to support autodiff for coroutines. It mostly targeted to `_modify` accessors in standard library (and beyond), but overall implementation is quite generic. There are some specifics of implementation and known limitations: - Only `@yield_once` coroutines are naturally supported - VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in resume part of a coroutine (this is required for defers and other cleanups / commits) - Pullback is a coroutine, we assume that coroutine cannot abort and therefore we execute the original coroutine in reverse from return via yield and then back to the entry - It seems there is no semantically sane way to support `_read` coroutines (as we will need to "accept" adjoints via yields), therefore only coroutines with inout yields are supported (`_modify` accessors). Pullbacks of such coroutines take adjoint buffer as input argument, yield this buffer (to accumulate adjoint values in the caller) and finally return the adjoints indirectly. - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, one cannot e.g. store them into tuples, etc. So, everywhere where AST type is required, we have to hack around. - As there is no AST type for coroutines, there is no way one could register custom derivative for coroutines. So far only compiler-produced derivatives are supported - There are lots of common things wrt normal function apply's, but still there are subtle but important differences. I tried to organize the code to enable code reuse, still it was not always possible, so some code duplication could be seen - The order of how pullback closures are produced in VJP is a bit different: for normal apply's VJP produces both value and pullback closure via a single nested VJP apply. This is not so anymore with coroutine VJP's: yielded values are produced at `begin_apply` site and pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly – effectively delay them) - On the way some complementary changes were required in e.g. mangler / demangler This patch covers the generation of derivatives up to SIL level, however, it is not enough as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately as it is not directly autodiff-related. --------- Co-authored-by: Andrew Savonichev <[email protected]> Co-authored-by: Richard Wei <[email protected]>
1 parent b3b2f37 commit c7a2160

30 files changed

+1044
-302
lines changed

include/swift/AST/ASTDemangler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ class ASTBuilder {
150150

151151
Type createImplFunctionType(
152152
Demangle::ImplParameterConvention calleeConvention,
153+
Demangle::ImplCoroutineKind coroutineKind,
153154
ArrayRef<Demangle::ImplFunctionParam<Type>> params,
155+
ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
154156
ArrayRef<Demangle::ImplFunctionResult<Type>> results,
155157
std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
156158
ImplFunctionTypeFlags flags);

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
607607
"cannot differentiate through multiple results", ())
608608
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
609609
"cannot differentiate through 'inout' arguments", ())
610+
NOTE(autodiff_cannot_differentiate_through_direct_yield,none,
611+
"cannot differentiate through a '_read' accessor", ())
610612
NOTE(autodiff_enums_unsupported,none,
611613
"differentiating enum values is not yet supported", ())
612614
NOTE(autodiff_stored_property_parent_not_differentiable,none,

include/swift/AST/IndexSubset.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode {
108108
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
109109
ArrayRef<unsigned> indices) {
110110
SmallBitVector indicesBitVec(capacity, false);
111-
for (auto index : indices)
111+
for (auto index : indices) {
112+
assert(index < capacity);
112113
indicesBitVec.set(index);
114+
}
113115
return IndexSubset::get(ctx, indicesBitVec);
114116
}
115117

include/swift/AST/Types.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5174,8 +5174,11 @@ class SILFunctionType final
51745174
/// Returns the number of function potential semantic results:
51755175
/// * Usual results
51765176
/// * Inout parameters
5177+
/// * yields
51775178
unsigned getNumAutoDiffSemanticResults() const {
5178-
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
5179+
return getNumResults() +
5180+
getNumAutoDiffSemanticResultsParameters() +
5181+
getNumYields();
51795182
}
51805183

51815184
/// Get the generic signature that the component types are specified

include/swift/Demangling/Demangle.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ struct [[nodiscard]] ManglingError {
553553
UnknownEncoding,
554554
InvalidImplCalleeConvention,
555555
InvalidImplDifferentiability,
556+
InvalidImplCoroutineKind,
556557
InvalidImplFunctionAttribute,
557558
InvalidImplParameterConvention,
558559
InvalidImplParameterTransferring,

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ NODE(ImplFunctionAttribute)
139139
NODE(ImplFunctionConvention)
140140
NODE(ImplFunctionConventionName)
141141
NODE(ImplFunctionType)
142+
NODE(ImplCoroutineKind)
142143
NODE(ImplInvocationSubstitutions)
143144
CONTEXT_NODE(ImplicitClosure)
144145
NODE(ImplParameter)

include/swift/Demangling/TypeDecoder.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ enum class ImplMetatypeRepresentation {
4848
ObjC,
4949
};
5050

51+
enum class ImplCoroutineKind {
52+
None,
53+
YieldOnce,
54+
YieldMany,
55+
};
56+
5157
/// Describe a function parameter, parameterized on the type
5258
/// representation.
5359
template <typename BuiltType>
@@ -188,6 +194,9 @@ class ImplFunctionParam {
188194
BuiltType getType() const { return Type; }
189195
};
190196

197+
template<typename Type>
198+
using ImplFunctionYield = ImplFunctionParam<Type>;
199+
191200
enum class ImplResultConvention {
192201
Indirect,
193202
Owned,
@@ -1023,9 +1032,11 @@ class TypeDecoder {
10231032
case NodeKind::ImplFunctionType: {
10241033
auto calleeConvention = ImplParameterConvention::Direct_Unowned;
10251034
llvm::SmallVector<ImplFunctionParam<BuiltType>, 8> parameters;
1035+
llvm::SmallVector<ImplFunctionYield<BuiltType>, 8> yields;
10261036
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> results;
10271037
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> errorResults;
10281038
ImplFunctionTypeFlags flags;
1039+
ImplCoroutineKind coroutineKind = ImplCoroutineKind::None;
10291040

10301041
for (unsigned i = 0; i < Node->getNumChildren(); i++) {
10311042
auto child = Node->getChild(i);
@@ -1066,6 +1077,15 @@ class TypeDecoder {
10661077
} else if (child->getText() == "@async") {
10671078
flags = flags.withAsync();
10681079
}
1080+
} else if (child->getKind() == NodeKind::ImplCoroutineKind) {
1081+
if (!child->hasText())
1082+
return MAKE_NODE_TYPE_ERROR0(child, "expected text");
1083+
if (child->getText() == "yield_once") {
1084+
coroutineKind = ImplCoroutineKind::YieldOnce;
1085+
} else if (child->getText() == "yield_many") {
1086+
coroutineKind = ImplCoroutineKind::YieldMany;
1087+
} else
1088+
return MAKE_NODE_TYPE_ERROR0(child, "failed to decode coroutine kind");
10691089
} else if (child->getKind() == NodeKind::ImplDifferentiabilityKind) {
10701090
ImplFunctionDifferentiabilityKind implDiffKind;
10711091
switch ((MangledDifferentiabilityKind)child->getIndex()) {
@@ -1088,10 +1108,14 @@ class TypeDecoder {
10881108
if (decodeImplFunctionParam(child, depth + 1, parameters))
10891109
return MAKE_NODE_TYPE_ERROR0(child,
10901110
"failed to decode function parameter");
1111+
} else if (child->getKind() == NodeKind::ImplYield) {
1112+
if (decodeImplFunctionParam(child, depth + 1, yields))
1113+
return MAKE_NODE_TYPE_ERROR0(child,
1114+
"failed to decode function yields");
10911115
} else if (child->getKind() == NodeKind::ImplResult) {
10921116
if (decodeImplFunctionParam(child, depth + 1, results))
10931117
return MAKE_NODE_TYPE_ERROR0(child,
1094-
"failed to decode function parameter");
1118+
"failed to decode function results");
10951119
} else if (child->getKind() == NodeKind::ImplErrorResult) {
10961120
if (decodeImplFunctionPart(child, depth + 1, errorResults))
10971121
return MAKE_NODE_TYPE_ERROR0(child,
@@ -1115,11 +1139,10 @@ class TypeDecoder {
11151139

11161140
// TODO: Some cases not handled above, but *probably* they cannot
11171141
// appear as the types of values in SIL (yet?):
1118-
// - functions with yield returns
11191142
// - functions with generic signatures
11201143
// - foreign error conventions
1121-
return Builder.createImplFunctionType(calleeConvention,
1122-
parameters, results,
1144+
return Builder.createImplFunctionType(calleeConvention, coroutineKind,
1145+
parameters, yields, results,
11231146
errorResult, flags);
11241147
}
11251148

include/swift/RemoteInspection/TypeRefBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,9 @@ class TypeRefBuilder {
11341134

11351135
const FunctionTypeRef *createImplFunctionType(
11361136
Demangle::ImplParameterConvention calleeConvention,
1137+
Demangle::ImplCoroutineKind coroutineKind,
11371138
llvm::ArrayRef<Demangle::ImplFunctionParam<const TypeRef *>> params,
1139+
llvm::ArrayRef<Demangle::ImplFunctionYield<const TypeRef *>> yields,
11381140
llvm::ArrayRef<Demangle::ImplFunctionResult<const TypeRef *>> results,
11391141
std::optional<Demangle::ImplFunctionResult<const TypeRef *>> errorResult,
11401142
ImplFunctionTypeFlags flags) {

include/swift/SIL/SILFunctionConventions.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,14 @@ class SILFunctionConventions {
248248
idx < indirectResults + getNumIndirectSILErrorResults();
249249
}
250250

251+
unsigned getNumAutoDiffSemanticResults() const {
252+
return funcTy->getNumAutoDiffSemanticResults();
253+
}
254+
255+
unsigned getNumAutoDiffSemanticResultParameters() const {
256+
return funcTy->getNumAutoDiffSemanticResultsParameters();
257+
}
258+
251259
/// Are any SIL results passed as address-typed arguments?
252260
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
253261
bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; }

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
1818
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
1919

20+
#include "swift/SIL/ApplySite.h"
2021
#include "swift/SILOptimizer/Differentiation/Common.h"
2122
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2223

@@ -51,6 +52,12 @@ struct NestedApplyInfo {
5152
/// The original pullback type before reabstraction. `None` if the pullback
5253
/// type is not reabstracted.
5354
std::optional<CanSILFunctionType> originalPullbackType;
55+
/// Index of `apply` pullback in nested pullback call
56+
unsigned pullbackIdx = -1U;
57+
/// Pullback value itself that is memoized in some cases (e.g. pullback is
58+
/// called by `begin_apply`, but should be destroyed after `end_apply`).
59+
SILValue pullback = SILValue();
60+
SILValue beginApplyToken = SILValue();
5461
};
5562

5663
/// Per-module contextual information for the Differentiation pass.
@@ -97,7 +104,7 @@ class ADContext {
97104

98105
/// Mapping from original `apply` instructions to their corresponding
99106
/// `NestedApplyInfo`s.
100-
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
107+
llvm::DenseMap<FullApplySite, NestedApplyInfo> nestedApplyInfo;
101108

102109
/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
103110
/// Saved for deletion during cleanup.
@@ -185,7 +192,7 @@ class ADContext {
185192
invokers.insert({witness, DifferentiationInvoker(witness)});
186193
}
187194

188-
llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
195+
llvm::DenseMap<FullApplySite, NestedApplyInfo> &getNestedApplyInfo() {
189196
return nestedApplyInfo;
190197
}
191198

0 commit comments

Comments
 (0)