keras/engine/training.py (67 additions, 4 deletions)

@@ -31,6 +31,7 @@
 from keras.engine import base_layer_utils
 from keras.engine import compile_utils
 from keras.engine import data_adapter
+from keras.engine import functional
 from keras.engine import input_layer as input_layer_module
 from keras.engine import training_utils
 from keras.metrics import base_metric
@@ -191,8 +192,6 @@ def __new__(cls, *args, **kwargs):
         # Signature detection
         if is_functional_model_init_params(args, kwargs) and cls == Model:
             # Functional model
-            from keras.engine import functional
-
             return functional.Functional(skip_init=True, *args, **kwargs)
         else:
             return super(Model, cls).__new__(cls, *args, **kwargs)
@@ -206,8 +205,6 @@ def __init__(self, *args, **kwargs):
         # Special case for Subclassed Functional Model, which we couldn't
         # detect when __new__ is called. We only realize it is a functional
         # model when it calls super.__init__ with input and output tensors.
-        from keras.engine import functional
-
         if is_functional_model_init_params(args, kwargs) and not isinstance(
             self, functional.Functional
         ):
@@ -723,6 +720,13 @@ def compile(
               Defaults to `0`.
           **kwargs: Arguments supported for backwards compatibility only.
         """
+
+        validate_softmax_activation = kwargs.pop(
+            "experimental_validate_softmax_activation", True
+        )
+        if validate_softmax_activation:
+            _validate_softmax_output(self)
+
         if jit_compile and not tf_utils.can_jit_compile(warn=True):
             jit_compile = False
         base_layer.keras_api_gauge.get_cell("compile").set(True)
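For reviewers, a minimal usage sketch of the opt-out this hunk adds. The kwarg name is taken from this PR and is experimental; everything else is standard Keras:

import tensorflow as tf
from tensorflow import keras

# Softmax over a single unit always outputs 1.0, so compile() now warns:
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(10,)),
    keras.layers.Dense(1, activation="softmax"),
])
model.compile()  # emits SyntaxWarning

# Opting out of the new check:
model.compile(experimental_validate_softmax_activation=False)  # no warning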
@@ -3795,6 +3799,7 @@ def _validate_compile(self, optimizer, metrics, **kwargs):

         kwargs.pop("cloning", None)  # Legacy DistStrat argument, never used.
         kwargs.pop("experimental_run_tf_function", None)  # Always `True`.
+        kwargs.pop("experimental_validate_softmax_activation", None)
         distribute_arg = kwargs.pop("distribute", None)
         if distribute_arg is not None:
             raise ValueError(
@@ -4399,3 +4404,61 @@ def is_functional_model_init_params(args, kwargs):
     if "inputs" in kwargs and "outputs" in kwargs:
         return True
     return False
+
+
+def _validate_softmax_output(model_instance):
+    """Checks the model's outputs for a softmax applied to a single unit.
+
+    Args:
+        model_instance: A `Model` instance, either functional or sequential.
+    """
+    outputs = model_instance.outputs
+    if outputs is not None:
+        for output in outputs:
+            if (
+                "softmax" in output.name.lower()
+                and output.__class__.__name__ == "KerasTensor"
+            ):
+                check_output_activation(output)
+    else:  # Model is subclassed/custom, so we don't apply any checks.
+        return
+
+
+def check_output_activation(output):
+    """Warns if a model output has a softmax activation over a single unit.
+
+    Args:
+        output: A single output of a Keras (Functional or Sequential)
+            model, as a `KerasTensor`.
+
+    Warns:
+        SyntaxWarning: If the output has a softmax activation and a
+            single-unit last axis.
+    """
+    # Outputs are instances of KerasTensor. The activation is stored in
+    # the name of the tensor, e.g. `dense_12/Softmax:0`.
+    layer_name_and_act = output.name.split("/")
+    output_act = layer_name_and_act[-1].lower()
+
+    # Softmax is applied on the last axis of the logits.
+    output_shape_last_axis = output.shape[-1]
+
+    if "softmax" in output_act and output_shape_last_axis == 1:
+        warnings.warn(
+            "Found a layer with softmax activation and single unit output. "
+            "This is most likely an error as this will produce a model "
+            "which outputs ones (1) all the time. Ensure you are using "
+            "the correct activation function. "
+            f"Found activation: {output_act} at "
+            f"{layer_name_and_act[0]} with output shape: {output.shape}. "
+            "If you don't apply softmax on the last axis, you can ignore "
+            "this warning.",
+            SyntaxWarning,
+            stacklevel=2,
+        )
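For context on why the check fires: softmax normalizes over an axis, so over a single logit it is identically 1.0, and the check detects the candidate case from the tensor name, as the comment above notes. A standalone sketch with plain numpy (the tensor name shown is illustrative):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # stabilized exponentials
    return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[3.2], [-1.7], [0.0]])  # shape (batch, 1)
print(softmax(logits))  # [[1.], [1.], [1.]] -- constant, regardless of input

# The name-based heuristic used above, on a hypothetical output name:
name = "dense_12/Softmax:0"
parts = name.split("/")
print("softmax" in parts[-1].lower())  # True -> candidate for the warning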
keras/engine/training_test.py (76 additions, 1 deletion)

@@ -14,11 +14,11 @@
 # ==============================================================================
 """Tests for training routines."""

-
 import collections
 import io
 import sys
 import tempfile
+import warnings

 import numpy as np
 import tensorflow.compat.v2 as tf
@@ -5000,6 +5000,81 @@ def test_sequential_model_get_weight_paths(self):
         )


+class TestCheckLastLayerActivation(test_combinations.TestCase):
+    def test_sequential_model_output(self):
+        for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
+            model = sequential.Sequential(
+                [
+                    layers_module.InputLayer(input_shape=(10,)),
+                    layers_module.Dense(1, activation=activation),
+                ]
+            )
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                model.compile()
+            self.assertIs(w[-1].category, SyntaxWarning)
+            self.assertIn(
+                "Found a layer with softmax activation and single unit "
+                "output",
+                str(w[-1].message),
+            )
+            del model
+
+    def test_functional_model_output(self):
+        inputs = input_layer.Input(shape=(10,))
+        for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
+            x = layers_module.Dense(1, activation=activation)(inputs)
+            model = training_module.Model(inputs, x)
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                model.compile()
+            self.assertIs(w[-1].category, SyntaxWarning)
+            self.assertIn(
+                "Found a layer with softmax activation and single unit "
+                "output",
+                str(w[-1].message),
+            )
+            del model
+
+    def test_multi_output_model(self):
+        inputs = input_layer.Input(shape=(10,))
+        for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
+            x = layers_module.Dense(1, activation=activation)(inputs)
+            y = layers_module.Dense(1, activation=activation)(inputs)
+            model = training_module.Model(inputs, [x, y])
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                model.compile()
+            self.assertIs(w[-1].category, SyntaxWarning)
+            self.assertIn(
+                "Found a layer with softmax activation and single unit "
+                "output",
+                str(w[-1].message),
+            )
+            del model
+
+    def test_multi_input_output_model(self):
+        inputs = [
+            input_layer.Input(shape=(10,)),
+            input_layer.Input(shape=(10,)),
+        ]
+        for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
+            x = layers_module.Dense(1, activation=activation)(inputs[0])
+            y = layers_module.Dense(1, activation=activation)(inputs[1])
+            model = training_module.Model(inputs, [x, y])
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                model.compile()
+            self.assertIs(w[-1].category, SyntaxWarning)
+            self.assertIn(
+                "Found a layer with softmax activation and single unit "
+                "output",
+                str(w[-1].message),
+            )
+            del model
+
+
 def _is_oss():
     """Returns whether the test is run under OSS."""
     return len(sys.argv) >= 1 and "bazel" in sys.argv[0]
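Downstream code that wants to silence just this new warning, rather than disable the whole check via the experimental kwarg, could filter the category instead. A sketch, assuming `model` is already built:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore", SyntaxWarning)  # category used by the new check
    model.compile()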