Closed as not planned
| | |
|---|---|
| Previous ID | SR-12157 |
| Radar | None |
| Original Reporter | @dan-zheng |
| Type | Sub-task |

Additional Detail from JIRA

| | |
|---|---|
| Votes | 0 |
| Component/s | |
| Labels | Sub-task |
| Assignee | None |
| Priority | Medium |

md5: 24831c89fd1b09f4ffb802214eb06a7a
Issue Description:

Deprecate/remove the workaround `Parameter` class when possible. Directly convert user types (e.g. `BatchNorm`) to classes.
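For context, the `Parameter` workaround is essentially a mutable reference box around a tensor, used so that value-type layers can update state (such as running statistics) from inside non-mutating, differentiable methods. A minimal sketch of that shape (not necessarily the exact swift-apis declaration) is:

```swift
import TensorFlow

// Sketch of the kind of reference box `Parameter` provides; the actual
// swift-apis API may differ in name, members, and attributes.
public final class Parameter<Scalar: TensorFlowScalar> {
    public var value: Tensor<Scalar>
    public init(_ value: Tensor<Scalar>) {
        self.value = value
    }
}
```

Making `BatchNorm` itself a class provides the same reference semantics directly, so the running statistics can be plain stored properties, as in the example below: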
```swift
import TensorFlow

// @frozen
// public class BatchNorm<Scalar: TensorFlowFloatingPoint>: Differentiable {
public class BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
    @noDerivative public let axis: Int
    @noDerivative public let momentum: Tensor<Scalar>
    public var offset: Tensor<Scalar>
    public var scale: Tensor<Scalar>
    @noDerivative public let epsilon: Tensor<Scalar>
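    // Running statistics are mutable state (not trainable parameters) and are
    // updated in place during training; storing them as plain properties relies
    // on `BatchNorm` having reference semantics.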
    @noDerivative public var runningMean: Tensor<Scalar>
    @noDerivative public var runningVariance: Tensor<Scalar>

    public var allKeyPaths: [PartialKeyPath<BatchNorm>] { [] }

    /// Creates a batch normalization layer.
    ///
    /// - Parameters:
    ///   - axis: The axis that should not be normalized (typically the feature axis).
    ///   - momentum: The momentum for the moving average.
    ///   - offset: The offset to be added to the normalized tensor.
    ///   - scale: The scale to multiply the normalized tensor by.
    ///   - epsilon: A small scalar added to the denominator to improve numerical stability.
    ///   - runningMean: The running mean.
    ///   - runningVariance: The running variance.
    public init(
        axis: Int,
        momentum: Tensor<Scalar>,
        offset: Tensor<Scalar>,
        scale: Tensor<Scalar>,
        epsilon: Tensor<Scalar>,
        runningMean: Tensor<Scalar>,
        runningVariance: Tensor<Scalar>
    ) {
        self.axis = axis
        self.momentum = momentum
        self.offset = offset
        self.scale = scale
        self.epsilon = epsilon
        self.runningMean = runningMean
        self.runningVariance = runningVariance
    }

    @differentiable
    private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        let positiveAxis = (input.rank + axis) % input.rank
        var normalizedAxes = Array(0..<input.rank)
        normalizedAxes.remove(at: positiveAxis)
        let mean = input.mean(alongAxes: normalizedAxes)
        let variance = input.variance(alongAxes: normalizedAxes)
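        // These in-place updates of the running statistics are what previously
        // required the `Parameter` reference-box workaround; with `BatchNorm` as
        // a class, the mutation happens directly on stored properties.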
        runningMean += (mean - runningMean) * (1 - momentum)
        runningVariance += (variance - runningVariance) * (1 - momentum)
        let inv = rsqrt(variance + epsilon) * scale
        return (input - mean) * inv + offset
    }

    @differentiable
    private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        let inv = rsqrt(runningVariance + epsilon) * scale
        return (input - runningMean) * inv + offset
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable(vjp: _vjpApplied(to:))
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        switch Context.local.learningPhase {
        case .training:
            return applyingTraining(to: input)
        case .inference:
            return applyingInference(to: input)
        }
    }

    @usableFromInline
    func _vjpApplied(to input: Tensor<Scalar>) ->
        (Tensor<Scalar>, (Tensor<Scalar>) ->
            (BatchNorm<Scalar>.TangentVector, Tensor<Scalar>)) {
        switch Context.local.learningPhase {
        case .training:
            return valueWithPullback(at: input) {
                $0.applyingTraining(to: $1)
            }
        case .inference:
            return valueWithPullback(at: input) {
                $0.applyingInference(to: $1)
            }
        }
    }

    /// Creates a batch normalization layer.
    ///
    /// - Parameters:
    ///   - featureCount: The number of features.
    ///   - axis: The axis that should be normalized (typically the features axis).
    ///   - momentum: The momentum for the moving average.
    ///   - epsilon: A small scalar added to the denominator to improve numerical stability.
    public init(featureCount: Int,
                axis: Int = -1,
                momentum: Tensor<Scalar> = Tensor(0.99),
                epsilon: Tensor<Scalar> = Tensor(0.001)) {
        self.axis = axis
        self.momentum = momentum
        self.scale = Tensor<Scalar>(ones: [featureCount])
        self.offset = Tensor<Scalar>(zeros: [featureCount])
        self.epsilon = epsilon
        self.runningMean = Tensor(0)
        self.runningVariance = Tensor(1)
    }
}
```
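For completeness, a small usage sketch illustrating why reference semantics matter here (hypothetical; it assumes `Context.local.learningPhase` is settable as in swift-apis, and any names outside the class above are assumptions): calling the layer in the training phase mutates `runningMean`/`runningVariance` on the shared instance, which a value-type layer could not do without the `Parameter` box.

```swift
// Hypothetical usage sketch for the class-based BatchNorm defined above.
let layer = BatchNorm<Float>(featureCount: 4)
let input = Tensor<Float>(randomNormal: [8, 4])

// Inference uses the stored running statistics.
Context.local.learningPhase = .inference
let inferenceOutput = layer(input)

// Training updates the running statistics in place on the class instance.
Context.local.learningPhase = .training
let trainingOutput = layer(input)
print(layer.runningMean)  // Moved toward the batch mean by the training call.
```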