|
29 | 29 | numpy_helper = LazyImport("onnx.numpy_helper")
|
30 | 30 | onnx_proto = LazyImport("onnx.onnx_pb")
|
31 | 31 | torch = LazyImport("torch")
|
32 |
| -onnxruntime = LazyImport("onnxruntime") |
33 | 32 | symbolic_shape_infer = LazyImport("onnxruntime.tools.symbolic_shape_infer")
|
34 | 33 | onnx = LazyImport("onnx")
|
35 | 34 |
|
@@ -599,40 +598,36 @@ def to_numpy(data):
|
599 | 598 | else:
|
600 | 599 | return data
|
601 | 600 |
|
| 601 | +def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""): |
| 602 | + """Symbolic shape inference.""" |
| 603 | + |
| 604 | + class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference): |
| 605 | + def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""): |
| 606 | + super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix) |
| 607 | + self.base_dir = base_dir |
| 608 | + |
| 609 | + def _get_value(self, node, idx): |
| 610 | + name = node.input[idx] |
| 611 | + assert name in self.sympy_data_ or name in self.initializers_ |
| 612 | + return ( |
| 613 | + self.sympy_data_[name] |
| 614 | + if name in self.sympy_data_ |
| 615 | + else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir) |
| 616 | + ) |
602 | 617 |
|
603 |
| -class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference): |
604 |
| - """Shape inference for ONNX model.""" |
605 |
| - |
606 |
| - def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""): |
607 |
| - """Initialize Shape inference class.""" |
608 |
| - super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix) |
609 |
| - self.base_dir = base_dir |
610 |
| - |
611 |
| - def _get_value(self, node, idx): |
612 |
| - name = node.input[idx] |
613 |
| - assert name in self.sympy_data_ or name in self.initializers_ |
614 |
| - return ( |
615 |
| - self.sympy_data_[name] |
616 |
| - if name in self.sympy_data_ |
617 |
| - else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir) |
618 |
| - ) |
619 |
| - |
620 |
| - @staticmethod |
621 |
| - def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""): |
622 |
| - """Symbolic shape inference.""" |
623 |
| - onnx_opset = symbolic_shape_infer.get_opset(in_mp) |
624 |
| - if (not onnx_opset) or onnx_opset < 7: |
625 |
| - logger.warning("Only support models of onnx opset 7 and above.") |
626 |
| - return None |
627 |
| - symbolic_shape_inference = SymbolicShapeInference( |
628 |
| - int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir |
629 |
| - ) |
630 |
| - all_shapes_inferred = False |
631 |
| - symbolic_shape_inference._preprocess(in_mp) |
632 |
| - while symbolic_shape_inference.run_: |
633 |
| - all_shapes_inferred = symbolic_shape_inference._infer_impl() |
634 |
| - symbolic_shape_inference._update_output_from_vi() |
635 |
| - if not all_shapes_inferred: |
636 |
| - onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True) |
637 |
| - raise Exception("Incomplete symbolic shape inference") |
638 |
| - return symbolic_shape_inference.out_mp_ |
| 618 | + onnx_opset = symbolic_shape_infer.get_opset(in_mp) |
| 619 | + if (not onnx_opset) or onnx_opset < 7: |
| 620 | + logger.warning("Only support models of onnx opset 7 and above.") |
| 621 | + return None |
| 622 | + symbolic_shape_inference = SymbolicShapeInference( |
| 623 | + int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir |
| 624 | + ) |
| 625 | + all_shapes_inferred = False |
| 626 | + symbolic_shape_inference._preprocess(in_mp) |
| 627 | + while symbolic_shape_inference.run_: |
| 628 | + all_shapes_inferred = symbolic_shape_inference._infer_impl() |
| 629 | + symbolic_shape_inference._update_output_from_vi() |
| 630 | + if not all_shapes_inferred: |
| 631 | + onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True) |
| 632 | + raise Exception("Incomplete symbolic shape inference") |
| 633 | + return symbolic_shape_inference.out_mp_ |
0 commit comments