Skip to content

Commit 3a228fb

Browse files
committed
Fix edge case for _validate_softmax_output
1 parent afc156a commit 3a228fb

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

keras/engine/training.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4389,14 +4389,22 @@ def _validate_softmax_output(model_instance):
43894389
43904390
"""
43914391
outputs = model_instance.outputs
4392+
4393+
# `outputs` can be None in the case of subclassed models.
43924394
if outputs is not None:
43934395
for output in outputs:
4394-
if (
4395-
"softmax" in str(output.name.lower())
4396-
and output.__class__.__name__ == "KerasTensor"
4397-
):
4398-
check_output_activation(output)
43994396

4397+
# if an output layer ends with a native tf_ops the name can be None,
4398+
# i.e using output layer: tf.cast(outputs, tf.float32)
4399+
output_name = str(output.name)
4400+
if output_name is not None:
4401+
if (
4402+
"softmax" in str(output_name.lower())
4403+
and output.__class__.__name__ == "KerasTensor"
4404+
):
4405+
check_output_activation(output)
4406+
else:
4407+
continue
44004408
else: # model is a subclassed/custom model, so we don't apply any checks
44014409
return
44024410

0 commit comments

Comments
 (0)