22 changes: 22 additions & 0 deletions tests/test_lstm.py
@@ -751,6 +751,28 @@ def func(x):
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_lstm_recurrent_activation_is_hard_sigmoid(self):
in_shape = [10, 3]
x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)

model_in = tf.keras.layers.Input(tuple(in_shape), batch_size=2)
x = tf.keras.layers.LSTM(
units=5,
return_sequences=True,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_activation="hard_sigmoid"
)(model_in)
model = tf.keras.models.Model(inputs=model_in, outputs=x)

def func(x):
y = model(x)
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)

if __name__ == '__main__':
unittest_main()
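
For context on what the new test exercises: Keras' hard_sigmoid is the piecewise-linear function max(0, min(1, 0.2*x + 0.5)), which coincides with ONNX HardSigmoid at its default alpha=0.2, beta=0.5, so the matched subgraph can be folded into a single LSTM activation attribute. A minimal conversion sketch, not part of this PR (tf2onnx.convert.from_keras is the public API; the exact signature may vary by version):

import tensorflow as tf
import tf2onnx

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(5, recurrent_activation="hard_sigmoid",
                         return_sequences=True, input_shape=(10, 3)),
])
spec = (tf.TensorSpec((2, 10, 3), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)
# If the rewrite fired, the graph contains a fused LSTM node instead of
# an unrolled Mul/Add/Minimum/Maximum subgraph per time step.
print([n.op_type for n in onnx_model.graph.node])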
51 changes: 37 additions & 14 deletions tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -16,29 +16,52 @@

# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable

def _make_lstm_pattern_from_params(params):
return make_lstm_pattern(enter_or_id="Identity") if not params.get("from_keras", False) \
else make_lstm_pattern(
from_keras=True,
use_bias=params.get("use_bias", False),
activation=params.get("activation", ""),
recurrent_activation=params.get("recurrent_activation", "")
)
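
A hedged usage sketch of this helper (illustrative only; the real call site is the loop below): each params dict selects one pattern variant.

# Plain TF LSTM pattern, no keras-specific structure:
tf_pattern = _make_lstm_pattern_from_params({"enter_or_id": "Identity"})
# Keras LSTM with bias and hard_sigmoid as the recurrent activation:
keras_pattern = _make_lstm_pattern_from_params(
    {"from_keras": True, "use_bias": True,
     "recurrent_activation": "hard_sigmoid"})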

def rewriter_lstm_tf2(g, ops):

pattern1 = make_lstm_pattern(enter_or_id="Identity") # TF LSTM
pattern2 = make_lstm_pattern(from_keras=True, use_bias=False) # keras LSTM
pattern3 = make_lstm_pattern(from_keras=True, use_bias=True) # keras LSTM with bias

for pattern in [pattern1, pattern2, pattern3]:
lstm_params_variations = [
# default activations
{"enter_or_id": "Identity"}, # TF LSTM
{"from_keras": True, "use_bias": False}, # keras LSTM
{"from_keras": True, "use_bias": True}, # keras LSTM with bias
# hard sigmoid as recurrent activation
{"from_keras": True, "use_bias": False, "recurrent_activation": "hard_sigmoid"}, # keras LSTM
{"from_keras": True, "use_bias": True, "recurrent_activation": "hard_sigmoid"} # keras LSTM with bias
# Note: add other LSTM variations as needed
]
for params in lstm_params_variations:
pattern = _make_lstm_pattern_from_params(params)
matcher = GraphMatcher(pattern, allow_reorder=False)
match_results = list(matcher.match_ops(ops))

for match_result in match_results:
from_keras = pattern != pattern1
is_ft_hard_sigmoid = params.get("recurrent_activation", "") == "hard_sigmoid"
recurrent_activation_f = "HardSigmoid" if is_ft_hard_sigmoid else \
match_result.get_op("ft").type
activation_g = match_result.get_op("gt").type
activation_h = match_result.get_op("ct'").type

default_activations = ["Relu", "Sigmoid", "Tanh"]
if ((activation_g not in default_activations) or
(activation_h not in default_activations) or
(not is_ft_hard_sigmoid and recurrent_activation_f not in default_activations)):
continue

activations_fgh = [
match_result.get_op("ft").type,
match_result.get_op("gt").type,
match_result.get_op("ct'").type
recurrent_activation_f,
activation_g,
activation_h
]
supported_activations = ['Relu', 'Sigmoid', 'Tanh']
if any(f not in supported_activations for f in activations_fgh):
continue

# extract input x_t
from_keras = params.get("from_keras", False)
if from_keras:
get_item = match_result.get_op("xt")
else:
@@ -134,7 +157,7 @@ def has_tensor_list_consumer(n):

# Wb and Rb are concatenated
b_idx = None
if pattern is pattern3:
if from_keras and params.get("use_bias", False):
bias_add = match_result.get_op("bias_add")
if bias_add is not None and bias_add.data_format != "NHWC":
continue
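The wrinkle this rewriter now handles: hard_sigmoid is matched structurally as a Maximum/Minimum/Add/Mul subgraph rather than as a single activation op, so there is no "ft" node whose .type can be read back. A minimal sketch of that dispatch (hypothetical helper name, mirroring the logic above):

def resolve_recurrent_activation(params, match_result):
    # hard_sigmoid leaves no single matched activation op; substitute
    # the ONNX activation name directly.
    if params.get("recurrent_activation", "") == "hard_sigmoid":
        return "HardSigmoid"
    # Default path: read the type off the matched "ft" op.
    return match_result.get_op("ft").type  # "Sigmoid" | "Tanh" | "Relu"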
52 changes: 35 additions & 17 deletions tf2onnx/rewriter/rnn_utils.py
@@ -30,6 +30,25 @@ class REWRITER_RESULT(Enum):

# TensorFlow LSTMCell/BasicLSTMCell and Keras LSTM computation graph matching

def insert_activation(activation, name="", inputs=None):
inputs = inputs if inputs else [] # default to [] here rather than as a mutable default argument
if activation == "hard_sigmoid":
return OpTypePattern("Maximum", inputs=[
OpTypePattern("Minimum", inputs=[
OpTypePattern("Add|AddV2", inputs=[
OpTypePattern("Mul", inputs=[
*inputs,
OpTypePattern("*") # mul(x, 0.2)
]), OpTypePattern("*") # add(x, 0.5)
]), OpTypePattern("*") # minimum(x, 1)
]), OpTypePattern("*") # maximum(x, 0)
])
# Additional activation patterns can be added when needed:
# https://www.tensorflow.org/api_docs/python/tf/keras/activations
# Otherwise, fall back to the default activations.
return OpTypePattern("Tanh|Relu|Sigmoid", name=name, inputs=inputs)


def make_lstm_xc_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
if from_keras:
lstm_xh_pattern = OpTypePattern("Add|AddV2", allow_reorder=False, inputs=[
@@ -63,7 +82,8 @@ def make_lstm_xc_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
])


def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False,
activation="", recurrent_activation=""):
# split (Xt*(W[ifco]^T) + Ht-1*(R[ifco]^T)) on 'Const' axis
lstm_xc_pattern = OpTypePattern('Split', inputs=[
OpTypePattern("Const"),
@@ -77,23 +97,21 @@ def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
OpTypePattern("*", name="ft_bias"),
])

activation = "Tanh|Relu|Sigmoid"
recurrent_activation = "Tanh|Relu|Sigmoid"

return OpTypePattern("Mul", name='ht', inputs=[
OpTypePattern(recurrent_activation, name="ot", inputs=[lstm_xc_pattern]),
OpTypePattern(activation, name="ct'", inputs=[
OpTypePattern("Add|AddV2", name="ct", inputs=[
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
OpTypePattern(recurrent_activation, name="ft", inputs=[lstm_fb_pattern]),
OpTypePattern("*", name="c"),
]),
OpTypePattern("Mul", inputs=[
OpTypePattern(recurrent_activation, name="it", inputs=[lstm_xc_pattern]),
OpTypePattern(activation, name="gt", inputs=[lstm_xc_pattern]),
]),
]),
# cell state
lstm_ct_pattern = OpTypePattern("Add|AddV2", name="ct", inputs=[
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
insert_activation(recurrent_activation, name="ft", inputs=[lstm_fb_pattern]),
OpTypePattern("*", name="c"),
]),
OpTypePattern("Mul", inputs=[
insert_activation(recurrent_activation, name="it", inputs=[lstm_xc_pattern]),
insert_activation(activation, name="gt", inputs=[lstm_xc_pattern]),
]),
])

return OpTypePattern("Mul", name="ht", inputs=[
insert_activation(recurrent_activation, name="ot", inputs=[lstm_xc_pattern]),
insert_activation(activation, name="ct'", inputs=[lstm_ct_pattern]),
])

lstmcell_pattern = make_lstm_pattern()
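
The names in make_lstm_pattern mirror the standard LSTM cell step. A minimal numpy sketch of the computation the pattern matches, assuming the default Sigmoid/Tanh activations (illustrative only):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(gates, c_prev):
    # gates: the four pre-activation slices produced by the Split that
    # lstm_xc_pattern matches (Keras orders them i, f, c~, o).
    i, f, g, o = gates
    it = sigmoid(i)             # "it": input gate (recurrent activation)
    ft = sigmoid(f)             # "ft": forget gate (recurrent activation)
    gt = np.tanh(g)             # "gt": candidate cell state (activation)
    ot = sigmoid(o)             # "ot": output gate (recurrent activation)
    ct = ft * c_prev + it * gt  # "ct": lstm_ct_pattern, Add of two Muls
    ht = ot * np.tanh(ct)       # "ht": the Mul at the pattern root
    return ht, ct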