@@ -105,21 +105,43 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
105
105
precondition (
106
106
input. shape [ positiveAxis] == offset. shape [ 0 ] ,
107
107
" The number of features of the input and the offset doesn't match. " )
108
- var ( offset, scale) = { x in ( x. offset, x. scale) } ( self )
109
- if positiveAxis != input. rank - 1 {
110
- var broadcastShape = TensorShape ( [ Int] ( repeating: 1 , count: input. rank) )
111
- broadcastShape [ positiveAxis] = input. shape [ positiveAxis]
112
- offset = offset. reshaped ( to: broadcastShape)
113
-
114
- scale = scale. reshaped ( to: broadcastShape)
115
- }
108
+ // var (offset, scale) = {x in (x.offset, x.scale) }(self)
109
+ // if positiveAxis != input.rank - 1 {
110
+ // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
111
+ // broadcastShape[positiveAxis] = input.shape[positiveAxis]
112
+ // offset = offset.reshaped(to: broadcastShape)
113
+ // scale = scale.reshaped(to: broadcastShape)
114
+ // }
115
+ let offsetOriginal = self . offset
116
+ let scaleOriginal = self . scale
117
+ let ( offset, scale) = Self . _sr13263workaround ( offset: offsetOriginal,
118
+ scale: scaleOriginal,
119
+ input: input,
120
+ positiveAxis: positiveAxis)
116
121
switch Context . local. learningPhase {
117
122
case . training:
118
123
return doTraining ( input, offset: offset, scale: scale, axis: positiveAxis)
119
124
case . inference:
120
125
return doInference ( input, offset: offset, scale: scale)
121
126
}
122
127
}
128
+
129
+ @inline ( never) // Helper named for SR-13263; it holds the offset/scale reshape that the caller invokes through `Self._sr13263workaround`. NOTE(review): presumably it must stay un-inlined for the workaround to hold — confirm against the bug report.
130
+ @differentiable ( reverse) // Keep this function `private`: if the function is `public` or `internal`, the compiler crashes.
131
+ private static func _sr13263workaround( // Returns the (offset, scale) pair, reshaped only when `positiveAxis` is not the last axis of `input`.
132
+ offset: Tensor < Scalar > , // Offset tensor; returned unchanged or reshaped to `broadcastShape`.
133
+ scale: Tensor < Scalar > , // Scale tensor; handled exactly like `offset`.
134
+ input: Tensor < Scalar > , // Only `input.rank` and `input.shape` are read here; its values are not used.
135
+ positiveAxis: Int // Index into `input.shape`; the name suggests it is already normalized to be non-negative — TODO confirm against the caller.
136
+ ) -> ( Tensor < Scalar > , Tensor < Scalar > ) { // Tuple order is (offset, scale).
137
+ if positiveAxis != input. rank - 1 { // Reshape is needed only when the chosen axis is not the last one.
138
+ var broadcastShape = TensorShape ( [ Int] ( repeating: 1 , count: input. rank) ) // Start from a shape of all ones with the same rank as `input`.
139
+ broadcastShape [ positiveAxis] = input. shape [ positiveAxis] // Keep the input's extent on the chosen axis; every other dimension stays 1.
140
+ return ( offset. reshaped ( to: broadcastShape) , scale. reshaped ( to: broadcastShape) ) // Both tensors get the same target shape.
141
+ } else { // Chosen axis is the last axis: hand both tensors back untouched. NOTE(review): presumably they already broadcast along the trailing axis — confirm.
142
+ return ( offset, scale)
143
+ }
144
+ }
123
145
124
146
private func doTraining(
125
147
_ input: Tensor < Scalar > , offset: Tensor < Scalar > , scale: Tensor < Scalar > , axis: Int
0 commit comments