Skip to content

Commit 8dc9e2b

Browse files
Full conversion of MatrixBandPart
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ad4c792 commit 8dc9e2b

File tree

2 files changed: +102 −10 lines changed

tests/test_backend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,6 +2979,35 @@ def func(input_x):
29792979
return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
29802980
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
29812981

2982+
@check_opset_min_version(11, "CumSum")
def test_matrix_band_part_3(self):
    # Rank-2 band_part with assorted constant (num_lower, num_upper) pairs,
    # including negative widths (meaning "keep the whole triangle") and (0, 0).
    for lower, upper in [(-1, 3), (2, 3), (4, 3), (0, -1), (0, 0)]:
        input_val = np.random.randint(0, 666, (10, 15)).astype(np.int32)

        def func(input_x):
            banded = tf.linalg.band_part(input_x, lower, upper)
            return tf.identity(banded, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2990+
2991+
@check_opset_min_version(11, "CumSum")
def test_matrix_band_part_4(self):
    # Same band-width sweep as test_matrix_band_part_3, but on a batched
    # rank-4 input to exercise the broadcasting path of the conversion.
    for lower, upper in [(-1, 3), (2, 3), (4, 3), (0, -1), (0, 0)]:
        input_val = np.random.randint(0, 666, (2, 3, 10, 15)).astype(np.int32)

        def func(input_x):
            banded = tf.linalg.band_part(input_x, lower, upper)
            return tf.identity(banded, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2999+
3000+
@check_opset_min_version(11, "CumSum")
def test_matrix_band_part_5(self):
    # band_part where num_lower/num_upper arrive as graph inputs (non-const),
    # so the converter cannot special-case on constant band widths.
    for bands in [(2, 3), (4, 3), (0, 0), (2, 0)]:
        low_val = np.array(bands[0], np.int32)
        high_val = np.array(bands[1], np.int32)
        input_val = np.random.randint(0, 666, (2, 3, 10, 15)).astype(np.int32)

        def func(input_x, low, high):
            banded = tf.linalg.band_part(input_x, low, high)
            return tf.identity(banded, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: low_val, _INPUT2: high_val})
3010+
29823011
def test_floordiv(self):
29833012
input_val_1 = np.random.random_sample(100).astype(np.int32)
29843013
input_val_2 = (np.random.random_sample(100) + 1).astype(np.int32)

tf2onnx/onnx_opset/nn.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import logging
1313

1414
import numpy as np
15-
from onnx import onnx_pb
15+
from onnx import onnx_pb, helper
1616
from onnx.onnx_pb import TensorProto
1717
from tf2onnx import constants, utils
1818
from tf2onnx.graph_builder import GraphBuilder
@@ -1188,13 +1188,13 @@ def version_11(cls, ctx, node, **kwargs):
11881188
@tf_op("MatrixBandPart")
11891189
class MatrixBandPart:
11901190
@classmethod
1191-
def any_version_after7(cls, opset, ctx, node, **kwargs):
1191+
def version_7(cls, opset, ctx, node, **kwargs):
11921192
# T output = MatrixBandPart(T input, int num_lower, int num_upper)
11931193
# data-flow: first generate mask matrix and then use element-wise mul op
11941194
input_rank = len(ctx.get_shape(node.input[0]))
11951195
utils.make_sure(input_rank == 2, error_msg="MatrixBandPart op: only rank 2 is supported")
11961196
bandpart = [node.inputs[ind].get_tensor_value() for ind in [1, 2]]
1197-
utils.make_sure(bandpart in [[-1, 0], [0, -1]], "only support Lower/Upper triangular for now")
1197+
utils.make_sure(bandpart in [[-1, 0], [0, -1]], "only support Lower/Upper triangular for opset < 11")
11981198
# methods to generate mask matrix: if lower triangular is needed, then generate column one by one
11991199
# otherwise row is generated one by one.
12001200
axis, counter_axis, squeeze_axis = (1, 0, 2) if bandpart == [-1, 0] else (0, 1, 1)
@@ -1267,13 +1267,76 @@ def any_version_after7(cls, opset, ctx, node, **kwargs):
12671267
dtypes=dtypes)
12681268

