This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit fc55d91

add old optimizer initializers back (#36)
1 parent 6ef51b3 commit fc55d91

File tree

1 file changed (+42 −9)


Sources/DeepLearning/Optimizer.swift

Lines changed: 42 additions & 9 deletions
@@ -35,13 +35,11 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let decay: Scalar
 
     public init(
-        for _: __shared Model,
         learningRate: Scalar = 1e-3,
         beta1: Scalar = 0.9,
         beta2: Scalar = 0.999,
         epsilon: Scalar = 1e-8,
-        decay: Scalar = 0,
-        scalarType: Scalar.Type
+        decay: Scalar = 0
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
@@ -55,6 +53,23 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
         self.decay = decay
     }
 
+    public convenience init(
+        for _: __shared Model,
+        learningRate: Scalar = 1e-3,
+        beta1: Scalar = 0.9,
+        beta2: Scalar = 0.999,
+        epsilon: Scalar = 1e-8,
+        decay: Scalar = 0,
+        scalarType: Scalar.Type
+    ) {
+        self.init(
+            learningRate: learningRate,
+            beta1: beta1,
+            beta2: beta2,
+            epsilon: epsilon,
+            decay: decay)
+    }
+
     private var step: Scalar = 0
     private var firstMoments = Model.AllDifferentiableVariables.zero
     private var secondMoments = Model.AllDifferentiableVariables.zero
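
For context, here is a minimal sketch of the two call styles this pair of Adam hunks leaves available. The `MyModel` type and `model` value are hypothetical stand-ins for a `Layer`-conforming model:

// Designated initializer: Model and Scalar must be written out,
// since no argument mentions either generic parameter.
let adam = Adam<MyModel, Float>(learningRate: 1e-3)

// Restored convenience initializer: the otherwise-unused `for:` and
// `scalarType:` arguments exist so the compiler can infer Model and Scalar.
let adamInferred = Adam(for: model, learningRate: 1e-3, scalarType: Float.self)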
@@ -84,12 +99,10 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let decay: Scalar
 
     public init(
-        for _: __shared Model,
         learningRate: Scalar = 0.001,
         rho: Scalar = 0.9,
         epsilon: Scalar = 1e-8,
-        decay: Scalar = 0,
-        scalarType: Scalar.Type
+        decay: Scalar = 0
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(rho >= 0, "Rho must be non-negative")
@@ -101,6 +114,17 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
         self.decay = decay
     }
 
+    public convenience init(
+        for _: __shared Model,
+        learningRate: Scalar = 0.001,
+        rho: Scalar = 0.9,
+        epsilon: Scalar = 1e-8,
+        decay: Scalar = 0,
+        scalarType: Scalar.Type
+    ) {
+        self.init(learningRate: learningRate, rho: rho, epsilon: epsilon, decay: decay)
+    }
+
     private var step: Scalar = 0
     private var alpha = Model.AllDifferentiableVariables.zero
 
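The RMSProp hunks repeat the same pattern. Because the restored shim is marked `convenience`, Swift requires it to fully delegate via `self.init` to a designated initializer of the same class. A self-contained sketch of that language rule, with an illustrative `Counter` class that is not part of this library:

public class Counter {
    public let stride: Int

    // Designated initializer: responsible for full initialization.
    public init(stride: Int = 1) {
        self.stride = stride
    }

    // Convenience initializer: discards the extra argument and
    // delegates across to the designated initializer.
    public convenience init(for _: Any, stride: Int = 1) {
        self.init(stride: stride)
    }
}
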
@@ -125,12 +149,10 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let nesterov: Bool
 
     public init(
-        for _: __shared Model,
         learningRate: Scalar = 0.01,
         momentum: Scalar = 0,
         decay: Scalar = 0,
-        nesterov: Bool = false,
-        scalarType: Scalar.Type
+        nesterov: Bool = false
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(momentum >= 0, "Momentum must be non-negative")
@@ -142,6 +164,17 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
         self.nesterov = nesterov
     }
 
+    public convenience init(
+        for _: __shared Model,
+        learningRate: Scalar = 0.01,
+        momentum: Scalar = 0,
+        decay: Scalar = 0,
+        nesterov: Bool = false,
+        scalarType: Scalar.Type
+    ) {
+        self.init(learningRate: learningRate, momentum: momentum, decay: decay, nesterov: nesterov)
+    }
+
     private var step: Scalar = 0
     private var velocity = Model.AllDifferentiableVariables.zero
 
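As with Adam and RMSProp, the SGD shim discards its `for:` argument (note the wildcard `for _:`) and `scalarType:` before delegating; both parameters exist purely so that old call sites can drive generic inference. A hedged sketch of the two equivalent call sites, reusing the hypothetical `model` from above:

// Old-style call, which stopped compiling when the designated
// initializer dropped `for:`/`scalarType:`; the convenience shim restores it.
let sgd = SGD(for: model, learningRate: 0.01, nesterov: true, scalarType: Float.self)

// New-style call, with the generic parameters written out explicitly.
let sgdExplicit = SGD<MyModel, Float>(learningRate: 0.01, nesterov: true)
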
0 commit comments
