Skip to content

Commit b72c135

Browse files
committed
Fix the keras.sparse_categorical_crossentropy. (#985)
1 parent 5821275 commit b72c135

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ public override Tensor Apply(Tensor target, Tensor output, bool from_logits = fa
1414
{
1515
target = tf.cast(target, dtype: TF_DataType.TF_INT64);
1616

17+
if (!from_logits)
18+
{
19+
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
20+
output = tf.clip_by_value(output, epsilon, 1 - epsilon);
21+
output = tf.log(output);
22+
}
23+
1724
// Try to adjust the shape so that rank of labels = rank of logits - 1.
1825
var output_shape = array_ops.shape_v2(output);
1926
var output_rank = output.shape.ndim;

0 commit comments

Comments
 (0)