[SR-12157] Use Differentiable-conforming class types in swift-apis #53444


Description

@dan-zheng
Previous ID: SR-12157
Radar: None
Original Reporter: @dan-zheng
Type: Sub-task

Additional Detail from JIRA
Votes: 0
Component/s:
Labels: Sub-task
Assignee: None
Priority: Medium

md5: 24831c89fd1b09f4ffb802214eb06a7a

Issue Description:

Deprecate/remove the workaround Parameter class when possible, and convert user types (e.g. BatchNorm) directly to classes.
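
For context, Parameter is essentially a reference wrapper around a tensor. A minimal sketch from memory (not verbatim swift-apis source):

// Rough sketch of the Parameter workaround: a final class whose reference
// semantics allow mutating a wrapped tensor from inside otherwise
// value-semantic, differentiable code.
public final class Parameter<Scalar: TensorFlowScalar> {
    public var value: Tensor<Scalar>
    public init(_ value: Tensor<Scalar>) {
        self.value = value
    }
}

A class-based BatchNorm might then look like: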

import TensorFlow

// @frozen
// public class BatchNorm<Scalar: TensorFlowFloatingPoint>: Differentiable {
public class BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
    @noDerivative public let axis: Int
    @noDerivative public let momentum: Tensor<Scalar>
    public var offset: Tensor<Scalar>
    public var scale: Tensor<Scalar>
    @noDerivative public let epsilon: Tensor<Scalar>
    @noDerivative public var runningMean: Tensor<Scalar>
    @noDerivative public var runningVariance: Tensor<Scalar>

    // Key-path synthesis is presumably unavailable for classes, so this stub
    // reports no key paths for now.
    public var allKeyPaths: [PartialKeyPath<BatchNorm>] { [] }

    /// Creates a batch normalization layer.
    ///
    /// - Parameters:
    ///   - axis: The axis that should not be normalized (typically the feature axis).
    ///   - momentum: The momentum for the moving average.
    ///   - offset: The offset to be added to the normalized tensor.
    ///   - scale: The scale to multiply the normalized tensor by.
    ///   - epsilon: A small scalar added to the denominator to improve numerical stability.
    ///   - runningMean: The running mean.
    ///   - runningVariance: The running variance.
    public init(
        axis: Int,
        momentum: Tensor<Scalar>,
        offset: Tensor<Scalar>,
        scale: Tensor<Scalar>,
        epsilon: Tensor<Scalar>,
        runningMean: Tensor<Scalar>,
        runningVariance: Tensor<Scalar>
    ) {
        self.axis = axis
        self.momentum = momentum
        self.offset = offset
        self.scale = scale
        self.epsilon = epsilon
        self.runningMean = runningMean
        self.runningVariance = runningVariance
    }

    @differentiable
    private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        let positiveAxis = (input.rank + axis) % input.rank
        var normalizedAxes = Array(0..<input.rank)
        normalizedAxes.remove(at: positiveAxis)
        let mean = input.mean(alongAxes: normalizedAxes)
        let variance = input.variance(alongAxes: normalizedAxes)
        // These in-place updates to the running statistics work here because
        // BatchNorm is a class (reference semantics); the struct-based layer
        // needed the Parameter workaround for this kind of mutation.
        runningMean += (mean - runningMean) * (1 - momentum)
        runningVariance += (variance - runningVariance) * (1 - momentum)
        let inv = rsqrt(variance + epsilon) * scale
        return (input - mean) * inv + offset
    }

    @differentiable
    private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        let inv = rsqrt(runningVariance + epsilon) * scale
        return (input - runningMean) * inv + offset
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable(vjp: _vjpApplied(to:))
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        switch Context.local.learningPhase {
        case .training:
            return applyingTraining(to: input)
        case .inference:
            return applyingInference(to: input)
        }
    }

    // A custom VJP is registered via @differentiable(vjp:) above so the
    // derivative can dispatch on the runtime learning phase, presumably because
    // AD could not yet differentiate through the switch directly.
    @usableFromInline
    func _vjpApplied(to input: Tensor<Scalar>) ->
        (Tensor<Scalar>, (Tensor<Scalar>) ->
            (BatchNorm<Scalar>.TangentVector, Tensor<Scalar>)) {
        switch Context.local.learningPhase {
        case .training:
            return valueWithPullback(at: input) {
                $0.applyingTraining(to: $1)
            }
        case .inference:
            return valueWithPullback(at: input) {
                $0.applyingInference(to: $1)
            }
        }
    }

    /// Creates a batch normalization layer.
    ///
    /// - Parameters:
    ///   - featureCount: The number of features.
    ///   - axis: The axis that should not be normalized (typically the feature axis).
    ///   - momentum: The momentum for the moving average.
    ///   - epsilon: A small scalar added to the denominator to improve numerical stability.
    public init(featureCount: Int,
                axis: Int = -1,
                momentum: Tensor<Scalar> = Tensor(0.99),
                epsilon: Tensor<Scalar> = Tensor(0.001)) {
        self.axis = axis
        self.momentum = momentum
        self.scale = Tensor<Scalar>(ones: [featureCount])
        self.offset = Tensor<Scalar>(zeros: [featureCount])
        self.epsilon = epsilon
        self.runningMean = Tensor(0)
        self.runningVariance = Tensor(1)
    }
}
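
If Differentiable-conforming classes work end to end, usage could look like the following minimal sketch (hypothetical; assumes a Swift for TensorFlow toolchain and that differentiation through the class-based layer behaves like the struct version):

// Hypothetical usage of the class-based BatchNorm above.
let model = BatchNorm<Float>(featureCount: 16)
let x = Tensor<Float>(randomNormal: [8, 16])
Context.local.learningPhase = .training
// Differentiate a scalar loss with respect to the layer's parameters
// (offset and scale; @noDerivative properties are excluded from the tangent).
let grad = gradient(at: model) { model in
    model(x).sum()
}

Note that the training path also mutates runningMean and runningVariance as a side effect of the forward pass, which is exactly the behavior that reference semantics enable without the Parameter wrapper.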
