Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,23 @@ def test_TFDisillBertModel(self):
outputs = ["start_logits", "end_logits"]
self.run_test(model, input_dict, input_signature=spec, outputs=outputs, rtol=1e-5)

## FUNNEL

def _test_TFFunnelSquad(self, size, large=False):
from transformers import FunnelTokenizer, TFFunnelForQuestionAnswering
tokenizer = FunnelTokenizer.from_pretrained(size)
model = TFFunnelForQuestionAnswering.from_pretrained(size)
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer(question, text, return_tensors='tf')
spec = (tf.TensorSpec((None, 14), tf.int32, name="input_ids"),
tf.TensorSpec((None, 14), tf.int32, name="token_type_ids"),
tf.TensorSpec((None, 14), tf.int32, name="attention_mask"))
outputs = ["start_logits", "end_logits"]
self.run_test(model, input_dict, input_signature=spec, outputs=outputs, rtol=1e-5)

def test_TFFunnelSquadSmall(self):
self._test_TFFunnelSquad("funnel-transformer/small")

## T5

def _test_TFT5Model(self, size, large=False):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ def func(x):
self.logger.debug(str(p))
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.15", "required for max_pool args")
def test_maxpool_int(self):
x_shape = [8, 16, 16, 3]
x_val = make_xval(x_shape).astype("int32")
def func(x):
mp = tf.nn.max_pool(x, ksize=[2], strides=[1, 2, 2, 1], padding="SAME")
return tf.identity(mp, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@skip_tf_cpu("only tf_gpu can run maxpool with NCHW format")
def test_maxpool_gpu(self):
# make sure converter behaves well when data format is NCHW
Expand Down
12 changes: 12 additions & 0 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,18 @@ def _convert(cls, ctx, node, **kwargs):
else:
spatial = 2

origin_dtype = ctx.get_dtype(node.output[0])
if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
# the onnx spec doesn't allow int types for pool ops
input_shapes = [ctx.get_shape(node.input[0])]
output_shapes = [ctx.get_shape(node.output[0])]
cast_node = ctx.make_node("Cast", [node.input[0]], dtypes=[onnx_pb.TensorProto.FLOAT], shapes=input_shapes,
name=node.name + "_cast", attr={"to": onnx_pb.TensorProto.FLOAT})
_ = ctx.insert_node_on_output(cast_node, node.inputs[0].output[0])
cast_back_node = ctx.make_node("Cast", [node.output[0]], dtypes=[origin_dtype], shapes=output_shapes,
name=node.name + "_castback", attr={"to": origin_dtype})
_ = ctx.insert_node_on_output(cast_back_node, node.output[0])

if len(node.input) < 3:
kernel_shape_tf = node.get_attr("ksize").ints
strides_tf = node.get_attr("strides").ints
Expand Down