Skip to content

Commit 1f10a92

Browse files
committed
fix import bug
Signed-off-by: yuwenzho <[email protected]>
1 parent 966aa9b commit 1f10a92

File tree

2 files changed

+34
-39
lines changed

2 files changed

+34
-39
lines changed

neural_compressor/adaptor/ox_utils/util.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
numpy_helper = LazyImport("onnx.numpy_helper")
3030
onnx_proto = LazyImport("onnx.onnx_pb")
3131
torch = LazyImport("torch")
32-
onnxruntime = LazyImport("onnxruntime")
3332
symbolic_shape_infer = LazyImport("onnxruntime.tools.symbolic_shape_infer")
3433
onnx = LazyImport("onnx")
3534

@@ -599,40 +598,36 @@ def to_numpy(data):
599598
else:
600599
return data
601600

601+
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
602+
"""Symbolic shape inference."""
603+
604+
class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
605+
def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
606+
super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
607+
self.base_dir = base_dir
608+
609+
def _get_value(self, node, idx):
610+
name = node.input[idx]
611+
assert name in self.sympy_data_ or name in self.initializers_
612+
return (
613+
self.sympy_data_[name]
614+
if name in self.sympy_data_
615+
else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
616+
)
602617

603-
class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
604-
"""Shape inference for ONNX model."""
605-
606-
def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
607-
"""Initialize Shape inference class."""
608-
super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
609-
self.base_dir = base_dir
610-
611-
def _get_value(self, node, idx):
612-
name = node.input[idx]
613-
assert name in self.sympy_data_ or name in self.initializers_
614-
return (
615-
self.sympy_data_[name]
616-
if name in self.sympy_data_
617-
else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
618-
)
619-
620-
@staticmethod
621-
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
622-
"""Symbolic shape inference."""
623-
onnx_opset = symbolic_shape_infer.get_opset(in_mp)
624-
if (not onnx_opset) or onnx_opset < 7:
625-
logger.warning("Only support models of onnx opset 7 and above.")
626-
return None
627-
symbolic_shape_inference = SymbolicShapeInference(
628-
int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
629-
)
630-
all_shapes_inferred = False
631-
symbolic_shape_inference._preprocess(in_mp)
632-
while symbolic_shape_inference.run_:
633-
all_shapes_inferred = symbolic_shape_inference._infer_impl()
634-
symbolic_shape_inference._update_output_from_vi()
635-
if not all_shapes_inferred:
636-
onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
637-
raise Exception("Incomplete symbolic shape inference")
638-
return symbolic_shape_inference.out_mp_
618+
onnx_opset = symbolic_shape_infer.get_opset(in_mp)
619+
if (not onnx_opset) or onnx_opset < 7:
620+
logger.warning("Only support models of onnx opset 7 and above.")
621+
return None
622+
symbolic_shape_inference = SymbolicShapeInference(
623+
int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
624+
)
625+
all_shapes_inferred = False
626+
symbolic_shape_inference._preprocess(in_mp)
627+
while symbolic_shape_inference.run_:
628+
all_shapes_inferred = symbolic_shape_inference._infer_impl()
629+
symbolic_shape_inference._update_output_from_vi()
630+
if not all_shapes_inferred:
631+
onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
632+
raise Exception("Incomplete symbolic shape inference")
633+
return symbolic_shape_inference.out_mp_

neural_compressor/model/onnx_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,9 +1045,9 @@ def split_model_with_node(
10451045
if shape_infer:
10461046
try:
10471047
# need ort.GraphOptimizationLevel <= ORT_ENABLE_BASIC
1048-
from neural_compressor.adaptor.ox_utils.util import SymbolicShapeInference
1048+
from neural_compressor.adaptor.ox_utils.util import infer_shapes
10491049

1050-
self._model = SymbolicShapeInference.infer_shapes(
1050+
self._model = infer_shapes(
10511051
self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)
10521052
)
10531053
except Exception as e: # pragma: no cover

0 commit comments

Comments
 (0)