|
31 | 31 | from keras.engine import base_layer_utils
|
32 | 32 | from keras.engine import compile_utils
|
33 | 33 | from keras.engine import data_adapter
|
| 34 | +from keras.engine import functional |
34 | 35 | from keras.engine import input_layer as input_layer_module
|
35 | 36 | from keras.engine import training_utils
|
36 | 37 | from keras.metrics import base_metric
|
@@ -191,8 +192,6 @@ def __new__(cls, *args, **kwargs):
|
191 | 192 | # Signature detection
|
192 | 193 | if is_functional_model_init_params(args, kwargs) and cls == Model:
|
193 | 194 | # Functional model
|
194 |
| - from keras.engine import functional |
195 |
| - |
196 | 195 | return functional.Functional(skip_init=True, *args, **kwargs)
|
197 | 196 | else:
|
198 | 197 | return super(Model, cls).__new__(cls, *args, **kwargs)
|
@@ -724,7 +723,7 @@ def compile(
|
724 | 723 | **kwargs: Arguments supported for backwards compatibility only.
|
725 | 724 | """
|
726 | 725 |
|
727 |
| - _check_output_activations(self.outputs) |
| 726 | + _validate_softmax_output(self) |
728 | 727 |
|
729 | 728 | if jit_compile and not tf_utils.can_jit_compile(warn=True):
|
730 | 729 | jit_compile = False
|
@@ -4379,16 +4378,74 @@ def is_functional_model_init_params(args, kwargs):
|
4379 | 4378 | return False
|
4380 | 4379 |
|
4381 | 4380 |
|
4382 |
| -def _check_output_activations(outputs): |
| 4381 | +def _validate_softmax_output(model_instance): |
| 4382 | + """ |
| 4383 | + Calls the related function for checking the output activations |
| 4384 | +
|
| 4385 | + Args: |
| 4386 | + model_instance: A `Model` instance, either functional or sequential. |
| 4387 | +
|
| 4388 | + """ |
| 4389 | + |
| 4390 | + if isinstance(model_instance, tf.keras.Sequential): |
| 4391 | + output = model_instance.layers[-1] |
| 4392 | + check_sequential_output_activation(output) |
| 4393 | + |
| 4394 | + elif isinstance(model_instance, functional.Functional): |
| 4395 | + outputs = model_instance.outputs |
| 4396 | + check_functional_output_activation(outputs) |
| 4397 | + |
| 4398 | + else: # model is a subclassed/custom model, so we don't apply any checks |
| 4399 | + return |
| 4400 | + |
| 4401 | + |
| 4402 | +def check_sequential_output_activation(last_layer): |
| 4403 | + """ |
| 4404 | + Checks if the last layer of a sequential model has a softmax activation |
| 4405 | + and a single unit output. |
| 4406 | +
|
| 4407 | + Args: |
| 4408 | + last_layer: The last layer of a sequential model. Instance of `Layer`. |
| 4409 | +
|
| 4410 | + Raises: |
| 4411 | + Warning: If the last layer has a softmax activation and a |
| 4412 | + single unit output. |
| 4413 | + """ |
| 4414 | + try: |
| 4415 | + # Check if model had an input layer, if not we can't check the output |
| 4416 | + # shape of the layer, and it is difficult determine. So we just skip |
| 4417 | + # this check if this try-catch fails. |
| 4418 | + output_shape_last_axis = last_layer.output_shape[-1] |
| 4419 | + except AttributeError: |
| 4420 | + return |
| 4421 | + |
| 4422 | + activation = last_layer.activation |
| 4423 | + |
| 4424 | + if "softmax" in str(activation) and output_shape_last_axis == 1: |
| 4425 | + warnings.warn( |
| 4426 | + "Found a layer with softmax activation and single unit output. " |
| 4427 | + "This is most likely an error as this will produce a model " |
| 4428 | + "which outputs ones (1) all the time. Ensure you are using " |
| 4429 | + "the correct activation function. " |
| 4430 | + f"Found activation: {activation} at " |
| 4431 | + f"{last_layer.name} with output shape: {last_layer.output_shape}. " |
| 4432 | + "If you don't apply softmax on the last axis, you can ignore " |
| 4433 | + "this warning." |
| 4434 | + ) |
| 4435 | + |
| 4436 | + |
| 4437 | +def check_functional_output_activation(outputs): |
4383 | 4438 | """
|
4384 |
| - Checks if the output activation is softmax and the output shape is 1. |
| 4439 | + Checks if the last layer(s) of a functional model has a softmax activation |
| 4440 | + and a single unit output. |
4385 | 4441 |
|
4386 | 4442 | Args:
|
4387 |
| - outputs: List of outputs of the model, instance of KerasTensor. |
| 4443 | + outputs: The output(s) of a functional model. List of `KerasTensor`. |
4388 | 4444 |
|
4389 | 4445 | Raises:
|
4390 |
| - Warning: If the last axis of the output shape is 1 and the activation |
4391 |
| - is softmax. |
| 4446 | + Warning: If the last layer has a softmax activation and a |
| 4447 | + single unit output. |
| 4448 | +
|
4392 | 4449 | """
|
4393 | 4450 | for output in outputs:
|
4394 | 4451 | # Outputs are instance of KerasTensor. The activation is stored in
|
|
0 commit comments