
Need a way to pass trainable Variables to Optimizer #307

Open
@JimClarke5

Description


While working on Model training, an issue with the Optimizer has surfaced.

Currently, when minimize(loss) is called on an Optimizer instance, the Optimizer walks the entire Graph and collects every Variable defined in it. The idea is that when you call minimize(loss), the Optimizer builds gradients with respect to all of those variables. However, this "all variables" approach breaks down when working with Model, because some variables are not on the loss operand's execution path. This produces the following error:

org.tensorflow.exceptions.TFInvalidArgumentException: Cannot compute the partial derivative for node 'model/mse_total' as it's unreachable from the output node(s).

This specific error occurs because the MSE metric's internal variables are not on the loss execution path. This pattern of "non-trainable variables (weights)" appears in most Metric classes and in the Model itself, so it is widespread. What we need is a way to distinguish trainable variables from non-trainable ones. The Optimizer would then compute gradients only for the trainable variables.
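To make the trainable/non-trainable distinction concrete, here is a minimal toy sketch in plain Java. It does not use the TensorFlow Java API: `TrackedVariable` and `VariableRegistry` are hypothetical classes, and the variable names (`dense/kernel`, `dense/bias`) are illustrative. It contrasts "collect every variable" (the current Optimizer behaviour) with "collect only variables flagged as trainable".

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-in for a graph variable; NOT org.tensorflow's Variable class.
class TrackedVariable {
    final String name;
    final boolean trainable;

    TrackedVariable(String name, boolean trainable) {
        this.name = name;
        this.trainable = trainable;
    }
}

// Hypothetical registry mimicking "walk the whole Graph and collect every Variable".
class VariableRegistry {
    private final List<TrackedVariable> all = new ArrayList<>();

    TrackedVariable create(String name, boolean trainable) {
        TrackedVariable v = new TrackedVariable(name, trainable);
        all.add(v);
        return v;
    }

    // Current behaviour: every variable, including metric state such as mse_total.
    List<TrackedVariable> allVariables() {
        return new ArrayList<>(all);
    }

    // Desired behaviour: only variables explicitly flagged as trainable.
    List<TrackedVariable> trainableVariables() {
        return all.stream().filter(v -> v.trainable).collect(Collectors.toList());
    }
}
```

With a model's two weights plus the metric's `model/mse_total` registered, `allVariables()` returns three entries, but only the two trainable ones would be handed to gradient computation.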

In Python TensorFlow, Keras Layers track their trainable variables in an attribute list, and the Model passes the collected lists to the Optimizer's minimize method.
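The Keras pattern described above can be sketched in plain Java as follows. This is a toy model, not TensorFlow Java code: `LayerVar`, `SketchLayer`, `SketchModel`, and `SketchOptimizer` are hypothetical, and `minimize(loss, trainableVariables)` is an assumed signature that only reports which partial derivatives would be requested instead of building real gradient ops.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical variable handle (not org.tensorflow's Variable).
class LayerVar {
    final String name;
    LayerVar(String name) { this.name = name; }
}

// Each layer records which of its weights are trainable,
// similar in spirit to Keras' trainable_weights attribute.
class SketchLayer {
    private final List<LayerVar> trainable = new ArrayList<>();
    private final List<LayerVar> nonTrainable = new ArrayList<>();

    LayerVar addWeight(String name, boolean isTrainable) {
        LayerVar v = new LayerVar(name);
        if (isTrainable) {
            trainable.add(v);
        } else {
            nonTrainable.add(v);
        }
        return v;
    }

    List<LayerVar> trainableVariables() { return Collections.unmodifiableList(trainable); }
}

// The model aggregates trainable variables across all of its layers.
class SketchModel {
    private final List<SketchLayer> layers = new ArrayList<>();

    void add(SketchLayer layer) { layers.add(layer); }

    List<LayerVar> trainableVariables() {
        List<LayerVar> result = new ArrayList<>();
        for (SketchLayer layer : layers) {
            result.addAll(layer.trainableVariables());
        }
        return result;
    }
}

// Hypothetical optimizer: gradients are requested only for the variables passed in,
// so unreachable metric state never appears as a gradient target.
class SketchOptimizer {
    List<String> minimize(String loss, List<LayerVar> trainableVariables) {
        List<String> targets = new ArrayList<>();
        for (LayerVar v : trainableVariables) {
            targets.add("d(" + loss + ")/d(" + v.name + ")");
        }
        return targets;
    }
}
```

In the Java port, the list returned by `SketchModel.trainableVariables()` would correspond to the list the Model hands to the Optimizer.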

There are a couple of options here:

  1. Mimic TF Keras and have each Layer identify its trainable variables. Then pass the trainable variables as a List<Variable<?>> using a call like Optimizer.minimize(loss, trainableVariables), and have the Optimizer's minimize routine call addGradients with this variable list, rather than walking the whole Graph, to compute the gradients.
  2. Within Optimizer.minimize(loss), walk the loss operand's execution path to locate the variables that contribute to the loss calculation, then pass these to addGradients. A solution based on this option may be facilitated by #232, "Add graph walking functions to Graph and GraphOperation".
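Option 2 amounts to a reverse reachability search from the loss node. The sketch below shows the idea on a toy dataflow graph in plain Java; `GraphNode` and `LossPathWalker` are hypothetical and do not reflect the TensorFlow Java `Graph`/`GraphOperation` API or whatever #232 ends up providing. It does a breadth-first walk backwards through each node's inputs and collects the variables it reaches, so a metric variable that merely consumes the loss is never visited.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical node in a toy dataflow graph (not TensorFlow's GraphOperation).
class GraphNode {
    final String name;
    final boolean isVariable;
    final List<GraphNode> inputs = new ArrayList<>();

    GraphNode(String name, boolean isVariable, GraphNode... inputs) {
        this.name = name;
        this.isVariable = isVariable;
        Collections.addAll(this.inputs, inputs);
    }
}

class LossPathWalker {
    // Breadth-first walk from the loss back through node inputs, collecting
    // the names of every variable the loss actually depends on.
    static Set<String> variablesReachableFrom(GraphNode loss) {
        Set<String> found = new LinkedHashSet<>();
        Set<GraphNode> visited = new HashSet<>();
        Deque<GraphNode> queue = new ArrayDeque<>();
        queue.add(loss);
        visited.add(loss);
        while (!queue.isEmpty()) {
            GraphNode node = queue.poll();
            if (node.isVariable) {
                found.add(node.name);
            }
            for (GraphNode in : node.inputs) {
                // visited.add returns false if the node was already seen.
                if (visited.add(in)) {
                    queue.add(in);
                }
            }
        }
        return found;
    }
}
```

For a graph where `loss` depends on `dense/kernel` and `dense/bias`, while an update op feeds `loss` into `model/mse_total`, the walk returns only the two dense variables. Those would be the ones passed to addGradients. A real implementation would also need to decide how to handle variables that are reachable but should still be excluded, such as frozen layers. That is one argument for combining this with option 1's explicit trainable flag.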
