Commit 2e20635

Remove duplication in the code creating optimizer variables

1 parent 78251c4 · commit 2e20635

File tree

9 files changed: +40 -155 lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaDelta.kt

Lines changed: 3 additions & 16 deletions

```diff
@@ -6,10 +6,8 @@
 package org.jetbrains.kotlinx.dl.api.core.optimizer
 
 import org.jetbrains.kotlinx.dl.api.core.KGraph
-import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Constant
 import org.tensorflow.op.core.Gradients
@@ -74,7 +72,9 @@ public class AdaDelta(
         epsilonConstant = tf.constant(epsilon, getDType())
 
         for ((i, variable) in weights.withIndex()) {
-            val (accumSlot, accumUpdateSlot) = createAdaDeltaSlot(graph, tf, variable.asOutput())
+            val output = variable.asOutput()
+            val accumSlot = createSlot(ACCUMULATOR, output, tf, graph)
+            val accumUpdateSlot = createSlot(ACCUMULATOR_UPDATE, output, tf, graph)
 
             targets.add(
                 tf.train.applyAdadelta(
@@ -91,19 +91,6 @@ public class AdaDelta(
         return targets
     }
 
-    private fun createAdaDeltaSlot(graph: KGraph, tf: Ops, v: Output<Float>): Pair<Variable<Float>, Variable<Float>> {
-        val accumInitializerName = defaultInitializerOpName(createName(v, ACCUMULATOR))
-        val accumulatorInitializer = tf.withName(accumInitializerName)
-            .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), getDType()))
-        val accumulator = createSlot(graph, tf, v.asOutput(), ACCUMULATOR, accumulatorInitializer)
-
-        val accumUpdateInitializerName = defaultInitializerOpName(createName(v, ACCUMULATOR_UPDATE))
-        val updateInitializer: Operand<Float> = tf.withName(accumUpdateInitializerName)
-            .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), getDType()))
-        val accumulatorUpdate = createSlot(graph, tf, v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer)
-        return accumulator to accumulatorUpdate
-    }
-
     override val optimizerName: String get() = "Adadelta"
 
     override val isRunningOnGPU: Boolean get() = true
```
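The same unfolding repeats in the other multi-slot optimizers below (AdaGradDA, Adam, Adamax, Ftrl): a private `Pair`-returning helper is replaced by two direct calls to the shared base-class `createSlot`. A minimal sketch of the shape of that change, with plain strings standing in for the TensorFlow `Output`/`Variable` types and illustrative slot names:

```kotlin
// Simplified stand-in for Optimizer.createSlot: a plain string stands in for
// the TensorFlow Variable<Float> the real method creates.
fun createSlot(slotName: String, output: String): String = "$output-$slotName"

// Before the commit: each optimizer bundled its slots into a Pair-returning helper.
fun createAdaDeltaSlot(output: String): Pair<String, String> =
    createSlot("accumulator", output) to createSlot("accum_update", output)

fun main() {
    val output = "conv2d_1/kernel" // hypothetical weight op name
    // After the commit: the destructured Pair becomes two explicit calls.
    val accumSlot = createSlot("accumulator", output)
    val accumUpdateSlot = createSlot("accum_update", output)
    check(createAdaDeltaSlot(output) == (accumSlot to accumUpdateSlot))
}
```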

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGrad.kt

Lines changed: 1 addition & 11 deletions

```diff
@@ -6,10 +6,8 @@
 package org.jetbrains.kotlinx.dl.api.core.optimizer
 
 import org.jetbrains.kotlinx.dl.api.core.KGraph
-import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Constant
 import org.tensorflow.op.core.Gradients
@@ -64,7 +62,7 @@ public class AdaGrad(
         learningRateConst = tf.constant(learningRate, getDType())
 
         for ((i, variable) in weights.withIndex()) {
-            val slot = createAdaGradSlot(graph, tf, variable.asOutput())
+            val slot = createSlot(ACCUMULATOR, variable.asOutput(), tf, graph, initialValue = initialAccumulatorValue)
 
             targets.add(
                 tf.train.applyAdagrad(
@@ -80,14 +78,6 @@ public class AdaGrad(
         return targets
     }
 
-    private fun createAdaGradSlot(graph: KGraph, tf: Ops, v: Output<Float>): Variable<Float> {
-        val accumInitializerName = defaultInitializerOpName(createName(v, ACCUMULATOR))
-
-        val initializer: Operand<Float> = tf.withName(accumInitializerName)
-            .fill(tf.shape(v), tf.constant(initialAccumulatorValue))
-        return createSlot(graph, tf, v.asOutput(), ACCUMULATOR, initializer)
-    }
-
     override val optimizerName: String get() = "Adagrad"
 
     override val isRunningOnGPU: Boolean get() = true
```
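AdaGrad is the only optimizer in this commit that seeds its slot with a non-zero value, which is what the new `initialValue` parameter (default `0.0f`, see the Optimizer.kt diff below) exists for. A self-contained sketch of the two call styles, under the assumption of simplified stand-in types:

```kotlin
// Simplified model of the new signature: initialValue defaults to 0.0f, so most
// optimizers omit it while AdaGrad passes its configured starting value.
data class Slot(val name: String, val initialValue: Float)

fun createSlot(slotName: String, variableName: String, initialValue: Float = 0.0f): Slot =
    Slot("$variableName-$slotName", initialValue)

fun main() {
    // Zero-initialized slot, as in AdaDelta, Adam, Momentum, and the rest:
    println(createSlot("first_moment", "dense_1/kernel"))
    // AdaGrad-style call with an explicit starting value (0.01f is illustrative):
    println(createSlot("accumulator", "dense_1/kernel", initialValue = 0.01f))
}
```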

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGradDA.kt

Lines changed: 3 additions & 16 deletions

```diff
@@ -11,7 +11,6 @@ import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.defaultOptimizerVariableName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.Shape
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Assign
@@ -83,7 +82,9 @@ public class AdaGradDA(
         graph.addOptimizerVariableInitializer(globalStepInit)
 
         for ((i, variable) in weights.withIndex()) {
-            val (gradSlot, gradSquaredSlot) = createAdaGradDASlot(graph, tf, variable.asOutput())
+            val output = variable.asOutput()
+            val gradSlot = createSlot(ACCUMULATOR, output, tf, graph)
+            val gradSquaredSlot = createSlot(SQUARED_ACCUMULATOR, output, tf, graph)
             targets.add(
                 tf.train.applyAdagradDa(
                     variable,
@@ -105,20 +106,6 @@ public class AdaGradDA(
         return targets
     }
 
-    private fun createAdaGradDASlot(graph: KGraph, tf: Ops, v: Output<Float>): Pair<Variable<Float>, Variable<Float>> {
-        val accumulatorInitializerName = defaultInitializerOpName(createName(v, ACCUMULATOR))
-        val accumInitializer: Operand<Float> = tf.withName(accumulatorInitializerName)
-            .fill(tf.shape(v), tf.constant(0.0f))
-        val accumulator = createSlot(graph, tf, v.asOutput(), ACCUMULATOR, accumInitializer)
-
-        val squareAccumInitializerName = defaultInitializerOpName(createName(v, SQUARED_ACCUMULATOR))
-        val sqInitializer: Operand<Float> = tf.withName(squareAccumInitializerName)
-            .fill(tf.shape(v), tf.constant(initialAccumulatorValue))
-
-        val squaredAccumulator = createSlot(graph, tf, v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer)
-        return accumulator to squaredAccumulator
-    }
-
     override val optimizerName: String get() = "AdaGradDA"
 
     override val isRunningOnGPU: Boolean get() = true
```

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adam.kt

Lines changed: 3 additions & 15 deletions

```diff
@@ -11,7 +11,6 @@ import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.defaultOptimizerVariableName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.Shape
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Assign
@@ -99,7 +98,9 @@ public class Adam(
         graph.addOptimizerVariableInitializer(betaTwoPowerInit)
 
         for ((i, variable) in weights.withIndex()) {
-            val (firstMomentSlot, secondMomentSlot) = createAdamSlot(graph, tf, variable.asOutput())
+            val output = variable.asOutput()
+            val firstMomentSlot = createSlot(FIRST_MOMENT, output, tf, graph)
+            val secondMomentSlot = createSlot(SECOND_MOMENT, output, tf, graph)
             targets.add(
                 tf.train.applyAdam(
                     variable,
@@ -133,19 +134,6 @@ public class Adam(
         return targets
     }
 
-    private fun createAdamSlot(graph: KGraph, tf: Ops, v: Output<Float>): Pair<Variable<Float>, Variable<Float>> {
-        val firstMomentInitializerName = defaultInitializerOpName(createName(v, FIRST_MOMENT))
-        val firstMomentInitializer =
-            tf.withName(firstMomentInitializerName).fill(tf.shape(v), tf.constant(0.0f, getDType()))
-        val firstMoment = createSlot(graph, tf, v.asOutput(), FIRST_MOMENT, firstMomentInitializer)
-
-        val secondMomentInitializerName = defaultInitializerOpName(createName(v, SECOND_MOMENT))
-        val secondMomentInitializer =
-            tf.withName(secondMomentInitializerName).fill(tf.shape(v), tf.constant(0.0f, getDType()))
-        val secondMoment = createSlot(graph, tf, v.asOutput(), SECOND_MOMENT, secondMomentInitializer)
-        return firstMoment to secondMoment
-    }
-
     override val optimizerName: String get() = "Adam"
 
     override val isRunningOnGPU: Boolean get() = true
```

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adamax.kt

Lines changed: 3 additions & 16 deletions

```diff
@@ -11,7 +11,6 @@ import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.defaultOptimizerVariableName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.Shape
 import org.tensorflow.op.Ops
 import org.tensorflow.op.Scope
@@ -91,7 +90,9 @@ public class Adamax(
         val scope = Scope(graph.tfGraph)
 
         for ((i, variable) in weights.withIndex()) {
-            val (firstMomentSlot, secondMomentSlot) = createAdamaxSlot(graph, tf, variable.asOutput())
+            val output = variable.asOutput()
+            val firstMomentSlot = createSlot(FIRST_MOMENT, output, tf, graph)
+            val secondMomentSlot = createSlot(SECOND_MOMENT, output, tf, graph)
             targets.add(
                 ApplyAdaMax.create(
                     scope,
@@ -117,20 +118,6 @@ public class Adamax(
         return targets
     }
 
-    private fun createAdamaxSlot(graph: KGraph, tf: Ops, v: Output<Float>): Pair<Variable<Float>, Variable<Float>> {
-        val firstMomentInitializerName = defaultInitializerOpName(createName(v, FIRST_MOMENT))
-        val firstMomentInitializer =
-            tf.withName(firstMomentInitializerName).fill(tf.shape(v), tf.constant(0.0f, getDType()))
-        val firstMoment = createSlot(graph, tf, v.asOutput(), FIRST_MOMENT, firstMomentInitializer)
-
-        val secondMomentInitializerName = defaultInitializerOpName(createName(v, SECOND_MOMENT))
-        val secondMomentInitializer = tf.withName(secondMomentInitializerName)
-            .fill(tf.shape(v), tf.constant(0.0f, getDType()))
-        val secondMoment = createSlot(graph, tf, v.asOutput(), SECOND_MOMENT, secondMomentInitializer)
-
-        return firstMoment to secondMoment
-    }
-
     override val optimizerName: String get() = "Adamax"
 
     override val isRunningOnGPU: Boolean get() = false
```

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Ftrl.kt

Lines changed: 3 additions & 17 deletions

```diff
@@ -6,10 +6,8 @@
 package org.jetbrains.kotlinx.dl.api.core.optimizer
 
 import org.jetbrains.kotlinx.dl.api.core.KGraph
-import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Constant
 import org.tensorflow.op.core.Gradients
@@ -91,7 +89,9 @@ public class Ftrl(
         learningRatePowerConst = tf.constant(learningRatePower, getDType())
 
         for ((i, variable) in weights.withIndex()) {
-            val (accumSlot, linearSlot) = createFtrlSlot(graph, tf, variable.asOutput())
+            val output = variable.asOutput()
+            val accumSlot = createSlot(ACCUMULATOR, output, tf, graph)
+            val linearSlot = createSlot(LINEAR_ACCUMULATOR, output, tf, graph)
 
             val options = ApplyFtrl.useLocking(true)
 
@@ -114,20 +114,6 @@ public class Ftrl(
         return targets
     }
 
-    private fun createFtrlSlot(graph: KGraph, tf: Ops, v: Output<Float>): Pair<Variable<Float>, Variable<Float>> {
-        val accumInitializerName = defaultInitializerOpName(createName(v, ACCUMULATOR))
-        val accumInitializer = tf.withName(accumInitializerName)
-            .fill(tf.shape(v), tf.constant(initialAccumulatorValue))
-        val accumulator = createSlot(graph, tf, v.asOutput(), ACCUMULATOR, accumInitializer)
-
-        val linearAccumInitializerName = defaultInitializerOpName(createName(v, LINEAR_ACCUMULATOR))
-        val linearAccumInitializer = tf.withName(linearAccumInitializerName)
-            .fill(tf.shape(v), tf.constant(0.0f))
-        val linearAccumulator = createSlot(graph, tf, v.asOutput(), LINEAR_ACCUMULATOR, linearAccumInitializer)
-
-        return accumulator to linearAccumulator
-    }
-
     override val optimizerName: String get() = "Ftrl"
 
     override val isRunningOnGPU: Boolean get() = false
```

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Momentum.kt

Lines changed: 1 addition & 10 deletions

```diff
@@ -6,9 +6,7 @@
 package org.jetbrains.kotlinx.dl.api.core.optimizer
 
 import org.jetbrains.kotlinx.dl.api.core.KGraph
-import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.tensorflow.Operand
-import org.tensorflow.Output
 import org.tensorflow.op.Ops
 import org.tensorflow.op.core.Constant
 import org.tensorflow.op.core.Gradients
@@ -50,7 +48,7 @@ public class Momentum(
         momentumConst = tf.constant(momentum)
 
         for ((i, variable) in weights.withIndex()) {
-            val slot = createMomentumSlot(graph, tf, variable.asOutput())
+            val slot = createSlot(MOMENTUM, variable.asOutput(), tf, graph)
 
             targets.add(
                 tf.train.applyMomentum(
@@ -67,13 +65,6 @@ public class Momentum(
         return targets
     }
 
-    private fun createMomentumSlot(graph: KGraph, tf: Ops, v: Output<Float>): Variable<Float> {
-        val momentumInitializerName = defaultInitializerOpName(createName(v, MOMENTUM))
-        val initializer: Operand<Float> = tf.withName(momentumInitializerName)
-            .fill(tf.shape(v), tf.constant(0.0f))
-        return createSlot(graph, tf, v.asOutput(), MOMENTUM, initializer)
-    }
-
     override val optimizerName: String get() = "Momentum"
 
     override val isRunningOnGPU: Boolean get() = true
```

api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt

Lines changed: 19 additions & 23 deletions

```diff
@@ -7,12 +7,12 @@ package org.jetbrains.kotlinx.dl.api.core.optimizer
 
 import org.jetbrains.kotlinx.dl.api.core.KGraph
 import org.jetbrains.kotlinx.dl.api.core.util.defaultAssignOpName
+import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
 import org.jetbrains.kotlinx.dl.api.core.util.defaultOptimizerVariableName
 import org.jetbrains.kotlinx.dl.api.core.util.getDType
 import org.tensorflow.Operand
 import org.tensorflow.Output
 import org.tensorflow.op.Ops
-import org.tensorflow.op.core.Assign
 import org.tensorflow.op.core.Gradients
 import org.tensorflow.op.core.Variable
 
@@ -74,38 +74,34 @@ public abstract class Optimizer(public val clipGradient: ClipGradientAction) {
      * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's
      * initializer to the graph's initializers.
      *
-     * @param [graph] KGraph to be updated.
-     * @param [tf] TensorFlow graph API for building operations.
-     * @param [variable] The variable to create the slot for.
      * @param [slotName] The name of the slot.
-     * @param [initializer] The initializer for the slot.
+     * @param [variable] The variable to create the slot for.
+     * @param [tf] TensorFlow graph API for building operations.
+     * @param [graph] KGraph to be updated.
+     * @param [initialValue] The initial value to use.
      */
-    protected open fun createSlot(
-        graph: KGraph,
-        tf: Ops,
-        variable: Output<Float>,
-        slotName: String,
-        initializer: Operand<Float>
+    protected fun createSlot(slotName: String,
+                             variable: Output<Float>,
+                             tf: Ops,
+                             graph: KGraph,
+                             initialValue: Float = 0.0f
     ): Variable<Float> {
-        val createName: String = createName(variable, slotName)
-        val slot: Variable<Float> = tf.withName(createName).variable(variable.shape(), getDType())
+        val slotVariableName = defaultOptimizerVariableName(variable.op().name() + "-" + slotName)
+        val slot = tf.withName(slotVariableName).variable(variable.shape(), getDType())
+
+        val initializerOpName = defaultInitializerOpName(slotVariableName)
+        val initializerOp = tf.withName(initializerOpName)
+            .fill(tf.shape(variable), tf.dtypes.cast(tf.constant(initialValue), getDType()))
 
-        val assignName = defaultAssignOpName(createName(variable, slotName))
-        val slotInit: Assign<Float> = tf.withName(assignName).assign(slot, initializer)
+        val assignOpName = defaultAssignOpName(slotVariableName)
+        val assignOp = tf.withName(assignOpName).assign(slot, initializerOp)
 
-        graph.addOptimizerVariableInitializer(slotInit)
+        graph.addOptimizerVariableInitializer(assignOp)
         graph.addOptimizerVariable(slot)
 
        return slot
     }
 
-    /**
-     * Creates name for [variable] used in slot with name [slotName].
-     */
-    internal open fun createName(variable: Output<Float>, slotName: String): String {
-        return defaultOptimizerVariableName(variable.op().name() + "-" + slotName)
-    }
-
     /** True, if optimizer is implemented for GPU. */
     internal abstract val isRunningOnGPU: Boolean
 }
```
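The consolidated method also inlines the former `createName` hook, deriving the initializer and assign op names from a single slot variable name. A runnable sketch of that naming scheme; the helper bodies below are assumptions for illustration (the real helpers live in `org.jetbrains.kotlinx.dl.api.core.util` and may differ):

```kotlin
// The default*Name helper bodies are assumed for illustration; only the
// weightOpName + "-" + slotName composition is taken from the diff above.
fun defaultOptimizerVariableName(name: String) = "optimizer_$name"
fun defaultInitializerOpName(name: String) = "Init_$name"
fun defaultAssignOpName(name: String) = "Assign_$name"

// Mirrors createSlot's bookkeeping: one slot variable name drives all three ops.
fun slotOpNames(weightOpName: String, slotName: String): List<String> {
    val slotVariableName = defaultOptimizerVariableName("$weightOpName-$slotName")
    return listOf(
        slotVariableName,                           // the slot variable itself
        defaultInitializerOpName(slotVariableName), // fill op producing the initial value
        defaultAssignOpName(slotVariableName)       // assign op registered as the initializer
    )
}

fun main() {
    slotOpNames("dense_1/kernel", "accumulator").forEach(::println)
}
```

Centralizing the name derivation this way is what allows the commit to delete the overridable `createName` along with the per-optimizer slot helpers shown above.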