12691269
@classmethod
def version_11(cls, ctx, node, **kwargs):
    """Convert tf MatrixBandPart for opset >= 11.

    T output = MatrixBandPart(T input, int num_lower, int num_upper)

    Builds a 0/1 band mask from per-element (row, col) indices of the trailing
    two dims and multiplies it with the input. Supports any rank >= 2 and both
    constant and non-constant band widths; band_part(x, 0, 0) is special-cased
    to an EyeLike-based diagonal mask.
    """
    # Band widths may be constants (enables special-casing) or graph inputs.
    num_lower_const = node.inputs[1].get_tensor_value() if node.inputs[1].is_const() else None
    num_upper_const = node.inputs[2].get_tensor_value() if node.inputs[2].is_const() else None
    data, num_lower, num_upper = node.input
    rank = ctx.get_rank(data)
    int_max_val = utils.get_max_value(np.int64)
    dtype = ctx.get_dtype(data)
    # 'shape' is the trailing [rows, cols] of the (possibly batched) input.
    if rank == 2:
        shape = ctx.make_node("Shape", [data]).output[0]
    else:
        whole_shape = ctx.make_node("Shape", [data]).output[0]
        shape = GraphBuilder(ctx).make_slice(
            {'data': whole_shape, 'starts': [-2], 'ends': [int_max_val], 'axes': [0]})
    if num_lower_const == 0 and num_upper_const == 0:
        # band_part(x, 0, 0) keeps only the main diagonal: mask with an identity.
        if rank == 2:
            identity_node = ctx.make_node("EyeLike", [data]).output[0]
        else:
            # EyeLike requires a rank-2 input, so build one of the right trailing
            # shape; the resulting eye broadcasts over the batch dims in Mul.
            zero_tensor = helper.make_tensor("value", dtype, dims=[1], vals=[0])
            const_of_shape = ctx.make_node("ConstantOfShape", [shape], attr={'value': zero_tensor}).output[0]
            identity_node = ctx.make_node("EyeLike", [const_of_shape]).output[0]
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Mul", inputs=[identity_node, data],
                      name=node.name, outputs=node.output, shapes=shapes,
                      dtypes=dtypes)
        return
    zero_const = ctx.make_const(utils.make_name("zero"), np.array(0, np.int64)).output[0]
    one_const = ctx.make_const(utils.make_name("one"), np.array(1, np.int64)).output[0]
    conditions = []
    row_cnt = GraphBuilder(ctx).make_slice({'data': shape, 'axes': [0], 'starts': [0], 'ends': [1]})
    col_cnt = GraphBuilder(ctx).make_slice({'data': shape, 'axes': [0], 'starts': [1], 'ends': [2]})
    limit = ctx.make_node("Mul", [row_cnt, col_cnt]).output[0]
    # Range(0, limit, 1) would be the natural op here, but an exclusive CumSum
    # over a tensor of ones yields the same 0..limit-1 sequence — presumably
    # chosen for wider runtime support of CumSum vs Range (TODO confirm).
    ones_of_shape = ctx.make_node("Expand", [one_const, limit]).output[0]
    idx_cnt = ctx.make_node("CumSum", [ones_of_shape, zero_const], attr={'exclusive': True}).output[0]
    # Turn the flat index into (row, col); col - row is the diagonal offset:
    # 0 on the main diagonal, positive above it, negative below it.
    idx_reshape = ctx.make_node("Reshape", [idx_cnt, shape]).output[0]
    row_idx = ctx.make_node("Div", [idx_reshape, col_cnt]).output[0]
    col_idx = ctx.make_node("Mod", [idx_reshape, col_cnt]).output[0]
    idx_diff = ctx.make_node("Sub", [col_idx, row_idx]).output[0]

    # A negative *constant* width means "keep the whole triangle", so that side
    # needs no condition. NOTE(review): a non-const width is assumed to be >= 0
    # at runtime — a dynamic -1 would produce a wrong mask; verify callers.
    if num_upper_const is None or num_upper_const >= 0:
        if ctx.get_dtype(num_upper) != TensorProto.INT64:
            num_upper = ctx.make_node("Cast", [num_upper], attr={'to': TensorProto.INT64}).output[0]
        # Keep elements at most num_upper diagonals above the main one.
        conditions.append(ctx.make_node("LessOrEqual", [idx_diff, num_upper]).output[0])
    if num_lower_const is None or num_lower_const >= 0:
        if ctx.get_dtype(num_lower) != TensorProto.INT64:
            num_lower = ctx.make_node("Cast", [num_lower], attr={'to': TensorProto.INT64}).output[0]
        # Keep elements at most num_lower diagonals below the main one.
        num_lower_neg = ctx.make_node("Neg", [num_lower]).output[0]
        conditions.append(ctx.make_node("LessOrEqual", [num_lower_neg, idx_diff]).output[0])
    if len(conditions) == 0:
        # Both widths are negative constants: the whole matrix is kept.
        node.type = "Identity"
        ctx.replace_inputs(node, [data])
        return
    if len(conditions) == 1:
        cond = conditions[0]
    else:
        cond = ctx.make_node("And", conditions).output[0]
    # Cast the boolean mask to the input dtype and apply it.
    mask = ctx.make_node("Cast", [cond], attr={'to': ctx.get_dtype(data)}).output[0]
    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Mul", inputs=[mask, data],
                  name=node.name, outputs=node.output, shapes=shapes,
                  dtypes=dtypes)
12771340

12781341

12791342
def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):

0 commit comments

Comments
 (0)