Add handling of HardSigmoid recurrent activation for Keras LSTM #2001


Merged
merged 2 commits into onnx:main from lstm_hardsigmoid_ra on Jul 22, 2022

Conversation

q-ycong-p
Contributor

@q-ycong-p q-ycong-p commented Jul 15, 2022

This commit adds pattern matching and parsing where Keras LSTM uses
HardSigmoid as the recurrent activation. Refactored to allow addition
of other activations when needed.

Signed-off-by: Yu Cong <[email protected]>

This commit adds pattern matching and parsing where Keras LSTM uses
HardSigmoid as the recurrent activation. Refactored to allow addition
of other activations when needed. Non-default activation is not needed
for custom LSTMs (e.g. LSTMLN), hence not added in this CR.

Signed-off-by: Yu Cong <[email protected]>
@q-ycong-p
Contributor Author

The existing tf2 LSTM rewriter handles only {sigmoid, tanh, relu} as (recurrent) activations, but tf.keras.activations defines others as well. For example, a tf.keras.layers.LSTM layer's recurrent activation may be hard_sigmoid. In the tf2onnx frozen graph, the hard sigmoid computation is decomposed into low-level ops (see attached screenshot), which breaks the existing pattern matching and causes the tf2 LSTM rewriter to fail. This PR extends support to handle this variation and refactors the pattern matching so that other activation patterns can be plugged in easily in the future (i.e. ad hoc, as needed).
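For reference, Keras's hard_sigmoid corresponds to ONNX's HardSigmoid operator with its default attributes alpha=0.2 and beta=0.5: a piecewise-linear approximation of the sigmoid. A minimal sketch in plain Python of the computation that the frozen graph expresses as low-level mul/add/clip ops:

```python
def hard_sigmoid(x, alpha=0.2, beta=0.5):
    # Piecewise-linear sigmoid approximation (ONNX HardSigmoid semantics):
    #   y = max(0, min(1, alpha * x + beta))
    # i.e. 0 for large negative x, 1 for large positive x, linear in between.
    return max(0.0, min(1.0, alpha * x + beta))

print(hard_sigmoid(0.0))  # 0.5
```

It is this decomposition into elementary multiply/add/clip nodes, rather than a single activation op, that the rewriter's pattern matching has to recognize.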

See below for a minimal example reproducing the original issue. Prior to this PR, a Loop op was produced from the Keras LSTM layer; this PR extends the rewriter's pattern matching so the LSTM node is converted correctly.

import tensorflow as tf
import tf2onnx

input = tf.keras.Input(shape=(10, 8))
lstm_output = tf.keras.layers.LSTM(4, recurrent_activation="hard_sigmoid")(input)
model = tf.keras.Model(input, lstm_output)
onnx_model, _ = tf2onnx.convert.from_keras(model)  # before this PR: onnx model contains a Loop instead of an LSTM op

[Screenshot, 2022-07-15: hard_sigmoid recurrent activation decomposed into low-level ops in the frozen graph]

@hwangdeyu hwangdeyu added the keras Issues related to Keras label Jul 18, 2022
Contributor

@hwangdeyu hwangdeyu left a comment


LGTM, thanks for your contribution!

Support for HardSigmoid as the recurrent activation will be very helpful. Could you also tell us which model you were using when you ran into this issue?

@q-ycong-p
Contributor Author

Hi @hwangdeyu , thank you for taking a look. The original model is internal. The general architecture is a simple LSTM encoder. If all is good, would you mind merging this change? Thanks!

@hwangdeyu hwangdeyu merged commit 71105c1 into onnx:main Jul 22, 2022
@q-ycong-p q-ycong-p deleted the lstm_hardsigmoid_ra branch August 3, 2022 17:27