Skip to content

fix: optimize some APIs #1129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions src/TensorFlowNET.Core/APIs/tf.nn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,8 @@ public Tensor batch_normalization(Tensor x,
Tensor offset,
Tensor scale,
float variance_epsilon,
string name = null)
{
var inv = math_ops.rsqrt(variance + variance_epsilon);
tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
{
if (scale != null) inv *= scale;
});
if (offset != null) return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, dtype: x.dtype);
else return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, dtype: x.dtype);
}
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name);


public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
Expand Down
1 change: 0 additions & 1 deletion src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,6 @@ public static Tensor stop_gradient(Tensor input, string name = null)
var tape = tf.GradientTape().stop_recording();
var result = gen_array_ops.stop_gradient(input, name);
tape.StartRecord();
tf.GradientTape().PushTape(tape);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this removal looks highly suspicious. Does anyone even know why this statement is here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything looks fine when I run though all the examples.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

none of the examples test this specific method though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not deleting this line of code will result in NULL gradient when using the stop_gradient API.

Copy link
Contributor Author

@Beacontownfc Beacontownfc Jul 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LayerNorm uses tf.nn.moments and tf.nn.moments uses the stop_gradient API. If this line of code is not deleted, an error will be reported. Deleting this line of code will enable normal training.

return result;
}

Expand Down