@@ -15,7 +15,9 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
 
     @torch_utils.no_grad
     def _backend_reset_gradient_accumulators(self):
-        acc_list = [v.value for v in self._accumulated_gradients]
+        acc_list = [
+            v.value for v in self._accumulated_gradients if v is not None
+        ]
         torch._foreach_mul_(acc_list, 0.0)
 
     @torch_utils.no_grad
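Since accumulator slots can now be `None` (for variables whose gradients overwrite them directly), the list must be filtered before the fused in-place call. Below is a tiny standalone illustration, not part of this PR; `torch._foreach_mul_` is the same private PyTorch op used above, and the example tensors are made up.

import torch

# Slot list as the optimizer now stores it: `None` marks a variable that has
# no gradient accumulator because its gradient overwrites it directly.
accumulators = [torch.ones(2), None, torch.ones(3)]

# Mirror the filtering in the diff above, then zero everything in place.
acc_list = [t for t in accumulators if t is not None]
torch._foreach_mul_(acc_list, 0.0)

print(acc_list)  # [tensor([0., 0.]), tensor([0., 0., 0.])]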
15 changes: 6 additions & 9 deletions keras/src/optimizers/adadelta.py
@@ -75,15 +75,12 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulated_grads = []
-        self._accumulated_delta_vars = []
-        for var in var_list:
-            self._accumulated_grads.append(
-                self.add_variable_from_reference(var, "accumulated_grad")
-            )
-            self._accumulated_delta_vars.append(
-                self.add_variable_from_reference(var, "accumulated_delta_var")
-            )
+        self._accumulated_grads = self.add_optimizer_variables(
+            var_list, "accumulated_grad"
+        )
+        self._accumulated_delta_vars = self.add_optimizer_variables(
+            var_list, "accumulated_delta_var"
+        )
 
     def update_step(self, grad, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
30 changes: 15 additions & 15 deletions keras/src/optimizers/adafactor.py
@@ -1,4 +1,3 @@
-from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -97,16 +96,13 @@ def build(self, var_list):
         self._c = []
         self._v = []
         for var in var_list:
-            if len(var.shape) < 2:
-                # Don't factor if variable is of dimension < 2, but we still
-                # need to create dummy variables as placeholder.
-                with backend.name_scope(self.name, caller=self):
-                    self._r.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
-                    self._c.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
+            if (
+                self._overwrite_variable_with_gradient(var)
+                or len(var.shape) < 2
+            ):
+                # Don't factor if variable is of dimension < 2.
+                self._r.append(None)
+                self._c.append(None)
             else:
                 # Always factor the last 2 dimensions.
                 r_shape = var.shape[:-1]
@@ -125,11 +121,15 @@ def build(self, var_list):
                         name=var.name,
                     )
                 )
-            self._v.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
+
+            if self._overwrite_variable_with_gradient(var):
+                self._v.append(None)
+            else:
+                self._v.append(
+                    self.add_variable_from_reference(
+                        reference_variable=var, name="velocity"
+                    )
                 )
-            )
 
     def _rms(self, x):
         return ops.sqrt(ops.mean(ops.square(x)))
13 changes: 3 additions & 10 deletions keras/src/optimizers/adagrad.py
@@ -70,17 +70,10 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulators = []
         initializer = initializers.Constant(self.initial_accumulator_value)
-        for var in var_list:
-            self._accumulators.append(
-                self.add_variable(
-                    shape=var.shape,
-                    initializer=initializer,
-                    dtype=var.dtype,
-                    name="accumulator",
-                )
-            )
+        self._accumulators = self.add_optimizer_variables(
+            var_list, "accumulator", initializer=initializer
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
26 changes: 6 additions & 20 deletions keras/src/optimizers/adam.py
@@ -90,27 +90,13 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        self._velocities = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._velocities.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
+        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+
         if self.amsgrad:
-            self._velocity_hats = []
-            for var in var_list:
-                self._velocity_hats.append(
-                    self.add_variable_from_reference(
-                        reference_variable=var, name="velocity_hat"
-                    )
-                )
+            self._velocity_hats = self.add_optimizer_variables(
+                var_list, "velocity_hat"
+            )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
15 changes: 2 additions & 13 deletions keras/src/optimizers/adamax.py
@@ -98,19 +98,8 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._m = []
-        self._u = []
-        for var in var_list:
-            self._m.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._u.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="norm"
-                )
-            )
+        self._m = self.add_optimizer_variables(var_list, "momentum")
+        self._u = self.add_optimizer_variables(var_list, "norm")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
80 changes: 62 additions & 18 deletions keras/src/optimizers/base_optimizer.py
@@ -204,21 +204,19 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _overwrite_variable_with_gradient(self, variable):
+        return getattr(variable, "overwrite_with_gradient", False)
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
-            self._model_variables_moving_average = []
+            self._model_variables_moving_average = self.add_optimizer_variables(
+                variables, "average"
+            )
         if self.gradient_accumulation_steps:
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
             self._trainable_variables_indices[self._var_key(variable)] = i
-            if self.use_ema:
-                self._model_variables_moving_average.append(
-                    self.add_variable_from_reference(
-                        variable,
-                        name="average",
-                    )
-                )
             if self.gradient_accumulation_steps:
                 self._accumulated_gradients.append(
                     self.add_variable_from_reference(
@@ -323,6 +321,49 @@ def add_variable_from_reference(
             name=name,
         )
 
+    def add_optimizer_variables(
+        self, trainable_variables, name, initializer="zeros"
+    ):
+        """Add optimizer variables from the list of trainable model variables.
+
+        Create an optimizer variable based on the information of the supplied
+        model variables. For example, for SGD with momentum, a corresponding
+        momentum variable is created for each model variable, with the same
+        shape and dtype.
+
+        Note that a `None` placeholder is inserted for trainable variables
+        with `v.overwrite_with_gradient == True`, since the optimizer
+        variable would never be used and creating it would be wasteful.
+
+        Args:
+            trainable_variables: The list of model variables for which
+                corresponding optimizer variables should be created.
+            name: The name prefix of the optimizer variable to be created. The
+                variable name will follow the pattern
+                `{variable_name}_{trainable_variable.name}`, e.g.,
+                `momentum/dense_1`.
+            initializer: Initializer object to use to populate the initial
+                variable value, or string name of a built-in initializer (e.g.
+                `"random_normal"`). If unspecified, defaults to `"zeros"`.
+
+        Returns:
+            A list of optimizer variables (`keras.Variable` or `None`).
+        """
+        optimizer_variables = []
+        for variable in trainable_variables:
+            if not self._overwrite_variable_with_gradient(variable):
+                optimizer_variables.append(
+                    self.add_variable_from_reference(
+                        variable,
+                        name=name,
+                        initializer=initializer,
+                    )
+                )
+            else:
+                optimizer_variables.append(None)
+
+        return optimizer_variables
+
     def _check_variables_are_known(self, variables):
         for v in variables:
             if self._var_key(v) not in self._trainable_variables_indices:
@@ -544,7 +585,8 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
 
     def _backend_reset_gradient_accumulators(self):
         for g_acc in self._accumulated_gradients:
-            g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
+            if g_acc is not None:
+                g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
 
     def _backend_increment_gradient_accumulators(self, grads, acc_grads):
         new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)]
@@ -711,8 +753,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         After the update, the processed pairs will be filtered out.
         """
         # Shortcut for `tf.Variable` because it doesn't have a
-        # `overwrite_with_gradient` attr
-        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
+        # `overwrite_with_gradient` attr.
+        if not any(self._overwrite_variable_with_gradient(v) for v in vars):
             return grads, vars
 
         # Shallow copies
@@ -722,7 +764,7 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         # Iterate from right to left for safe popping
         for i in range(len(filtered_grads) - 1, -1, -1):
             g, v = filtered_grads[i], filtered_vars[i]
-            if v.overwrite_with_gradient:
+            if self._overwrite_variable_with_gradient(v):
                 if self.gradient_accumulation_steps:
                     # Utilize a stateless manner for JAX compatibility
                     steps = self.gradient_accumulation_steps
@@ -886,11 +928,12 @@ def _update_model_variables_moving_average(self, trainable_variables):
             for var, average in zip(
                 trainable_variables, self._model_variables_moving_average
             ):
-                not_first_step = ops.not_equal(self.iterations, 0)
-                momentum = (
-                    ops.cast(not_first_step, var.dtype) * self.ema_momentum
-                )
-                average.assign(momentum * average + (1 - momentum) * var)
+                if average is not None:
+                    not_first_step = ops.not_equal(self.iterations, 0)
+                    momentum = (
+                        ops.cast(not_first_step, var.dtype) * self.ema_momentum
+                    )
+                    average.assign(momentum * average + (1 - momentum) * var)
 
     def _overwrite_model_variables_with_average_value(
         self, trainable_variables
@@ -909,7 +952,8 @@ def _overwrite_model_variables_with_average_value(
         for var, average_var in zip(
             trainable_variables, self._model_variables_moving_average
         ):
-            var.assign(average_var)
+            if average_var is not None:
+                var.assign(average_var)
 
     def finalize_variable_values(self, var_list):
         """Set the final value of model's trainable variables.
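For orientation, here is a minimal sketch of how an optimizer subclass can use the new `add_optimizer_variables` helper, and why the `None` placeholders never reach `update_step` (those variables are filtered out beforehand by `_overwrite_variables_directly_with_gradients`, and the EMA and gradient-accumulation paths above guard with `is not None`). The sketch is not part of this PR: the `ToyMomentum` class and its hyperparameters are invented for illustration, while `add_optimizer_variables`, `_get_variable_index`, `assign`, and `assign_add` are existing `Optimizer` methods.

from keras.src import ops
from keras.src.optimizers import optimizer


class ToyMomentum(optimizer.Optimizer):
    """Hypothetical optimizer used only to illustrate the new helper."""

    def __init__(self, learning_rate=0.01, momentum=0.9, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # One slot per model variable; slots are `None` for variables with
        # `overwrite_with_gradient == True`, matching the helper's contract.
        self._momentums = self.add_optimizer_variables(var_list, "momentum")

    def update_step(self, gradient, variable, learning_rate):
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        # Variables whose gradients overwrite them directly never reach
        # `update_step`, so the slot looked up here is never `None`.
        m = self._momentums[self._get_variable_index(variable)]
        self.assign(
            m,
            ops.subtract(
                ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
                ops.multiply(gradient, learning_rate),
            ),
        )
        self.assign_add(variable, m)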
25 changes: 7 additions & 18 deletions keras/src/optimizers/ftrl.py
@@ -159,24 +159,13 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulators = []
-        self._linears = []
-        for var in var_list:
-            self._accumulators.append(
-                self.add_variable(
-                    shape=var.shape,
-                    dtype=var.dtype,
-                    name="accumulator",
-                    initializer=initializers.Constant(
-                        self.initial_accumulator_value,
-                    ),
-                )
-            )
-            self._linears.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="linear"
-                )
-            )
+        accumulator_initializer = initializers.Constant(
+            self.initial_accumulator_value,
+        )
+        self._accumulators = self.add_optimizer_variables(
+            var_list, "accumulator", initializer=accumulator_initializer
+        )
+        self._linears = self.add_optimizer_variables(var_list, "linear")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
15 changes: 2 additions & 13 deletions keras/src/optimizers/lamb.py
@@ -82,19 +82,8 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        self._velocities = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._velocities.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
+        self._velocities = self.add_optimizer_variables(var_list, "velocity")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
8 changes: 1 addition & 7 deletions keras/src/optimizers/lion.py
@@ -91,13 +91,7 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
6 changes: 0 additions & 6 deletions keras/src/optimizers/loss_scale_optimizer.py
@@ -102,12 +102,6 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
             ),
         )
 
-    def _overwrite_variable_with_gradient(self, variable):
-        return (
-            hasattr(variable, "overwrite_with_gradient")
-            and variable.overwrite_with_gradient
-        )
-
     def _stateless_handle_finite_grads(
         self, optimizer_variables, grads, trainable_variables
     ):
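The override deleted above became redundant once the base class gained the `getattr`-based `_overwrite_variable_with_gradient` helper, which already treats a missing attribute as `False`. A quick standalone check, not from this PR (the two stand-in classes below are made up):

def _overwrite_variable_with_gradient(variable):
    # Same logic as the new base-class helper.
    return getattr(variable, "overwrite_with_gradient", False)


class PlainVariable:
    # Stands in for e.g. a `tf.Variable`, which has no such attribute.
    pass


class FlaggedVariable:
    overwrite_with_gradient = True


print(_overwrite_variable_with_gradient(PlainVariable()))    # False
print(_overwrite_variable_with_gradient(FlaggedVariable()))  # True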