
Commit 3052221

Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into findex2
2 parents b81b2f8 + c867e52

6 files changed: +429 -265 lines changed


tests/test_backend.py

Lines changed: 99 additions & 19 deletions
@@ -54,6 +54,7 @@

 if is_tf2():
     conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
+    conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
     multinomial = tf.compat.v1.random.multinomial
     space_to_batch_nd = tf.compat.v1.space_to_batch_nd
     batch_to_space_nd = tf.compat.v1.batch_to_space_nd
@@ -73,6 +74,7 @@
     fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
 elif LooseVersion(tf.__version__) >= "1.13":
     conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
+    conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
     multinomial = tf.compat.v1.random.multinomial
     space_to_batch_nd = tf.compat.v1.space_to_batch_nd
     batch_to_space_nd = tf.compat.v1.batch_to_space_nd
@@ -93,6 +95,7 @@
     fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
 else:
     conv2d_backprop_input = tf.nn.conv2d_backprop_input
+    conv3d_transpose = tf.nn.conv3d_transpose
     multinomial = tf.multinomial
     space_to_batch_nd = tf.space_to_batch_nd
     batch_to_space_nd = tf.batch_to_space_nd
@@ -395,6 +398,24 @@ def test_conv2d_6(self):
         kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
         self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)

+    def test_conv2d_dilation_same(self):
+        x_shape = [1, 35, 35, 288]  # NHWC
+        kernel_shape = [3, 3, 288, 384]  # [filter_height, filter_width, in_channels, out_channels]
+        strides = [1, 1, 1, 1]  # NHWC
+        dilations = [1, 3, 1, 1]  # NHWC
+        x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
+        kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
+        self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
+
+    def test_conv2d_dilation_strides_same(self):
+        x_shape = [1, 35, 35, 288]  # NHWC
+        kernel_shape = [3, 3, 288, 384]  # [filter_height, filter_width, in_channels, out_channels]
+        strides = [1, 2, 4, 1]  # NHWC
+        dilations = [1, 3, 1, 1]  # NHWC
+        x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
+        kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
+        self._conv_test(x_val, kernel_val, strides=strides, padding="SAME", dilations=dilations, rtol=1e-05)
+
     def test_conv3d_1(self):
         strides = [1, 1, 1, 1, 1]
         dilations = [1, 1, 1, 1, 1]
@@ -3136,45 +3157,38 @@ def func(x):
     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const(self):
         input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 10, 10, 5])
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 1, 1, 1],
                                          padding='SAME', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const_strided(self):
         input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 5, 5, 5])
-
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 2, 2, 1],
                                          padding='SAME', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const_valid(self):
         input_sizes_val_ = np.array([1, 12, 12, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 10, 10, 5])
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 1, 1, 1],
                                          padding='VALID', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput(self):
@@ -3206,6 +3220,72 @@ def func(input_sizes, filters, out_backprop):
         out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
         self._run_test_case(func, [_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})

+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const(self):
+        output_shape_val_ = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 1, 1, 1, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const_strided(self):
+        output_shape_val_ = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 2, 2, 2, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const_valid(self):
+        output_shape_val_ = np.array([1, 12, 12, 12, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 1, 1, 1, 1],
+                                    padding='VALID', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 1, 1, 1, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[2, 3, 4, 4, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[2, 7, 8, 9, 5]).astype(np.float32)
+        output_shape_val = np.array([2, 7, 8, 9, 4], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_strided(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 2, 2, 2, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5, 5]).astype(np.float32)
+        output_shape_val = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_valid(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 1, 1, 1, 1],
+                                    padding='VALID', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        output_shape_val = np.array([1, 12, 12, 12, 3], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
     @check_opset_min_version(8, "CategoryMapper")
     @skip_tf2()
     def test_hashtable_lookup(self):

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
@@ -148,6 +148,8 @@ def data_format(self, val):

     def is_nhwc(self):
         """Return True if node is in NHWC format."""
+        utils.make_sure('D' not in self.data_format, "is_nhwc called on %s with spatial=2 but data_format=%s",
+                        self.name, self.data_format)
         return self.data_format == "NHWC"

     def is_const(self):
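
The added guard makes is_nhwc() fail loudly when it is called on a 3-D node rather than quietly returning False. A simplified stand-in sketch (make_sure below mimics tf2onnx.utils.make_sure, which raises when its condition is false; names are illustrative, not part of the commit):

def make_sure(cond, msg, *args):
    # simplified stand-in for tf2onnx.utils.make_sure
    if not cond:
        raise ValueError(msg % args)

def is_nhwc(name, data_format):
    make_sure('D' not in data_format, "is_nhwc called on %s with spatial=2 but data_format=%s",
              name, data_format)
    return data_format == "NHWC"

print(is_nhwc("conv2d_node", "NHWC"))   # True
# is_nhwc("conv3d_node", "NDHWC")       # raises ValueError instead of silently returning False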

tf2onnx/onnx_opset/generator.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ def version_8(cls, ctx, node, **kwargs):
         ctx.add_graph_input(output_names[1], type_1, shape_1)


-@tf_op("QueueDequeueManyV2")
+@tf_op("QueueDequeueManyV2", "QueueDequeueUpToV2")
 class QueueDequeueManyV2:
     @classmethod
     def version_8(cls, ctx, node, **kwargs):

tf2onnx/onnx_opset/nn.py

Lines changed: 46 additions & 18 deletions
@@ -240,7 +240,7 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
         for i in range(spatial):
             pad = (
                 (output_shape[i + 2] - 1) * strides[i]
-                + dilations[i] * kernel_shape[i]
+                + dilations[i] * (kernel_shape[i] - 1) + 1
                 - input_shape[i + 2]
             )
             pad = max(pad, 0)
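
The corrected term is the dilated kernel extent, dilation * (k - 1) + 1, instead of dilation * k. A quick sanity check of the new SAME-padding arithmetic (illustration only, using the sizes from the new test_conv2d_dilation_same case above):

# Hedged illustration, not part of the commit: SAME padding for one spatial dim
# with input 35, stride 1, kernel 3, dilation 3.
output_dim, stride, k, dilation, input_dim = 35, 1, 3, 3, 35
effective_kernel = dilation * (k - 1) + 1                     # dilated kernel extent = 7
pad = max((output_dim - 1) * stride + effective_kernel - input_dim, 0)
print(pad)  # 6, split across both sides; the old term dilation * k would give 8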
@@ -362,8 +362,7 @@ def version_11(cls, ctx, node, **kwargs):
         # No change.
         cls.version_1(ctx, node, **kwargs)

-
-@tf_op("Conv2DBackpropInput")
+@tf_op(["Conv2DBackpropInput", "Conv3DBackpropInputV2"])
 class ConvTranspose:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
@@ -372,24 +371,36 @@ def version_1(cls, ctx, node, **kwargs):
         # T Y = ConvTranspose(T X, T W, T B, @STRING auto_pad, @INTS dilations,
         #   @INT group, @INTS kernel_shape, @INTS output_shape, @INTS pads, @INTS strides)

+        if node.type == "Conv3DBackpropInputV2":
+            spatial = 3
+        else:
+            spatial = 2
         node.type = "ConvTranspose"
         # Note: inputs are reversed from what one would expect.
-        conv_kernel_shape(ctx, node, 1)
+        conv_kernel_shape(ctx, node, 1, spatial=spatial)
         input_shape = ctx.get_shape(node.input[2])
         output_shape_orig = node.output_shapes

         # ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
         if node.inputs[0].is_const():
             output_shape = ctx.get_shape(node.output[0])
-            if node.is_nhwc():
+            if is_channels_last(node):
                 new_output_shape = [output_shape[1], output_shape[2]]
-                input_hw = [input_shape[1], input_shape[2]]
+                input_dims = [input_shape[1], input_shape[2]]
+                if spatial == 3:
+                    new_output_shape.append(output_shape[3])
+                    input_dims.append(input_shape[3])
             else:
                 new_output_shape = [output_shape[2], output_shape[3]]
-                input_hw = [input_shape[2], input_shape[3]]
-            utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
-            utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
-                            "output h and w cannot be smaller than input h and w.")
+                input_dims = [input_shape[2], input_shape[3]]
+                if spatial == 3:
+                    new_output_shape.append(output_shape[4])
+                    input_dims.append(input_shape[4])
+
+            utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
+            utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
+                            "output dims cannot be smaller than input dims.")
+
             node.set_attr("output_shape", new_output_shape)
         else:
             input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
@@ -409,20 +420,37 @@ def version_1(cls, ctx, node, **kwargs):
             start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
             end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
             end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
-            starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
-            ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
-            const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
-                                           np.array([1, 2], dtype=np.int64))
+            if spatial == 3:
+                output_d = GraphBuilder(ctx).make_slice(
+                    {"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
+                expect_d = GraphBuilder(ctx).make_slice(
+                    {"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
+                diff_d = ctx.make_node("Sub", [output_d, expect_d])
+                start_d = ctx.make_node("Div", [diff_d.output[0], const_two.output[0]])
+                end_d = ctx.make_node("Add", [start_d.output[0], expect_d])
+
+                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0], start_d.output[0]],
+                                       attr={"axis": 0})
+                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
+                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
+                                            np.array([1, 2, 3], dtype=np.int64))
+            else:
+                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
+                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
+                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
+                                            np.array([1, 2], dtype=np.int64))
+
             slice_node = ctx.make_node("Slice",
-                                       [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]],
+                                       [node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
                                        shapes=output_shape_orig)
+
             downstream_nodes = ctx.find_output_consumers(node.output[0])
             downstream_nodes.remove(output_shape)
             downstream_nodes.remove(slice_node)
             ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])

-        conv_dims_attr(node, "strides")
-        conv_dims_attr(node, "dilations")
+        conv_dims_attr(node, "strides", spatial=spatial)
+        conv_dims_attr(node, "dilations", spatial=spatial)

         # remove output_shapes input
         ctx.remove_input(node, node.input[0], 0)
@@ -431,7 +459,7 @@ def version_1(cls, ctx, node, **kwargs):
         ctx.replace_input(node, node.input[0], node.input[1], 0)
         ctx.replace_input(node, node.input[1], t, 1)

-        conv_convert_inputs(ctx, node, with_kernel=True)
+        conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)

     @classmethod
     def version_11(cls, ctx, node, **kwargs):
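
For the dynamic-output-shape path, the generated Sub/Div/Add/Slice nodes center-crop the ConvTranspose result to the requested spatial sizes, and the 3-D branch extends that crop to the depth axis. A rough numpy analogue of the crop (a sketch only, with layout and axes simplified; not part of the commit):

import numpy as np

def center_crop(actual, requested_sizes, spatial_axes):
    # start = (actual - requested) // 2, end = start + requested, per spatial axis,
    # mirroring the Sub/Div/Add that feed the Slice node above.
    idx = [slice(None)] * actual.ndim
    for ax, want in zip(spatial_axes, requested_sizes):
        start = (actual.shape[ax] - want) // 2
        idx[ax] = slice(start, start + want)
    return actual[tuple(idx)]

y = np.zeros((1, 11, 11, 11, 3))                      # oversized ConvTranspose output
print(center_crop(y, [10, 10, 10], [1, 2, 3]).shape)  # (1, 10, 10, 10, 3)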

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 8 additions & 0 deletions
@@ -110,6 +110,14 @@ def _fold_transpose(node, graph) -> list:
         const_val_after_trans = const_val.transpose(perm)
         return [const_val_after_trans]

+    @staticmethod
+    @_register_func("Reshape")
+    def _fold_reshape(node, graph):
+        const_val_data = node.inputs[0].get_tensor_value(as_list=False)
+        const_val_shape = node.inputs[1].get_tensor_value(as_list=False)
+        const_val_after_trans = const_val_data.reshape(const_val_shape)
+        return [const_val_after_trans]
+
     @staticmethod
     @_register_func("Unsqueeze")
     def _fold_unsqueeze(node, graph):
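
The new handler folds a Reshape whose data and shape inputs are both constants by applying numpy's reshape to the constant value. A minimal sketch of the same operation (illustration only, with stand-in values rather than graph nodes):

import numpy as np

const_val_data = np.arange(12, dtype=np.float32)     # constant data input of the Reshape
const_val_shape = np.array([3, 4], dtype=np.int64)   # constant shape input of the Reshape
folded = const_val_data.reshape(const_val_shape)     # this array replaces the Reshape output
print(folded.shape)  # (3, 4)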
