@@ -723,6 +723,9 @@ def compile(
723
723
Defaults to `0`.
724
724
**kwargs: Arguments supported for backwards compatibility only.
725
725
"""
726
+
727
+ _check_output_activations (self .outputs )
728
+
726
729
if jit_compile and not tf_utils .can_jit_compile (warn = True ):
727
730
jit_compile = False
728
731
base_layer .keras_api_gauge .get_cell ("compile" ).set (True )
@@ -789,6 +792,7 @@ def compile(
789
792
else :
790
793
self ._jit_compile = jit_compile
791
794
795
+
792
796
def _get_optimizer (self , optimizer ):
793
797
"""Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
794
798
@@ -4373,3 +4377,36 @@ def is_functional_model_init_params(args, kwargs):
4373
4377
if "inputs" in kwargs and "outputs" in kwargs :
4374
4378
return True
4375
4379
return False
4380
+
4381
+
4382
def _check_output_activations(outputs):
    """Warn about outputs that apply softmax over a last axis of size 1.

    Softmax normalizes along the last axis, so applying it to a single unit
    always yields 1. This helper only inspects tensor names and shapes; it is
    a best-effort diagnostic and never raises for outputs it cannot analyze.

    Args:
        outputs: Iterable of model outputs. Each item is expected to expose a
            `name` (e.g. `"dense_12/Softmax:0"`) and a `shape`. `None` is
            accepted and means there is nothing to check.

    Returns:
        None. Findings are reported through `warnings.warn`.
    """
    if outputs is None:
        return

    for output in outputs:
        # The op that produced the tensor is the last component of its name,
        # e.g. "dense_12/Softmax:0" -> layer "dense_12", op "Softmax:0".
        # Names without a "/" carry no op information, so there is nothing
        # to inspect for them.
        name = getattr(output, "name", None)
        if not isinstance(name, str):
            continue
        layer_name, separator, output_act = name.rpartition("/")
        if not separator:
            continue
        output_act = output_act.lower()
        if "softmax" not in output_act:
            continue

        # Softmax is applied on the last axis of logits. Scalar or
        # unknown-rank shapes have no usable last axis; skip them instead of
        # failing the caller over a diagnostic.
        try:
            output_shape_last_axis = output.shape[-1]
        except (IndexError, TypeError, ValueError):
            continue

        if output_shape_last_axis == 1:
            warnings.warn(
                "Found a layer with softmax activation and single unit "
                "output. This is most likely an error as this will produce "
                "a model which outputs ones (1) all the time. Ensure you "
                "are using the correct activation function. "
                f"Found activation: {output_act} at "
                f"{layer_name} with output shape: {output.shape}. "
                "If you don't apply softmax on the last axis, you can "
                "ignore this warning."
            )
0 commit comments