Skip to content
This repository was archived by the owner on Mar 11, 2021. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dual_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ def model_fn(features, labels, mode, params):
train_op = optimizer.minimize(combined_cost, global_step=global_step)

# Computations to be executed on CPU, outside of the main TPU queues.
def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cost,
value_cost, l2_cost, combined_cost, step,
def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
value_tensor, policy_cost, value_cost,
l2_cost, combined_cost, step,
est_mode=tf.estimator.ModeKeys.TRAIN):
policy_entropy = -tf.reduce_mean(tf.reduce_sum(
policy_output * tf.log(policy_output), axis=1))
Expand All @@ -299,6 +300,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

value_cost_normalized = value_cost / params['value_cost_weight']
avg_value_observed = tf.reduce_mean(value_tensor)

with tf.variable_scope('metrics'):
metric_ops = {
Expand All @@ -308,7 +310,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
'l2_cost': tf.metrics.mean(l2_cost),
'policy_entropy': tf.metrics.mean(policy_entropy),
'combined_cost': tf.metrics.mean(combined_cost),

'avg_value_observed': tf.metrics.mean(avg_value_observed),
'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
Expand Down Expand Up @@ -345,6 +347,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
policy_output,
value_output,
labels['pi_tensor'],
labels['value_tensor'],
tf.reshape(policy_cost, [1]),
tf.reshape(value_cost, [1]),
tf.reshape(l2_cost, [1]),
Expand Down