Skip to content

[AutoDiff] Initial support for differentiation of throwing functions #82653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ inline void createEntryArguments(SILFunction *f) {
indResTy = indResTy.mapTypeOutOfContext();
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
}
if (auto indErrorResTy = conv.getIndirectErrorResultType(f->getTypeExpansionContext())) {
if (indErrorResTy.hasArchetype())
indErrorResTy = indErrorResTy.mapTypeOutOfContext();
createFunctionArgument(f->mapTypeIntoContext(indErrorResTy).getAddressType());
}

for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) {
if (paramTy.hasArchetype())
paramTy = paramTy.mapTypeOutOfContext();
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2892,6 +2892,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
if (!autodiff::getBuiltinApplyDerivativeConfig(
OperationName, kind, arity, throws))
return nullptr;
// TODO: Support somehow typed throws
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
throws, /*thrownType=*/Type());
}
Expand All @@ -2901,6 +2902,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
if (!autodiff::getBuiltinApplyTransposeConfig(
OperationName, arity, throws))
return nullptr;
// TODO: Support somehow typed throws
return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws,
/*thrownType=*/Type());
}
Expand Down
24 changes: 13 additions & 11 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,11 +1191,8 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF,

static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
AutoDiffDerivativeFunctionKind kind, unsigned arity,
bool throws, SILGenFunction &SGF, SILLocation loc,
SubstitutionMap substitutions, ArrayRef<ManagedValue> args, SGFContext C) {
// FIXME(https://github.com/apple/swift/issues/54259): Support throwing functions.
assert(!throws && "Throwing functions are not yet supported");

SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
ArrayRef<ManagedValue> args, SGFContext C) {
auto origFnVal = args[0];
SmallVector<SILValue, 2> origFnArgVals;
for (auto& arg : args.drop_front(1))
Expand All @@ -1213,7 +1210,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
origFnVal = SGF.B.createBeginBorrow(loc, origFnVal);
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
loc, kind, origFnVal.getValue());
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
SILType derivativeType = derivativeFn->getType();
auto derivativeFnType = derivativeType.castTo<SILFunctionType>();
assert(derivativeFnType->getNumResults() == 2);
assert(derivativeFnType->getNumParameters() == origFnArgVals.size());

Expand All @@ -1240,8 +1238,10 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
applyArgs.push_back(SGF.B.createTupleElementAddr(loc, indResBuffer, 0));
for (auto origFnArgVal : origFnArgVals)
applyArgs.push_back(origFnArgVal);
auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(),
applyArgs);
auto differential =
SGF.emitApplyWithRethrow(loc,
derivativeFn, derivativeType,
SubstitutionMap(), applyArgs);

derivativeFn = SILValue();

Expand All @@ -1253,8 +1253,10 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
}

// Do the apply for the direct result case.
auto resultTuple = SGF.B.createApply(
loc, derivativeFn, SubstitutionMap(), origFnArgVals);
auto resultTuple =
SGF.emitApplyWithRethrow(loc,
derivativeFn, derivativeType,
SubstitutionMap(), origFnArgVals);

derivativeFn = SILValue();

Expand Down Expand Up @@ -1324,7 +1326,7 @@ static ManagedValue emitBuiltinApplyDerivative(
builtinName, kind, arity, throws);
assert(successfullyParsed);
return emitBuiltinAutoDiffApplyDerivativeFunction(
kind, arity, throws, SGF, loc, substitutions, args, C);
kind, arity, SGF, loc, substitutions, args, C);
}

static ManagedValue emitBuiltinApplyTranspose(
Expand Down
19 changes: 11 additions & 8 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
heapAllocatedContext = true;
decl->setInterfaceType(astCtx.TheRawPointerType);
} else { // Otherwise the payload is the linear map tuple.
auto *linearMapStructTy = getLinearMapTupleType(predBB);
auto *linearMapTupleTy = getLinearMapTupleType(predBB);
// Do not create entries for unreachable predecessors
if (!linearMapStructTy)
if (!linearMapTupleTy)
continue;
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();

auto canLinearMapTupleTy = linearMapTupleTy->getCanonicalType();
decl->setInterfaceType(
canLinearMapStructTy->hasArchetype()
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
canLinearMapTupleTy->hasArchetype()
? canLinearMapTupleTy->mapTypeOutOfContext() : canLinearMapTupleTy);
}
// Create enum element and enum case declarations.
auto *paramList = ParameterList::create(astCtx, {decl});
Expand Down Expand Up @@ -183,6 +184,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
return activityInfo.isActive(res, config);
});

bool hasActiveSemanticResultArgument = false;
bool hasActiveArguments = false;
auto numIndirectResults = fai.getNumIndirectSILResults();
Expand Down Expand Up @@ -311,10 +313,11 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
params, silFnTy->getAllResultsInterfaceType().getASTType(), info);
}

if (astFnTy->hasArchetype())
return astFnTy->mapTypeOutOfContext();
Type resultType = astFnTy->hasArchetype() ? astFnTy->mapTypeOutOfContext() : astFnTy;
if (fai.getKind() == FullApplySiteKind::TryApplyInst)
resultType = resultType->wrapInOptionalType();

return astFnTy;
return resultType;
}

void LinearMapInfo::generateDifferentiationDataStructures(
Expand Down
Loading