Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a6b432f

Browse files
waytrue17Wei Chu
andauthored
[v1.x] ONNX support rewrite norm (#20195)
* fix norm * fix sanity Co-authored-by: Wei Chu <weichu@amazon.com>
1 parent 8c15875 commit a6b432f

File tree

3 files changed

+98
-18
lines changed

3 files changed

+98
-18
lines changed

python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,7 @@ def convert_norm(node, **kwargs):
23932393
"""Map MXNet's norm operator attributes to onnx's ReduceL1 and ReduceL2 operators
23942394
and return the created node.
23952395
"""
2396+
from onnx.helper import make_node
23962397
name, input_nodes, attrs = get_inputs(node, kwargs)
23972398

23982399
mx_axis = attrs.get("axis", None)
@@ -2407,24 +2408,32 @@ def convert_norm(node, **kwargs):
24072408
onnx_op_name = "ReduceL1" if ord == 1 else "ReduceL2"
24082409

24092410
if axes:
2410-
reduce_node = onnx.helper.make_node(
2411-
onnx_op_name,
2412-
input_nodes,
2413-
[name],
2414-
axes=axes,
2415-
keepdims=keepdims,
2416-
name=name
2417-
)
2418-
return [reduce_node]
2411+
if keepdims:
2412+
reduce_node = make_node(onnx_op_name, input_nodes, [name], axes=axes, keepdims=keepdims)
2413+
return [reduce_node]
2414+
else:
2415+
create_tensor([1], name+'_1', kwargs['initializer'])
2416+
nodes = [
2417+
make_node(onnx_op_name, input_nodes, [name+'_norm'], axes=axes, keepdims=keepdims),
2418+
make_node('Shape', [name+'_norm'], [name+'_norm_shape']),
2419+
make_node('Concat', [name+'_1', name+'_norm_shape'], [name+'_concat'], axis=0),
2420+
make_node('Reshape', [name+'_norm', name+'_concat'], [name+'_reshape']),
2421+
make_node('Squeeze', [name+'_reshape'], [name], axes=[0]),
2422+
]
2423+
return nodes
24192424
else:
2420-
reduce_node = onnx.helper.make_node(
2421-
onnx_op_name,
2422-
input_nodes,
2423-
[name],
2424-
keepdims=keepdims,
2425-
name=name
2426-
)
2427-
return [reduce_node]
2425+
2426+
if keepdims:
2427+
reduce_node = make_node(onnx_op_name, input_nodes, [name], keepdims=keepdims)
2428+
return [reduce_node]
2429+
else:
2430+
create_tensor([1], name+'_1', kwargs['initializer'])
2431+
nodes = [
2432+
make_node(onnx_op_name, input_nodes, [name+'_norm'], keepdims=keepdims),
2433+
make_node('Reshape', [name+'_norm', name+'_1'], [name])
2434+
]
2435+
return nodes
2436+
24282437

24292438
@mx_op.register("_sample_multinomial")
24302439
def convert_multinomial(node, **kwargs):

python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,3 +1544,57 @@ def convert_squeeze(node, **kwargs):
15441544
name=name,
15451545
)
15461546
return [node]
1547+
1548+
1549+
@mx_op.register("norm", OPSET_VERSION)
1550+
def convert_norm(node, **kwargs):
1551+
"""Map MXNet's norm operator attributes to onnx's ReduceL1 and ReduceL2 operators
1552+
and return the created node.
1553+
"""
1554+
from onnx.helper import make_node
1555+
name, input_nodes, attrs = get_inputs(node, kwargs)
1556+
1557+
mx_axis = attrs.get("axis", None)
1558+
axes = convert_string_to_list(str(mx_axis)) if mx_axis else None
1559+
1560+
keepdims = get_boolean_attribute_value(attrs, "keepdims")
1561+
ord = int(attrs.get("ord", 2))
1562+
1563+
onnx_op_name = "ReduceL1" if ord == 1 else "ReduceL2"
1564+
1565+
if axes:
1566+
if keepdims:
1567+
reduce_node = make_node(onnx_op_name, input_nodes, [name], axes=axes, keepdims=keepdims)
1568+
return [reduce_node]
1569+
else:
1570+
create_tensor([1], name+'_1', kwargs['initializer'])
1571+
create_tensor([0], name+'_0', kwargs['initializer'])
1572+
create_tensor([len(axes)], name+'_axes_dim', kwargs['initializer'])
1573+
nodes = [
1574+
make_node(onnx_op_name, input_nodes, [name+'_reduce'], axes=axes, keepdims=keepdims),
1575+
make_node('Shape', [name+'_reduce'], [name+'_reduce_shape']),
1576+
make_node('Shape', [name+'_reduce_shape'], [name+'_reduce_dim']),
1577+
make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
1578+
make_node('Shape', [name+'_in_shape'], [name+'_in_dim']),
1579+
make_node('Equal', [name+'_axes_dim', name+'_in_dim'], [name+'_equal']),
1580+
make_node('Where', [name+'_equal', name+'_1', name+'_reduce_dim'], [name+'_where0']),
1581+
make_node('Tile', [name+'_0', name+'_where0'], [name+'_tile']),
1582+
make_node('Unsqueeze', [name+'_0', name+'_0'], [name+'_unsqueeze']),
1583+
make_node('Where', [name+'_equal', name+'_1', name+'_0'], [name+'_where1']),
1584+
make_node('ScatterND', [name+'_tile', name+'_unsqueeze', name+'_where1'], [name+'_SND']),
1585+
make_node('Reshape', [name+'_reduce', name+'_SND'], [name]),
1586+
]
1587+
return nodes
1588+
else:
1589+
1590+
if keepdims:
1591+
reduce_node = make_node(onnx_op_name, input_nodes, [name], keepdims=keepdims)
1592+
return [reduce_node]
1593+
else:
1594+
create_tensor([1], name+'_1', kwargs['initializer'])
1595+
nodes = [
1596+
make_node(onnx_op_name, input_nodes, [name+'_norm'], keepdims=keepdims),
1597+
make_node('Reshape', [name+'_norm', name+'_1'], [name])
1598+
]
1599+
return nodes
1600+

