Workaround for sr-13263 #1


Merged (27 commits) on Jan 13, 2022.
Commits: all 27 are titled "Update Normalization.swift", authored by philipturner on Jan 9, 2022:

36cdb76, cea4777, fef6edd, 6e7d8bd, d7dc92e, 3fc82a8, 9e3b130, da6804d, f798332, 1c4947f, 51dfcb6, f88c2c9, 8b87c39, c85e761, 0a6c6ab, 9176237, fe062ac, 9545583, 449b1c5, 718b988, b1e1475, ec5c664, c34889c, 4cb7676, b771936, 0126970, da1722e
38 changes: 30 additions & 8 deletions Sources/TensorFlow/Layers/Normalization.swift
@@ -105,21 +105,43 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     precondition(
       input.shape[positiveAxis] == offset.shape[0],
       "The number of features of the input and the offset doesn't match.")
-    var (offset, scale) = {x in (x.offset, x.scale) }(self)
-    if positiveAxis != input.rank - 1 {
-      var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
-      broadcastShape[positiveAxis] = input.shape[positiveAxis]
-      offset = offset.reshaped(to: broadcastShape)
-
-      scale = scale.reshaped(to: broadcastShape)
-    }
+    // var (offset, scale) = {x in (x.offset, x.scale) }(self)
+    // if positiveAxis != input.rank - 1 {
+    //   var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
+    //   broadcastShape[positiveAxis] = input.shape[positiveAxis]
+    //   offset = offset.reshaped(to: broadcastShape)
+    //   scale = scale.reshaped(to: broadcastShape)
+    // }
+    let offsetOriginal = self.offset
+    let scaleOriginal = self.scale
+    let (offset, scale) = Self._sr13263workaround(offset: offsetOriginal,
+                                                  scale: scaleOriginal,
+                                                  input: input,
+                                                  positiveAxis: positiveAxis)
     switch Context.local.learningPhase {
     case .training:
       return doTraining(input, offset: offset, scale: scale, axis: positiveAxis)
     case .inference:
       return doInference(input, offset: offset, scale: scale)
     }
   }
+
+  @inline(never)
+  @differentiable(reverse)  // if the function is `public` or `internal`, the compiler crashes
+  private static func _sr13263workaround(
+    offset: Tensor<Scalar>,
+    scale: Tensor<Scalar>,
+    input: Tensor<Scalar>,
+    positiveAxis: Int
+  ) -> (Tensor<Scalar>, Tensor<Scalar>) {
+    if positiveAxis != input.rank - 1 {
+      var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
+      broadcastShape[positiveAxis] = input.shape[positiveAxis]
+      return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape))
+    } else {
+      return (offset, scale)
+    }
+  }
 
   private func doTraining(
     _ input: Tensor<Scalar>, offset: Tensor<Scalar>, scale: Tensor<Scalar>, axis: Int