From 401ca1aa26658a44210774e225e0df626b9f2d48 Mon Sep 17 00:00:00 2001
From: Guenther Schmuelling
Date: Tue, 12 Jan 2021 15:52:36 -0800
Subject: [PATCH] updated supported ops doc

Signed-off-by: Guenther Schmuelling
---
 support_status.md           | 423 ++++++++++++++++++++----------------
 tools/gen_doc.py            |   5 +-
 tools/quantitize_weights.py | 184 ----------------
 3 files changed, 239 insertions(+), 373 deletions(-)
 delete mode 100644 tools/quantitize_weights.py

diff --git a/support_status.md b/support_status.md
index 4eac09e49..6a196fb36 100644
--- a/support_status.md
+++ b/support_status.md
@@ -2,193 +2,232 @@
 ### Domain: "" (default domain)
 | Tensorflow Op | Convertible to ONNX Op Versions |
 | ------------- | ------------------------------- |
-| Abs | 1 ~ 12 |
-| Acos | 7 ~ 12 |
-| Acosh | 9 ~ 12 |
-| Add | 1 ~ 12 |
-| AddN | 6 ~ 12 |
-| AddV2 | 1 ~ 12 |
-| All | 6 ~ 12 |
-| Any | 6 ~ 12 |
-| ArgMax | 1 ~ 12 |
-| ArgMin | 1 ~ 12 |
-| Asin | 7 ~ 12 |
-| Asinh | 9 ~ 12 |
-| Atan | 7 ~ 12 |
-| Atanh | 9 ~ 12 |
-| AvgPool | 1 ~ 12 |
-| AvgPool3D | 1 ~ 12 |
-| BatchMatMul | 1 ~ 12 |
-| BatchMatMulV2 | 1 ~ 12 |
-| BatchToSpaceND | 1 ~ 12 |
-| BiasAdd | 1 ~ 12 |
-| BiasAddV1 | 1 ~ 12 |
-| BroadcastTo | 8 ~ 12 |
-| Cast | 1 ~ 12 |
-| Ceil | 1 ~ 12 |
-| CheckNumerics | 1 ~ 12 |
-| ClipByValue | 8 ~ 12 |
-| Concat | 1 ~ 12 |
-| ConcatV2 | 1 ~ 12 |
-| Const | 1 ~ 12 |
-| ConstV2 | 1 ~ 12 |
-| Conv1D | 1 ~ 12 |
-| Conv2D | 1 ~ 12 |
-| Conv2DBackpropInput | 1 ~ 12 |
-| Conv3D | 1 ~ 12 |
-| Cos | 7 ~ 12 |
-| Cosh | 9 ~ 12 |
-| CropAndResize | 10 ~ 12 |
-| CudnnRNN | 10 ~ 12 |
-| Cumsum | 11 ~ 12 |
-| DepthToSpace | 1 ~ 12 |
-| DepthwiseConv2d | 1 ~ 12 |
-| DepthwiseConv2dNative | 1 ~ 12 |
-| Div | 1 ~ 12 |
-| Dropout | 1 ~ 12 |
-| Einsum | 12 |
-| Elu | 1 ~ 12 |
-| Equal | 1 ~ 12 |
-| Erf | 1 ~ 12 |
-| Exp | 1 ~ 12 |
-| ExpandDims | 1 ~ 12 |
-| FIFOQueueV2 | 8 ~ 12 |
-| Fill | 7 ~ 12 |
-| Flatten | 1 ~ 12 |
-| Floor | 1 ~ 12 |
-| FloorDiv | 6 ~ 12 |
-| FloorMod | 7 ~ 12 |
-| FusedBatchNorm | 6 ~ 12 |
-| FusedBatchNormV2 | 6 ~ 12 |
-| FusedBatchNormV3 | 6 ~ 12 |
-| Gather | 1 ~ 12 |
-| GatherNd | 1 ~ 12 |
-| GatherV2 | 1 ~ 12 |
-| Greater | 1 ~ 12 |
-| GreaterEqual | 7 ~ 12 |
-| HashTableV2 | 8 ~ 12 |
-| Identity | 1 ~ 12 |
-| IdentityN | 1 ~ 12 |
-| If | 1 ~ 12 |
-| IsFinite | 10 ~ 12 |
-| IsInf | 10 ~ 12 |
-| IsNan | 9 ~ 12 |
-| IteratorGetNext | 8 ~ 12 |
-| IteratorV2 | 8 ~ 12 |
-| LRN | 1 ~ 12 |
-| LSTMBlockCell | 1 ~ 12 |
-| LeakyRelu | 1 ~ 12 |
-| LeftShift | 11 ~ 12 |
-| Less | 1 ~ 12 |
-| LessEqual | 7 ~ 12 |
-| Log | 1 ~ 12 |
-| LogSoftmax | 1 ~ 12 |
-| LogicalAnd | 1 ~ 12 |
-| LogicalNot | 1 ~ 12 |
-| LogicalOr | 1 ~ 12 |
-| LookupTableFindV2 | 8 ~ 12 |
-| Loop | 7 ~ 12 |
-| MatMul | 1 ~ 12 |
-| MatrixBandPart | 7 ~ 12 |
-| MatrixDeterminant | 11 ~ 12 |
-| MatrixDiagPart | 11 ~ 12 |
-| MatrixDiagPartV2 | 11 ~ 12 |
-| MatrixDiagPartV3 | 11 ~ 12 |
-| Max | 1 ~ 12 |
-| MaxPool | 1 ~ 12 |
-| MaxPoolV2 | 1 ~ 12 |
-| MaxPoolWithArgmax | 8 ~ 12 |
-| Maximum | 1 ~ 12 |
-| Mean | 1 ~ 12 |
-| Min | 1 ~ 12 |
-| Minimum | 1 ~ 12 |
-| MirrorPad | 1 ~ 12 |
-| Mul | 1 ~ 12 |
-| Multinomial | 7 ~ 12 |
-| Neg | 1 ~ 12 |
-| NoOp | 1 ~ 12 |
-| NonMaxSuppressionV2 | 10 ~ 12 |
-| NonMaxSuppressionV3 | 10 ~ 12 |
-| NonMaxSuppressionV4 | 10 ~ 12 |
-| NonMaxSuppressionV5 | 10 ~ 12 |
-| NotEqual | 1 ~ 12 |
-| OneHot | 1 ~ 12 |
-| Pack | 1 ~ 12 |
-| Pad | 1 ~ 12 |
-| PadV2 | 1 ~ 12 |
-| Placeholder | 1 ~ 12 |
-| PlaceholderV2 | 1 ~ 12 |
-| PlaceholderWithDefault | 1 ~ 12 |
-| Pow | 1 ~ 12 |
-| Prod | 1 ~ 12 |
-| QueueDequeueV2 | 8 ~ 12 |
-| RandomNormal | 1 ~ 12 |
-| RandomNormalLike | 1 ~ 12 |
-| RandomUniform | 1 ~ 12 |
-| RandomUniformLike | 1 ~ 12 |
-| Range | 7 ~ 12 |
-| RealDiv | 1 ~ 12 |
-| Reciprocal | 1 ~ 12 |
-| Relu | 1 ~ 12 |
-| Relu6 | 1 ~ 12 |
-| Reshape | 1 ~ 12 |
-| ResizeBilinear | 7 ~ 12 |
-| ResizeNearestNeighbor | 7 ~ 12 |
-| ReverseSequence | 8 ~ 12 (Except 9) |
-| ReverseV2 | 10 ~ 12 |
-| RightShift | 11 ~ 12 |
-| Round | 11 ~ 12 |
-| Rsqrt | 1 ~ 12 |
-| Scan | 7 ~ 12 |
-| ScatterNd | 11 ~ 12 |
-| Select | 7 ~ 12 |
-| SelectV2 | 7 ~ 12 |
-| Selu | 1 ~ 12 |
-| Shape | 1 ~ 12 |
-| Sigmoid | 1 ~ 12 |
-| Sign | 1 ~ 12 |
-| Sin | 7 ~ 12 |
-| Sinh | 9 ~ 12 |
-| Size | 1 ~ 12 |
-| Slice | 1 ~ 12 |
-| Softmax | 1 ~ 12 |
-| SoftmaxCrossEntropyWithLogits | 7 ~ 12 |
-| Softplus | 1 ~ 12 |
-| Softsign | 1 ~ 12 |
-| SpaceToBatchND | 1 ~ 12 |
-| SpaceToDepth | 1 ~ 12 |
-| SparseSoftmaxCrossEntropyWithLogits | 7 ~ 12 |
-| Split | 1 ~ 12 |
-| SplitV | 1 ~ 12 |
-| Sqrt | 1 ~ 12 |
-| Square | 1 ~ 12 |
-| SquaredDifference | 1 ~ 12 |
-| SquaredDistance | 12 |
-| Squeeze | 1 ~ 12 |
-| StatelessIf | 1 ~ 12 |
-| StatelessWhile | 7 ~ 12 |
-| StopGradient | 1 ~ 12 |
-| StridedSlice | 1 ~ 12 |
-| Sub | 1 ~ 12 |
-| Sum | 1 ~ 12 |
-| Tan | 7 ~ 12 |
-| Tanh | 1 ~ 12 |
-| TensorListFromTensor | 7 ~ 12 |
-| TensorListGetItem | 7 ~ 12 |
-| TensorListLength | 7 ~ 12 |
-| TensorListReserve | 7 ~ 12 |
-| TensorListResize | 7 ~ 12 |
-| TensorListSetItem | 7 ~ 12 |
-| TensorListStack | 7 ~ 12 |
-| Tile | 1 ~ 12 |
-| TopKV2 | 1 ~ 12 |
-| Transpose | 1 ~ 12 |
-| TruncateDiv | 1 ~ 12 |
-| Unique | 11 ~ 12 |
-| Unpack | 1 ~ 12 |
-| Where | 9 ~ 12 |
-| While | 7 ~ 12 |
-| ZerosLike | 1 ~ 12 |
+| Abs | 1 ~ 13 |
+| Acos | 7 ~ 13 |
+| Acosh | 9 ~ 13 |
+| Add | 1 ~ 13 |
+| AddN | 6 ~ 13 |
+| AddV2 | 1 ~ 13 |
+| All | 6 ~ 13 |
+| Any | 6 ~ 13 |
+| ArgMax | 1 ~ 13 |
+| ArgMin | 1 ~ 13 |
+| Asin | 7 ~ 13 |
+| Asinh | 9 ~ 13 |
+| Atan | 7 ~ 13 |
+| Atan2 | 9 ~ 13 |
+| Atanh | 9 ~ 13 |
+| AvgPool | 1 ~ 13 |
+| AvgPool3D | 1 ~ 13 |
+| BatchMatMul | 1 ~ 13 |
+| BatchMatMulV2 | 1 ~ 13 |
+| BatchToSpaceND | 1 ~ 13 |
+| BiasAdd | 1 ~ 13 |
+| BiasAddV1 | 1 ~ 13 |
+| Bincount | 11 ~ 13 |
+| BroadcastTo | 8 ~ 13 |
+| Cast | 1 ~ 13 |
+| Ceil | 1 ~ 13 |
+| CheckNumerics | 1 ~ 13 |
+| ClipByValue | 8 ~ 13 |
+| ComplexAbs | 1 ~ 13 |
+| Concat | 1 ~ 13 |
+| ConcatV2 | 1 ~ 13 |
+| Const | 1 ~ 13 |
+| ConstV2 | 1 ~ 13 |
+| Conv1D | 1 ~ 13 |
+| Conv2D | 1 ~ 13 |
+| Conv2DBackpropInput | 1 ~ 13 |
+| Conv3D | 1 ~ 13 |
+| Conv3DBackpropInputV2 | 1 ~ 13 |
+| Cos | 7 ~ 13 |
+| Cosh | 9 ~ 13 |
+| CropAndResize | 10 ~ 13 |
+| CudnnRNN | 10 ~ 13 |
+| Cumsum | 11 ~ 13 |
+| DepthToSpace | 1 ~ 13 |
+| DepthwiseConv2d | 1 ~ 13 |
+| DepthwiseConv2dNative | 1 ~ 13 |
+| Div | 1 ~ 13 |
+| Dropout | 1 ~ 13 |
+| DynamicPartition | 9 ~ 13 |
+| DynamicStitch | 10 ~ 13 |
+| Einsum | 12 ~ 13 |
+| Elu | 1 ~ 13 |
+| Equal | 1 ~ 13 |
+| Erf | 1 ~ 13 |
+| Exp | 1 ~ 13 |
+| ExpandDims | 1 ~ 13 |
+| FIFOQueueV2 | 8 ~ 13 |
+| FakeQuantWithMinMaxArgs | 10 ~ 13 |
+| Fill | 7 ~ 13 |
+| Flatten | 1 ~ 13 |
+| Floor | 1 ~ 13 |
+| FloorDiv | 6 ~ 13 |
+| FloorMod | 7 ~ 13 |
+| FusedBatchNorm | 6 ~ 13 |
+| FusedBatchNormV2 | 6 ~ 13 |
+| FusedBatchNormV3 | 6 ~ 13 |
+| Gather | 1 ~ 13 |
+| GatherNd | 1 ~ 13 |
+| GatherV2 | 1 ~ 13 |
+| Greater | 1 ~ 13 |
+| GreaterEqual | 7 ~ 13 |
+| HashTableV2 | 8 ~ 13 |
+| Identity | 1 ~ 13 |
+| IdentityN | 1 ~ 13 |
+| If | 1 ~ 13 |
+| InvertPermutation | 11 ~ 13 |
+| IsFinite | 10 ~ 13 |
+| IsInf | 10 ~ 13 |
+| IsNan | 9 ~ 13 |
+| IteratorGetNext | 8 ~ 13 |
+| IteratorV2 | 8 ~ 13 |
+| LRN | 1 ~ 13 |
+| LSTMBlockCell | 1 ~ 13 |
+| LeakyRelu | 1 ~ 13 |
+| LeftShift | 11 ~ 13 |
+| Less | 1 ~ 13 |
+| LessEqual | 7 ~ 13 |
+| Log | 1 ~ 13 |
+| LogSoftmax | 1 ~ 13 |
+| LogicalAnd | 1 ~ 13 |
+| LogicalNot | 1 ~ 13 |
+| LogicalOr | 1 ~ 13 |
+| LookupTableFindV2 | 8 ~ 13 |
+| LookupTableSizeV2 | 1 ~ 13 |
+| Loop | 7 ~ 13 |
+| MatMul | 1 ~ 13 |
+| MatrixBandPart | 7 ~ 13 |
+| MatrixDeterminant | 11 ~ 13 |
+| MatrixDiag | 12 ~ 13 |
+| MatrixDiagPart | 11 ~ 13 |
+| MatrixDiagPartV2 | 11 ~ 13 |
+| MatrixDiagPartV3 | 11 ~ 13 |
+| MatrixDiagV2 | 12 ~ 13 |
+| MatrixDiagV3 | 12 ~ 13 |
+| MatrixSetDiagV3 | 12 ~ 13 |
+| Max | 1 ~ 13 |
+| MaxPool | 1 ~ 13 |
+| MaxPool3D | 1 ~ 13 |
+| MaxPoolV2 | 1 ~ 13 |
+| MaxPoolWithArgmax | 8 ~ 13 |
+| Maximum | 1 ~ 13 |
+| Mean | 1 ~ 13 |
+| Min | 1 ~ 13 |
+| Minimum | 1 ~ 13 |
+| MirrorPad | 1 ~ 13 |
+| Mul | 1 ~ 13 |
+| Multinomial | 7 ~ 13 |
+| Neg | 1 ~ 13 |
+| NoOp | 1 ~ 13 |
+| NonMaxSuppressionV2 | 10 ~ 13 |
+| NonMaxSuppressionV3 | 10 ~ 13 |
+| NonMaxSuppressionV4 | 10 ~ 13 |
+| NonMaxSuppressionV5 | 10 ~ 13 |
+| NotEqual | 1 ~ 13 |
+| OneHot | 1 ~ 13 |
+| Pack | 1 ~ 13 |
+| Pad | 1 ~ 13 |
+| PadV2 | 1 ~ 13 |
+| ParallelDynamicStitch | 10 ~ 13 |
+| Placeholder | 1 ~ 13 |
+| PlaceholderV2 | 1 ~ 13 |
+| PlaceholderWithDefault | 1 ~ 13 |
+| Pow | 1 ~ 13 |
+| Prod | 1 ~ 13 |
+| QueueDequeueManyV2 | 8 ~ 13 |
+| QueueDequeueUpToV2 | 8 ~ 13 |
+| QueueDequeueV2 | 8 ~ 13 |
+| RFFT | 1 ~ 13 |
+| RaggedRange | 11 ~ 13 |
+| RandomNormal | 1 ~ 13 |
+| RandomNormalLike | 1 ~ 13 |
+| RandomUniform | 1 ~ 13 |
+| RandomUniformLike | 1 ~ 13 |
+| Range | 7 ~ 13 |
+| RealDiv | 1 ~ 13 |
+| Reciprocal | 1 ~ 13 |
+| Relu | 1 ~ 13 |
+| Relu6 | 1 ~ 13 |
+| Reshape | 1 ~ 13 |
+| ResizeBilinear | 7 ~ 13 |
+| ResizeNearestNeighbor | 7 ~ 13 |
+| ReverseSequence | 8 ~ 13 (Except 9) |
+| ReverseV2 | 10 ~ 13 |
+| RightShift | 11 ~ 13 |
+| Roll | 10 ~ 13 |
+| Round | 11 ~ 13 |
+| Rsqrt | 1 ~ 13 |
+| Scan | 7 ~ 13 |
+| ScatterNd | 11 ~ 13 |
+| SegmentMax | 9 ~ 13 |
+| SegmentMean | 9 ~ 13 |
+| SegmentMin | 9 ~ 13 |
+| SegmentProd | 9 ~ 13 |
+| SegmentSum | 9 ~ 13 |
+| Select | 7 ~ 13 |
+| SelectV2 | 7 ~ 13 |
+| Selu | 1 ~ 13 |
+| Shape | 1 ~ 13 |
+| Sigmoid | 1 ~ 13 |
+| Sign | 1 ~ 13 |
+| Sin | 7 ~ 13 |
+| Sinh | 9 ~ 13 |
+| Size | 1 ~ 13 |
+| Slice | 1 ~ 13 |
+| Softmax | 1 ~ 13 |
+| SoftmaxCrossEntropyWithLogits | 7 ~ 13 |
+| Softplus | 1 ~ 13 |
+| Softsign | 1 ~ 13 |
+| SpaceToBatchND | 1 ~ 13 |
+| SpaceToDepth | 1 ~ 13 |
+| SparseFillEmptyRows | 11 ~ 13 |
+| SparseReshape | 11 ~ 13 |
+| SparseSegmentMean | 9 ~ 13 |
+| SparseSegmentMeanWithNumSegments | 9 ~ 13 |
+| SparseSegmentSqrtN | 9 ~ 13 |
+| SparseSegmentSqrtNWithNumSegments | 9 ~ 13 |
+| SparseSegmentSum | 9 ~ 13 |
+| SparseSegmentSumWithNumSegments | 9 ~ 13 |
+| SparseSoftmaxCrossEntropyWithLogits | 7 ~ 13 |
+| SparseToDense | 11 ~ 13 |
+| Split | 1 ~ 13 |
+| SplitV | 1 ~ 13 |
+| Sqrt | 1 ~ 13 |
+| Square | 1 ~ 13 |
+| SquaredDifference | 1 ~ 13 |
+| SquaredDistance | 12 ~ 13 |
+| Squeeze | 1 ~ 13 |
+| StatelessIf | 1 ~ 13 |
+| StatelessWhile | 7 ~ 13 |
+| StopGradient | 1 ~ 13 |
+| StridedSlice | 1 ~ 13 |
+| Sub | 1 ~ 13 |
+| Sum | 1 ~ 13 |
+| Tan | 7 ~ 13 |
+| Tanh | 1 ~ 13 |
+| TensorListFromTensor | 7 ~ 13 |
+| TensorListGetItem | 7 ~ 13 |
+| TensorListLength | 7 ~ 13 |
+| TensorListReserve | 7 ~ 13 |
+| TensorListResize | 7 ~ 13 |
+| TensorListSetItem | 7 ~ 13 |
+| TensorListStack | 7 ~ 13 |
+| TensorScatterUpdate | 11 ~ 13 |
+| Tile | 1 ~ 13 |
+| TopKV2 | 1 ~ 13 |
+| Transpose | 1 ~ 13 |
+| TruncateDiv | 1 ~ 13 |
+| Unique | 11 ~ 13 |
+| Unpack | 1 ~ 13 |
+| UnsortedSegmentMax | 9 ~ 13 |
+| UnsortedSegmentMin | 9 ~ 13 |
+| UnsortedSegmentProd | 9 ~ 13 |
+| UnsortedSegmentSum | 9 ~ 13 |
+| Where | 9 ~ 13 |
+| While | 7 ~ 13 |
+| ZerosLike | 1 ~ 13 |
 ### Domain: "com.microsoft"
 | Tensorflow Op | Convertible to ONNX Op Versions |
 | ------------- | ------------------------------- |
