# Reconstructed from a whitespace-mangled git diff. The patch adds test
# methods to two files; each section below is labeled with its target file.
# (The patch also touches tf2onnx/onnx_opset/nn.py — that hunk is a
# mid-function fragment and is not reproduced here.)

# --- tests/huggingface.py : inserted after test_TFDisillBertModel ---

## FUNNEL

def _test_TFFunnelSquad(self, size, large=False):
    """Export a TFFunnelForQuestionAnswering checkpoint to ONNX and compare outputs.

    Args:
        size: pretrained checkpoint name, e.g. "funnel-transformer/small".
        large: unused in this helper; kept only so the signature matches the
            other _test_* helpers in this file (e.g. _test_TFT5Model).
    """
    from transformers import FunnelTokenizer, TFFunnelForQuestionAnswering
    tokenizer = FunnelTokenizer.from_pretrained(size)
    model = TFFunnelForQuestionAnswering.from_pretrained(size)
    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    input_dict = tokenizer(question, text, return_tensors='tf')
    # Sequence length 14 corresponds to the tokenized question/text pair
    # above; the batch dimension is left dynamic (None).
    spec = (tf.TensorSpec((None, 14), tf.int32, name="input_ids"),
            tf.TensorSpec((None, 14), tf.int32, name="token_type_ids"),
            tf.TensorSpec((None, 14), tf.int32, name="attention_mask"))
    outputs = ["start_logits", "end_logits"]
    self.run_test(model, input_dict, input_signature=spec, outputs=outputs, rtol=1e-5)

def test_TFFunnelSquadSmall(self):
    """Run the Funnel QA export test against the small checkpoint."""
    self._test_TFFunnelSquad("funnel-transformer/small")

# --- tests/test_backend.py : inserted before test_maxpool_gpu ---

@check_tf_min_version("1.15", "required for max_pool args")
def test_maxpool_int(self):
    """max_pool over an int32 input.

    Exercises the converter path that wraps pool ops in Cast nodes, since
    (per the accompanying nn.py change) the ONNX spec does not allow int
    types for pool ops.
    """
    x_shape = [8, 16, 16, 3]
    x_val = make_xval(x_shape).astype("int32")

    def func(x):
        # ksize=[2] uses the short-form ksize argument (needs tf >= 1.15).
        mp = tf.nn.max_pool(x, ksize=[2], strides=[1, 2, 2, 1], padding="SAME")
        return tf.identity(mp, name=_TFOUTPUT)

    self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
b/tf2onnx/onnx_opset/nn.py index 0b48d9019..9d0f9966f 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -623,6 +623,18 @@ def _convert(cls, ctx, node, **kwargs): else: spatial = 2 + origin_dtype = ctx.get_dtype(node.output[0]) + if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]: + # the onnx spec doesn't allow int types for pool ops + input_shapes = [ctx.get_shape(node.input[0])] + output_shapes = [ctx.get_shape(node.output[0])] + cast_node = ctx.make_node("Cast", [node.input[0]], dtypes=[onnx_pb.TensorProto.FLOAT], shapes=input_shapes, + name=node.name + "_cast", attr={"to": onnx_pb.TensorProto.FLOAT}) + _ = ctx.insert_node_on_output(cast_node, node.inputs[0].output[0]) + cast_back_node = ctx.make_node("Cast", [node.output[0]], dtypes=[origin_dtype], shapes=output_shapes, + name=node.name + "_castback", attr={"to": origin_dtype}) + _ = ctx.insert_node_on_output(cast_back_node, node.output[0]) + if len(node.input) < 3: kernel_shape_tf = node.get_attr("ksize").ints strides_tf = node.get_attr("strides").ints