Skip to content

Commit 6355b23

Browse files
committed
Update softmax check tests
1 parent e9c950e commit 6355b23

File tree

1 file changed

+16
-33
lines changed

1 file changed

+16
-33
lines changed

keras/engine/training_test.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import io
1919
import sys
2020
import tempfile
21-
import warnings
2221

2322
import numpy as np
2423
import tensorflow.compat.v2 as tf
@@ -5010,31 +5009,23 @@ def test_sequential_model_output(self):
50105009
layers_module.Dense(1, activation=activation),
50115010
]
50125011
)
5013-
with warnings.catch_warnings(record=True) as w:
5014-
warnings.simplefilter("always")
5012+
with self.assertRaisesRegex(
5013+
ValueError,
5014+
"has a single unit output, but the activation is softmax.*",
5015+
):
50155016
model.compile()
5016-
self.assertIs(w[-1].category, SyntaxWarning)
5017-
self.assertIn(
5018-
"Found a layer with softmax activation and single unit "
5019-
"output",
5020-
str(w[-1].message),
5021-
)
50225017
del model
50235018

50245019
def test_functional_model_output(self):
50255020
inputs = input_layer.Input(shape=(10,))
50265021
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
50275022
x = layers_module.Dense(1, activation=activation)(inputs)
50285023
model = training_module.Model(inputs, x)
5029-
with warnings.catch_warnings(record=True) as w:
5030-
warnings.simplefilter("always")
5024+
with self.assertRaisesRegex(
5025+
ValueError,
5026+
"has a single unit output, but the activation is softmax.*",
5027+
):
50315028
model.compile()
5032-
self.assertIs(w[-1].category, SyntaxWarning)
5033-
self.assertIn(
5034-
"Found a layer with softmax activation and single unit "
5035-
"output",
5036-
str(w[-1].message),
5037-
)
50385029
del model
50395030

50405031
def test_multi_output_model(self):
@@ -5043,15 +5034,11 @@ def test_multi_output_model(self):
50435034
x = layers_module.Dense(1, activation=activation)(inputs)
50445035
y = layers_module.Dense(1, activation=activation)(inputs)
50455036
model = training_module.Model(inputs, [x, y])
5046-
with warnings.catch_warnings(record=True) as w:
5047-
warnings.simplefilter("always")
5037+
with self.assertRaisesRegex(
5038+
ValueError,
5039+
"has a single unit output, but the activation is softmax.*",
5040+
):
50485041
model.compile()
5049-
self.assertIs(w[-1].category, SyntaxWarning)
5050-
self.assertIn(
5051-
"Found a layer with softmax activation and single unit "
5052-
"output",
5053-
str(w[-1].message),
5054-
)
50555042
del model
50565043

50575044
def test_multi_input_output_model(self):
@@ -5063,15 +5050,11 @@ def test_multi_input_output_model(self):
50635050
x = layers_module.Dense(1, activation=activation)(inputs[0])
50645051
y = layers_module.Dense(1, activation=activation)(inputs[1])
50655052
model = training_module.Model(inputs, [x, y])
5066-
with warnings.catch_warnings(record=True) as w:
5067-
warnings.simplefilter("always")
5053+
with self.assertRaisesRegex(
5054+
ValueError,
5055+
"has a single unit output, but the activation is softmax.*",
5056+
):
50685057
model.compile()
5069-
self.assertIs(w[-1].category, SyntaxWarning)
5070-
self.assertIn(
5071-
"Found a layer with softmax activation and single unit "
5072-
"output",
5073-
str(w[-1].message),
5074-
)
50755058
del model
50765059

50775060

0 commit comments

Comments
 (0)