|
12 | 12 | import logging
|
13 | 13 |
|
14 | 14 | import numpy as np
|
15 |
| -from onnx import onnx_pb |
| 15 | +from onnx import onnx_pb, helper |
16 | 16 | from onnx.onnx_pb import TensorProto
|
17 | 17 | from tf2onnx import constants, utils
|
18 | 18 | from tf2onnx.graph_builder import GraphBuilder
|
@@ -1188,13 +1188,13 @@ def version_11(cls, ctx, node, **kwargs):
|
1188 | 1188 | @tf_op("MatrixBandPart")
|
1189 | 1189 | class MatrixBandPart:
|
1190 | 1190 | @classmethod
|
1191 |
| - def any_version_after7(cls, opset, ctx, node, **kwargs): |
| 1191 | + def version_7(cls, opset, ctx, node, **kwargs): |
1192 | 1192 | # T output = MatrixBandPart(T input, int num_lower, int num_upper)
|
1193 | 1193 | # data-flow: first generate mask matrix and then use element-wise mul op
|
1194 | 1194 | input_rank = len(ctx.get_shape(node.input[0]))
|
1195 | 1195 | utils.make_sure(input_rank == 2, error_msg="MatrixBandPart op: only rank 2 is supported")
|
1196 | 1196 | bandpart = [node.inputs[ind].get_tensor_value() for ind in [1, 2]]
|
1197 |
| - utils.make_sure(bandpart in [[-1, 0], [0, -1]], "only support Lower/Upper triangular for now") |
| 1197 | + utils.make_sure(bandpart in [[-1, 0], [0, -1]], "only support Lower/Upper triangular for opset < 11") |
1198 | 1198 | # methods to generate mask matrix: if lower triangular is needed, then generate column one by one
|
1199 | 1199 | # otherwise row is generated one by one.
|
1200 | 1200 | axis, counter_axis, squeeze_axis = (1, 0, 2) if bandpart == [-1, 0] else (0, 1, 1)
|
@@ -1267,13 +1267,76 @@ def any_version_after7(cls, opset, ctx, node, **kwargs):
|
1267 | 1267 | dtypes=dtypes)
|
1268 | 1268 |
|
1269 | 1269 | @classmethod
|
1270 |
| - def version_7(cls, ctx, node, **kwargs): |
1271 |
| - cls.any_version_after7(7, ctx, node, **kwargs) |
1272 |
| - |
1273 |
| - @classmethod |
1274 |
| - def version_13(cls, ctx, node, **kwargs): |
1275 |
| - # Signature of operator Squeeze changed. |
1276 |
| - cls.any_version_after7(13, ctx, node, **kwargs) |
def version_11(cls, ctx, node, **kwargs):
    """Convert tf MatrixBandPart for opset >= 11 (no const-band restriction).

    T output = MatrixBandPart(T input, int num_lower, int num_upper)
    Builds a 0/1 mask over the two innermost dimensions and multiplies it
    elementwise with the input.  A negative *constant* band count means
    "unbounded" on that side, matching tf.linalg.band_part semantics.

    ctx:  the tf2onnx graph being rewritten
    node: the MatrixBandPart node to replace (replaced in-place; its name
          and outputs are preserved)
    """
    # Band limits may be constants (enables shortcuts below) or dynamic tensors.
    num_lower_const = node.inputs[1].get_tensor_value() if node.inputs[1].is_const() else None
    num_upper_const = node.inputs[2].get_tensor_value() if node.inputs[2].is_const() else None
    data, num_lower, num_upper = node.input
    rank = ctx.get_rank(data)
    int_max_val = utils.get_max_value(np.int64)
    dtype = ctx.get_dtype(data)
    # shape = [rows, cols] of the innermost 2-D slices.
    if rank == 2:
        shape = ctx.make_node("Shape", [data]).output[0]
    else:
        whole_shape = ctx.make_node("Shape", [data]).output[0]
        shape = GraphBuilder(ctx).make_slice(
            {'data': whole_shape, 'starts': [-2], 'ends': [int_max_val], 'axes': [0]})
    if num_lower_const == 0 and num_upper_const == 0:
        # Band (0, 0) keeps only the diagonal: multiply by an identity mask.
        if rank == 2:
            identity_node = ctx.make_node("EyeLike", [data]).output[0]
        else:
            # EyeLike requires a 2-D input, so first materialize a 2-D zero
            # tensor shaped like the innermost dims, then take its eye.
            zero_tensor = helper.make_tensor("value", dtype, dims=[1], vals=[0])
            const_of_shape = ctx.make_node("ConstantOfShape", [shape], attr={'value': zero_tensor}).output[0]
            identity_node = ctx.make_node("EyeLike", [const_of_shape]).output[0]
        # NOTE: the previous revision also emitted an unused Sub(1 - eye)
        # "mask" node here; it was dead code and has been removed.
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Mul", inputs=[identity_node, data],
                      name=node.name, outputs=node.output, shapes=shapes,
                      dtypes=dtypes)
        return
    # General case: compare the per-element index difference j - i against
    # the band limits.
    zero_const = ctx.make_const(utils.make_name("zero"), np.array(0, np.int64)).output[0]
    one_const = ctx.make_const(utils.make_name("one"), np.array(1, np.int64)).output[0]
    conditions = []
    row_cnt = GraphBuilder(ctx).make_slice({'data': shape, 'axes': [0], 'starts': [0], 'ends': [1]})
    col_cnt = GraphBuilder(ctx).make_slice({'data': shape, 'axes': [0], 'starts': [1], 'ends': [2]})
    limit = ctx.make_node("Mul", [row_cnt, col_cnt]).output[0]
    # Equivalent to Range(0, limit, 1), built from Expand + exclusive CumSum.
    ones_of_shape = ctx.make_node("Expand", [one_const, limit]).output[0]
    idx_cnt = ctx.make_node("CumSum", [ones_of_shape, zero_const], attr={'exclusive': True}).output[0]

    # Flat index -> (row i, col j); idx_diff holds j - i for every element.
    idx_reshape = ctx.make_node("Reshape", [idx_cnt, shape]).output[0]
    row_idx = ctx.make_node("Div", [idx_reshape, col_cnt]).output[0]
    col_idx = ctx.make_node("Mod", [idx_reshape, col_cnt]).output[0]
    idx_diff = ctx.make_node("Sub", [col_idx, row_idx]).output[0]

    # Keep (i, j) iff j - i <= num_upper, unless the upper band is a known
    # negative constant (= unbounded above).
    if num_upper_const is None or num_upper_const >= 0:
        if ctx.get_dtype(num_upper) != TensorProto.INT64:
            num_upper = ctx.make_node("Cast", [num_upper], attr={'to': TensorProto.INT64}).output[0]
        conditions.append(ctx.make_node("LessOrEqual", [idx_diff, num_upper]).output[0])
    # ... and -num_lower <= j - i, unless the lower band is unbounded.
    if num_lower_const is None or num_lower_const >= 0:
        if ctx.get_dtype(num_lower) != TensorProto.INT64:
            num_lower = ctx.make_node("Cast", [num_lower], attr={'to': TensorProto.INT64}).output[0]
        num_lower_neg = ctx.make_node("Neg", [num_lower]).output[0]
        conditions.append(ctx.make_node("LessOrEqual", [num_lower_neg, idx_diff]).output[0])
    if not conditions:
        # Both bands unbounded: the whole op degenerates to an identity.
        node.type = "Identity"
        ctx.replace_inputs(node, [data])
        return
    cond = conditions[0] if len(conditions) == 1 else ctx.make_node("And", conditions).output[0]
    mask = ctx.make_node("Cast", [cond], attr={'to': ctx.get_dtype(data)}).output[0]
    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Mul", inputs=[mask, data],
                  name=node.name, outputs=node.output, shapes=shapes,
                  dtypes=dtypes)
1277 | 1340 |
|
1278 | 1341 |
|
1279 | 1342 | def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
|
|
0 commit comments