Support constants in StackedRNNCells #9089

Merged 3 commits on Jan 19, 2018

yuyang-huang
Contributor

This PR is probably useful for implementing (some variants of) attention RNNs.

  1. Enable the constants argument in StackedRNNCells.call. The constants will be routed to an underlying cell if it also supports constants.
  2. Handle the shape of the constants properly in StackedRNNCells.build so that an RNN cell with constants can be stacked on top of another cell. Before this PR,

  • If the bottom cell does not support constants:

        constants = Input((32,))
        x = Input((None, 5))
        x = RNN([GRUCell(16), RNNCellWithConstants(32)])(x, constants=constants)

    The GRUCell fails to build because input_shape, which is now a list [step_input_shape] + constants_shape, is passed to GRUCell.build directly. We should only pass step_input_shape to GRUCell.build in this case.

  • If the bottom cell also supports constants:

        constants = Input((32,))
        x = Input((None, 5))
        x = RNN([RNNCellWithConstants(16), RNNCellWithConstants(32)])(x, constants=constants)

    The top cell fails to build because the input_shape passed to it is incorrect.

Also modified the test to cover the two cases above.
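
The shape routing described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: `PlainCell`, `ConstantsCell`, `build_stack` and the `accepts_constants` flag are hypothetical names, and the flag stands in for whatever check the real implementation uses to detect constants support.

```python
# Hypothetical stand-ins for RNN cells; these are not the Keras classes.
class PlainCell:
    """A cell whose build() expects a single step-input shape tuple."""

    def __init__(self, units):
        self.units = units
        self.built_with = None

    def build(self, input_shape):
        if isinstance(input_shape, list):
            raise ValueError("expected a single shape, got a list")
        self.built_with = input_shape


class ConstantsCell(PlainCell):
    """A cell whose build() expects [step_input_shape] + constants_shape."""

    accepts_constants = True  # assumed marker, for illustration only

    def build(self, input_shape):
        if not isinstance(input_shape, list):
            raise ValueError("expected [step_input_shape] + constants_shape")
        self.built_with = input_shape


def build_stack(cells, input_shape):
    """Build cells bottom to top, passing constants shapes only where supported."""
    if isinstance(input_shape, list):
        step_shape, constants_shape = input_shape[0], input_shape[1:]
    else:
        step_shape, constants_shape = input_shape, []
    for cell in cells:
        if getattr(cell, "accepts_constants", False):
            cell.build([step_shape] + constants_shape)
        else:
            cell.build(step_shape)
        # The next cell's step input is this cell's output.
        step_shape = (step_shape[0], cell.units)


bottom, top = PlainCell(16), ConstantsCell(32)
build_stack([bottom, top], [(None, 5), (None, 32)])
print(bottom.built_with)  # (None, 5)
print(top.built_with)     # [(None, 16), (None, 32)]
```

In this sketch the bottom cell only ever sees the step input shape, and the top cell sees the bottom cell's output shape followed by the constants shape, which corresponds to the two cases listed above.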

@fchollet
Collaborator

Re-running tests

@fchollet closed this on Jan 18, 2018
@fchollet reopened this on Jan 18, 2018
@fchollet
Collaborator

That's a reasonable behavior change. Code LGTM.

@yuyang-huang
Contributor Author

Thanks for reviewing! Seems like the tests are passing now.

@fchollet merged commit 950e5d0 into keras-team:master on Jan 19, 2018
@yuyang-huang deleted the stacked-rnn-cells-constants branch on January 19, 2018 01:44
@LeZhengThu

Hi, can anyone help explain the argument "constants"? I'm a bit lost about the argument. Thanks.

@fchollet
Collaborator

fchollet commented Apr 6, 2018

They're constant tensors passed to the underlying cell at every call.

@LeZhengThu

@fchollet Do these tensors have timesteps? Are they something like static features fed into the network?

@fchollet
Collaborator

fchollet commented Apr 6, 2018

No, they don't have timesteps. They're passed to each cell call. They have no concept of time; that's why they're called "constants".
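
A minimal plain-Python sketch may help illustrate the answer above. It is not Keras code: `run_rnn` and `toy_step` are hypothetical names, and scalars stand in for tensors. The inputs have a time axis and are consumed one step at a time, while the constants are handed unchanged to every step.

```python
def run_rnn(step_fn, inputs, initial_state, constants):
    """Apply step_fn over the time axis; constants are identical at every step."""
    state = initial_state
    outputs = []
    for x_t in inputs:  # one entry per timestep
        out, state = step_fn(x_t, state, constants)
        outputs.append(out)
    return outputs, state


def toy_step(x_t, state, constants):
    """Toy cell: accumulate the input plus a time-independent bias."""
    (bias,) = constants
    new_state = state + x_t + bias
    return new_state, new_state


outputs, final = run_rnn(toy_step, [1.0, 2.0, 3.0], 0.0, constants=(10.0,))
print(outputs)  # [11.0, 23.0, 36.0]
print(final)    # 36.0
```

Here the same bias of 10.0 is seen at all three steps, which is roughly how static features could be supplied alongside a sequence.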
