Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,86 @@ def test_transpose_leaky_relu(self):
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@check_opset_min_version(10, "QuantizeLinear")
def test_transpose_quantize(self):
scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize")
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")

graph = helper.make_graph(
[node1, node2, node3],
"quantize-test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Z1", TensorProto.UINT8, (2, 3, 4, 5))],
[scale, zero_point]
)

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@check_opset_min_version(13, "QuantizeLinear with axis")
def test_transpose_quantize_with_axis(self):
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize", axis=2)
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")

graph = helper.make_graph(
[node1, node2, node3],
"quantize-test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Z1", TensorProto.UINT8, (2, 3, 4, 5))],
[scale, zero_point]
)

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@check_opset_min_version(10, "DequantizeLinear")
def test_transpose_dequantize(self):
scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize")
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")

graph = helper.make_graph(
[node1, node2, node3],
"dequantize-test",
[helper.make_tensor_value_info("X", TensorProto.UINT8, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5))],
[scale, zero_point]
)

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["Z1"], {"X": np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)},
model_proto, remaining_transpose_num=0)

@check_opset_min_version(13, "DequantizeLinear with axis")
def test_transpose_dequantize_with_axis(self):
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize", axis=2)
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")

graph = helper.make_graph(
[node1, node2, node3],
"dequantize-test",
[helper.make_tensor_value_info("X", TensorProto.UINT8, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5))],
[scale, zero_point]
)

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["Z1"], {"X": np.random.randint(0, 100, (2, 3, 4, 5), np.uint8)},
model_proto, remaining_transpose_num=0)

@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
def test_transpose_slice(self):
starts = np.array([0, 0, 0, 0], dtype=np.int64)
Expand Down
13 changes: 13 additions & 0 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def _initialize_handlers(self):
"Sub": self._sub_handler,
"Tanh": self._simple_through_handler,
"Transpose": self._transpose_handler,
"DequantizeLinear": self._quantize_handler,
"QuantizeLinear": self._quantize_handler,
}

def _handle_node_having_branches(self, node):
Expand Down Expand Up @@ -698,6 +700,17 @@ def _slice_handler(self, trans, node):
return self._switch_transpose_and_node(node, trans)
return False

def _quantize_handler(self, trans, node):
# Used for QuantizeLinear and DequantizeLinear
if not self._switch_transpose_and_node(node, trans):
return False
if 'axis' in node.attr:
perm = trans.get_attr_value("perm")
axis = node.get_attr_value("axis")
new_axis = perm[axis]
node.set_attr("axis", new_axis)
return True

def _simple_through_handler(self, trans, node):
return self._switch_transpose_and_node(node, trans)

Expand Down