Skip to content

[AutoDiff] First cut of coroutines differentiation #71461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/ASTDemangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ class ASTBuilder {

Type createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
ArrayRef<Demangle::ImplFunctionParam<Type>> params,
ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
ArrayRef<Demangle::ImplFunctionResult<Type>> results,
std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
ImplFunctionTypeFlags flags);
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
"cannot differentiate through multiple results", ())
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
"cannot differentiate through 'inout' arguments", ())
NOTE(autodiff_cannot_differentiate_through_direct_yield,none,
"cannot differentiate through a direct yield result", ())
NOTE(autodiff_enums_unsupported,none,
"differentiating enum values is not yet supported", ())
NOTE(autodiff_stored_property_parent_not_differentiable,none,
Expand Down
4 changes: 3 additions & 1 deletion include/swift/AST/IndexSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode {
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
ArrayRef<unsigned> indices) {
SmallBitVector indicesBitVec(capacity, false);
for (auto index : indices)
for (auto index : indices) {
assert(index < capacity);
indicesBitVec.set(index);
}
return IndexSubset::get(ctx, indicesBitVec);
}

Expand Down
5 changes: 4 additions & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -5174,8 +5174,11 @@ class SILFunctionType final
/// Returns the number of function potential semantic results:
/// * Usual results
/// * Inout parameters
/// * yields
unsigned getNumAutoDiffSemanticResults() const {
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
return getNumResults() +
getNumAutoDiffSemanticResultsParameters() +
getNumYields();
}

/// Get the generic signature that the component types are specified
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/Demangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ struct [[nodiscard]] ManglingError {
UnknownEncoding,
InvalidImplCalleeConvention,
InvalidImplDifferentiability,
InvalidImplCoroutineKind,
InvalidImplFunctionAttribute,
InvalidImplParameterConvention,
InvalidImplParameterTransferring,
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ NODE(ImplFunctionAttribute)
NODE(ImplFunctionConvention)
NODE(ImplFunctionConventionName)
NODE(ImplFunctionType)
NODE(ImplCoroutineKind)
NODE(ImplInvocationSubstitutions)
CONTEXT_NODE(ImplicitClosure)
NODE(ImplParameter)
Expand Down
31 changes: 27 additions & 4 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ enum class ImplMetatypeRepresentation {
ObjC,
};

enum class ImplCoroutineKind {
None,
YieldOnce,
YieldMany,
};

/// Describe a function parameter, parameterized on the type
/// representation.
template <typename BuiltType>
Expand Down Expand Up @@ -188,6 +194,9 @@ class ImplFunctionParam {
BuiltType getType() const { return Type; }
};

template<typename Type>
using ImplFunctionYield = ImplFunctionParam<Type>;

enum class ImplResultConvention {
Indirect,
Owned,
Expand Down Expand Up @@ -1023,9 +1032,11 @@ class TypeDecoder {
case NodeKind::ImplFunctionType: {
auto calleeConvention = ImplParameterConvention::Direct_Unowned;
llvm::SmallVector<ImplFunctionParam<BuiltType>, 8> parameters;
llvm::SmallVector<ImplFunctionYield<BuiltType>, 8> yields;
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> results;
llvm::SmallVector<ImplFunctionResult<BuiltType>, 8> errorResults;
ImplFunctionTypeFlags flags;
ImplCoroutineKind coroutineKind = ImplCoroutineKind::None;

for (unsigned i = 0; i < Node->getNumChildren(); i++) {
auto child = Node->getChild(i);
Expand Down Expand Up @@ -1066,6 +1077,15 @@ class TypeDecoder {
} else if (child->getText() == "@async") {
flags = flags.withAsync();
}
} else if (child->getKind() == NodeKind::ImplCoroutineKind) {
if (!child->hasText())
return MAKE_NODE_TYPE_ERROR0(child, "expected text");
if (child->getText() == "yield_once") {
coroutineKind = ImplCoroutineKind::YieldOnce;
} else if (child->getText() == "yield_many") {
coroutineKind = ImplCoroutineKind::YieldMany;
} else
return MAKE_NODE_TYPE_ERROR0(child, "failed to decode coroutine kind");
} else if (child->getKind() == NodeKind::ImplDifferentiabilityKind) {
ImplFunctionDifferentiabilityKind implDiffKind;
switch ((MangledDifferentiabilityKind)child->getIndex()) {
Expand All @@ -1088,10 +1108,14 @@ class TypeDecoder {
if (decodeImplFunctionParam(child, depth + 1, parameters))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function parameter");
} else if (child->getKind() == NodeKind::ImplYield) {
if (decodeImplFunctionParam(child, depth + 1, yields))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function yields");
} else if (child->getKind() == NodeKind::ImplResult) {
if (decodeImplFunctionParam(child, depth + 1, results))
return MAKE_NODE_TYPE_ERROR0(child,
"failed to decode function parameter");
"failed to decode function results");
} else if (child->getKind() == NodeKind::ImplErrorResult) {
if (decodeImplFunctionPart(child, depth + 1, errorResults))
return MAKE_NODE_TYPE_ERROR0(child,
Expand All @@ -1115,11 +1139,10 @@ class TypeDecoder {

// TODO: Some cases not handled above, but *probably* they cannot
// appear as the types of values in SIL (yet?):
// - functions with yield returns
// - functions with generic signatures
// - foreign error conventions
return Builder.createImplFunctionType(calleeConvention,
parameters, results,
return Builder.createImplFunctionType(calleeConvention, coroutineKind,
parameters, yields, results,
errorResult, flags);
}

Expand Down
2 changes: 2 additions & 0 deletions include/swift/RemoteInspection/TypeRefBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,9 @@ class TypeRefBuilder {

const FunctionTypeRef *createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
llvm::ArrayRef<Demangle::ImplFunctionParam<const TypeRef *>> params,
llvm::ArrayRef<Demangle::ImplFunctionYield<const TypeRef *>> yields,
llvm::ArrayRef<Demangle::ImplFunctionResult<const TypeRef *>> results,
std::optional<Demangle::ImplFunctionResult<const TypeRef *>> errorResult,
ImplFunctionTypeFlags flags) {
Expand Down
8 changes: 8 additions & 0 deletions include/swift/SIL/SILFunctionConventions.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ class SILFunctionConventions {
idx < indirectResults + getNumIndirectSILErrorResults();
}

unsigned getNumAutoDiffSemanticResults() const {
return funcTy->getNumAutoDiffSemanticResults();
}

unsigned getNumAutoDiffSemanticResultParameters() const {
return funcTy->getNumAutoDiffSemanticResultsParameters();
}

/// Are any SIL results passed as address-typed arguments?
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; }
Expand Down
11 changes: 9 additions & 2 deletions include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H

#include "swift/SIL/ApplySite.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"

Expand Down Expand Up @@ -51,6 +52,12 @@ struct NestedApplyInfo {
/// The original pullback type before reabstraction. `None` if the pullback
/// type is not reabstracted.
std::optional<CanSILFunctionType> originalPullbackType;
/// Index of `apply` pullback in nested pullback call
unsigned pullbackIdx = -1U;
/// Pullback value itself that is memoized in some cases (e.g. pullback is
/// called by `begin_apply`, but should be destroyed after `end_apply`).
SILValue pullback = SILValue();
SILValue beginApplyToken = SILValue();
};

/// Per-module contextual information for the Differentiation pass.
Expand Down Expand Up @@ -97,7 +104,7 @@ class ADContext {

/// Mapping from original `apply` instructions to their corresponding
/// `NestedApplyInfo`s.
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
llvm::DenseMap<FullApplySite, NestedApplyInfo> nestedApplyInfo;

/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
/// Saved for deletion during cleanup.
Expand Down Expand Up @@ -185,7 +192,7 @@ class ADContext {
invokers.insert({witness, DifferentiationInvoker(witness)});
}

llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
llvm::DenseMap<FullApplySite, NestedApplyInfo> &getNestedApplyInfo() {
return nestedApplyInfo;
}

Expand Down
5 changes: 3 additions & 2 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/SemanticAttrs.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/Projection.h"
Expand Down Expand Up @@ -112,15 +113,15 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
void collectAllActualResultsInTypeOrder(
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
SmallVectorImpl<SILValue> &results);

/// For an `apply` instruction with active results, compute:
/// - The results of the `apply` instruction, in type order.
/// - The set of minimal parameter and result indices for differentiating the
/// `apply` instruction.
void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, const AutoDiffConfig &parentConfig,
FullApplySite fai, const AutoDiffConfig &parentConfig,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices);
Expand Down
24 changes: 12 additions & 12 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class LinearMapInfo {
/// For differentials: these are successor enums.
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

/// Mapping from `apply` instructions in the original function to the
/// Mapping from `apply` / `begin_apply` instructions in the original function to the
/// corresponding linear map tuple type index.
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;
llvm::DenseMap<FullApplySite, unsigned> linearMapIndexMap;

/// Mapping from predecessor-successor basic block pairs in the original
/// function to the corresponding branching trace enum case.
Expand Down Expand Up @@ -112,9 +112,9 @@ class LinearMapInfo {
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
SILLoopInfo *loopInfo);

/// Given an `apply` instruction, conditionally gets a linear map tuple field
/// AST type for its linear map function if it is active.
Type getLinearMapType(ADContext &context, ApplyInst *ai);
/// Given an `apply` / `begin_apply` instruction, conditionally gets a linear
/// map tuple field AST type for its linear map function if it is active.
Type getLinearMapType(ADContext &context, FullApplySite fai);

/// Generates linear map struct and branching enum declarations for the given
/// function. Linear map structs are populated with linear map fields and a
Expand Down Expand Up @@ -180,18 +180,18 @@ class LinearMapInfo {
}

/// Finds the linear map index in the pullback tuple for the given
/// `apply` instruction in the original function.
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
assert(ai->getFunction() == original);
auto lookup = linearMapIndexMap.find(ai);
/// `apply` / `begin_apply` instruction in the original function.
unsigned lookUpLinearMapIndex(FullApplySite fas) const {
assert(fas->getFunction() == original);
auto lookup = linearMapIndexMap.find(fas);
assert(lookup != linearMapIndexMap.end() &&
"No linear map field corresponding to the given `apply`");
return lookup->getSecond();
}

Type lookUpLinearMapType(ApplyInst *ai) const {
unsigned idx = lookUpLinearMapIndex(ai);
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
Type lookUpLinearMapType(FullApplySite fas) const {
unsigned idx = lookUpLinearMapIndex(fas);
return getLinearMapTupleType(fas->getParent())->getElement(idx).getType();
}

bool hasHeapAllocatedContext() const {
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SILOptimizer/Differentiation/Thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
CanSILFunctionType fromType,
CanSILFunctionType toType);

SILValue reabstractCoroutine(
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
SILValue fn, CanSILFunctionType toType,
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);

/// Reabstracts the given function-typed value `fn` to the target type `toType`.
/// Remaps substitutions using `remapSubstitutions`.
SILValue reabstractFunction(
Expand Down
25 changes: 24 additions & 1 deletion lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,17 +571,33 @@ getResultOptions(ImplResultInfoOptions implOptions) {
return result;
}

static SILCoroutineKind
getCoroutineKind(ImplCoroutineKind kind) {
switch (kind) {
case ImplCoroutineKind::None:
return SILCoroutineKind::None;
case ImplCoroutineKind::YieldOnce:
return SILCoroutineKind::YieldOnce;
case ImplCoroutineKind::YieldMany:
return SILCoroutineKind::YieldMany;
}
llvm_unreachable("unknown coroutine kind");
}

Type ASTBuilder::createImplFunctionType(
Demangle::ImplParameterConvention calleeConvention,
Demangle::ImplCoroutineKind coroutineKind,
ArrayRef<Demangle::ImplFunctionParam<Type>> params,
ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
ArrayRef<Demangle::ImplFunctionResult<Type>> results,
std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
ImplFunctionTypeFlags flags) {
GenericSignature genericSig;

SILCoroutineKind funcCoroutineKind = SILCoroutineKind::None;
ParameterConvention funcCalleeConvention =
getParameterConvention(calleeConvention);
SILCoroutineKind funcCoroutineKind =
getCoroutineKind(coroutineKind);

SILFunctionTypeRepresentation representation;
switch (flags.getRepresentation()) {
Expand Down Expand Up @@ -644,6 +660,13 @@ Type ASTBuilder::createImplFunctionType(
funcParams.emplace_back(type, conv, options);
}

for (const auto &yield : yields) {
auto type = yield.getType()->getCanonicalType();
auto conv = getParameterConvention(yield.getConvention());
auto options = *getParameterOptions(yield.getOptions());
funcParams.emplace_back(type, conv, options);
}

for (const auto &result : results) {
auto type = result.getType()->getCanonicalType();
auto conv = getResultConvention(result.getConvention());
Expand Down
34 changes: 23 additions & 11 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
llvm_unreachable("invalid derivative kind");
}

void AutoDiffConfig::dump() const {
print(llvm::errs());
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down Expand Up @@ -354,22 +358,30 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
// Require differentiability results to conform to `Differentiable`.
SmallVector<SILResultInfo, 2> originalResults;
getSemanticResults(originalFnTy, diffParamIndices, originalResults);
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
originalFnTy->getNumAutoDiffSemanticResultsParameters();
for (unsigned resultIdx : diffResultIndices->getIndices()) {
// Handle formal original result.
if (resultIdx < originalFnTy->getNumResults()) {
if (resultIdx < firstSemanticParamResultIdx) {
auto resultType = originalResults[resultIdx].getInterfaceType();
addRequirement(resultType);
continue;
} else if (resultIdx < firstYieldResultIndex) {
// Handle original semantic result parameters.
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
auto resultParamIt = std::next(
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
resultParamIndex);
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
} else {
// Handle formal original yields.
assert(originalFnTy->isCoroutine());
assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
auto yieldResultIndex = resultIdx - firstYieldResultIndex;
addRequirement(originalFnTy->getYields()[yieldResultIndex].getInterfaceType());
}
// Handle original semantic result parameters.
// FIXME: Constraint generic yields when we will start supporting them
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
auto resultParamIt = std::next(
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
resultParamIndex);
auto paramIndex =
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
}

return buildGenericSignature(ctx, derivativeGenSig,
Expand Down
Loading