tests/python-pytest/onnx/test_operators.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,6 @@ def test_onnx_export_maximum_minimum(tmp_path, dtype, shape, op_name):
14881488
op_export_test(op_name, M, [lhs, rhs], tmp_path)
14891489

14901490

1491-
14921491
# onnx reduce ops do not support float64
14931492
@pytest.mark.parametrize('dtype', ['float16', 'float32','int32', 'int64'])
14941493
@pytest.mark.parametrize('shape', [(2, 3), (4, 5, 6)])
@@ -1557,6 +1556,24 @@ def test_onnx_export_squeeze(tmp_path, dtype, shape_axis):
15571556
op_export_test('squeeze', M, [x], tmp_path)
15581557

15591558

1559+
@pytest.mark.parametrize("dtype", ["float16", "float32"])
1560+
@pytest.mark.parametrize("order", [1, 2])
1561+
@pytest.mark.parametrize("keepdims", [0, 1])
1562+
@pytest.mark.parametrize("axis", [None, 0, 1, 2, -1, (0, 2), (0, 1, 2)])
1563+
@pytest.mark.parametrize("shape", [(4, 5, 6), (3, 4, 5, 6)])
1564+
def test_onnx_export_norm(tmp_path, dtype, order, axis, shape, keepdims):
1565+
kwargs = {}
1566+
if order is not None:
1567+
kwargs['ord'] = order
1568+
if axis is not None:
1569+
kwargs['axis'] = axis
1570+
if keepdims is not None:
1571+
kwargs['keepdims'] = keepdims
1572+
M = def_model('norm', **kwargs)
1573+
x = mx.random.normal(0, 10, shape).astype(dtype)
1574+
op_export_test('norm', M, [x], tmp_path)
1575+
1576+
15601577
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
15611578
@pytest.mark.parametrize("shape", [(10,), (2,3), (4,5,6)])
15621579
def test_onnx_export_logical_not(tmp_path, dtype, shape):

0 commit comments

Comments
 (0)