Skip to content

Commit f3685f9

Browse files
authored
[AutoDiff] Conform Optional to Differentiable. (#32948)
Make `Optional` conditionally conform to `Differentiable` when the `Wrapped` type does. `Optional.TangentVector` is a wrapper around `Wrapped.TangentVector?`. Also, fix `Array.TangentVector.zeroTangentVectorInitializer`. Resolves TF-1301.
1 parent a6b9815 commit f3685f9

File tree

5 files changed

+201
-1
lines changed

5 files changed

+201
-1
lines changed

docs/DifferentiableProgramming.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,14 @@ extension Optional: Differentiable where Wrapped: Differentiable {
11921192

11931193
@noDerivative
11941194
public var zeroTangentVectorInitializer: () -> TangentVector {
1195-
{ TangentVector(.zero) }
1195+
switch self {
1196+
case nil:
1197+
return { TangentVector(nil) }
1198+
case let x?:
1199+
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
1200+
TangentVector(zeroTanInit())
1201+
}
1202+
}
11961203
}
11971204
}
11981205
```

stdlib/public/Differentiation/ArrayDifferentiation.swift

+6
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ where Element: Differentiable {
8282
base[i].move(along: direction.base[i])
8383
}
8484
}
85+
86+
/// A closure that produces a `TangentVector` of zeros with the same
87+
/// `count` as `self`.
88+
public var zeroTangentVectorInitializer: () -> TangentVector {
89+
return base.zeroTangentVectorInitializer
90+
}
8591
}
8692

8793
extension Array.DifferentiableView: Equatable

stdlib/public/Differentiation/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1616
DifferentiationUtilities.swift
1717
AnyDifferentiable.swift
1818
ArrayDifferentiation.swift
19+
OptionalDifferentiation.swift
1920

2021
GYB_SOURCES
2122
FloatingPointDifferentiation.swift.gyb
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===--- OptionalDifferentiation.swift ------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import Swift
14+
15+
extension Optional: Differentiable where Wrapped: Differentiable {
16+
public struct TangentVector: Differentiable, AdditiveArithmetic {
17+
public typealias TangentVector = Self
18+
19+
public var value: Wrapped.TangentVector?
20+
21+
public init(_ value: Wrapped.TangentVector?) {
22+
self.value = value
23+
}
24+
25+
public static var zero: Self {
26+
return Self(.zero)
27+
}
28+
29+
public static func + (lhs: Self, rhs: Self) -> Self {
30+
switch (lhs.value, rhs.value) {
31+
case (nil, nil): return Self(nil)
32+
case let (x?, nil): return Self(x)
33+
case let (nil, y?): return Self(y)
34+
case let (x?, y?): return Self(x + y)
35+
}
36+
}
37+
38+
public static func - (lhs: Self, rhs: Self) -> Self {
39+
switch (lhs.value, rhs.value) {
40+
case (nil, nil): return Self(nil)
41+
case let (x?, nil): return Self(x)
42+
case let (nil, y?): return Self(.zero - y)
43+
case let (x?, y?): return Self(x - y)
44+
}
45+
}
46+
47+
public mutating func move(along direction: TangentVector) {
48+
if let value = direction.value {
49+
self.value?.move(along: value)
50+
}
51+
}
52+
53+
@noDerivative
54+
public var zeroTangentVectorInitializer: () -> TangentVector {
55+
switch value {
56+
case nil:
57+
return { Self(nil) }
58+
case let x?:
59+
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
60+
Self(zeroTanInit())
61+
}
62+
}
63+
}
64+
}
65+
66+
public mutating func move(along direction: TangentVector) {
67+
if let value = direction.value {
68+
self?.move(along: value)
69+
}
70+
}
71+
72+
@noDerivative
73+
public var zeroTangentVectorInitializer: () -> TangentVector {
74+
switch self {
75+
case nil:
76+
return { TangentVector(nil) }
77+
case let x?:
78+
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
79+
TangentVector(zeroTanInit())
80+
}
81+
}
82+
}
83+
}

test/AutoDiff/stdlib/optional.swift

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
import StdlibUnittest
6+
7+
var OptionalDifferentiationTests = TestSuite("OptionalDifferentiation")
8+
9+
OptionalDifferentiationTests.test("Optional operations") {
10+
// Differentiable.move(along:)
11+
do {
12+
var some: Float? = 2
13+
some.move(along: .init(3))
14+
expectEqual(5, some)
15+
16+
var none: Float? = nil
17+
none.move(along: .init(3))
18+
expectEqual(nil, none)
19+
}
20+
21+
// Differentiable.zeroTangentVectorInitializer
22+
do {
23+
let some: [Float]? = [1, 2, 3]
24+
expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer())
25+
26+
let none: [Float]? = nil
27+
expectEqual(.init(nil), none.zeroTangentVectorInitializer())
28+
}
29+
}
30+
31+
OptionalDifferentiationTests.test("Optional.TangentVector operations") {
32+
// Differentiable.move(along:)
33+
do {
34+
var some: Optional<Float>.TangentVector = .init(2)
35+
some.move(along: .init(3))
36+
expectEqual(5, some.value)
37+
38+
var none: Optional<Float>.TangentVector = .init(nil)
39+
none.move(along: .init(3))
40+
expectEqual(nil, none.value)
41+
42+
var nestedSome: Optional<Optional<Float>>.TangentVector = .init(.init(2))
43+
nestedSome.move(along: .init(.init(3)))
44+
expectEqual(.init(5), nestedSome.value)
45+
46+
var nestedNone: Optional<Optional<Float>>.TangentVector = .init(.init(nil))
47+
nestedNone.move(along: .init(.init(3)))
48+
expectEqual(.init(nil), nestedNone.value)
49+
}
50+
51+
// Differentiable.zeroTangentVectorInitializer
52+
do {
53+
let some: [Float]? = [1, 2, 3]
54+
expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer())
55+
56+
let none: [Float]? = nil
57+
expectEqual(.init(nil), none.zeroTangentVectorInitializer())
58+
59+
let nestedSome: [Float]?? = [1, 2, 3]
60+
expectEqual(.init(.init([0, 0, 0])), nestedSome.zeroTangentVectorInitializer())
61+
62+
let nestedNone: [Float]?? = nil
63+
expectEqual(.init(nil), nestedNone.zeroTangentVectorInitializer())
64+
}
65+
66+
// AdditiveArithmetic.zero
67+
expectEqual(.init(Float.zero), Float?.TangentVector.zero)
68+
expectEqual(.init([Float].TangentVector.zero), [Float]?.TangentVector.zero)
69+
70+
expectEqual(.init(.init(Float.zero)), Float??.TangentVector.zero)
71+
expectEqual(.init(.init([Float].TangentVector.zero)), [Float]??.TangentVector.zero)
72+
73+
// AdditiveArithmetic.+, AdditiveArithmetic.-
74+
do {
75+
let some: Optional<Float>.TangentVector = .init(2)
76+
let none: Optional<Float>.TangentVector = .init(nil)
77+
78+
expectEqual(.init(4), some + some)
79+
expectEqual(.init(2), some + none)
80+
expectEqual(.init(2), none + some)
81+
expectEqual(.init(nil), none + none)
82+
83+
expectEqual(.init(0), some - some)
84+
expectEqual(.init(2), some - none)
85+
expectEqual(.init(-2), none - some)
86+
expectEqual(.init(nil), none - none)
87+
88+
let nestedSome: Optional<Optional<Float>>.TangentVector = .init(.init(2))
89+
let nestedNone: Optional<Optional<Float>>.TangentVector = .init(.init(nil))
90+
91+
expectEqual(.init(.init(4)), nestedSome + nestedSome)
92+
expectEqual(.init(.init(2)), nestedSome + nestedNone)
93+
expectEqual(.init(.init(2)), nestedNone + nestedSome)
94+
expectEqual(.init(.init(nil)), nestedNone + nestedNone)
95+
96+
expectEqual(.init(.init(0)), nestedSome - nestedSome)
97+
expectEqual(.init(.init(2)), nestedSome - nestedNone)
98+
expectEqual(.init(.init(-2)), nestedNone - nestedSome)
99+
expectEqual(.init(.init(nil)), nestedNone - nestedNone)
100+
}
101+
}
102+
103+
runAllTests()

0 commit comments

Comments
 (0)