Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5013c57

Browse files
committed
TensorFlow: annotate functions using _forEachFieldWithKeyPath
Because the standard toolchain support requires the newest runtime in order to have access to `_forEachFieldWithKeyPath` and the runtime is bundled into the OS on macOS, we need to annotate the functions with availability.
1 parent fd04e95 commit 5013c57

6 files changed

+25
-2
lines changed

Sources/TensorFlow/Core/ElementaryFunctions.swift

+4
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
// limitations under the License.
1414

1515
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
16+
1617
import Numerics
1718
@_spi(Reflection) import Swift
19+
1820
extension ElementaryFunctions {
21+
@available(macOS 9999, *)
1922
internal static func visitChildren(
2023
_ body: (PartialKeyPath<Self>, ElementaryFunctionsVisit.Type) -> Void
2124
) {
@@ -164,4 +167,5 @@ extension ElementaryFunctions {
164167
public static func root(_ x: Self, _ n: Int) -> Self { .init(mapped: Functor_root(n: n), x) }
165168
public static func pow(_ x: Self, _ y: Self) -> Self { .init(mapped: Functor_pow2(), x, y) }
166169
}
170+
167171
#endif

Sources/TensorFlow/Core/EuclideanDifferentiable.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

21+
@available(macOS 9999, *)
2022
func listFields<Root>(of type: Root.Type) -> [(String, PartialKeyPath<Root>)] {
2123
var out = [(String, PartialKeyPath<Root>)]()
2224
_forEachFieldWithKeyPath(of: type, options: .ignoreUnknown) { name, kp in
@@ -27,8 +29,8 @@ func listFields<Root>(of type: Root.Type) -> [(String, PartialKeyPath<Root>)] {
2729
}
2830

2931
extension Differentiable {
30-
static var differentiableFields: [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]
31-
{
32+
@available(macOS 9999, *)
33+
static var differentiableFields: [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)] {
3234
let tangentFields = listFields(of: TangentVector.self)
3335
var i = 0
3436
var out = [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]()

Sources/TensorFlow/Core/KeyPathIterable.swift

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import _Differentiation
2222

2323
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
24+
2425
@_spi(Reflection) import Swift
2526

2627
/// An implementation detail of `KeyPathIterable`; do not use this protocol
@@ -41,6 +42,7 @@ public protocol KeyPathIterable: _KeyPathIterableBase {
4142
}
4243

4344
public extension KeyPathIterable {
45+
@available(macOS 9999, *)
4446
var allKeyPaths: [PartialKeyPath<Self>] {
4547
var out = [PartialKeyPath<Self>]()
4648
_forEachFieldWithKeyPath(of: Self.self, options: .ignoreUnknown) { name, kp in
@@ -171,4 +173,5 @@ extension Optional.TangentVector: KeyPathIterable {
171173
return []
172174
}
173175
}
176+
174177
#endif

Sources/TensorFlow/Core/PointwiseMultiplicative.swift

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

2021
infix operator .*: MultiplicationPrecedence
@@ -111,6 +112,7 @@ extension PointwiseMultiplicative {
111112
}
112113

113114
extension PointwiseMultiplicative {
115+
@available(macOS 9999, *)
114116
internal static func visitChildren(
115117
_ body: (PartialKeyPath<Self>, _PointwiseMultiplicative.Type) -> Void
116118
) {
@@ -134,4 +136,5 @@ extension PointwiseMultiplicative {
134136
extension Array.DifferentiableView: _PointwiseMultiplicative
135137
where Element: Differentiable & PointwiseMultiplicative {}
136138
extension Tensor: _PointwiseMultiplicative where Scalar: Numeric {}
139+
137140
#endif

Sources/TensorFlow/Core/TensorGroup.swift

+8
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,10 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
340340
}
341341

342342
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
343+
343344
@_spi(Reflection) import Swift
344345

346+
@available(macOS 9999, *)
345347
func reflectionInit<T>(type: T.Type, body: (inout T, PartialKeyPath<T>) -> Void) -> T {
346348
let x = UnsafeMutablePointer<T>.allocate(capacity: 1)
347349
defer { x.deallocate() }
@@ -355,6 +357,7 @@ func reflectionInit<T>(type: T.Type, body: (inout T, PartialKeyPath<T>) -> Void)
355357
}
356358

357359
extension TensorGroup {
360+
@available(macOS 9999, *)
358361
public static var _typeList: [TensorDataType] {
359362
var out = [TensorDataType]()
360363
if !(_forEachFieldWithKeyPath(of: Self.self) { name, kp in
@@ -366,6 +369,7 @@ extension TensorGroup {
366369
}
367370
return out
368371
}
372+
369373
public static func initialize<Root>(
370374
_ base: inout Root, _ kp: PartialKeyPath<Root>,
371375
_owning tensorHandles: UnsafePointer<CTensorHandle>?
@@ -377,6 +381,7 @@ extension TensorGroup {
377381
v.initialize(to: .init(_owning: tensorHandles))
378382
}
379383
}
384+
380385
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
381386
var i = 0
382387
self = reflectionInit(type: Self.self) { base, kp in
@@ -387,6 +392,8 @@ extension TensorGroup {
387392
i += Int(valueType._tensorHandleCount)
388393
}
389394
}
395+
396+
@available(macOS 9999, *)
390397
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
391398
var i = 0
392399
if !_forEachFieldWithKeyPath(of: Self.self, body: { name, kp in
@@ -399,4 +406,5 @@ extension TensorGroup {
399406
}
400407
}
401408
}
409+
402410
#endif

Sources/TensorFlow/Core/VectorProtocol.swift

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

2021
/// Implementation detail for reflection.
@@ -35,6 +36,7 @@ public protocol _VectorProtocol {
3536
}
3637

3738
extension VectorProtocol {
39+
@available(macOS 9999, *)
3840
internal static func visitChildren(
3941
_ body: (PartialKeyPath<Self>, _VectorProtocol.Type) -> Void
4042
) {
@@ -127,4 +129,5 @@ extension VectorProtocol {
127129
extension Tensor: _VectorProtocol where Scalar: TensorFlowFloatingPoint {}
128130
extension Array.DifferentiableView: _VectorProtocol
129131
where Element: Differentiable & VectorProtocol {}
132+
130133
#endif

0 commit comments

Comments
 (0)