Skip to content

Commit 950e5d0

Browse files
myutwo150fchollet
authored andcommitted
Support constants in StackedRNNCells (#9089)
* Support constants in StackedRNNCells * Add test * Add missing test wrappers
1 parent a6542e8 commit 950e5d0

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

keras/layers/recurrent.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def state_size(self):
7373
state_size.append(cell.state_size)
7474
return tuple(state_size)
7575

76-
def call(self, inputs, states, **kwargs):
76+
def call(self, inputs, states, constants=None, **kwargs):
7777
# Recover per-cell states.
7878
nested_states = []
7979
for cell in self.cells[::-1]:
@@ -88,7 +88,12 @@ def call(self, inputs, states, **kwargs):
8888
# Call the cells in order and store the returned states.
8989
new_nested_states = []
9090
for cell, states in zip(self.cells, nested_states):
91-
inputs, states = cell.call(inputs, states, **kwargs)
91+
if has_arg(cell.call, 'constants'):
92+
inputs, states = cell.call(inputs, states,
93+
constants=constants,
94+
**kwargs)
95+
else:
96+
inputs, states = cell.call(inputs, states, **kwargs)
9297
new_nested_states.append(states)
9398

9499
# Format the new states as a flat list
@@ -99,9 +104,15 @@ def call(self, inputs, states, **kwargs):
99104
return inputs, states
100105

101106
def build(self, input_shape):
107+
if isinstance(input_shape, list):
108+
constants_shape = input_shape[1:]
109+
input_shape = input_shape[0]
102110
for cell in self.cells:
103111
if isinstance(cell, Layer):
104-
cell.build(input_shape)
112+
if has_arg(cell.call, 'constants'):
113+
cell.build([input_shape] + constants_shape)
114+
else:
115+
cell.build(input_shape)
105116
if hasattr(cell.state_size, '__len__'):
106117
output_dim = cell.state_size[0]
107118
else:

tests/keras/layers/recurrent_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,7 @@ def test_batch_size_equal_one(layer_class):
698698
model.train_on_batch(x, y)
699699

700700

701+
@keras_test
701702
def test_rnn_cell_with_constants_layer():
702703

703704
class RNNCellWithConstants(keras.layers.Layer):
@@ -778,7 +779,35 @@ def get_config(self):
778779
y_np_3 = model.predict([x_np, c_np])
779780
assert_allclose(y_np, y_np_3, atol=1e-4)
780781

782+
# Test stacking.
783+
cells = [recurrent.GRUCell(8),
784+
RNNCellWithConstants(12),
785+
RNNCellWithConstants(32)]
786+
layer = recurrent.RNN(cells)
787+
y = layer(x, constants=c)
788+
model = keras.models.Model([x, c], y)
789+
model.compile(optimizer='rmsprop', loss='mse')
790+
model.train_on_batch(
791+
[np.zeros((6, 5, 5)), np.zeros((6, 3))],
792+
np.zeros((6, 32))
793+
)
794+
795+
# Test stacked RNN serialization.
796+
x_np = np.random.random((6, 5, 5))
797+
c_np = np.random.random((6, 3))
798+
y_np = model.predict([x_np, c_np])
799+
weights = model.get_weights()
800+
config = layer.get_config()
801+
with keras.utils.CustomObjectScope(custom_objects):
802+
layer = recurrent.RNN.from_config(config.copy())
803+
y = layer(x, constants=c)
804+
model = keras.models.Model([x, c], y)
805+
model.set_weights(weights)
806+
y_np_2 = model.predict([x_np, c_np])
807+
assert_allclose(y_np, y_np_2, atol=1e-4)
808+
781809

810+
@keras_test
782811
def test_rnn_cell_with_constants_layer_passing_initial_state():
783812

784813
class RNNCellWithConstants(keras.layers.Layer):

0 commit comments

Comments
 (0)