Skip to content

[AutoDiff upstream] Add @noDerivative flag to SILParameterInfo. #29405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3637,13 +3637,35 @@ inline bool isGuaranteedParameter(ParameterConvention conv) {
llvm_unreachable("bad convention kind");
}

/// The differentiability of a SIL function type parameter.
enum class SILParameterDifferentiability : unsigned {
/// Either differentiable or not applicable.
///
/// - If the function type is not `@differentiable`, parameter
/// differentiability is not applicable. This case is the default value.
/// - If the function type is `@differentiable`, the function is
/// differentiable with respect to this parameter.
DifferentiableOrNotApplicable,

/// Not differentiable: a `@noDerivative` parameter.
///
/// May be applied only to parameters of `@differentiable` function types.
/// The function type is not differentiable with respect to this parameter.
NotDifferentiable,
};

/// A parameter type and the rules for passing it.
class SILParameterInfo {
llvm::PointerIntPair<CanType, 3, ParameterConvention> TypeAndConvention;
SILParameterDifferentiability Differentiability : 1;

public:
SILParameterInfo() = default;//: Ty(), Convention((ParameterConvention)0) {}
SILParameterInfo(CanType type, ParameterConvention conv)
: TypeAndConvention(type, conv) {
SILParameterInfo(
CanType type, ParameterConvention conv,
SILParameterDifferentiability differentiability =
SILParameterDifferentiability::DifferentiableOrNotApplicable)
: TypeAndConvention(type, conv), Differentiability(differentiability) {
assert(type->isLegalSILType() && "SILParameterInfo has illegal SIL type");
}

Expand Down Expand Up @@ -3698,6 +3720,16 @@ class SILParameterInfo {
return isGuaranteedParameter(getConvention());
}

SILParameterDifferentiability getDifferentiability() const {
return Differentiability;
}

SILParameterInfo getWithDifferentiability(
SILParameterDifferentiability differentiability) const {
return SILParameterInfo(getInterfaceType(), getConvention(),
differentiability);
}

/// The SIL storage type determines the ABI for arguments based purely on the
/// formal parameter conventions. The actual SIL type for the argument values
/// may differ in canonical SIL. In particular, opaque values require indirect
Expand Down Expand Up @@ -3726,6 +3758,7 @@ class SILParameterInfo {
void profile(llvm::FoldingSetNodeID &id) {
id.AddPointer(getInterfaceType().getPointer());
id.AddInteger((unsigned)getConvention());
id.AddInteger((unsigned)getDifferentiability());
}

SWIFT_DEBUG_DUMP;
Expand All @@ -3739,8 +3772,9 @@ class SILParameterInfo {
}

bool operator==(SILParameterInfo rhs) const {
return getInterfaceType() == rhs.getInterfaceType()
&& getConvention() == rhs.getConvention();
return getInterfaceType() == rhs.getInterfaceType() &&
getConvention() == rhs.getConvention() &&
getDifferentiability() == rhs.getDifferentiability();
}
bool operator!=(SILParameterInfo rhs) const {
return !(*this == rhs);
Expand Down Expand Up @@ -4093,6 +4127,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
return ExtInfo(NoEscape ? (Bits | NoEscapeMask) : (Bits & ~NoEscapeMask),
Other);
}
ExtInfo
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
return ExtInfo(
(Bits & ~DifferentiabilityMask) |
((unsigned)differentiability << DifferentiabilityMaskOffset),
Other);
}

std::pair<unsigned, const void *> getFuncAttrKey() const {
return std::make_pair(Bits, Other.ClangFunctionType);
Expand Down
9 changes: 9 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,15 @@ SILFunctionType::SILFunctionType(
"Cannot return an @noescape function type");
}
}

// Check that `@noDerivative` parameters only exist on `@differentiable`
// functions.
if (!ext.isDifferentiable())
for (auto param : getParameters())
assert(param.getDifferentiability() ==
SILParameterDifferentiability::DifferentiableOrNotApplicable &&
"non-`@differentiable` function should not have NotDifferentiable "
"parameter");
#endif
}

Expand Down
7 changes: 7 additions & 0 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4649,6 +4649,13 @@ void SILParameterInfo::print(raw_ostream &OS, const PrintOptions &Opts) const {
}
void SILParameterInfo::print(ASTPrinter &Printer,
const PrintOptions &Opts) const {
switch (getDifferentiability()) {
case SILParameterDifferentiability::NotDifferentiable:
Printer << "@noDerivative ";
break;
default:
break;
}
Printer << getStringForParameterConvention(getConvention());
getInterfaceType().print(Printer, Opts);
}
Expand Down
21 changes: 13 additions & 8 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,8 @@ class DestructureInputs {
auto eltPattern = origType.getFunctionParamType(i);
auto flags = params[i].getParameterFlags();

visit(flags.getValueOwnership(), /*forSelf=*/false,
eltPattern, ty, silRepresentation);
visit(flags.getValueOwnership(), /*forSelf=*/false, eltPattern, ty,
silRepresentation, flags.isNoDerivative());
}

