7
7
import typing
8
8
from dataclasses import dataclass
9
9
10
- import habana_frameworks as htcore
11
10
import torch
12
- from habana_quantization_toolkit ._core .common import mod_default_dict
13
- from habana_quantization_toolkit ._quant_common .quant_config import Fp8cfg , QuantMode , ScaleMethod
11
+ from neural_compressor . torch . algorithms . fp8_quant ._core .common import mod_default_dict
12
+ from neural_compressor . torch . algorithms . fp8_quant ._quant_common .quant_config import Fp8cfg , QuantMode , ScaleMethod
14
13
15
14
16
15
@dataclass
@@ -60,8 +59,6 @@ def run_accuracy_test(
60
59
This test also asserts that the quantization actually happened.
61
60
This may be moved to another test in the future.
62
61
63
- You can use the generate_test_vectors.py script to generate input test vectors.
64
-
65
62
Args:
66
63
module_class: The reference module class to test.
67
64
This should be the direct module to test, e.g. Matmul, Linear, etc.
@@ -82,7 +79,7 @@ def run_accuracy_test(
82
79
measure_vectors , test_vectors = itertools .tee (test_vectors )
83
80
84
81
for mode in [QuantMode .MEASURE , QuantMode .QUANTIZE ]:
85
- import habana_quantization_toolkit . prepare_quant .prepare_model as hqt
82
+ import neural_compressor . torch . algorithms . fp8_quant . prepare_quant .prepare_model as prepare_model
86
83
87
84
reference_model = WrapModel (module_class , seed , * module_args , ** module_kwargs )
88
85
quantized_model = WrapModel (module_class , seed , * module_args , ** module_kwargs )
@@ -92,7 +89,7 @@ def run_accuracy_test(
92
89
lp_dtype = lp_dtype ,
93
90
scale_method = scale_method ,
94
91
)
95
- hqt ._prep_model_with_predefined_config (quantized_model , config = config )
92
+ prepare_model ._prep_model_with_predefined_config (quantized_model , config = config )
96
93
97
94
_assert_quantized_correctly (reference_model = reference_model , quantized_model = quantized_model )
98
95
@@ -120,7 +117,7 @@ def run_accuracy_test(
120
117
f"\n { scale_method .name = } "
121
118
)
122
119
123
- hqt .finish_measurements (quantized_model )
120
+ prepare_model .finish_measurements (quantized_model )
124
121
125
122
126
123
def _set_optional_seed (* , module_class : typing .Type [M ], seed : typing .Optional [int ]):
0 commit comments