diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index f8f3423e9b110..fc96f568d56b0 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -227,6 +227,12 @@ struct AutoDiffConfig { SWIFT_DEBUG_DUMP; }; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, + const SILAutoDiffIndices &indices) { + indices.print(s); + return s; +} + /// A semantic function result type: either a formal function result type or /// an `inout` parameter type. Used in derivative function type calculation. struct AutoDiffSemanticFunctionResultType { diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 65f6aadf3248b..54f74eeaef6e1 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -445,6 +445,36 @@ ERROR(not_constant_evaluable, none, "not constant evaluable", ()) ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable " "function '%0' must be annotated '@_optimize(none)'", (StringRef)) +// Differentiation transform diagnostics +ERROR(autodiff_internal_swift_not_imported,none, + "Automatic differentiation internal error: the Swift module is not " + "imported", ()) +ERROR(autodiff_differentiation_module_not_imported,none, + "Automatic differentiation requires the '_Differentiation' module to be " + "imported", ()) +ERROR(autodiff_conversion_to_linear_function_not_supported,none, + "conversion to '@differentiable(linear)' function type is not yet " + "supported", ()) +ERROR(autodiff_function_not_differentiable_error,none, + "function is not differentiable", ()) +ERROR(autodiff_expression_not_differentiable_error,none, + "expression is not differentiable", ()) +NOTE(autodiff_expression_not_differentiable_note,none, + "expression is not differentiable", ()) +NOTE(autodiff_when_differentiating_function_call,none, + "when differentiating this function call", ()) +NOTE(autodiff_when_differentiating_function_definition,none, + "when differentiating this function definition", ()) +NOTE(autodiff_implicitly_inherited_differentiable_attr_here,none, + "differentiability required by the corresponding protocol requirement " + "here", ()) +NOTE(autodiff_jvp_control_flow_not_supported,none, + "forward-mode differentiation does not yet support control flow", ()) +NOTE(autodiff_control_flow_not_supported,none, + "cannot differentiate unsupported control flow", ()) +NOTE(autodiff_missing_return,none, + "missing return for differentiation", ()) + ERROR(non_physical_addressof,none, "addressof only works with purely physical lvalues; " "use 'withUnsafePointer' or 'withUnsafeBytes' unless you're implementing " diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index fda662f028150..1548411deedd3 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -327,6 +327,9 @@ namespace swift { /// `@differentiable` declaration attribute, etc. bool EnableExperimentalDifferentiableProgramming = false; + /// Whether to enable forward mode differentiation. + bool EnableExperimentalForwardModeDifferentiation = false; + /// Whether to enable experimental `AdditiveArithmetic` derived /// conformances. bool EnableExperimentalAdditiveArithmeticDerivedConformances = false; diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index cd73ceabafe2d..b698629f55da0 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -498,6 +498,14 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">, HelpText<"Disable automatic generation of bridging PCH files">; // Experimental feature options + +// Note: this flag will be removed when JVP/differential generation in the +// differentiation transform is robust. +def enable_experimental_forward_mode_differentiation : + Flag<["-"], "enable-experimental-forward-mode-differentiation">, + Flags<[FrontendOption]>, + HelpText<"Enable experimental forward mode differentiation">; + def enable_experimental_additive_arithmetic_derivation : Flag<["-"], "enable-experimental-additive-arithmetic-derivation">, Flags<[FrontendOption]>, diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index 791bc98e8206b..2d4010753f4ad 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -132,6 +132,11 @@ class SILDifferentiabilityWitness bool isSerialized() const { return IsSerialized; } const DeclAttribute *getAttribute() const { return Attribute; } + /// Returns the `SILAutoDiffIndices` corresponding to this config's indices. + // TODO(TF-893): This is a temporary shim for incremental removal of + // `SILAutoDiffIndices`. Eventually remove this. + SILAutoDiffIndices getSILAutoDiffIndices() const; + /// Verify that the differentiability witness is well-formed. void verify(const SILModule &module) const; diff --git a/include/swift/SILOptimizer/PassManager/Passes.def b/include/swift/SILOptimizer/PassManager/Passes.def index 4bf6b0439f718..83db39ea944b8 100644 --- a/include/swift/SILOptimizer/PassManager/Passes.def +++ b/include/swift/SILOptimizer/PassManager/Passes.def @@ -120,6 +120,8 @@ PASS(CopyForwarding, "copy-forwarding", "Copy Forwarding to Remove Redundant Copies") PASS(CopyPropagation, "copy-propagation", "Copy propagation to Remove Redundant SSA Copies") +PASS(Differentiation, "differentiation", + "Automatic Differentiation") PASS(EpilogueARCMatcherDumper, "sil-epilogue-arc-dumper", "Print Epilogue retains of Returned Values and Argument releases") PASS(EpilogueRetainReleaseMatcherDumper, "sil-epilogue-retain-release-dumper", diff --git a/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h b/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h new file mode 100644 index 0000000000000..eb2e42138e8c4 --- /dev/null +++ b/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h @@ -0,0 +1,376 @@ +//===--- ADContext.h - Differentiation Context ----------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Per-module contextual information for the differentiation transform. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H + +#include "swift/AST/DiagnosticsSIL.h" +#include "swift/AST/Expr.h" +#include "swift/SIL/SILBuilder.h" +#include "swift/SILOptimizer/Utils/Differentiation/Common.h" +#include "swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" + +namespace swift { + +class ASTContext; +class DifferentiableFunctionExpr; +class DifferentiableFunctionInst; +class FuncDecl; +class SILDifferentiabilityWitness; +class SILFunction; +class SILModuleTransform; +class SILModule; +class SILPassManager; + +namespace autodiff { + +/// Stores `apply` instruction information calculated by VJP generation. +struct NestedApplyInfo { + /// The differentiation indices that are used to differentiate this `apply` + /// instruction. + SILAutoDiffIndices indices; + /// The original pullback type before reabstraction. `None` if the pullback + /// type is not reabstracted. + Optional originalPullbackType; +}; + +/// Per-module contextual information for the Differentiation pass. +class ADContext { +private: + /// Reference to the main transform. + SILModuleTransform &transform; + + /// The module where Differentiation is performed on. + SILModule &module; + + /// AST context. + ASTContext &astCtx = module.getASTContext(); + + /// Shared pass manager. + SILPassManager &passManager; + + /// The worklist (stack) of `differentiable_function` instructions to be + /// processed. + llvm::SmallVector + differentiableFunctionInsts; + + /// The set of `differentiable_function` instructions that have been + /// processed. Used to avoid reprocessing invalidated instructions. + /// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace + /// `ADContext::processDifferentiableFunctionInst`, this field may be removed. + llvm::SmallPtrSet + processedDifferentiableFunctionInsts; + + /// Mapping from witnesses to invokers. + /// `SmallMapVector` is used for deterministic insertion order iteration. + llvm::SmallMapVector + invokers; + + /// Mapping from `differentiable_function` instructions to result indices. + llvm::DenseMap resultIndices; + + /// Mapping from original `apply` instructions to their corresponding + /// `NestedApplyInfo`s. + llvm::DenseMap nestedApplyInfo; + + /// List of generated functions (JVPs, VJPs, pullbacks, and thunks). + /// Saved for deletion during cleanup. + llvm::SmallVector generatedFunctions; + + /// List of references to generated functions. + /// Saved for deletion during cleanup. + llvm::SmallVector generatedFunctionReferences; + + /// The AdditiveArithmetic protocol in the standard library. + ProtocolDecl *additiveArithmeticProtocol = + astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); + + /// `AdditiveArithmetic.+` declaration. + mutable FuncDecl *cachedPlusFn = nullptr; + /// `AdditiveArithmetic.+=` declaration. + mutable FuncDecl *cachedPlusEqualFn = nullptr; + +public: + /// Construct an ADContext for the given module. + explicit ADContext(SILModuleTransform &transform); + + //--------------------------------------------------------------------------// + // General utilities + //--------------------------------------------------------------------------// + + SILModuleTransform &getTransform() const { return transform; } + SILModule &getModule() const { return module; } + ASTContext &getASTContext() const { return module.getASTContext(); } + SILPassManager &getPassManager() const { return passManager; } + Lowering::TypeConverter &getTypeConverter() { return module.Types; } + + /// Returns true if the `differentiable_function` instruction worklist is + /// empty. + bool isDifferentiableFunctionInstsWorklistEmpty() const { + return differentiableFunctionInsts.empty(); + } + + /// Pops and returns a `differentiable_function` instruction from the + /// worklist. Returns nullptr if the worklist is empty. + DifferentiableFunctionInst *popDifferentiableFunctionInstFromWorklist() { + if (differentiableFunctionInsts.empty()) + return nullptr; + return differentiableFunctionInsts.pop_back_val(); + } + + /// Adds the given `differentiable_function` instruction to the worklist. + void + addDifferentiableFunctionInstToWorklist(DifferentiableFunctionInst *dfi) { + differentiableFunctionInsts.push_back(dfi); + } + + /// Returns true if the given `differentiable_function` instruction has + /// already been processed. + bool + isDifferentiableFunctionInstProcessed(DifferentiableFunctionInst *dfi) const { + return processedDifferentiableFunctionInsts.count(dfi); + } + + /// Adds the given `differentiable_function` instruction to the worklist. + void + markDifferentiableFunctionInstAsProcessed(DifferentiableFunctionInst *dfi) { + processedDifferentiableFunctionInsts.insert(dfi); + } + + const llvm::SmallMapVector & + getInvokers() const { + return invokers; + } + + void addInvoker(SILDifferentiabilityWitness *witness) { + assert(!invokers.count(witness) && + "Differentiability witness already has an invoker"); + invokers.insert({witness, DifferentiationInvoker(witness)}); + } + + /// Returns the result index for `dfi` if found in this context. Otherwise, + /// sets the result index to zero and returns it. + unsigned getResultIndex(DifferentiableFunctionInst *dfi) { + return resultIndices[dfi]; + } + + /// Sets the result index for `dfi`. + void setResultIndex(DifferentiableFunctionInst *dfi, unsigned index) { + resultIndices[dfi] = index; + } + + llvm::DenseMap &getNestedApplyInfo() { + return nestedApplyInfo; + } + + void recordGeneratedFunction(SILFunction *function) { + generatedFunctions.push_back(function); + } + + void recordGeneratedFunctionReference(SILValue functionRef) { + generatedFunctionReferences.push_back(functionRef); + } + + ProtocolDecl *getAdditiveArithmeticProtocol() const { + return additiveArithmeticProtocol; + } + + FuncDecl *getPlusDecl() const; + FuncDecl *getPlusEqualDecl() const; + + /// Cleans up all the internal state. + void cleanUp(); + + /// Creates an `differentiable_function` instruction using the given builder + /// and arguments. Erase the newly created instruction from the processed set, + /// if it exists - it may exist in the processed set if it has the same + /// pointer value as a previously processed and deleted instruction. + /// TODO(TF-784): The pointer reuse is a real concern and the use of + /// `CanonicalizeInstruction` may get rid of the need for this workaround. + DifferentiableFunctionInst *createDifferentiableFunction( + SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices, + SILValue original, + Optional> derivativeFunctions = None); + + // Given an `differentiable_function` instruction, finds the corresponding + // differential operator used in the AST. If no differential operator is + // found, return nullptr. + DifferentiableFunctionExpr * + findDifferentialOperator(DifferentiableFunctionInst *inst); + + template + InFlightDiagnostic diagnose(SourceLoc loc, Diag diag, + U &&... args) const { + return getASTContext().Diags.diagnose(loc, diag, std::forward(args)...); + } + + /// Given an instruction and a differentiation task associated with the + /// parent function, emits a "not differentiable" error based on the task. If + /// the task is indirect, emits notes all the way up to the outermost task, + /// and emits an error at the outer task. Otherwise, emits an error directly. + template + InFlightDiagnostic + emitNondifferentiabilityError(SILInstruction *inst, + DifferentiationInvoker invoker, Diag diag, + U &&... args); + + /// Given a value and a differentiation task associated with the parent + /// function, emits a "not differentiable" error based on the task. If the + /// task is indirect, emits notes all the way up to the outermost task, and + /// emits an error at the outer task. Otherwise, emits an error directly. + template + InFlightDiagnostic + emitNondifferentiabilityError(SILValue value, DifferentiationInvoker invoker, + Diag diag, U &&... args); + + /// Emit a "not differentiable" error based on the given differentiation task + /// and diagnostic. + template + InFlightDiagnostic + emitNondifferentiabilityError(SourceLoc loc, DifferentiationInvoker invoker, + Diag diag, U &&... args); +}; + +template +InFlightDiagnostic +ADContext::emitNondifferentiabilityError(SILValue value, + DifferentiationInvoker invoker, + Diag diag, U &&... args) { + LLVM_DEBUG({ + getADDebugStream() << "Diagnosing non-differentiability.\n"; + getADDebugStream() << "For value:\n" << value; + getADDebugStream() << "With invoker:\n" << invoker << '\n'; + }); + auto valueLoc = value.getLoc().getSourceLoc(); + // If instruction does not have a valid location, use the function location + // as a fallback. Improves diagnostics in some cases. + if (valueLoc.isInvalid()) + valueLoc = value->getFunction()->getLocation().getSourceLoc(); + return emitNondifferentiabilityError(valueLoc, invoker, diag, + std::forward(args)...); +} + +template +InFlightDiagnostic +ADContext::emitNondifferentiabilityError(SILInstruction *inst, + DifferentiationInvoker invoker, + Diag diag, U &&... args) { + LLVM_DEBUG({ + getADDebugStream() << "Diagnosing non-differentiability.\n"; + getADDebugStream() << "For instruction:\n" << *inst; + getADDebugStream() << "With invoker:\n" << invoker << '\n'; + }); + auto instLoc = inst->getLoc().getSourceLoc(); + // If instruction does not have a valid location, use the function location + // as a fallback. Improves diagnostics for `ref_element_addr` generated in + // synthesized stored property getters. + if (instLoc.isInvalid()) + instLoc = inst->getFunction()->getLocation().getSourceLoc(); + return emitNondifferentiabilityError(instLoc, invoker, diag, + std::forward(args)...); +} + +template +InFlightDiagnostic +ADContext::emitNondifferentiabilityError(SourceLoc loc, + DifferentiationInvoker invoker, + Diag diag, U &&... args) { + switch (invoker.getKind()) { + // For `differentiable_function` instructions: if the + // `differentiable_function` instruction comes from a differential operator, + // emit an error on the expression and a note on the non-differentiable + // operation. Otherwise, emit both an error and note on the + // non-differentiation operation. + case DifferentiationInvoker::Kind::DifferentiableFunctionInst: { + auto *inst = invoker.getDifferentiableFunctionInst(); + if (auto *expr = findDifferentialOperator(inst)) { + diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error) + .highlight(expr->getSubExpr()->getSourceRange()); + return diagnose(loc, diag, std::forward(args)...); + } + diagnose(loc, diag::autodiff_expression_not_differentiable_error); + return diagnose(loc, diag, std::forward(args)...); + } + + // For differentiability witnesses: try to find a `@differentiable` or + // `@derivative` attribute. If an attribute is found, emit an error on it; + // otherwise, emit an error on the original function. + case DifferentiationInvoker::Kind::SILDifferentiabilityWitnessInvoker: { + auto *witness = invoker.getSILDifferentiabilityWitnessInvoker(); + auto *original = witness->getOriginalFunction(); + // If the witness has an associated attribute, emit an error at its + // location. + if (auto *attr = witness->getAttribute()) { + diagnose(attr->getLocation(), + diag::autodiff_function_not_differentiable_error) + .highlight(attr->getRangeWithAt()); + // Emit informative note. + bool emittedNote = false; + // If the witness comes from an implicit `@differentiable` attribute + // inherited from a protocol requirement's `@differentiable` attribute, + // emit a note on the inherited attribute. + if (auto *diffAttr = dyn_cast(attr)) { + auto inheritedAttrLoc = + diffAttr->getImplicitlyInheritedDifferentiableAttrLocation(); + if (inheritedAttrLoc.isValid()) { + diagnose(inheritedAttrLoc, + diag::autodiff_implicitly_inherited_differentiable_attr_here) + .highlight(inheritedAttrLoc); + emittedNote = true; + } + } + // Otherwise, emit a note on the original function. + if (!emittedNote) { + diagnose(original->getLocation().getSourceLoc(), + diag::autodiff_when_differentiating_function_definition); + } + } + // Otherwise, emit an error on the original function. + else { + diagnose(original->getLocation().getSourceLoc(), + diag::autodiff_function_not_differentiable_error); + } + return diagnose(loc, diag, std::forward(args)...); + } + + // For indirect differentiation, emit a "not differentiable" note on the + // expression first. Then emit an error at the source invoker of + // differentiation, and a "when differentiating this" note at each indirect + // invoker. + case DifferentiationInvoker::Kind::IndirectDifferentiation: { + SILInstruction *inst; + SILDifferentiabilityWitness *witness; + std::tie(inst, witness) = invoker.getIndirectDifferentiation(); + auto invokerLookup = invokers.find(witness); + assert(invokerLookup != invokers.end() && "Expected parent invoker"); + emitNondifferentiabilityError( + inst, invokerLookup->second, + diag::autodiff_expression_not_differentiable_note); + return diagnose(loc, diag::autodiff_when_differentiating_function_call); + } + } +} + +} // end namespace autodiff +} // end namespace swift + +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H diff --git a/include/swift/SILOptimizer/Utils/Differentiation/Common.h b/include/swift/SILOptimizer/Utils/Differentiation/Common.h new file mode 100644 index 0000000000000..8989f75dd2bd9 --- /dev/null +++ b/include/swift/SILOptimizer/Utils/Differentiation/Common.h @@ -0,0 +1,90 @@ +//===--- Common.h - Automatic differentiation common utils ----*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Automatic differentiation common utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H + +#include "swift/SIL/SILDifferentiabilityWitness.h" +#include "swift/SIL/SILFunction.h" +#include "swift/SIL/SILModule.h" + +namespace swift { + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +namespace autodiff { + +/// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream. +/// This is being used to print short debug messages within the AD pass. +raw_ostream &getADDebugStream(); + +/// Returns the underlying instruction for the given SILValue, if it exists, +/// peering through function conversion instructions. +template Inst *peerThroughFunctionConversions(SILValue value) { + if (auto *inst = dyn_cast(value)) + return inst; + if (auto *cvi = dyn_cast(value)) + return peerThroughFunctionConversions(cvi->getOperand()); + if (auto *bbi = dyn_cast(value)) + return peerThroughFunctionConversions(bbi->getOperand()); + if (auto *tttfi = dyn_cast(value)) + return peerThroughFunctionConversions(tttfi->getOperand()); + if (auto *cfi = dyn_cast(value)) + return peerThroughFunctionConversions(cfi->getOperand()); + if (auto *pai = dyn_cast(value)) + return peerThroughFunctionConversions(pai->getCallee()); + return nullptr; +} + +} // end namespace autodiff + +/// Creates arguments in the entry block based on the function type. +inline void createEntryArguments(SILFunction *f) { + auto *entry = f->getEntryBlock(); + auto conv = f->getConventions(); + auto &ctx = f->getASTContext(); + auto moduleDecl = f->getModule().getSwiftModule(); + assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) && + "Entry already has arguments?!"); + auto createFunctionArgument = [&](SILType type) { + // Create a dummy parameter declaration. + // Necessary to prevent crash during argument explosion optimization. + auto loc = f->getLocation().getSourceLoc(); + auto *decl = new (ctx) + ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl); + decl->setSpecifier(ParamDecl::Specifier::Default); + entry->createFunctionArgument(type, decl); + }; + // f->getLoweredFunctionType()->remap + for (auto indResTy : conv.getIndirectSILResultTypes()) { + if (indResTy.hasArchetype()) + indResTy = indResTy.mapTypeOutOfContext(); + createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType()); + // createFunctionArgument(indResTy.getAddressType()); + } + for (auto paramTy : conv.getParameterSILTypes()) { + if (paramTy.hasArchetype()) + paramTy = paramTy.mapTypeOutOfContext(); + createFunctionArgument(f->mapTypeIntoContext(paramTy)); + // createFunctionArgument(paramTy); + } +} + +} // end namespace swift + +#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_COMMON_H diff --git a/include/swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h b/include/swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h new file mode 100644 index 0000000000000..405fcc15a97f7 --- /dev/null +++ b/include/swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h @@ -0,0 +1,120 @@ +//===--- DifferentiationInvoker.h -----------------------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Class that represents an invoker of differentiation. +// Used to track diagnostic source locations. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H + +#include "swift/Basic/SourceLoc.h" +#include + +namespace swift { + +class ApplyInst; +class DifferentiableFunctionInst; +class SILDifferentiabilityWitness; + +namespace autodiff { + +/// The invoker of a differentiation task. It can be some user syntax, e.g. +/// an `differentiable_function` instruction lowered from an +/// `DifferentiableFunctionExpr` expression, the differentiation pass, or +/// nothing at all. This will be used to emit informative diagnostics. +struct DifferentiationInvoker { +public: + /// The kind of the invoker of a differentiation task. + enum class Kind { + // Invoked by an `differentiable_function` instruction, which may or may not + // be linked to a Swift AST node (e.g. an `DifferentiableFunctionExpr` + // expression). + DifferentiableFunctionInst, + + // Invoked by the indirect application of differentiation. This case has an + // associated original `apply` instruction and + // `SILDifferentiabilityWitness`. + IndirectDifferentiation, + + // Invoked by a `SILDifferentiabilityWitness` **without** being linked to a + // Swift AST attribute. This case has an associated + // `SILDifferentiabilityWitness`. + SILDifferentiabilityWitnessInvoker + }; + +private: + Kind kind; + union Value { + /// The instruction associated with the `DifferentiableFunctionInst` case. + DifferentiableFunctionInst *diffFuncInst; + Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {} + + /// The parent `apply` instruction and the witness associated with the + /// `IndirectDifferentiation` case. + std::pair + indirectDifferentiation; + Value(ApplyInst *applyInst, SILDifferentiabilityWitness *witness) + : indirectDifferentiation({applyInst, witness}) {} + + /// The witness associated with the `SILDifferentiabilityWitnessInvoker` + /// case. + SILDifferentiabilityWitness *witness; + Value(SILDifferentiabilityWitness *witness) : witness(witness) {} + } value; + + /*implicit*/ + DifferentiationInvoker(Kind kind, Value value) : kind(kind), value(value) {} + +public: + DifferentiationInvoker(DifferentiableFunctionInst *inst) + : kind(Kind::DifferentiableFunctionInst), value(inst) {} + DifferentiationInvoker(ApplyInst *applyInst, + SILDifferentiabilityWitness *witness) + : kind(Kind::IndirectDifferentiation), value({applyInst, witness}) {} + DifferentiationInvoker(SILDifferentiabilityWitness *witness) + : kind(Kind::SILDifferentiabilityWitnessInvoker), value(witness) {} + + Kind getKind() const { return kind; } + + DifferentiableFunctionInst *getDifferentiableFunctionInst() const { + assert(kind == Kind::DifferentiableFunctionInst); + return value.diffFuncInst; + } + + std::pair + getIndirectDifferentiation() const { + assert(kind == Kind::IndirectDifferentiation); + return value.indirectDifferentiation; + } + + SILDifferentiabilityWitness *getSILDifferentiabilityWitnessInvoker() const { + assert(kind == Kind::SILDifferentiabilityWitnessInvoker); + return value.witness; + } + + SourceLoc getLocation() const; + + void print(llvm::raw_ostream &os) const; +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + DifferentiationInvoker invoker) { + invoker.print(os); + return os; +} + +} // end namespace autodiff +} // end namespace swift + +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_DIFFERENTIATIONINVOKER_H diff --git a/lib/IRGen/GenDiffWitness.cpp b/lib/IRGen/GenDiffWitness.cpp index a13293dca59cf..d931647524fc3 100644 --- a/lib/IRGen/GenDiffWitness.cpp +++ b/lib/IRGen/GenDiffWitness.cpp @@ -27,23 +27,14 @@ void IRGenModule::emitSILDifferentiabilityWitness( SILDifferentiabilityWitness *dw) { PrettyStackTraceDifferentiabilityWitness _st( "emitting differentiability witness for", dw->getKey()); - // Don't emit declarations. if (dw->isDeclaration()) return; - // Don't emit `public_external` witnesses. if (dw->getLinkage() == SILLinkage::PublicExternal) return; - ConstantInitBuilder builder(*this); auto diffWitnessContents = builder.beginStruct(); - - // TODO(TF-1211): Uncomment assertions after upstreaming differentiation - // transform. - // The mandatory differentiation transform canonicalizes differentiability - // witnesses and ensures that JVPs/VJPs are populated. - /* assert(dw->getJVP() && "Differentiability witness definition should have JVP"); assert(dw->getVJP() && @@ -52,16 +43,6 @@ void IRGenModule::emitSILDifferentiabilityWitness( getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy); diffWitnessContents.addBitCast( getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy); - */ - llvm::Constant *jvpValue = llvm::UndefValue::get(Int8PtrTy); - llvm::Constant *vjpValue = llvm::UndefValue::get(Int8PtrTy); - if (auto *jvpFn = dw->getJVP()) - jvpValue = getAddrOfSILFunction(dw->getJVP(), NotForDefinition); - if (auto *vjpFn = dw->getJVP()) - vjpValue = getAddrOfSILFunction(dw->getVJP(), NotForDefinition); - diffWitnessContents.addBitCast(jvpValue, Int8PtrTy); - diffWitnessContents.addBitCast(vjpValue, Int8PtrTy); - getAddrOfDifferentiabilityWitness( dw, diffWitnessContents.finishAndCreateFuture()); } diff --git a/lib/SIL/IR/SILDifferentiabilityWitness.cpp b/lib/SIL/IR/SILDifferentiabilityWitness.cpp index c3f6df4ccb11d..e363102049777 100644 --- a/lib/SIL/IR/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/IR/SILDifferentiabilityWitness.cpp @@ -74,3 +74,7 @@ void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp, SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { return std::make_pair(getOriginalFunction()->getName(), getConfig()); } + +SILAutoDiffIndices SILDifferentiabilityWitness::getSILAutoDiffIndices() const { + return getConfig().getSILAutoDiffIndices(); +} diff --git a/lib/SILOptimizer/IPO/DeadFunctionElimination.cpp b/lib/SILOptimizer/IPO/DeadFunctionElimination.cpp index 9fa665e9def5e..b1b2a25960fc1 100644 --- a/lib/SILOptimizer/IPO/DeadFunctionElimination.cpp +++ b/lib/SILOptimizer/IPO/DeadFunctionElimination.cpp @@ -587,7 +587,14 @@ class DeadFunctionElimination : FunctionLivenessComputation { ensureKeyPathComponentIsAlive(*component); } } - + // Check differentiability witness entries. + for (auto &dw : Module->getDifferentiabilityWitnessList()) { + ensureAlive(dw.getOriginalFunction()); + if (dw.getJVP()) + ensureAlive(dw.getJVP()); + if (dw.getVJP()) + ensureAlive(dw.getVJP()); + } } /// Removes all dead methods from vtables and witness tables. diff --git a/lib/SILOptimizer/Mandatory/CMakeLists.txt b/lib/SILOptimizer/Mandatory/CMakeLists.txt index ac2d4a9044868..e8ee3a6084e5e 100644 --- a/lib/SILOptimizer/Mandatory/CMakeLists.txt +++ b/lib/SILOptimizer/Mandatory/CMakeLists.txt @@ -11,6 +11,7 @@ silopt_register_sources( DiagnoseInvalidEscapingCaptures.cpp DiagnoseStaticExclusivity.cpp DiagnoseUnreachable.cpp + Differentiation.cpp GuaranteedARCOpts.cpp IRGenPrepare.cpp MandatoryInlining.cpp diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp new file mode 100644 index 0000000000000..883a0281e491d --- /dev/null +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -0,0 +1,605 @@ +//===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file implements automatic differentiation. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "differentiation" + +#include "swift/AST/ASTMangler.h" +#include "swift/AST/ASTPrinter.h" +#include "swift/AST/AnyFunctionRef.h" +#include "swift/AST/AutoDiff.h" +#include "swift/AST/Builtins.h" +#include "swift/AST/DeclContext.h" +#include "swift/AST/DiagnosticsSIL.h" +#include "swift/AST/Expr.h" +#include "swift/AST/GenericEnvironment.h" +#include "swift/AST/GenericSignatureBuilder.h" +#include "swift/AST/LazyResolver.h" +#include "swift/AST/ParameterList.h" +#include "swift/AST/SourceFile.h" +#include "swift/AST/SubstitutionMap.h" +#include "swift/AST/TypeCheckRequests.h" +#include "swift/SIL/FormalLinkage.h" +#include "swift/SIL/PrettyStackTrace.h" +#include "swift/SIL/SILBuilder.h" +#include "swift/SIL/TypeSubstCloner.h" +#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" +#include "swift/SILOptimizer/PassManager/Passes.h" +#include "swift/SILOptimizer/PassManager/Transforms.h" +#include "swift/SILOptimizer/Utils/Differentiation/ADContext.h" +#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/BreadthFirstIterator.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/CommandLine.h" + +using namespace swift; +using namespace swift::autodiff; +using llvm::DenseMap; +using llvm::SmallDenseMap; +using llvm::SmallDenseSet; +using llvm::SmallMapVector; +using llvm::SmallSet; + +/// This flag is used to disable `differentiable_function_extract` instruction +/// folding for SIL testing purposes. +static llvm::cl::opt SkipFoldingDifferentiableFunctionExtraction( + "differentiation-skip-folding-differentiable-function-extraction", + llvm::cl::init(true)); + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +/// Given a dumpable value, dumps it to `llvm::dbgs()`. +template static inline void debugDump(T &v) { + LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n" + << v << "\n==== END DEBUG DUMP ====\n"); +} + +namespace { + +class DifferentiationTransformer { +private: + /// Reference to the main transform. + SILModuleTransform &transform; + + /// Context necessary for performing the transformations. + ADContext context; + + /// Promotes the given `differentiable_function` instruction to a valid + /// `@differentiable` function-typed value. + SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst, + SILBuilder &builder, SILLocation loc, + DifferentiationInvoker invoker); + +public: + /// Construct an `DifferentiationTransformer` for the given module. + explicit DifferentiationTransformer(SILModuleTransform &transform) + : transform(transform), context(transform) {} + + ADContext &getContext() { return context; } + + /// Canonicalize the given witness, filling in derivative functions if + /// missing. + /// + /// Generated derivative functions have the same linkage as the witness. + /// + /// \param serializeFunctions specifies whether generated functions should be + /// serialized. + bool canonicalizeDifferentiabilityWitness( + SILFunction *original, SILDifferentiabilityWitness *witness, + DifferentiationInvoker invoker, IsSerialized_t serializeFunctions); + + /// Process the given `differentiable_function` instruction, filling in + /// missing derivative functions if necessary. + bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); + + /// Fold `differentiable_function_extract` users of the given + /// `differentiable_function` instruction, directly replacing them with + /// `differentiable_function` instruction operands. If the + /// `differentiable_function` instruction has no remaining uses, delete the + /// instruction itself after folding. + /// + /// Folding can be disabled by the + /// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing + /// purposes. + void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source); +}; + +} // end anonymous namespace + +/// If the original function doesn't have a return, it cannot be differentiated. +/// Returns true if error is emitted. +static bool diagnoseNoReturn(ADContext &context, SILFunction *original, + DifferentiationInvoker invoker) { + if (original->findReturnBB() != original->end()) + return false; + context.emitNondifferentiabilityError( + original->getLocation().getEndSourceLoc(), invoker, + diag::autodiff_missing_return); + return true; +} + +/// If the original function contains unsupported control flow, emit a "control +/// flow unsupported" error at appropriate source locations. Returns true if +/// error is emitted. +/// +/// Update as control flow support is added. Currently, branching terminators +/// other than `br`, `cond_br`, `switch_enum` are not supported. +static bool diagnoseUnsupportedControlFlow(ADContext &context, + SILFunction *original, + DifferentiationInvoker invoker) { + if (original->getBlocks().size() <= 1) + return false; + // Diagnose unsupported branching terminators. + for (auto &bb : *original) { + auto *term = bb.getTerminator(); + // Supported terminators are: `br`, `cond_br`, `switch_enum`, + // `switch_enum_addr`. + if (isa(term) || isa(term) || + isa(term) || isa(term)) + continue; + // If terminator is an unsupported branching terminator, emit an error. + if (term->isBranch()) { + context.emitNondifferentiabilityError( + term, invoker, diag::autodiff_control_flow_not_supported); + return true; + } + } + return false; +} + +//===----------------------------------------------------------------------===// +// `SILDifferentiabilityWitness` processing +//===----------------------------------------------------------------------===// + +static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, + IsSerialized_t isSerialized) { + LLVM_DEBUG({ + auto &s = getADDebugStream(); + s << "Creating VJP:\n\t"; + s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; + }); + + auto &module = context.getModule(); + auto originalTy = original->getLoweredFunctionType(); + auto indices = witness->getSILAutoDiffIndices(); + + // === Create an empty VJP. === + Mangle::ASTMangler mangler; + auto vjpName = + original->getASTContext() + .getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::VJP, + witness->getConfig())) + .str(); + CanGenericSignature vjpCanGenSig; + if (auto jvpGenSig = witness->getDerivativeGenericSignature()) + vjpCanGenSig = jvpGenSig->getCanonicalSignature(); + GenericEnvironment *vjpGenericEnv = nullptr; + if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete()) + vjpGenericEnv = vjpCanGenSig->getGenericEnvironment(); + auto vjpType = originalTy->getAutoDiffDerivativeFunctionType( + indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::VJP, + module.Types, LookUpConformanceInModule(module.getSwiftModule()), + vjpCanGenSig, + /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk); + + SILOptFunctionBuilder fb(context.getTransform()); + auto *vjp = fb.createFunction( + witness->getLinkage(), vjpName, vjpType, vjpGenericEnv, + original->getLocation(), original->isBare(), IsNotTransparent, + isSerialized, original->isDynamicallyReplaceable()); + vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp)); + + LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType() + << "\n"); + return vjp; +} + +static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, + IsSerialized_t isSerialized) { + LLVM_DEBUG({ + auto &s = getADDebugStream(); + s << "Creating JVP:\n\t"; + s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; + }); + + auto &module = context.getModule(); + auto originalTy = original->getLoweredFunctionType(); + auto indices = witness->getSILAutoDiffIndices(); + + // === Create an empty JVP. === + Mangle::ASTMangler mangler; + auto jvpName = + original->getASTContext() + .getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper( + original->getName(), AutoDiffDerivativeFunctionKind::JVP, + witness->getConfig())) + .str(); + CanGenericSignature jvpCanGenSig; + if (auto jvpGenSig = witness->getDerivativeGenericSignature()) + jvpCanGenSig = jvpGenSig->getCanonicalSignature(); + GenericEnvironment *jvpGenericEnv = nullptr; + if (jvpCanGenSig && !jvpCanGenSig->areAllParamsConcrete()) + jvpGenericEnv = jvpCanGenSig->getGenericEnvironment(); + auto jvpType = originalTy->getAutoDiffDerivativeFunctionType( + indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::JVP, + module.Types, LookUpConformanceInModule(module.getSwiftModule()), + jvpCanGenSig, + /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk); + + SILOptFunctionBuilder fb(context.getTransform()); + auto *jvp = fb.createFunction( + witness->getLinkage(), jvpName, jvpType, jvpGenericEnv, + original->getLocation(), original->isBare(), IsNotTransparent, + isSerialized, original->isDynamicallyReplaceable()); + jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp)); + + LLVM_DEBUG(llvm::dbgs() << "JVP type: " << jvp->getLoweredFunctionType() + << "\n"); + return jvp; +} + +/// Apply the fatal error function with the given name of type +/// `@convention(thin) () -> Never` in `f`. +static void emitFatalError(ADContext &context, SILFunction *f, + StringRef fatalErrorFuncName) { + auto *entry = f->createBasicBlock(); + createEntryArguments(f); + SILBuilder builder(entry); + auto loc = f->getLocation(); + // Destroy all owned arguments to pass ownership verification. + for (auto *arg : entry->getArguments()) + if (arg->getOwnershipKind() == ValueOwnershipKind::Owned) + builder.emitDestroyOperation(loc, arg); + // Fatal error with a nice message. + auto neverResultInfo = + SILResultInfo(context.getModule().getASTContext().getNeverType(), + ResultConvention::Unowned); + // Fatal error function must have type `@convention(thin) () -> Never`. + auto fatalErrorFnType = SILFunctionType::get( + /*genericSig*/ nullptr, + SILFunctionType::ExtInfo().withRepresentation( + SILFunctionTypeRepresentation::Thin), + SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {}, + /*interfaceYields*/ {}, neverResultInfo, + /*interfaceErrorResults*/ None, {}, {}, context.getASTContext()); + auto fnBuilder = SILOptFunctionBuilder(context.getTransform()); + auto *fatalErrorFn = fnBuilder.getOrCreateFunction( + loc, fatalErrorFuncName, SILLinkage::PublicExternal, fatalErrorFnType, + IsNotBare, IsNotTransparent, IsNotSerialized, IsNotDynamic, + ProfileCounter(), IsNotThunk); + auto *fatalErrorFnRef = builder.createFunctionRef(loc, fatalErrorFn); + builder.createApply(loc, fatalErrorFnRef, SubstitutionMap(), {}); + builder.createUnreachable(loc); +} + +/// Returns true on error. +bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( + SILFunction *original, SILDifferentiabilityWitness *witness, + DifferentiationInvoker invoker, IsSerialized_t serializeFunctions) { + std::string traceMessage; + llvm::raw_string_ostream OS(traceMessage); + OS << "processing "; + witness->print(OS); + OS << " on"; + OS.flush(); + PrettyStackTraceSILFunction trace(traceMessage.c_str(), original); + + assert(witness->isDefinition()); + + // If the JVP doesn't exist, need to synthesize it. + if (!witness->getJVP()) { + // Diagnose: + // - Functions with no return. + // - Functions with unsupported control flow. + if (context.getASTContext() + .LangOpts.EnableExperimentalForwardModeDifferentiation && + (diagnoseNoReturn(context, original, invoker) || + diagnoseUnsupportedControlFlow(context, original, invoker))) + return true; + + witness->setJVP( + createEmptyJVP(context, original, witness, serializeFunctions)); + context.recordGeneratedFunction(witness->getJVP()); + + // For now, only do JVP generation if the flag is enabled and if custom VJP + // does not exist. If custom VJP exists but custom JVP does not, skip JVP + // generation because generated JVP may not match semantics of custom VJP. + // Instead, create an empty JVP. + if (context.getASTContext() + .LangOpts.EnableExperimentalForwardModeDifferentiation && + !witness->getVJP()) { + // JVP and differential generation do not currently support functions with + // multiple basic blocks. + if (original->getBlocks().size() > 1) { + context.emitNondifferentiabilityError( + original->getLocation().getSourceLoc(), invoker, + diag::autodiff_jvp_control_flow_not_supported); + return true; + } + // TODO(TF-1211): Upstream and use `JVPEmitter`. Fatal error with a nice + // message for now. + auto *jvp = witness->getJVP(); + emitFatalError(context, jvp, "_fatalErrorJVPNotGenerated"); + } else { + // If JVP generation is disabled or a user-defined custom VJP function + // exists, fatal error with a nice message. + emitFatalError(context, witness->getJVP(), + "_fatalErrorForwardModeDifferentiationDisabled"); + LLVM_DEBUG(getADDebugStream() + << "Generated empty JVP for " << original->getName() << ":\n" + << *witness->getJVP()); + } + } + + // If the VJP doesn't exist, need to synthesize it. + if (!witness->getVJP()) { + // Diagnose: + // - Functions with no return. + // - Functions with unsupported control flow. + if (diagnoseNoReturn(context, original, invoker) || + diagnoseUnsupportedControlFlow(context, original, invoker)) + return true; + + // Create empty VJP. + auto *vjp = createEmptyVJP(context, original, witness, serializeFunctions); + witness->setVJP(vjp); + context.recordGeneratedFunction(vjp); + // TODO(TF-1211): Upstream and use `VJPEmitter`. Fatal error with a nice + // message for now. + emitFatalError(context, vjp, "_fatalErrorVJPNotGenerated"); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Differentiation pass implementation +//===----------------------------------------------------------------------===// + +/// The automatic differentiation pass. +namespace { +class Differentiation : public SILModuleTransform { +public: + Differentiation() : SILModuleTransform() {} + void run() override; +}; +} // end anonymous namespace + +SILValue DifferentiationTransformer::promoteToDifferentiableFunction( + DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc, + DifferentiationInvoker invoker) { + auto origFnOperand = dfi->getOriginalFunction(); + auto origFnTy = origFnOperand->getType().castTo(); + auto parameterIndices = dfi->getParameterIndices(); + unsigned resultIndex = context.getResultIndex(dfi); + + // TODO(TF-1211): Upstream full derivative function referen ce emission logic. + SmallVector derivativeFns; + for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP, + AutoDiffDerivativeFunctionKind::VJP}) { + auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( + parameterIndices, resultIndex, derivativeFnKind, + context.getTypeConverter(), + LookUpConformanceInModule(context.getModule().getSwiftModule())); + auto expectedDerivativeSilTy = + SILType::getPrimitiveObjectType(expectedDerivativeFnTy); + // TODO: Replace undef with actual derivative function reference. + auto derivativeFn = + SILUndef::get(expectedDerivativeSilTy, builder.getFunction()); + derivativeFns.push_back(derivativeFn); + } + auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand); + auto *newDFI = context.createDifferentiableFunction( + builder, loc, parameterIndices, origFnCopy, + std::make_pair(derivativeFns[0], derivativeFns[1])); + context.setResultIndex(dfi, resultIndex); + context.addDifferentiableFunctionInstToWorklist(dfi); + + return newDFI; +} + +/// Fold `differentiable_function_extract` users of the given +/// `differentiable_function` instruction, directly replacing them with +/// `differentiable_function` instruction operands. If the +/// `differentiable_function` instruction has no remaining uses, delete the +/// instruction itself after folding. +/// +/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction` +/// flag for SIL testing purposes. +// FIXME: This function is not correctly detecting the foldable pattern and +// needs to be rewritten. +void DifferentiationTransformer::foldDifferentiableFunctionExtraction( + DifferentiableFunctionInst *source) { + // Iterate through all `differentiable_function` instruction uses. + for (auto use : source->getUses()) { + auto *dfei = dyn_cast(use->getUser()); + // If user is not an `differentiable_function_extract` instruction, set flag + // to false. + if (!dfei) + continue; + // Fold original function extractors. + if (dfei->getExtractee() == + NormalDifferentiableFunctionTypeComponent::Original) { + auto originalFnValue = source->getOriginalFunction(); + dfei->replaceAllUsesWith(originalFnValue); + dfei->eraseFromParent(); + continue; + } + // Fold derivative function extractors. + auto derivativeFnValue = + source->getDerivativeFunction(dfei->getDerivativeFunctionKind()); + dfei->replaceAllUsesWith(derivativeFnValue); + dfei->eraseFromParent(); + } + // If the `differentiable_function` instruction has no remaining uses, erase + // it. + if (isInstructionTriviallyDead(source)) { + SILBuilder builder(source); + builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction()); + builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction()); + source->eraseFromParent(); + } + // Mark `source` as processed so that it won't be reprocessed after deletion. + context.markDifferentiableFunctionInstAsProcessed(source); +} + +bool DifferentiationTransformer::processDifferentiableFunctionInst( + DifferentiableFunctionInst *dfi) { + PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`", + cast(dfi)); + PrettyStackTraceSILFunction fnTrace("...in", dfi->getFunction()); + LLVM_DEBUG({ + auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n"; + dfi->printInContext(s); + }); + + // If `dfi` already has derivative functions, do not process. + if (dfi->hasDerivativeFunctions()) + return false; + + SILFunction *parent = dfi->getFunction(); + auto loc = dfi->getLoc(); + SILBuilderWithScope builder(dfi); + auto differentiableFnValue = + promoteToDifferentiableFunction(dfi, builder, loc, dfi); + // Mark `dfi` as processed so that it won't be reprocessed after deletion. + context.markDifferentiableFunctionInstAsProcessed(dfi); + if (!differentiableFnValue) + return true; + // Replace all uses of `dfi`. + dfi->replaceAllUsesWith(differentiableFnValue); + // Destroy the original operand. + builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction()); + dfi->eraseFromParent(); + // If the promoted `@differentiable` function-typed value is an + // `differentiable_function` instruction, fold + // `differentiable_function_extract` instructions. If + // `differentiable_function_extract` folding is disabled, return. + if (!SkipFoldingDifferentiableFunctionExtraction) + if (auto *newDFI = + dyn_cast(differentiableFnValue)) + foldDifferentiableFunctionExtraction(newDFI); + transform.invalidateAnalysis(parent, + SILAnalysis::InvalidationKind::FunctionBody); + return false; +} + +/// Automatic differentiation transform entry. +void Differentiation::run() { + auto &module = *getModule(); + auto &astCtx = module.getASTContext(); + debugDump(module); + + // A transformation helper. + DifferentiationTransformer transformer(*this); + ADContext &context = transformer.getContext(); + + bool errorOccurred = false; + + // Register all the SIL differentiability witnesses in the module that trigger + // differentiation. + for (auto &witness : module.getDifferentiabilityWitnesses()) { + if (witness.isDeclaration()) + continue; + context.addInvoker(&witness); + } + + // Register all the `differentiable_function` instructions in the module that + // trigger differentiation. + for (SILFunction &f : module) { + for (SILBasicBlock &bb : f) { + for (SILInstruction &i : bb) { + if (auto *dfi = dyn_cast(&i)) + context.addDifferentiableFunctionInstToWorklist(dfi); + // Reject uncanonical `linear_function` instructions. + // FIXME(SR-11850): Add support for linear map transposition. + else if (auto *lfi = dyn_cast(&i)) { + if (!lfi->hasTransposeFunction()) { + astCtx.Diags.diagnose( + lfi->getLoc().getSourceLoc(), + diag::autodiff_conversion_to_linear_function_not_supported); + errorOccurred = true; + } + } + } + } + } + + // If nothing has triggered differentiation, there's nothing to do. + if (context.getInvokers().empty() && + context.isDifferentiableFunctionInstsWorklistEmpty()) + return; + + // Differentiation relies on the stdlib (the Swift module). + // If it's not imported, it's an internal error. + if (!astCtx.getStdlibModule()) { + astCtx.Diags.diagnose(SourceLoc(), + diag::autodiff_internal_swift_not_imported); + return; + } + if (!astCtx.getLoadedModule(astCtx.Id_Differentiation)) { + SourceLoc loc; + if (!context.getInvokers().empty()) { + loc = context.getInvokers().front().second.getLocation(); + } else { + assert(!context.isDifferentiableFunctionInstsWorklistEmpty()); + loc = context.popDifferentiableFunctionInstFromWorklist() + ->getLoc() + .getSourceLoc(); + } + astCtx.Diags.diagnose(loc, + diag::autodiff_differentiation_module_not_imported); + return; + } + + // Process all invokers. + for (auto invokerPair : context.getInvokers()) { + auto *witness = invokerPair.first; + auto *original = witness->getOriginalFunction(); + auto invoker = invokerPair.second; + + if (transformer.canonicalizeDifferentiabilityWitness( + original, witness, invoker, original->isSerialized())) + errorOccurred = true; + } + + // Iteratively process `differentiable_function` instruction worklist. + while (auto *dfi = context.popDifferentiableFunctionInstFromWorklist()) { + // Skip instructions that have been already been processed. + if (context.isDifferentiableFunctionInstProcessed(dfi)) + continue; + errorOccurred |= transformer.processDifferentiableFunctionInst(dfi); + } + + // If any error occurred while processing witnesses or + // `differentiable_function` instructions, clean up. + if (errorOccurred) { + context.cleanUp(); + return; + } + + LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n"); +} + +//===----------------------------------------------------------------------===// +// Pass creation +//===----------------------------------------------------------------------===// + +SILTransform *swift::createDifferentiation() { return new Differentiation; } diff --git a/lib/SILOptimizer/PassManager/PassPipeline.cpp b/lib/SILOptimizer/PassManager/PassPipeline.cpp index e66b4ff69b015..8966facd12e97 100644 --- a/lib/SILOptimizer/PassManager/PassPipeline.cpp +++ b/lib/SILOptimizer/PassManager/PassPipeline.cpp @@ -92,6 +92,8 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P) { P.addAllocBoxToStack(); P.addNoReturnFolding(); addDefiniteInitialization(P); + P.addDifferentiation(); + // Only run semantic arc opts if we are optimizing and if mandatory semantic // arc opts is explicitly enabled. // diff --git a/lib/SILOptimizer/Utils/CMakeLists.txt b/lib/SILOptimizer/Utils/CMakeLists.txt index 72587675a5500..ee7fa1d0aa6fb 100644 --- a/lib/SILOptimizer/Utils/CMakeLists.txt +++ b/lib/SILOptimizer/Utils/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Differentiation) + silopt_register_sources( BasicBlockOptUtils.cpp CFGOptUtils.cpp diff --git a/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp b/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp new file mode 100644 index 0000000000000..545ebeafdecb2 --- /dev/null +++ b/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp @@ -0,0 +1,114 @@ +//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Per-module contextual information for the differentiation transform. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "differentiation" + +#include "swift/SILOptimizer/Utils/Differentiation/ADContext.h" +#include "swift/AST/DiagnosticsSIL.h" +#include "swift/SILOptimizer/PassManager/Transforms.h" + +using llvm::DenseMap; +using llvm::SmallPtrSet; +using llvm::SmallVector; + +namespace swift { +namespace autodiff { + +//===----------------------------------------------------------------------===// +// Local helpers +//===----------------------------------------------------------------------===// + +/// Given an operator name, such as '+', and a protocol, returns the '+' +/// operator. If the operator does not exist in the protocol, returns null. +static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName, + ProtocolDecl *protocol) { + assert(operatorName.isOperator()); + // Find the operator requirement in the given protocol declaration. + auto opLookup = protocol->lookupDirect(operatorName); + for (auto *decl : opLookup) { + if (!decl->isProtocolRequirement()) + continue; + auto *fd = dyn_cast(decl); + if (!fd || !fd->isStatic() || !fd->isOperator()) + continue; + return fd; + } + // Not found. + return nullptr; +} + +//===----------------------------------------------------------------------===// +// ADContext methods +//===----------------------------------------------------------------------===// + +ADContext::ADContext(SILModuleTransform &transform) + : transform(transform), module(*transform.getModule()), + passManager(*transform.getPassManager()) {} + +FuncDecl *ADContext::getPlusDecl() const { + if (!cachedPlusFn) { + cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"), + additiveArithmeticProtocol); + assert(cachedPlusFn && "AdditiveArithmetic.+ not found"); + } + return cachedPlusFn; +} + +FuncDecl *ADContext::getPlusEqualDecl() const { + if (!cachedPlusEqualFn) { + cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="), + additiveArithmeticProtocol); + assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found"); + } + return cachedPlusEqualFn; +} + +void ADContext::cleanUp() { + // Delete all references to generated functions. + for (auto fnRef : generatedFunctionReferences) { + if (auto *fnRefInst = + peerThroughFunctionConversions(fnRef)) { + fnRefInst->replaceAllUsesWithUndef(); + fnRefInst->eraseFromParent(); + } + } + // Delete all generated functions. + for (auto *generatedFunction : generatedFunctions) { + LLVM_DEBUG(getADDebugStream() << "Deleting generated function " + << generatedFunction->getName() << '\n'); + generatedFunction->dropAllReferences(); + transform.notifyWillDeleteFunction(generatedFunction); + module.eraseFunction(generatedFunction); + } +} + +DifferentiableFunctionInst *ADContext::createDifferentiableFunction( + SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices, + SILValue original, + Optional> derivativeFunctions) { + auto *dfi = builder.createDifferentiableFunction( + loc, parameterIndices, original, derivativeFunctions); + processedDifferentiableFunctionInsts.erase(dfi); + return dfi; +} + +DifferentiableFunctionExpr * +ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) { + return inst->getLoc().getAsASTNode(); +} + +} // end namespace autodiff +} // end namespace swift diff --git a/lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt b/lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt new file mode 100644 index 0000000000000..d3affa98414a6 --- /dev/null +++ b/lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt @@ -0,0 +1,5 @@ +silopt_register_sources( + ADContext.cpp + Common.cpp + DifferentiationInvoker.cpp +) diff --git a/lib/SILOptimizer/Utils/Differentiation/Common.cpp b/lib/SILOptimizer/Utils/Differentiation/Common.cpp new file mode 100644 index 0000000000000..3c584a3bb4ac6 --- /dev/null +++ b/lib/SILOptimizer/Utils/Differentiation/Common.cpp @@ -0,0 +1,27 @@ +//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Automatic differentiation common utilities. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "differentiation" + +#include "swift/SILOptimizer/Utils/Differentiation/Common.h" + +namespace swift { +namespace autodiff { + +raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; } + +} // end namespace autodiff +} // end namespace swift diff --git a/lib/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.cpp b/lib/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.cpp new file mode 100644 index 0000000000000..94859c4c847f4 --- /dev/null +++ b/lib/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.cpp @@ -0,0 +1,74 @@ +//===--- DifferentiationInvoker.cpp ---------------------------*- C++ -*---===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Class that represents an invoker of differentiation. +// Used to track diagnostic source locations. +// +//===----------------------------------------------------------------------===// + +#include "swift/SILOptimizer/Utils/Differentiation/DifferentiationInvoker.h" +#include "swift/SIL/SILDifferentiabilityWitness.h" +#include "swift/SIL/SILFunction.h" +#include "swift/SIL/SILInstruction.h" + +namespace swift { +namespace autodiff { + +SourceLoc DifferentiationInvoker::getLocation() const { + switch (kind) { + case Kind::DifferentiableFunctionInst: + return getDifferentiableFunctionInst()->getLoc().getSourceLoc(); + case Kind::IndirectDifferentiation: + return getIndirectDifferentiation().first->getLoc().getSourceLoc(); + case Kind::SILDifferentiabilityWitnessInvoker: + return getSILDifferentiabilityWitnessInvoker() + ->getOriginalFunction() + ->getLocation() + .getSourceLoc(); + } +} + +void DifferentiationInvoker::print(llvm::raw_ostream &os) const { + os << "(differentiation_invoker "; + switch (kind) { + case Kind::DifferentiableFunctionInst: + os << "differentiable_function_inst=(" << *getDifferentiableFunctionInst() + << ")"; + break; + case Kind::IndirectDifferentiation: { + auto indDiff = getIndirectDifferentiation(); + os << "indirect_differentiation=(" << *std::get<0>(indDiff) << ')'; + // TODO: Enable printing parent invokers. + // May require storing a `DifferentiableInvoker *` in the + // `IndirectDifferentiation` case. + /* + SILInstruction *inst; + SILDifferentiableAttr *attr; + std::tie(inst, attr) = getIndirectDifferentiation(); + auto invokerLookup = invokers.find(attr); // No access to ADContext? + assert(invokerLookup != invokers.end() && "Expected parent invoker"); + */ + break; + } + case Kind::SILDifferentiabilityWitnessInvoker: { + auto witness = getSILDifferentiabilityWitnessInvoker(); + os << "sil_differentiability_witness_invoker=(witness=("; + witness->print(os); + os << ") function=" << witness->getOriginalFunction()->getName(); + break; + } + } + os << ')'; +} + +} // end namespace autodiff +} // end namespace swift diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index 01687fe9ea5d4..88c36f6e1ed79 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -502,6 +502,95 @@ void TBDGenVisitor::addConformances(DeclContext *DC) { } } +void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original, + AutoDiffConfig config, + AutoDiffLinearMapKind kind) { + auto &ctx = original->getASTContext(); + auto declRef = + SILDeclRef(original).asForeign(requiresForeignEntryPoint(original)); + + if (!declRef.isSerialized()) + return; + // Linear maps are public only when the original function is serialized. + if (!declRef.isSerialized()) + return; + // Differential functions are emitted only when forward-mode is enabled. + if (kind == AutoDiffLinearMapKind::Differential && + !ctx.LangOpts.EnableExperimentalForwardModeDifferentiation) + return; + auto *loweredParamIndices = autodiff::getLoweredParameterIndices( + config.parameterIndices, + original->getInterfaceType()->castTo()); + Mangle::ASTMangler mangler; + AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices, + config.derivativeGenericSignature}; + std::string linearMapName = + mangler.mangleAutoDiffLinearMapHelper(declRef.mangle(), kind, silConfig); + addSymbol(linearMapName); +} + +void TBDGenVisitor::addAutoDiffDerivativeFunction( + AbstractFunctionDecl *original, IndexSubset *parameterIndices, + GenericSignature derivativeGenericSignature, + AutoDiffDerivativeFunctionKind kind) { + auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get( + kind, parameterIndices, derivativeGenericSignature, + original->getASTContext()); + auto declRef = + SILDeclRef(original).asForeign(requiresForeignEntryPoint(original)); + addSymbol(declRef.asAutoDiffDerivativeFunction(assocFnId)); +} + +void TBDGenVisitor::addDifferentiabilityWitness( + AbstractFunctionDecl *original, IndexSubset *astParameterIndices, + IndexSubset *resultIndices, GenericSignature derivativeGenericSignature) { + bool foreign = requiresForeignEntryPoint(original); + auto declRef = SILDeclRef(original).asForeign(foreign); + + // Skip symbol emission for original functions that do not have public + // linkage. Exclude original functions that require a foreign entry point with + // `public_external` linkage. + auto originalLinkage = declRef.getLinkage(ForDefinition); + if (foreign) + originalLinkage = stripExternalFromLinkage(originalLinkage); + if (originalLinkage != SILLinkage::Public) + return; + + auto *silParamIndices = autodiff::getLoweredParameterIndices( + astParameterIndices, + original->getInterfaceType()->castTo()); + + auto originalMangledName = declRef.mangle(); + AutoDiffConfig config{silParamIndices, resultIndices, + derivativeGenericSignature}; + SILDifferentiabilityWitnessKey key(originalMangledName, config); + + Mangle::ASTMangler mangler; + auto mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key); + addSymbol(mangledName); +} + +void TBDGenVisitor::addDerivativeConfiguration(AbstractFunctionDecl *original, + AutoDiffConfig config) { + auto inserted = AddedDerivatives.insert({original, config}); + if (!inserted.second) + return; + + addAutoDiffLinearMapFunction(original, config, + AutoDiffLinearMapKind::Differential); + addAutoDiffLinearMapFunction(original, config, + AutoDiffLinearMapKind::Pullback); + addAutoDiffDerivativeFunction(original, config.parameterIndices, + config.derivativeGenericSignature, + AutoDiffDerivativeFunctionKind::JVP); + addAutoDiffDerivativeFunction(original, config.parameterIndices, + config.derivativeGenericSignature, + AutoDiffDerivativeFunctionKind::VJP); + addDifferentiabilityWitness(original, config.parameterIndices, + config.resultIndices, + config.derivativeGenericSignature); +} + /// Determine whether dynamic replacement should be emitted for the allocator or /// the initializer given a decl. /// The rule is that structs and convenience init of classes emit a @@ -565,6 +654,22 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) { addSymbol(SILDeclRef(AFD).asForeign()); } + // Add derivative function symbols. + for (const auto *differentiableAttr : + AFD->getAttrs().getAttributes()) + addDerivativeConfiguration( + AFD, + AutoDiffConfig(differentiableAttr->getParameterIndices(), + IndexSubset::get(AFD->getASTContext(), 1, {0}), + differentiableAttr->getDerivativeGenericSignature())); + for (const auto *derivativeAttr : + AFD->getAttrs().getAttributes()) + addDerivativeConfiguration( + derivativeAttr->getOriginalFunction(), + AutoDiffConfig(derivativeAttr->getParameterIndices(), + IndexSubset::get(AFD->getASTContext(), 1, {0}), + AFD->getGenericSignature())); + visitDefaultArguments(AFD, AFD->getParameters()); } @@ -617,6 +722,15 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) { ASD->visitEmittedAccessors([&](AccessorDecl *accessor) { visitFuncDecl(accessor); }); + + // Add derivative function symbols. + for (const auto *differentiableAttr : + ASD->getAttrs().getAttributes()) + addDerivativeConfiguration( + ASD->getAccessor(AccessorKind::Get), + AutoDiffConfig(differentiableAttr->getParameterIndices(), + IndexSubset::get(ASD->getASTContext(), 1, {0}), + differentiableAttr->getDerivativeGenericSignature())); } void TBDGenVisitor::visitVarDecl(VarDecl *VD) { diff --git a/lib/TBDGen/TBDGenVisitor.h b/lib/TBDGen/TBDGenVisitor.h index 68512ef2c767a..ac09ddec94bca 100644 --- a/lib/TBDGen/TBDGenVisitor.h +++ b/lib/TBDGen/TBDGenVisitor.h @@ -70,6 +70,14 @@ class TBDGenVisitor : public ASTVisitor { ModuleDecl *SwiftModule; const TBDGenOptions &Opts; + /// A set of original function and derivative configuration pairs for which + /// derivative symbols have been emitted. + /// + /// Used to deduplicate derivative symbol emission for `@differentiable` and + /// `@derivative` attributes. + llvm::DenseSet> + AddedDerivatives; + private: std::vector DeclStack; std::unique_ptr> @@ -98,6 +106,34 @@ class TBDGenVisitor : public ASTVisitor { void addAssociatedConformanceDescriptor(AssociatedConformance conformance); void addBaseConformanceDescriptor(BaseConformance conformance); + /// Adds the symbol for the linear map function of the given kind associated + /// with the given original function and derivative function configuration. + void addAutoDiffLinearMapFunction(AbstractFunctionDecl *original, + AutoDiffConfig config, + AutoDiffLinearMapKind kind); + + /// Adds the symbol for the autodiff function of the given kind associated + /// with the given original function, parameter indices, and derivative + /// generic signature. + void + addAutoDiffDerivativeFunction(AbstractFunctionDecl *original, + IndexSubset *parameterIndices, + GenericSignature derivativeGenericSignature, + AutoDiffDerivativeFunctionKind kind); + + /// Adds the symbol for the differentiability witness associated with the + /// given original function, AST parameter indices, result indices, and + /// derivative generic signature. + void addDifferentiabilityWitness(AbstractFunctionDecl *original, + IndexSubset *astParameterIndices, + IndexSubset *resultIndices, + GenericSignature derivativeGenericSignature); + + /// Adds symbols associated with the given original function and + /// derivative function configuration. + void addDerivativeConfiguration(AbstractFunctionDecl *original, + AutoDiffConfig config); + public: TBDGenVisitor(llvm::MachO::InterfaceFile &symbols, llvm::MachO::TargetList targets, StringSet *stringSymbols, diff --git a/stdlib/public/Differentiation/DifferentiationUtilities.swift b/stdlib/public/Differentiation/DifferentiationUtilities.swift index 81afe8c7e6c5c..9afbf3ea30ff5 100644 --- a/stdlib/public/Differentiation/DifferentiationUtilities.swift +++ b/stdlib/public/Differentiation/DifferentiationUtilities.swift @@ -90,3 +90,34 @@ public func withoutDerivative(at x: T) -> T { public func withoutDerivative(at x: T, in body: (T) -> R) -> R { body(x) } + +//===----------------------------------------------------------------------===// +// Diagnostics +//===----------------------------------------------------------------------===// + +@_silgen_name("_fatalErrorForwardModeDifferentiationDisabled") +public func _fatalErrorForwardModeDifferentiationDisabled() -> Never { + fatalError(""" + JVP does not exist. Use \ + '-Xfrontend -enable-experimental-forward-mode-differentiation' to enable \ + differential-first differentiation APIs. + """) +} + +// TODO(TF-1211): Remove this diagnostic helper function. +@_silgen_name("_fatalErrorJVPNotGenerated") +public func _fatalErrorJVPNotGenerated() -> Never { + fatalError(""" + Forward-mode automatic differentiation has not yet been upstreamed from \ + tensorflow branch. Tracked by https://bugs.swift.org/browse/TF-1211. + """) +} + +// TODO(TF-1211): Remove this diagnostic helper function. +@_silgen_name("_fatalErrorVJPNotGenerated") +public func _fatalErrorVJPNotGenerated() -> Never { + fatalError(""" + Reverse-mode automatic differentiation has not yet been upstreamed from \ + tensorflow branch. Tracked by https://bugs.swift.org/browse/TF-1211. + """) +} diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift new file mode 100644 index 0000000000000..200e8a15904b0 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -0,0 +1,69 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// Test differentiation transform diagnostics. + +import _Differentiation + +//===----------------------------------------------------------------------===// +// Basic function +//===----------------------------------------------------------------------===// + +@differentiable +func basic(_ x: Float) -> Float { + return x + 2 +} + +//===----------------------------------------------------------------------===// +// Control flow +//===----------------------------------------------------------------------===// + +@differentiable +func conditional(_ x: Float, _ flag: Bool) -> Float { + let y: Float + if flag { + y = x + 1 + } else { + y = x + } + return y +} + +// TF-433: Test `try_apply` differentiation. + +func throwing() throws -> Void {} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+2 {{when differentiating this function definition}} +@differentiable +func try_apply(_ x: Float) -> Float { + // expected-note @+1 {{cannot differentiate unsupported control flow}} + try! throwing() + return x +} + +func rethrowing(_ x: () throws -> Void) rethrows -> Void {} + +// expected-error @+2 {{function is not differentiable}} +// expected-note @+2 {{when differentiating this function definition}} +@differentiable +func try_apply_rethrows(_ x: Float) -> Float { + // expected-note @+1 {{cannot differentiate unsupported control flow}} + rethrowing({}) + return x +} + +//===----------------------------------------------------------------------===// +// Unreachable +//===----------------------------------------------------------------------===// + +let _: @differentiable (Float) -> Float = { x in + let _ = x + 1 + // expected-error @+1 {{missing return in a closure expected to return 'Float'}} +} + +//===----------------------------------------------------------------------===// +// Conversion to `@differentiable(linear)` (not yet supported) +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{conversion to '@differentiable(linear)' function type is not yet supported}} +let _: @differentiable(linear) (Float) -> Float = { x in x } diff --git a/test/AutoDiff/SILOptimizer/differentiation_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_sil.swift new file mode 100644 index 0000000000000..8ce9b95c36574 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differentiation_sil.swift @@ -0,0 +1,39 @@ +// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN +// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s --check-prefix=CHECK-SIL + +// Simple differentiation transform test: check SIL before and after the transform. + +import _Differentiation + +@_silgen_name("basic") +@differentiable +func basic(_ x: Float) -> Float { x } + +// Test differentiability witnesses. + +// CHECK-SILGEN-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @basic : $@convention(thin) (Float) -> Float { +// CHECK-SILGEN-NEXT: } + +// CHECK-SIL-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @basic : $@convention(thin) (Float) -> Float { +// CHECK-SIL-NEXT: jvp: @AD__basic__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-SIL-NEXT: vjp: @AD__basic__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-SIL-NEXT: } + +// Test `differentiable_function` instructions. + +@_silgen_name("test_differentiable_function") +func testDifferentiableFunction() { + let _: @differentiable (Float) -> Float = basic +} + +// CHECK-SILGEN-LABEL: sil hidden [ossa] @test_differentiable_function : $@convention(thin) () -> () { +// CHECK-SILGEN: [[ORIG_FN_REF:%.*]] = function_ref @basic : $@convention(thin) (Float) -> Float +// CHECK-SILGEN: [[ORIG_FN:%.*]] = thin_to_thick_function [[ORIG_FN_REF]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN: } + +// CHECK-SIL-LABEL: sil hidden @test_differentiable_function : $@convention(thin) () -> () { +// CHECK-SIL: [[ORIG_FN_REF:%.*]] = function_ref @basic : $@convention(thin) (Float) -> Float +// CHECK-SIL: [[ORIG_FN:%.*]] = thin_to_thick_function [[ORIG_FN_REF]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float +// CHECK-SIL: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float with_derivative {undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} +// CHECK-SIL: } diff --git a/test/AutoDiff/Serialization/differentiable_attr.swift b/test/AutoDiff/Serialization/differentiable_attr.swift index cc76ca718c878..d410f7a79ca6c 100644 --- a/test/AutoDiff/Serialization/differentiable_attr.swift +++ b/test/AutoDiff/Serialization/differentiable_attr.swift @@ -40,7 +40,7 @@ func vjpSimple(x: Float) -> (Float, (Float) -> Float) { // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float @differentiable(wrt: x) func testWrtClause(x: Float, y: Float) -> Float { - return x + y + return x } struct InstanceMethod : Differentiable { @@ -48,7 +48,7 @@ struct InstanceMethod : Differentiable { // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float @differentiable(wrt: (self, y)) func testWrtClause(x: Float, y: Float) -> Float { - return x + y + return x } struct TangentVector: Differentiable, AdditiveArithmetic { diff --git a/test/AutoDiff/TBD/derivative_symbols.swift b/test/AutoDiff/TBD/derivative_symbols.swift new file mode 100644 index 0000000000000..44afd8001c159 --- /dev/null +++ b/test/AutoDiff/TBD/derivative_symbols.swift @@ -0,0 +1,56 @@ +// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s +// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s -O +// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing +// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing -O + +import _Differentiation + +@differentiable +public func topLevelDifferentiable(_ x: Float, _ y: Float) -> Float { x } + +public func topLevelHasDerivative(_ x: T) -> T { + x +} + +@derivative(of: topLevelHasDerivative) +public func topLevelDerivative(_ x: T) -> ( + value: T, pullback: (T.TangentVector) -> T.TangentVector +) { + fatalError() +} + +struct Struct: Differentiable { + var stored: Float + + // Test property. + @differentiable + public var property: Float { + stored + } + + // Test initializer. + @differentiable + public init(_ x: Float) { + stored = x + } + + // Test method. + public func method(x: Float, y: Float) -> Float { x } + + @derivative(of: method) + public func jvpMethod(x: Float, y: Float) -> ( + value: Float, differential: (TangentVector, Float, Float) -> Float + ) { + fatalError() + } + + // Test subscript. + public subscript(x: Float) -> Float { x } + + @derivative(of: subscript) + public func vjpSubscript(x: Float) -> ( + value: Float, pullback: (Float) -> (TangentVector, Float) + ) { + fatalError() + } +}