// Process the self parameter. Note that we implicitly drop self
Expand All @@ -776,7 +776,8 @@ class DestructureInputs {

void visit(ValueOwnership ownership, bool forSelf,
AbstractionPattern origType, CanType substType,
SILFunctionTypeRepresentation rep) {
SILFunctionTypeRepresentation rep,
bool isNonDifferentiable = false) {
assert(!isa<InOutType>(substType));

// Tuples get handled specially, in some cases:
Expand Down Expand Up @@ -829,9 +830,12 @@ class DestructureInputs {
substTLConv);
assert(!isIndirectFormalParameter(convention));
}

Inputs.push_back(SILParameterInfo(
substTL.getLoweredType().getASTType(), convention));

SILParameterInfo param(substTL.getLoweredType().getASTType(), convention);
if (isNonDifferentiable)
param = param.getWithDifferentiability(
SILParameterDifferentiability::NotDifferentiable);
Inputs.push_back(param);

maybeAddForeignParameters();
}
Expand Down Expand Up @@ -1269,7 +1273,8 @@ static CanSILFunctionType getSILFunctionType(
auto silExtInfo = SILFunctionType::ExtInfo()
.withRepresentation(extInfo.getSILRepresentation())
.withIsPseudogeneric(pseudogeneric)
.withNoEscape(extInfo.isNoEscape());
.withNoEscape(extInfo.isNoEscape())
.withDifferentiabilityKind(extInfo.getDifferentiabilityKind());

// Build the substituted generic signature we extracted.
bool impliedSignature = false;
Expand Down Expand Up @@ -2734,7 +2739,7 @@ class SILTypeSubstituter :

SILParameterInfo substInterface(SILParameterInfo orig) {
return SILParameterInfo(visit(orig.getInterfaceType()),
orig.getConvention());
orig.getConvention(), orig.getDifferentiability());
}

/// Tuples need to have their component types substituted by these
Expand Down
9 changes: 8 additions & 1 deletion lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2948,6 +2948,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
auto convention = DefaultParameterConvention;
Type type;
bool hadError = false;
auto differentiability =
SILParameterDifferentiability::DifferentiableOrNotApplicable;

if (auto attrRepr = dyn_cast<AttributedTypeRepr>(repr)) {
auto attrs = attrRepr->getAttrs();
Expand All @@ -2973,6 +2975,10 @@ SILParameterInfo TypeResolver::resolveSILParameter(
checkFor(TypeAttrKind::TAK_owned, ParameterConvention::Direct_Owned);
checkFor(TypeAttrKind::TAK_guaranteed,
ParameterConvention::Direct_Guaranteed);
if (attrs.has(TAK_noDerivative)) {
attrs.clearAttribute(TAK_noDerivative);
differentiability = SILParameterDifferentiability::NotDifferentiable;
}

type = resolveAttributedType(attrs, attrRepr->getTypeRepr(), options);
} else {
Expand All @@ -2989,7 +2995,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
}

if (hadError) type = ErrorType::get(Context);
return SILParameterInfo(type->getCanonicalType(), convention);
return SILParameterInfo(type->getCanonicalType(), convention,
differentiability);
}

