Skip to content

Commit 335f94d

Browse files
authored
Workaround for sr-13263 (#1)
* Update Normalization.swift
1 parent 652f815 commit 335f94d

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -105,21 +105,43 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
105105
precondition(
106106
input.shape[positiveAxis] == offset.shape[0],
107107
"The number of features of the input and the offset doesn't match.")
108-
var (offset, scale) = {x in (x.offset, x.scale) }(self)
109-
if positiveAxis != input.rank - 1 {
110-
var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
111-
broadcastShape[positiveAxis] = input.shape[positiveAxis]
112-
offset = offset.reshaped(to: broadcastShape)
113-
114-
scale = scale.reshaped(to: broadcastShape)
115-
}
108+
// var (offset, scale) = {x in (x.offset, x.scale) }(self)
109+
// if positiveAxis != input.rank - 1 {
110+
// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
111+
// broadcastShape[positiveAxis] = input.shape[positiveAxis]
112+
// offset = offset.reshaped(to: broadcastShape)
113+
// scale = scale.reshaped(to: broadcastShape)
114+
// }
115+
let offsetOriginal = self.offset
116+
let scaleOriginal = self.scale
117+
let (offset, scale) = Self._sr13263workaround(offset: offsetOriginal,
118+
scale: scaleOriginal,
119+
input: input,
120+
positiveAxis: positiveAxis)
116121
switch Context.local.learningPhase {
117122
case .training:
118123
return doTraining(input, offset: offset, scale: scale, axis: positiveAxis)
119124
case .inference:
120125
return doInference(input, offset: offset, scale: scale)
121126
}
122127
}
128+
129+
/// Returns `offset` and `scale`, reshaped when `positiveAxis` is not the last axis of `input`.
///
/// When `positiveAxis != input.rank - 1`, both tensors are reshaped to a shape with
/// `input.rank` dimensions that is 1 everywhere except at `positiveAxis`, where it equals
/// `input.shape[positiveAxis]`. Otherwise both tensors are returned unchanged.
///
/// This logic lives in its own function as a workaround for SR-13263; the caller previously
/// performed the same reshaping inline.
/// NOTE(review): `@inline(never)` is presumably part of that workaround — confirm before
/// removing it. Per the note on `@differentiable(reverse)` below, the function must stay
/// `private`.
///
/// - Parameters:
///   - offset: The offset tensor to reshape if needed.
///   - scale: The scale tensor to reshape if needed.
///   - input: The tensor whose rank and shape determine the reshaped shape.
///   - positiveAxis: The index of the axis of `input` whose size is kept in the reshaped shape.
/// - Returns: The `(offset, scale)` pair, either reshaped as described above or unchanged.
@inline(never)
130+
@differentiable(reverse) // if the function is `public` or `internal`, the compiler crashes
131+
private static func _sr13263workaround(
132+
offset: Tensor<Scalar>,
133+
scale: Tensor<Scalar>,
134+
input: Tensor<Scalar>,
135+
positiveAxis: Int
136+
) -> (Tensor<Scalar>, Tensor<Scalar>) {
137+
if positiveAxis != input.rank - 1 {
138+
var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
139+
broadcastShape[positiveAxis] = input.shape[positiveAxis]
140+
return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape))
141+
} else {
142+
return (offset, scale)
143+
}
144+
}
123145

124146
private func doTraining(
125147
_ input: Tensor<Scalar>, offset: Tensor<Scalar>, scale: Tensor<Scalar>, axis: Int

0 commit comments

Comments
 (0)