
Commit 90c95b1

Add last layer activation check for softmax
1 parent 9b55a1c commit 90c95b1

1 file changed: +37 -0 lines changed

keras/engine/training.py

Lines changed: 37 additions & 0 deletions
@@ -723,6 +723,9 @@ def compile(
                 Defaults to `0`.
             **kwargs: Arguments supported for backwards compatibility only.
         """
+
+        _check_output_activations(self.outputs)
+
         if jit_compile and not tf_utils.can_jit_compile(warn=True):
             jit_compile = False
         base_layer.keras_api_gauge.get_cell("compile").set(True)
@@ -789,6 +792,7 @@ def compile(
         else:
             self._jit_compile = jit_compile
 
+
     def _get_optimizer(self, optimizer):
         """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
 
@@ -4373,3 +4377,36 @@ def is_functional_model_init_params(args, kwargs):
     if "inputs" in kwargs and "outputs" in kwargs:
         return True
     return False
+
+
+def _check_output_activations(outputs):
+    """
+    Checks whether an output applies softmax over a single unit.
+
+    Args:
+        outputs: List of outputs of the model, instances of KerasTensor.
+
+    Warns:
+        UserWarning: If the last axis of an output's shape is 1 and its
+            activation is softmax.
+    """
+    for output in outputs:
+        # Outputs are instances of KerasTensor. The activation is stored
+        # in the name of the tensor, e.g. dense_12/Softmax:0.
+        layer_name_and_act = output.name.split("/")
+        output_act = layer_name_and_act[1].lower()
+
+        # Softmax is applied on the last axis of the logits.
+        output_shape_last_axis = output.shape[-1]
+
+        if "softmax" in output_act and output_shape_last_axis == 1:
+            warnings.warn(
+                "Found a layer with softmax activation and a single-unit "
+                "output. This is most likely an error, as such a model "
+                "will output ones (1) all the time. Ensure you are using "
+                "the correct activation function. "
+                f"Found activation: {output_act} at "
+                f"{layer_name_and_act[0]} with output shape: {output.shape}. "
+                "If you don't apply softmax on the last axis, you can "
+                "ignore this warning."
+            )
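
Why this check matters: softmax normalizes across the last axis, so a single-unit output has nothing to normalize against and always evaluates to exactly 1 (exp(x) / exp(x) == 1 for any x). Below is a minimal sketch of the failure mode and of a model that would trigger the new warning; it assumes TensorFlow's Keras API, and the layer sizes and names are illustrative, not part of the commit:

    import numpy as np
    import tensorflow as tf

    # Softmax over a single logit is identically 1, whatever the input.
    logits = np.array([[3.7], [-1.2], [0.0]], dtype="float32")
    print(tf.nn.softmax(logits, axis=-1).numpy())  # [[1.], [1.], [1.]]

    # A model that would trigger the warning added by this commit:
    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(1, activation="softmax")(inputs)  # likely a bug
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy")  # warns here
    # The usual fix: Dense(1, activation="sigmoid") for binary targets, or
    # Dense(num_classes, activation="softmax") for multi-class targets.

Note that the check parses the activation out of the output tensor's name (e.g. dense_12/Softmax:0), so it only covers models whose outputs are populated and whose tensor names embed the op in that form.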