bool TypeResolver::resolveSingleSILResult(TypeRepr *repr,
Expand Down
37 changes: 33 additions & 4 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4503,6 +4503,21 @@ Optional<swift::ParameterConvention> getActualParameterConvention(uint8_t raw) {
return None;
}

/// Translate from the serialization SILParameterDifferentiability enumerators,
/// which are guaranteed to be stable, to the AST ones.
static Optional<swift::SILParameterDifferentiability>
getActualSILParameterDifferentiability(uint8_t raw) {
switch (serialization::SILParameterDifferentiability(raw)) {
#define CASE(ID) \
case serialization::SILParameterDifferentiability::ID: \
return swift::SILParameterDifferentiability::ID;
CASE(DifferentiableOrNotApplicable)
CASE(NotDifferentiable)
#undef CASE
}
return None;
}

/// Translate from the serialization ResultConvention enumerators,
/// which are guaranteed to be stable, to the AST ones.
static
Expand Down Expand Up @@ -5144,15 +5159,26 @@ class TypeDeserializer {
if (!calleeConvention.hasValue())
MF.fatal();

auto processParameter = [&](TypeID typeID, uint64_t rawConvention)
-> llvm::Expected<SILParameterInfo> {
auto processParameter =
[&](TypeID typeID, uint64_t rawConvention,
uint64_t ramDifferentiability) -> llvm::Expected<SILParameterInfo> {
auto convention = getActualParameterConvention(rawConvention);
if (!convention)
MF.fatal();
auto type = MF.getTypeChecked(typeID);
if (!type)
return type.takeError();
return SILParameterInfo(type.get()->getCanonicalType(), *convention);
auto differentiability =
swift::SILParameterDifferentiability::DifferentiableOrNotApplicable;
if (diffKind != DifferentiabilityKind::NonDifferentiable) {
auto differentiabilityOpt =
getActualSILParameterDifferentiability(ramDifferentiability);
if (!differentiabilityOpt)
MF.fatal();
differentiability = *differentiabilityOpt;
}
return SILParameterInfo(type.get()->getCanonicalType(), *convention,
differentiability);
};

auto processYield = [&](TypeID typeID, uint64_t rawConvention)
Expand Down Expand Up @@ -5191,7 +5217,10 @@ class TypeDeserializer {
for (unsigned i = 0; i != numParams; ++i) {
auto typeID = variableData[nextVariableDataIndex++];
auto rawConvention = variableData[nextVariableDataIndex++];
auto param = processParameter(typeID, rawConvention);
uint64_t differentiability = 0;
if (diffKind != DifferentiabilityKind::NonDifferentiable)
differentiability = variableData[nextVariableDataIndex++];
auto param = processParameter(typeID, rawConvention, differentiability);
if (!param)
return param.takeError();
allParams.push_back(param.get());
Expand Down
9 changes: 8 additions & 1 deletion lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 533; // removed @_implicitly_synthesizes_nested_requirement
const uint16_t SWIFTMODULE_VERSION_MINOR = 534; // add SIL parameter differentiability

/// A standard hash seed used for all string hashes in a serialized module.
///
Expand Down Expand Up @@ -347,6 +347,13 @@ enum class ParameterConvention : uint8_t {
};
using ParameterConventionField = BCFixed<4>;

// These IDs must \em not be renumbered or reordered without incrementing
// the module version.
enum class SILParameterDifferentiability : uint8_t {
DifferentiableOrNotApplicable,
NotDifferentiable,
};

// These IDs must \em not be renumbered or reordered without incrementing
// the module version.
enum class ResultConvention : uint8_t {
Expand Down
14 changes: 14 additions & 0 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3766,6 +3766,17 @@ static uint8_t getRawStableParameterConvention(swift::ParameterConvention pc) {
llvm_unreachable("bad parameter convention kind");
}

/// Translate from AST SILParameterDifferentiability enum to the Serialization
/// enum values, which are guaranteed to be stable.
static uint8_t
getRawSILParameterDifferentiability(swift::SILParameterDifferentiability pd) {
switch (pd) {
SIMPLE_CASE(SILParameterDifferentiability, DifferentiableOrNotApplicable)
SIMPLE_CASE(SILParameterDifferentiability, NotDifferentiable)
}
llvm_unreachable("bad parameter differentiability kind");
}

/// Translate from the AST ResultConvention enum to the
/// Serialization enum values, which are guaranteed to be stable.
static uint8_t getRawStableResultConvention(swift::ResultConvention rc) {
Expand Down Expand Up @@ -4075,6 +4086,9 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
variableData.push_back(S.addTypeRef(param.getInterfaceType()));
unsigned conv = getRawStableParameterConvention(param.getConvention());
variableData.push_back(TypeID(conv));
if (fnTy->isDifferentiable())
variableData.push_back(TypeID(
getRawSILParameterDifferentiability(param.getDifferentiability())));
}
for (auto yield : fnTy->getYields()) {
variableData.push_back(S.addTypeRef(yield.getInterfaceType()));
Expand Down
20 changes: 20 additions & 0 deletions test/AutoDiff/SIL/Serialization/differentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,23 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float):
// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float
// CHECK: }

sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
return %0 : $@differentiable (Float, @noDerivative Float) -> Float
}

// CHECK-LABEL: sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
// CHECK: bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
// CHECK: return %0 : $@differentiable (Float, @noDerivative Float) -> Float
// CHECK: }

sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
}

// CHECK-LABEL: sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
// CHECK: bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
// CHECK: return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
// CHECK: }
55 changes: 55 additions & 0 deletions test/AutoDiff/SILGen/differentiable_function.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s

// Test SILGen for `@differentiable` function typed values.

import _Differentiation

@_silgen_name("differentiable")
func differentiable(_ fn: @escaping @differentiable (Float) -> Float)
-> @differentiable (Float) -> Float {
return fn
}

@_silgen_name("linear")
func linear(_ fn: @escaping @differentiable(linear) (Float) -> Float)
-> @differentiable(linear) (Float) -> Float {
return fn
}

@_silgen_name("differentiable_noDerivative")
func differentiable_noDerivative(
_ fn: @escaping @differentiable (Float, @noDerivative Float) -> Float
) -> @differentiable (Float, @noDerivative Float) -> Float {
return fn
}

@_silgen_name("linear_noDerivative")
func linear_noDerivative(
_ fn: @escaping @differentiable(linear) (Float, @noDerivative Float) -> Float
) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
return fn
}

// CHECK-LABEL: sil hidden [ossa] @differentiable : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float) -> Float
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @linear : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float):
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @differentiable_noDerivative : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float):
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: return [[COPIED_FN]] : $@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @linear_noDerivative : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float):
// CHECK: [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: return [[COPIED_FN]] : $@differentiable(linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: }