From 36cdb76ab07c791b2383cfabec32dabe0bc53988 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:05:03 -0500 Subject: [PATCH 01/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index c77430fe4..4415fe8b5 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,7 +105,8 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - var (offset, scale) = {x in (x.offset, x.scale) }(self) + var offset = self.offset + var scale = self.scale if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] From cea4777af5cfe7585e9ddeb6120194306687c220 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:25:55 -0500 Subject: [PATCH 02/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 4415fe8b5..9c0c68728 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,15 +105,29 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - var offset = self.offset - var scale = self.scale - if positiveAxis != input.rank - 1 { - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - offset = offset.reshaped(to: broadcastShape) - - scale = scale.reshaped(to: broadcastShape) + + // Will document the SR name shortly - try @inline(never) if it doesn't work? + func srNameWorkaround( + params: (offset: Tensor, scale: Tensor) + ) -> (offset: Tensor, scale: Tensor) { + if positiveAxis == input.rank - 1 { + return params + } else { + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + broadcastShape[positiveAxis] = input.shape[positiveAxis] + return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) + } } + let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) +// var offset = self.offset +// var scale = self.scale +// if positiveAxis != input.rank - 1 { +// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) +// broadcastShape[positiveAxis] = input.shape[positiveAxis] +// offset = offset.reshaped(to: broadcastShape) + +// scale = scale.reshaped(to: broadcastShape) +// } switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) From fef6edd3cdec96ca0b5c79be717f3dba2bb94cd8 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:28:37 -0500 Subject: [PATCH 03/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 9c0c68728..58fef3a44 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -115,7 +115,7 @@ public struct BatchNorm: Layer { } else { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) + return (params.offset.reshaped(to: broadcastShape), params.scale.reshaped(to: broadcastShape)) } } let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) From 6e7d8bded3fc7eaea82c693ea9c933e73f14f7ec Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:32:43 -0500 Subject: [PATCH 04/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 58fef3a44..ec78198f3 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -110,13 +110,13 @@ public struct BatchNorm: Layer { func srNameWorkaround( params: (offset: Tensor, scale: Tensor) ) -> (offset: Tensor, scale: Tensor) { - if positiveAxis == input.rank - 1 { - return params - } else { +// if positiveAxis == input.rank - 1 { +// return params +// } else { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] return (params.offset.reshaped(to: broadcastShape), params.scale.reshaped(to: broadcastShape)) - } +// } } let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) // var offset = self.offset From d7dc92e81f01c2c777002f1591a9f32be83fdfd0 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:46:43 -0500 Subject: [PATCH 05/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 65 +++++++++++++------ 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index ec78198f3..2ca08bfd7 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -106,33 +106,60 @@ public struct BatchNorm: Layer { input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - // Will document the SR name shortly - try @inline(never) if it doesn't work? - func srNameWorkaround( - params: (offset: Tensor, scale: Tensor) - ) -> (offset: Tensor, scale: Tensor) { -// if positiveAxis == input.rank - 1 { -// return params -// } else { - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - return (params.offset.reshaped(to: broadcastShape), params.scale.reshaped(to: broadcastShape)) -// } - } - let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) +// // Will document the SR name shortly - try @inline(never) if it doesn't work? +// func srNameWorkaround( +// params: (offset: Tensor, scale: Tensor) +// ) -> (offset: Tensor, scale: Tensor) { +// // if positiveAxis == input.rank - 1 { +// // return params +// // } else { +// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) +// broadcastShape[positiveAxis] = input.shape[positiveAxis] +// return (params.offset.reshaped(to: broadcastShape), params.scale.reshaped(to: broadcastShape)) +// // } +// } +// let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) + // var offset = self.offset // var scale = self.scale // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] // offset = offset.reshaped(to: broadcastShape) - +// // scale = scale.reshaped(to: broadcastShape) // } - switch Context.local.learningPhase { - case .training: - return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) - case .inference: - return doInference(input, offset: offset, scale: scale) +// switch Context.local.learningPhase { +// case .training: +// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) +// case .inference: +// return doInference(input, offset: offset, scale: scale) +// } + + // Remove this workaround ASAP allow inlining of `doTraining` and `doInference` + if positiveAxis == input.rank - 1 { + let offset = self.offset + let scale = self.scale + + switch Context.local.learningPhase { + case .training: + return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) + case .inference: + return doInference(input, offset: offset, scale: scale) + } + } else { + // Might need to extract this into a function + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + broadcastShape[positiveAxis] = input.shape[positiveAxis] + let offset = self.offset.reshaped(to: broadcastShape) + let scale = self.scale.reshaped(to: broadcastShape) + + switch Context.local.learningPhase { + case .training: + return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) + case .inference: + return doInference(input, offset: offset, scale: scale) + } } } From 3fc82a8902f827393f3cf5f31feaa4c5a6988ef4 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 15:58:11 -0500 Subject: [PATCH 06/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 2ca08bfd7..f6dc281d4 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -106,20 +106,6 @@ public struct BatchNorm: Layer { input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") -// // Will document the SR name shortly - try @inline(never) if it doesn't work? -// func srNameWorkaround( -// params: (offset: Tensor, scale: Tensor) -// ) -> (offset: Tensor, scale: Tensor) { -// // if positiveAxis == input.rank - 1 { -// // return params -// // } else { -// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) -// broadcastShape[positiveAxis] = input.shape[positiveAxis] -// return (params.offset.reshaped(to: broadcastShape), params.scale.reshaped(to: broadcastShape)) -// // } -// } -// let (offset, scale) = srNameWorkaround(params: (self.offset, self.scale)) - // var offset = self.offset // var scale = self.scale // if positiveAxis != input.rank - 1 { @@ -138,7 +124,15 @@ public struct BatchNorm: Layer { // Remove this workaround ASAP allow inlining of `doTraining` and `doInference` if positiveAxis == input.rank - 1 { - let offset = self.offset + callAsFunction1(input) + } else { + callAsFunction2(input) + } + } + + @inline(never) + private func callAsFunction1(_ input: Tensor) -> Tensor { + let offset = self.offset let scale = self.scale switch Context.local.learningPhase { @@ -147,9 +141,11 @@ public struct BatchNorm: Layer { case .inference: return doInference(input, offset: offset, scale: scale) } - } else { - // Might need to extract this into a function - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + } + + @inline(never) + private func callAsFunction2(_ input: Tensor) -> Tensor { + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] let offset = self.offset.reshaped(to: broadcastShape) let scale = self.scale.reshaped(to: broadcastShape) @@ -160,7 +156,6 @@ public struct BatchNorm: Layer { case .inference: return doInference(input, offset: offset, scale: scale) } - } } private func doTraining( From 9e3b130b8717a6612eb9b18ab4a4d7af78956270 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:00:42 -0500 Subject: [PATCH 07/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index f6dc281d4..f4d985162 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -124,14 +124,14 @@ public struct BatchNorm: Layer { // Remove this workaround ASAP allow inlining of `doTraining` and `doInference` if positiveAxis == input.rank - 1 { - callAsFunction1(input) + callAsFunction1(input, positiveAxis: positiveAxis) } else { - callAsFunction2(input) + callAsFunction2(input, positiveAxis: positiveAxis) } } @inline(never) - private func callAsFunction1(_ input: Tensor) -> Tensor { + private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { let offset = self.offset let scale = self.scale @@ -144,7 +144,7 @@ public struct BatchNorm: Layer { } @inline(never) - private func callAsFunction2(_ input: Tensor) -> Tensor { + private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] let offset = self.offset.reshaped(to: broadcastShape) From da6804d1b8a1e1d29c75a8b93c69dd427afa46dd Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:01:32 -0500 Subject: [PATCH 08/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index f4d985162..c6cfbfc9a 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -133,29 +133,29 @@ public struct BatchNorm: Layer { @inline(never) private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { let offset = self.offset - let scale = self.scale - - switch Context.local.learningPhase { - case .training: - return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) - case .inference: - return doInference(input, offset: offset, scale: scale) - } + let scale = self.scale + + switch Context.local.learningPhase { + case .training: + return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) + case .inference: + return doInference(input, offset: offset, scale: scale) + } } @inline(never) private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - let offset = self.offset.reshaped(to: broadcastShape) - let scale = self.scale.reshaped(to: broadcastShape) - - switch Context.local.learningPhase { - case .training: - return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) - case .inference: - return doInference(input, offset: offset, scale: scale) - } + broadcastShape[positiveAxis] = input.shape[positiveAxis] + let offset = self.offset.reshaped(to: broadcastShape) + let scale = self.scale.reshaped(to: broadcastShape) + + switch Context.local.learningPhase { + case .training: + return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) + case .inference: + return doInference(input, offset: offset, scale: scale) + } } private func doTraining( From f79833210bf8a6ace0805c726bd3bf5c248c8919 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:02:50 -0500 Subject: [PATCH 09/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index c6cfbfc9a..de505c628 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -131,6 +131,7 @@ public struct BatchNorm: Layer { } @inline(never) + @differentiable(reverse, wrt: input) private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { let offset = self.offset let scale = self.scale @@ -144,6 +145,7 @@ public struct BatchNorm: Layer { } @inline(never) + @differentiable(reverse, wrt: input) private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] From 1c4947fd9e8eda4f6d3085701beaa743dabfd681 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:05:10 -0500 Subject: [PATCH 10/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index de505c628..9eb0b6d5c 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -124,9 +124,9 @@ public struct BatchNorm: Layer { // Remove this workaround ASAP allow inlining of `doTraining` and `doInference` if positiveAxis == input.rank - 1 { - callAsFunction1(input, positiveAxis: positiveAxis) + return callAsFunction1(input, positiveAxis: positiveAxis) } else { - callAsFunction2(input, positiveAxis: positiveAxis) + return callAsFunction2(input, positiveAxis: positiveAxis) } } From 51dfcb647e20ba6358ddcd7df4356bba9ae2e7e1 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:21:03 -0500 Subject: [PATCH 11/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 84 +++++++++++-------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 9eb0b6d5c..fbaabe109 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,9 +105,8 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - -// var offset = self.offset -// var scale = self.scale + var offset = self.offset + var scale = self.scale // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] @@ -115,27 +114,10 @@ public struct BatchNorm: Layer { // // scale = scale.reshaped(to: broadcastShape) // } -// switch Context.local.learningPhase { -// case .training: -// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) -// case .inference: -// return doInference(input, offset: offset, scale: scale) -// } - - // Remove this workaround ASAP allow inlining of `doTraining` and `doInference` - if positiveAxis == input.rank - 1 { - return callAsFunction1(input, positiveAxis: positiveAxis) - } else { - return callAsFunction2(input, positiveAxis: positiveAxis) - } - } - - @inline(never) - @differentiable(reverse, wrt: input) - private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { - let offset = self.offset - let scale = self.scale - + Self.srNameWorkaround(offset: &offset, + scale: &scale, + input: input, + positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) @@ -145,20 +127,50 @@ public struct BatchNorm: Layer { } @inline(never) - @differentiable(reverse, wrt: input) - private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - let offset = self.offset.reshaped(to: broadcastShape) - let scale = self.scale.reshaped(to: broadcastShape) - - switch Context.local.learningPhase { - case .training: - return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) - case .inference: - return doInference(input, offset: offset, scale: scale) + @differentiable(wrt: offset, scale) + private static func srNameWorkaround( + offset: inout Tensor, + scale: inout Tensor, + input: Tensor, + positiveAxis: Int + ) { + if positiveAxis != input.rank - 1 { + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + broadcastShape[positiveAxis] = input.shape[positiveAxis] + offset = offset.reshaped(to: broadcastShape) + scale = scale.reshaped(to: broadcastShape) } } + +// @inline(never) +// @differentiable(reverse, wrt: input) +// private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { +// let offset = self.offset +// let scale = self.scale + +// switch Context.local.learningPhase { +// case .training: +// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) +// case .inference: +// return doInference(input, offset: offset, scale: scale) +// } +// } + +// @inline(never) +// @differentiable(reverse, wrt: input) +// private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { +// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) +// broadcastShape[positiveAxis] = input.shape[positiveAxis] +// let offset = self.offset.reshaped(to: broadcastShape) +// let scale = self.scale.reshaped(to: broadcastShape) + +// switch Context.local.learningPhase { +// case .training: +// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) +// case .inference: +// return doInference(input, offset: offset, scale: scale) +// } +// } private func doTraining( _ input: Tensor, offset: Tensor, scale: Tensor, axis: Int From f88c2c9732b41565a654f8f1dab86144d71f5147 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:22:39 -0500 Subject: [PATCH 12/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index fbaabe109..5b1def4b9 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -127,7 +127,7 @@ public struct BatchNorm: Layer { } @inline(never) - @differentiable(wrt: offset, scale) + @differentiable(wrt: (offset, scale)) private static func srNameWorkaround( offset: inout Tensor, scale: inout Tensor, From 8b87c39e1992921873c3b307f48f08e5ff42959b Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:26:49 -0500 Subject: [PATCH 13/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 5b1def4b9..3cc36834a 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,8 +105,8 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - var offset = self.offset - var scale = self.scale + let offsetOriginal = self.offset + let scaleOriginal = self.scale // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] @@ -114,10 +114,10 @@ public struct BatchNorm: Layer { // // scale = scale.reshaped(to: broadcastShape) // } - Self.srNameWorkaround(offset: &offset, - scale: &scale, - input: input, - positiveAxis: positiveAxis) + let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, + scale: scaleOriginal, + input: input, + positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) @@ -127,19 +127,24 @@ public struct BatchNorm: Layer { } @inline(never) - @differentiable(wrt: (offset, scale)) + @differentiable(reverse, wrt: (offset, scale)) private static func srNameWorkaround( - offset: inout Tensor, - scale: inout Tensor, + offset: Tensor, + scale: Tensor, input: Tensor, positiveAxis: Int - ) { + ) -> (Tensor, Tensor) { + var offsetCopy = offset + var scaleCopy = offset + if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - offset = offset.reshaped(to: broadcastShape) - scale = scale.reshaped(to: broadcastShape) + offsetCopy = offsetCopy.reshaped(to: broadcastShape) + scaleCopy = scaleCopy.reshaped(to: broadcastShape) } + + return (offsetCopy, scaleCopy) } // @inline(never) From c85e7617592ac5ca8085993a05c266f305cd0097 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:32:12 -0500 Subject: [PATCH 14/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 3cc36834a..14c6e2e74 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -127,7 +127,7 @@ public struct BatchNorm: Layer { } @inline(never) - @differentiable(reverse, wrt: (offset, scale)) + @differentiable(reverse) private static func srNameWorkaround( offset: Tensor, scale: Tensor, From 0a6c6abf410cacd2b5516ab39df38e269fbbdc8e Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:37:13 -0500 Subject: [PATCH 15/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 14c6e2e74..14ec49295 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,19 +105,21 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") - let offsetOriginal = self.offset - let scaleOriginal = self.scale -// if positiveAxis != input.rank - 1 { -// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) -// broadcastShape[positiveAxis] = input.shape[positiveAxis] -// offset = offset.reshaped(to: broadcastShape) -// -// scale = scale.reshaped(to: broadcastShape) -// } - let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, - scale: scaleOriginal, - input: input, - positiveAxis: positiveAxis) +// let offsetOriginal = self.offset +// let scaleOriginal = self.scale + var offset = self.offset + var scale = self.scale + if positiveAxis != input.rank - 1 { + var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) + broadcastShape[positiveAxis] = input.shape[positiveAxis] + offset = offset.reshaped(to: broadcastShape) + + scale = scale.reshaped(to: broadcastShape) + } +// let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, +// scale: scaleOriginal, +// input: input, +// positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) From 917623727c9a4af40152def6656df1682a610c11 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:37:32 -0500 Subject: [PATCH 16/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 14ec49295..c9d45a462 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -113,7 +113,6 @@ public struct BatchNorm: Layer { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] offset = offset.reshaped(to: broadcastShape) - scale = scale.reshaped(to: broadcastShape) } // let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, From fe062ac66a2056ebd9302d3c502430759494075a Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:43:58 -0500 Subject: [PATCH 17/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index c9d45a462..72150d19d 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -109,11 +109,17 @@ public struct BatchNorm: Layer { // let scaleOriginal = self.scale var offset = self.offset var scale = self.scale +// if positiveAxis != input.rank - 1 { +// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) +// broadcastShape[positiveAxis] = input.shape[positiveAxis] +// offset = offset.reshaped(to: broadcastShape) +// scale = scale.reshaped(to: broadcastShape) +// } if positiveAxis != input.rank - 1 { - var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) - broadcastShape[positiveAxis] = input.shape[positiveAxis] - offset = offset.reshaped(to: broadcastShape) - scale = scale.reshaped(to: broadcastShape) + (offset, scale) = Self.srNameWorkaround(offset: offset, + scale: scale, + input: input, + positiveAxis: positiveAxis) } // let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, // scale: scaleOriginal, @@ -135,17 +141,18 @@ public struct BatchNorm: Layer { input: Tensor, positiveAxis: Int ) -> (Tensor, Tensor) { - var offsetCopy = offset - var scaleCopy = offset +// var offsetCopy = offset +// var scaleCopy = offset - if positiveAxis != input.rank - 1 { +// if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - offsetCopy = offsetCopy.reshaped(to: broadcastShape) - scaleCopy = scaleCopy.reshaped(to: broadcastShape) - } + return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) +// offsetCopy = offsetCopy.reshaped(to: broadcastShape) +// scaleCopy = scaleCopy.reshaped(to: broadcastShape) +// } - return (offsetCopy, scaleCopy) +// return (offsetCopy, scaleCopy) } // @inline(never) From 9545583525a417af5168372dfd8c9d8af8a22a5d Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:49:13 -0500 Subject: [PATCH 18/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 72150d19d..429d09964 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -116,8 +116,8 @@ public struct BatchNorm: Layer { // scale = scale.reshaped(to: broadcastShape) // } if positiveAxis != input.rank - 1 { - (offset, scale) = Self.srNameWorkaround(offset: offset, - scale: scale, + Self.srNameWorkaround(offset: &offset, + scale: &scale, input: input, positiveAxis: positiveAxis) } @@ -136,18 +136,19 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) private static func srNameWorkaround( - offset: Tensor, - scale: Tensor, + offset: inout Tensor, + scale: inout Tensor, input: Tensor, positiveAxis: Int - ) -> (Tensor, Tensor) { + ) { // var offsetCopy = offset // var scaleCopy = offset // if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) + offset = offset.reshaped(to: broadcastShape) + shape = scale.reshaped(to: broadcastShape) // offsetCopy = offsetCopy.reshaped(to: broadcastShape) // scaleCopy = scaleCopy.reshaped(to: broadcastShape) // } From 449b1c5c3207858f0f0c6f012d49b5f87cb0170b Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:53:15 -0500 Subject: [PATCH 19/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 429d09964..f93171d8c 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -115,12 +115,12 @@ public struct BatchNorm: Layer { // offset = offset.reshaped(to: broadcastShape) // scale = scale.reshaped(to: broadcastShape) // } - if positiveAxis != input.rank - 1 { - Self.srNameWorkaround(offset: &offset, - scale: &scale, +// if positiveAxis != input.rank - 1 { + (offset, scale) = Self.srNameWorkaround(offset: offset, + scale: scale, input: input, positiveAxis: positiveAxis) - } +// } // let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, // scale: scaleOriginal, // input: input, @@ -135,23 +135,24 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) - private static func srNameWorkaround( - offset: inout Tensor, - scale: inout Tensor, + private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct + offset: Tensor, + scale: Tensor, input: Tensor, positiveAxis: Int - ) { + ) -> (Tensor, Tensor) { // var offsetCopy = offset // var scaleCopy = offset -// if positiveAxis != input.rank - 1 { + if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - offset = offset.reshaped(to: broadcastShape) - shape = scale.reshaped(to: broadcastShape) + return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) // offsetCopy = offsetCopy.reshaped(to: broadcastShape) // scaleCopy = scaleCopy.reshaped(to: broadcastShape) -// } + } else { + return (offset, scale) + } // return (offsetCopy, scaleCopy) } From 718b98862476460e458be79a4683e43898648fc3 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 16:57:33 -0500 Subject: [PATCH 20/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index f93171d8c..513d14957 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -107,8 +107,8 @@ public struct BatchNorm: Layer { "The number of features of the input and the offset doesn't match.") // let offsetOriginal = self.offset // let scaleOriginal = self.scale - var offset = self.offset - var scale = self.scale + let offsetOriginal = self.offset + let scaleOriginal = self.scale // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] @@ -116,10 +116,10 @@ public struct BatchNorm: Layer { // scale = scale.reshaped(to: broadcastShape) // } // if positiveAxis != input.rank - 1 { - (offset, scale) = Self.srNameWorkaround(offset: offset, - scale: scale, - input: input, - positiveAxis: positiveAxis) + let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, + scale: scaleOriginal, + input: input, + positiveAxis: positiveAxis) // } // let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, // scale: scaleOriginal, From b1e147583c87030a11e6bc7577b7dd61f16c0aa3 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:04:19 -0500 Subject: [PATCH 21/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 513d14957..326fb69aa 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -141,20 +141,20 @@ public struct BatchNorm: Layer { input: Tensor, positiveAxis: Int ) -> (Tensor, Tensor) { -// var offsetCopy = offset -// var scaleCopy = offset + var offsetCopy = offset + var scaleCopy = offset if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] - return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) -// offsetCopy = offsetCopy.reshaped(to: broadcastShape) -// scaleCopy = scaleCopy.reshaped(to: broadcastShape) +// return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) + offsetCopy = offsetCopy.reshaped(to: broadcastShape) + scaleCopy = scaleCopy.reshaped(to: broadcastShape) } else { - return (offset, scale) +// return (offset, scale) } -// return (offsetCopy, scaleCopy) + return (offsetCopy, scaleCopy) } // @inline(never) From ec5c664280fa364144f8b3b9dede5095e3949f82 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:09:22 -0500 Subject: [PATCH 22/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 63 +++---------------- 1 file changed, 10 insertions(+), 53 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 326fb69aa..d6b54f943 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,26 +105,20 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") -// let offsetOriginal = self.offset -// let scaleOriginal = self.scale - let offsetOriginal = self.offset - let scaleOriginal = self.scale +// var offset = self.offset +// var scale = self.scale // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] // offset = offset.reshaped(to: broadcastShape) // scale = scale.reshaped(to: broadcastShape) // } -// if positiveAxis != input.rank - 1 { - let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, - scale: scaleOriginal, - input: input, - positiveAxis: positiveAxis) -// } -// let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, -// scale: scaleOriginal, -// input: input, -// positiveAxis: positiveAxis) + let offsetOriginal = self.offset + let scaleOriginal = self.scale + let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, + scale: scaleOriginal, + input: input, + positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) @@ -141,51 +135,14 @@ public struct BatchNorm: Layer { input: Tensor, positiveAxis: Int ) -> (Tensor, Tensor) { - var offsetCopy = offset - var scaleCopy = offset - if positiveAxis != input.rank - 1 { var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) broadcastShape[positiveAxis] = input.shape[positiveAxis] -// return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) - offsetCopy = offsetCopy.reshaped(to: broadcastShape) - scaleCopy = scaleCopy.reshaped(to: broadcastShape) + return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape)) } else { -// return (offset, scale) + return (offset, scale) } - - return (offsetCopy, scaleCopy) } - -// @inline(never) -// @differentiable(reverse, wrt: input) -// private func callAsFunction1(_ input: Tensor, positiveAxis: Int) -> Tensor { -// let offset = self.offset -// let scale = self.scale - -// switch Context.local.learningPhase { -// case .training: -// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) -// case .inference: -// return doInference(input, offset: offset, scale: scale) -// } -// } - -// @inline(never) -// @differentiable(reverse, wrt: input) -// private func callAsFunction2(_ input: Tensor, positiveAxis: Int) -> Tensor { -// var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) -// broadcastShape[positiveAxis] = input.shape[positiveAxis] -// let offset = self.offset.reshaped(to: broadcastShape) -// let scale = self.scale.reshaped(to: broadcastShape) - -// switch Context.local.learningPhase { -// case .training: -// return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) -// case .inference: -// return doInference(input, offset: offset, scale: scale) -// } -// } private func doTraining( _ input: Tensor, offset: Tensor, scale: Tensor, axis: Int From c34889ca593f06ed32ebceaaa423f6df3934f506 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:13:36 -0500 Subject: [PATCH 23/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index d6b54f943..c3dfa5db8 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -129,7 +129,7 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) - private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct + public static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging offset: Tensor, scale: Tensor, input: Tensor, From 4cb76764c2342025c01e4e05bf5da1128545d365 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:17:17 -0500 Subject: [PATCH 24/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index c3dfa5db8..65bb060f1 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -129,7 +129,7 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) - public static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging + private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging offset: Tensor, scale: Tensor, input: Tensor, From b7719364c67bf3d91c475cc5f173b2fbde364dd2 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:21:14 -0500 Subject: [PATCH 25/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 65bb060f1..eaed8f3b1 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -129,7 +129,7 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) - private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging + internal static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging offset: Tensor, scale: Tensor, input: Tensor, From 012697016fd310596e4debaeafb39eda60392d5d Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:23:54 -0500 Subject: [PATCH 26/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index eaed8f3b1..91fa54e31 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -128,8 +128,8 @@ public struct BatchNorm: Layer { } @inline(never) - @differentiable(reverse) - internal static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct --- remove `public` after debugging + @differentiable(reverse) // if the function is `public` or `internal`, the compiler crashes + private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct offset: Tensor, scale: Tensor, input: Tensor, From da1722e34350efb0f782ff6524d126dced8b6967 Mon Sep 17 00:00:00 2001 From: Philip Turner Date: Sun, 9 Jan 2022 17:47:22 -0500 Subject: [PATCH 27/27] Update Normalization.swift --- Sources/TensorFlow/Layers/Normalization.swift | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/Sources/TensorFlow/Layers/Normalization.swift b/Sources/TensorFlow/Layers/Normalization.swift index 91fa54e31..e46e0ed22 100644 --- a/Sources/TensorFlow/Layers/Normalization.swift +++ b/Sources/TensorFlow/Layers/Normalization.swift @@ -105,8 +105,7 @@ public struct BatchNorm: Layer { precondition( input.shape[positiveAxis] == offset.shape[0], "The number of features of the input and the offset doesn't match.") -// var offset = self.offset -// var scale = self.scale +// var (offset, scale) = {x in (x.offset, x.scale) }(self) // if positiveAxis != input.rank - 1 { // var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank)) // broadcastShape[positiveAxis] = input.shape[positiveAxis] @@ -115,10 +114,10 @@ public struct BatchNorm: Layer { // } let offsetOriginal = self.offset let scaleOriginal = self.scale - let (offset, scale) = Self.srNameWorkaround(offset: offsetOriginal, - scale: scaleOriginal, - input: input, - positiveAxis: positiveAxis) + let (offset, scale) = Self._sr13263workaround(offset: offsetOriginal, + scale: scaleOriginal, + input: input, + positiveAxis: positiveAxis) switch Context.local.learningPhase { case .training: return doTraining(input, offset: offset, scale: scale, axis: positiveAxis) @@ -129,7 +128,7 @@ public struct BatchNorm: Layer { @inline(never) @differentiable(reverse) // if the function is `public` or `internal`, the compiler crashes - private static func srNameWorkaround( // if this doesn't work, try a fileprivate generic struct + private static func _sr13263workaround( offset: Tensor, scale: Tensor, input: Tensor,