Skip to content

Commit 1cedb20

Browse files
committed
Split logic for sequential and functional models
1 parent 90c95b1 commit 1cedb20

File tree

1 file changed

+65
-8
lines changed

1 file changed

+65
-8
lines changed

keras/engine/training.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from keras.engine import base_layer_utils
3232
from keras.engine import compile_utils
3333
from keras.engine import data_adapter
34+
from keras.engine import functional
3435
from keras.engine import input_layer as input_layer_module
3536
from keras.engine import training_utils
3637
from keras.metrics import base_metric
@@ -191,8 +192,6 @@ def __new__(cls, *args, **kwargs):
191192
# Signature detection
192193
if is_functional_model_init_params(args, kwargs) and cls == Model:
193194
# Functional model
194-
from keras.engine import functional
195-
196195
return functional.Functional(skip_init=True, *args, **kwargs)
197196
else:
198197
return super(Model, cls).__new__(cls, *args, **kwargs)
@@ -724,7 +723,7 @@ def compile(
724723
**kwargs: Arguments supported for backwards compatibility only.
725724
"""
726725

727-
_check_output_activations(self.outputs)
726+
_validate_softmax_output(self)
728727

729728
if jit_compile and not tf_utils.can_jit_compile(warn=True):
730729
jit_compile = False
@@ -4379,16 +4378,74 @@ def is_functional_model_init_params(args, kwargs):
43794378
return False
43804379

43814380

4382-
def _check_output_activations(outputs):
4381+
def _validate_softmax_output(model_instance):
4382+
"""
4383+
Calls the related function for checking the output activations
4384+
4385+
Args:
4386+
model_instance: A `Model` instance, either functional or sequential.
4387+
4388+
"""
4389+
4390+
if isinstance(model_instance, tf.keras.Sequential):
4391+
output = model_instance.layers[-1]
4392+
check_sequential_output_activation(output)
4393+
4394+
elif isinstance(model_instance, functional.Functional):
4395+
outputs = model_instance.outputs
4396+
check_functional_output_activation(outputs)
4397+
4398+
else: # model is a subclassed/custom model, so we don't apply any checks
4399+
return
4400+
4401+
4402+
def check_sequential_output_activation(last_layer):
4403+
"""
4404+
Checks if the last layer of a sequential model has a softmax activation
4405+
and a single unit output.
4406+
4407+
Args:
4408+
last_layer: The last layer of a sequential model. Instance of `Layer`.
4409+
4410+
Raises:
4411+
Warning: If the last layer has a softmax activation and a
4412+
single unit output.
4413+
"""
4414+
try:
4415+
# Check if model had an input layer, if not we can't check the output
4416+
# shape of the layer, and it is difficult determine. So we just skip
4417+
# this check if this try-catch fails.
4418+
output_shape_last_axis = last_layer.output_shape[-1]
4419+
except AttributeError:
4420+
return
4421+
4422+
activation = last_layer.activation
4423+
4424+
if "softmax" in str(activation) and output_shape_last_axis == 1:
4425+
warnings.warn(
4426+
"Found a layer with softmax activation and single unit output. "
4427+
"This is most likely an error as this will produce a model "
4428+
"which outputs ones (1) all the time. Ensure you are using "
4429+
"the correct activation function. "
4430+
f"Found activation: {activation} at "
4431+
f"{last_layer.name} with output shape: {last_layer.output_shape}. "
4432+
"If you don't apply softmax on the last axis, you can ignore "
4433+
"this warning."
4434+
)
4435+
4436+
4437+
def check_functional_output_activation(outputs):
43834438
"""
4384-
Checks if the output activation is softmax and the output shape is 1.
4439+
Checks if the last layer(s) of a functional model has a softmax activation
4440+
and a single unit output.
43854441
43864442
Args:
4387-
outputs: List of outputs of the model, instance of KerasTensor.
4443+
outputs: The output(s) of a functional model. List of `KerasTensor`.
43884444
43894445
Raises:
4390-
Warning: If the last axis of the output shape is 1 and the activation
4391-
is softmax.
4446+
Warning: If the last layer has a softmax activation and a
4447+
single unit output.
4448+
43924449
"""
43934450
for output in outputs:
43944451
# Outputs are instance of KerasTensor. The activation is stored in

0 commit comments

Comments
 (0)