@@ -196,3 +235,13 @@
 | CropAndResize | 1 |
 | MatrixInverse | 1 |
 | Range | 1 |
+### Domain: "ai.onnx.contrib"
+| Tensorflow Op | Convertible to ONNX Op Versions |
+| ------------- | ------------------------------- |
+| Equal | 1 |
+| NotEqual | 1 |
+| StaticRegexReplace | 1 |
+| StringJoin | 1 |
+| StringSplit | 1 |
+| StringSplitV2 | 1 |
+| StringToHashBucketFast | 1 |
diff --git a/tools/gen_doc.py b/tools/gen_doc.py
index 2ada2afcd..e8e428181 100644
--- a/tools/gen_doc.py
+++ b/tools/gen_doc.py
@@ -18,8 +18,9 @@
 LATEST_OPSET = {
-    "": 12,  # default domain
-    "com.microsoft": 1  # microsoft domain
+    "": 13,  # default domain
+    "com.microsoft": 1,  # microsoft domain
+    "ai.onnx.contrib": 1,  # contrib ops
 }
diff --git a/tools/quantitize_weights.py b/tools/quantitize_weights.py
deleted file mode 100644
index ec1e184a9..000000000
--- a/tools/quantitize_weights.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT license.
-
-"""
-quantitize_weights.py - simple script to quantitize weights (not the model) to 8 bits.
-"""
-
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import logging
-import numpy as np
-from onnx import ModelProto, helper, onnx_pb, numpy_helper
-
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("quantitize_weights")
-
-
-def _get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input", required=True, help="input model")
-    parser.add_argument("--output", required=True, help="output model")
-    parser.add_argument("--verbose", help="verbose", action="store_true")
-    args = parser.parse_args()
-    return args
-
-
-def eight_bit_dequantitize(w_in, zp, scale):
-    w = w_in * scale + zp
-    w = w.astype("float32")
-    return w
-
-
-def eight_bit_quantitize(w_in):
-    """quantitize to 8 bit as scale and zeropoint"""
-    low = np.min(w_in)
-    high = np.max(w_in)
-    scale = (high - low) / 256.
-    w = (w_in - low) / scale
-    w_out = w.astype("uint8")
-    return w_out, low, scale
-
-
-def _port_name(name):
-    return name + "__out"
-
-
-def _make_node(nodes, op, name, inputs, **kwargs):
-    node = helper.make_node(op, inputs, [_port_name(name)], name=name, **kwargs)
-    nodes.append(node)
-
-
-def _compose_quantitize(nodes, weights, zp, scale, name):
-    """
-    Compose Dequantitize(input, zeropoint, scale).
- """ - name_zp = name + "_zeropoint" - name_scale = name + "_scale" - name_add = name + "_add" - name_mul = name + "_mul" - name_cast = name + "_cast" - - # add zeropoint and scale as initializers - weights.append(numpy_helper.from_array(np.array(zp, dtype=np.float32), name_zp)) - weights.append(numpy_helper.from_array(np.array(scale, dtype=np.float32), name_scale)) - - # insert ops to dequantitize - _make_node(nodes, "Cast", name_cast, [name], to=onnx_pb.TensorProto.FLOAT) - _make_node(nodes, "Mul", name_mul, [_port_name(name_cast), name_scale]) - _make_node(nodes, "Add", name_add, [_port_name(name_mul), name_zp]) - - return _port_name(name_add) - - -def stats(a): - return {"mean": a.mean(), "std": a.std(), "max": a.max(), "min": a.min()} - - -def quantitize_graph(g, verbose=False): - """Quantitize graph.""" - new_weights = [] - quantitized_weights = [] - nodes = [] - remap = {} - remove = [] - - for i, w in enumerate(g.initializer): - # only quantitize float32 - if w.data_type != onnx_pb.TensorProto.FLOAT: - continue - w_np = numpy_helper.to_array(w) - # only look at sizes >= 32 elements - if w_np.size < 32: - continue - - # weights we want to quantitize - remove.append(i) - name = w.name - if verbose: - logger.info("quantitizing %s", name) - w_quant, zp, scale = eight_bit_quantitize(w_np) - nw = numpy_helper.from_array(w_quant, name=name) - if verbose: - w_dequant = eight_bit_dequantitize(w_quant, zp, scale) - rtol = np.abs(w_dequant - w_np) - s = {} - for j in [1.0, 5.0, 10.0, 20.0]: - above_rtol = np.sum(rtol > np.abs(j * w_np / 100.)) / w_np.size - s["> " + str(j) + "%"] = "{:.2f}".format(100. * above_rtol) - logger.info("above_rtol: %s", str(s)) - logger.info("raw: %s", stats(w_np)) - logger.info("quant: %s", stats(w_dequant)) - output_name = _compose_quantitize(nodes, new_weights, zp, scale, name) - remap[name] = output_name - quantitized_weights.append(nw) - - # few things to do to initializers and graph inputs: - - # 1. remove initializers that got quantitized - for i in reversed(remove): - del g.initializer[i] - - # 2. add quantitized to initializers - g.initializer.extend(new_weights) - g.initializer.extend(quantitized_weights) - - # 3. modify the type of weights that we quantitized - modified = {w.name: w for w in quantitized_weights} - new_inputs = [] - remove = [] - for i, inp in enumerate(g.input): - w = modified.get(inp.name) - if w is not None: - new_inputs.append(helper.make_tensor_value_info(w.name, w.data_type, w.dims)) - remove.append(i) - for i in reversed(remove): - del g.input[i] - - # 4. add new weights as inputs - for w in new_weights: - tv = helper.make_tensor_value_info(w.name, w.data_type, w.dims) - new_inputs.append(tv) - g.input.extend(new_inputs) - - # 5. rewrite consumers of the quantitized weights - for node in g.node: - for i, name in enumerate(node.input): - new_name = remap.get(name) - if new_name is not None: - node.input[i] = new_name - - # 6. 
-    nodes.extend(g.node)
-    del g.node[:]
-    g.node.extend(nodes)
-    return g
-
-
-def main():
-    args = _get_args()
-
-    # read onnx graph
-    with open(args.input, "rb") as f:
-        data = f.read()
-    model_proto = ModelProto()
-    model_proto.ParseFromString(data)
-
-    # quantitize weights
-    g = quantitize_graph(model_proto.graph, args.verbose)
-
-    # write quantitized graph
-    with open(args.output, "wb") as f:
-        # create model proto
-        model_proto_out = helper.make_model(g,
-                                            producer_name="quantized {}".format(model_proto.producer_name),
-                                            producer_version=model_proto.producer_version,
-                                            opset_imports=model_proto.opset_import)
-        f.write(model_proto_out.SerializeToString())
-
-
-if __name__ == "__main__":
-    main()