
Commit dc4b5f5

zyuwen-habana authored and xinhe3 committed
[SW-192809] fix json_file bug when instantiating FP8Config class
Change-Id: I4a715d0a706efe20ccdb49033755cabbc729ccdc
Signed-off-by: Zhou Yuwen <[email protected]>
1 parent cfe135f commit dc4b5f5

File tree: 5 files changed, +137 -12 lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 29 additions & 4 deletions
@@ -1,13 +1,28 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import json
 import os
-import torch
-from enum import Enum, Flag, auto
 from dataclasses import dataclass
+from enum import Enum, Flag, auto
 from json.decoder import JSONDecodeError
 from typing import Any, Mapping
+
 import habana_frameworks.torch.utils.experimental as htexp
+import torch
 
 from ..utils.logger import logger
 
@@ -121,6 +136,16 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
                 else:
                     raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2")
 
+            if keys == "hp_dtype":
+                if custom_config[keys].lower() == "bf16":
+                    custom_config[keys] = torch.bfloat16
+                elif custom_config[keys].lower() == "fp16":
+                    custom_config[keys] = torch.float16
+                elif custom_config[keys].lower() == "fp32":
+                    custom_config[keys] = torch.float32
+                else:
+                    raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32")
+
             if keys == "scale_method":
                 if custom_config[keys].lower() == "unit_scale":
                     custom_config[keys] = ScaleMethod.UNIT_SCALE
@@ -176,7 +201,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
     # If seperate_measure_files is True (default value), then it is assumed that there are multiple distinct measure and scale files
     # and they are stored in / loaded from paths with the correct index as a suffix. Else, only one is searched for.
     measured_global_config["local_rank"] = (
-        local_rank if local_rank >= 0 and (custom_config.get("seperate_measure_files", True) == True) else None
+        local_rank if local_rank >= 0 and custom_config.get("seperate_measure_files", True) else None
     )
 
     base_name = measured_global_config["dump_stats_path"].split("/")[-1]
@@ -185,7 +210,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
     os.makedirs(folder_name, exist_ok=True)
     worker_st = (
         ""
-        if measured_global_config["local_rank"] == None
+        if measured_global_config["local_rank"] is None
        else "_" + str(measured_global_config["local_rank"]) + "_" + str(measured_global_config["world_size"])
     )
     measured_global_config["shape_file"] = measured_global_config["dump_stats_path"] + "_hooks_shape" + worker_st
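For context, the new hp_dtype branch converts the string values accepted in a JSON/dict config ("bf16", "fp16", "fp32") into the corresponding torch dtypes before the rest of the FP8 pipeline sees them. A minimal standalone sketch of the same mapping; parse_hp_dtype is a hypothetical helper used only for illustration, not part of quant_config.py:

import torch

# Hypothetical helper mirroring the hp_dtype branch added to parse():
# the JSON config carries strings, the quantizer wants torch dtypes.
_HP_DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}


def parse_hp_dtype(value: str) -> torch.dtype:
    try:
        return _HP_DTYPE_MAP[value.lower()]
    except KeyError:
        raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32") from None


print(parse_hp_dtype("BF16"))  # torch.bfloat16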

neural_compressor/torch/quantization/config.py

Lines changed: 11 additions & 8 deletions
@@ -1259,7 +1259,7 @@ def __init__(
         self,
         dump_stats_path: str = "./hqt_output/measure",
         fp8_config: str = "E4M3",
-        hp_dtype: torch.dtype = torch.bfloat16,
+        hp_dtype: str = "bf16",
         blocklist: dict = {'names': [], 'types': ()},
         allowlist: dict = {'names': [], 'types': FP8_WHITE_LIST},
         mode: str = "AUTO",
@@ -1294,13 +1294,6 @@ def quantize(self):
 
     @property
     def json_file(self):
-        if self._json_file is None:
-            import tempfile
-            from pathlib import Path
-
-            json_file_tmp = tempfile.NamedTemporaryFile(suffix=".json")
-            self.to_json_file(json_file_tmp.name)
-            self.json_file(json_file_tmp.name)
         return self._json_file
 
     @json_file.setter
@@ -1315,6 +1308,14 @@ def from_json_file(cls, filename):
         config.json_file = filename
         return config
 
+    def save_temp_json_file(self):
+        import tempfile
+        from pathlib import Path
+
+        json_file_tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
+        self.to_json_file(json_file_tmp.name)
+        self._json_file = json_file_tmp.name
+
     @classmethod
     def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]:
         # just a simple example here
@@ -1361,6 +1362,8 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ):
+        if self.json_file is None:
+            self.save_temp_json_file()
         config_mapping = OrderedDict()
         if config_list is None:
            config_list = [self]
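Taken together, the config.py changes move the temporary-file side effect out of the json_file getter: the getter is now a plain accessor, from_json_file still records the path it was given, and to_config_mapping writes a temporary JSON via save_temp_json_file only when no file was supplied. A rough usage sketch of the resulting behavior (assumes neural_compressor with the Intel Gaudi FP8 dependencies is installed; the exact temporary path will vary):

from neural_compressor.torch.quantization import FP8Config

# Instantiating without a JSON file no longer triggers the buggy getter,
# which previously tried to assign the temp file by calling the json_file
# property itself instead of its setter.
config = FP8Config(fp8_config="E4M3")
assert config.json_file is None  # plain getter, no side effects

# The temporary dump happens lazily (inside to_config_mapping), or it can
# be forced explicitly:
config.save_temp_json_file()
print(config.json_file)  # something like /tmp/tmpXXXXXXXX.json

# Loading from an existing JSON still sets json_file directly:
config2 = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
assert config2.json_file == "test_fp8_jsons/test_measure.json"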
New file (loaded by the test below as test_fp8_jsons/test_hw_quant.json)

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+{
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "maxabs_hw",
+    "allowlist": {
+        "types": [],
+        "names": []
+    },
+    "blocklist": {
+        "types": [],
+        "names": [
+            "lm_head"
+        ]
+    },
+    "dump_stats_path": "./test_outputs/unit_test"
+}
New file (loaded by the test below as test_fp8_jsons/test_measure.json)

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+{
+    "mode": "MEASURE",
+    "observer": "maxabs",
+    "allowlist": {
+        "types": [],
+        "names": []
+    },
+    "blocklist": {
+        "types": [],
+        "names": [
+            "lm_head"
+        ]
+    },
+    "dump_stats_path": "./test_outputs/unit_test"
+}
New test file (defines TestFP8StaticQuant; its path is not shown in this view)

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+import copy
+import shutil
+
+import pytest
+import torch
+import transformers
+
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear
+from neural_compressor.torch.quantization import (
+    FP8Config,
+    convert,
+    finalize_calibration,
+    get_default_fp8_config,
+    prepare,
+    quantize,
+)
+from neural_compressor.torch.utils import is_hpex_available
+
+
+@torch.no_grad()
+def calib_func(model):
+    example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to("hpu")
+    for i in range(2):
+        model(example_inputs)
+
+
+@pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!")
+class TestFP8StaticQuant:
+    def setup_class(self):
+        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM",
+            device_map="cpu",
+        )
+        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
+
+    def teardown_class(self):
+        shutil.rmtree("test_outputs", ignore_errors=True)
+
+    def test_one_step_quant(self):
+        model = copy.deepcopy(self.tiny_gptj)
+        qconfig = FP8Config(fp8_config="E4M3")
+        model = prepare(model, qconfig)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not prepared."
+        calib_func(model)
+        model = convert(model)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
+        assert (
+            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
+        ), "k_proj input dtype is not torch.float8_e4m3fn."
+
+    def test_two_step_quant(self):
+        # step 1: measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
+        model = prepare(model, config)
+        calib_func(model)
+        finalize_calibration(model)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not observed."
+        # step 2: quantize based on measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
+        model = convert(model, config)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
+        assert (
+            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
+        ), "k_proj input dtype is not torch.float8_e4m3fn."
