
Commit da4bcd2

[SW-184714] Port HQT code into INC
HQT lib content was copied as is under fp8_quant. Tests were copied to the 3.x torch location.

Change-Id: Iec6e1fa7ac4bf1df1c95b429524c40e32bc13ac9
1 parent 768c2a4 commit da4bcd2


48 files changed: +4465 -7 lines changed

neural_compressor/torch/algorithms/fp8_quant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,4 +18,5 @@
     restore_patched_module,
     with_patched_module,
 )
+from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import finish_measurements, prep_model
 from neural_compressor.torch.algorithms.fp8_quant.fp8_quant import FP8Quantizer
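For context, the two entry points exported here come from HQT's public API. A minimal usage sketch, assuming (as in HQT) that both take the model as their only required argument and that measurement/quantization settings are supplied externally via a JSON config (e.g. a QUANT_CONFIG-style environment variable); the exact signatures live in prepare_quant/prepare_model.py:

import torch
from neural_compressor.torch.algorithms.fp8_quant import finish_measurements, prep_model

model = torch.nn.Linear(16, 16)     # toy stand-in; real use targets HPU models
prep_model(model)                   # patch supported modules for measurement/quantization (assumed signature)
for batch in [torch.randn(4, 16)]:
    model(batch)                    # run calibration data through the patched model
finish_measurements(model)          # flush collected measurements/scales (assumed signature)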

neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py

Whitespace-only changes.
Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
import os
import torch
import json
import numpy as np
import functools
import importlib.util

from .._quant_common.helper_modules import *
from .._quant_common.quant_config import get_hqt_config
from ..utils.logger import logger

deepspeed_exists = False
if importlib.util.find_spec("deepspeed"):  # check if deepspeed is installed
    deepspeed_exists = True

UNMEASURED_MODELS = "UnmeasuredModels"


class ModuleInfo:
    def __init__(self, type, patched_module):
        self.type = type
        self.patched_module = patched_module


class ModuleConfig:
    def __init__(self, inputs=(None,), outputs=(None,), params=None):
        self.inputs = inputs
        self.outputs = outputs
        self.params = params if params is not None else {}


class ModuleExtraConfig:
    def __init__(self, inputs=(None,), outputs=(None,), params=None, scale=None, config_params=None):
        self.inputs = inputs
        self.outputs = outputs
        self.params = params if params is not None else {}
        self.scale = scale
        self.config_params = config_params if config_params is not None else {}


class ModuleType:
    def __init__(self, num_inputs, param_names, num_outputs, required_output):
        self.num_inputs = num_inputs
        self.param_names = param_names
        self.num_outputs = num_outputs
        self.required_output = required_output


# Expected number of inputs/outputs and parameter names per supported op type
mod_types = {
    "linear": ModuleType(1, ["weight"], 1, False),
    "matmul": ModuleType(2, [], 1, False),
    "kv_cache": ModuleType(1, [], 1, False),
    "softmax": ModuleType(1, [], 1, True),
    "fused_sdpa": ModuleType(3, [], 2, True),
}

# Elementwise scaling helpers and fp8 cast wrappers (Intel Gaudi / HPU ops)
descale_fcn = lambda x, scale: torch.mul(x, scale)
scale_fcn = lambda x, scale: torch.div(x, scale)
mat_scale_fcn = lambda x, scale_col, scale_row: torch.div(torch.div(x, scale_col), scale_row)
cast_fcn = lambda x, dtype: x.to(dtype=dtype)
cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype)


class ShapeList:
    data = None


# Apply fn to every leaf of a nested dict/list/tuple structure
def rec_fn(x, fn):
    if isinstance(x, dict):
        return {k: rec_fn(x[k], fn) for k in x}
    elif isinstance(x, list):
        return [rec_fn(k, fn) for k in x]
    elif isinstance(x, tuple):
        return tuple([rec_fn(k, fn) for k in x])
    else:
        return fn(x)


def np_to_pt(x):
    return rec_fn(x, lambda x: torch.tensor(x) if isinstance(x, np.ndarray) else x)


def pt_to_np(x):
    return rec_fn(
        x,
        lambda x: (x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x),
    )


def np_to_list(x):
    return rec_fn(x, lambda x: x.tolist() if isinstance(x, np.ndarray) else x)


def list_to_np(x):
    return rec_fn(x, lambda x: np.array(x) if isinstance(x, list) else x)


def save_json(d, fname):
    with open(fname, "w") as f:
        json.dump(d, f, indent=4)


def load_json(fname):
    with open(fname, "r") as f:
        d = json.load(f)
    return d


def save_npz(d, fname):
    np.savez(fname, d)


def load_npz(fname):
    d = np.load(fname, allow_pickle=True)
    return d["arr_0"].item()


def save_file(model, d, source_format, fname, mode):
    config = get_hqt_config(model)
    logger.debug("Saving %s file: %s", mode, fname)
    ext = os.path.splitext(fname)[1]
    target_format = file_functions[ext][0]
    dc = rec_fn(d, format_functions[(source_format, target_format)])
    df = {
        "GlobalRank": config.cfg["global_rank"],
        "LocalRank": config.cfg["local_rank"],
        "Mode": mode,
        "Nodes": dc,
    }
    try:
        file_functions[ext][1](df, fname)
    except:
        pass


# convert module config data to other format
def module_convert(m, fcn):
    mt = ModuleConfig(
        tuple([fcn(x) for x in m.inputs]),
        tuple([fcn(m.outputs)]) if type(m.outputs) == np.ndarray else tuple([fcn(y) for y in m.outputs]),
        {k: fcn(m.params[k]) for k in m.params},
    )
    return mt


def fix_fields(d):
    if "input" in d:
        d["inputs"] = d.pop("input")
    if "output" in d:
        d["outputs"] = d.pop("output")
    return d


def load_file(fname, target_format, fail_on_file_not_exist):
    logger.debug("Loading file: %s", fname)
    ext = os.path.splitext(fname)[1]
    source_format = file_functions[ext][0]
    d = {}
    if os.path.isfile(fname):
        d = file_functions[ext][2](fname)
    elif fail_on_file_not_exist:
        raise FileNotFoundError(f"Failed to load file {fname}")
    if "Nodes" in d:
        dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]}
        dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc}
    else:
        dc = {}
    return dc


def save_scales(model, d, source_format, fname):
    dc = {k: d[k].__dict__ for k in d}
    save_file(model, dc, source_format, fname, "Scale")


def load_scales(fname, target_format):
    logger.debug("Loading scales file %s", fname)
    d = load_file(fname, target_format, False)
    return d


def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype):
    scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj}
    scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp)
    scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device="hpu"))
    scales = {k: ModuleConfig(**scales_temp[k]) for k in scales_temp}
    return scales


# Per-extension handlers: (leaf format used on disk, save function, load function)
file_functions = {
    ".json": (list, save_json, load_json),
    ".npz": (np.ndarray, save_npz, load_npz),
}

# Leaf converters keyed by (source type, target type)
format_functions = {
    (torch.Tensor, torch.Tensor): lambda x: x,
    (np.ndarray, np.ndarray): lambda x: x,
    (list, list): lambda x: x,
    (torch.Tensor, np.ndarray): lambda x: x.detach().cpu().float().numpy(),
    (torch.Tensor, list): lambda x: x.detach().cpu().float().numpy().tolist(),
    (np.ndarray, torch.Tensor): torch.tensor,
    (np.ndarray, list): lambda x: x.tolist(),
    (list, torch.Tensor): torch.tensor,
    (list, np.ndarray): lambda x: np.array(x),
    (list, ShapeList): lambda x: [int(s) for s in x[0]],
}


format_functions_rec = lambda k: functools.partial(rec_fn, fn=format_functions[k])

# Default mapping from module class name to its ModuleInfo (op type + patched replacement)
mod_default_dict = {
    "Matmul": ModuleInfo("matmul", PatchedMatmul),
    "Linear": ModuleInfo("linear", PatchedLinear),
    "RowParallelLinear": ModuleInfo("linear", PatchedRowParallelLinear),
    "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "FalconLinear": ModuleInfo("linear", PatchedLinear),
    "KVCache": ModuleInfo("kv_cache", PatchedKVCache),
    "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache),
    "Conv2d": ModuleInfo("linear", PatchedConv2d),
    "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear),
    "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv),
    "Softmax": ModuleInfo("softmax", PatchedSoftmax),
    "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
}


if deepspeed_exists:
    mod_default_dict.update(
        {
            "LinearLayer": ModuleInfo("linear", PatchedLinear),
            "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce),
            "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce),
            "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce),
        }
    )


class ModInstInfo:
    def __init__(self, name, parent):
        self.name = name
        self.parent = parent


parent_child_mod_dict = {}


# Walk the model and record each submodule's attribute name and parent module
def generate_model_info(model):
    def create_mod_info_recursion(parent):
        for name, mod in parent.named_children():
            parent_child_mod_dict[mod] = ModInstInfo(name, parent)
            create_mod_info_recursion(mod)

    create_mod_info_recursion(model)
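The nested-structure helpers above (rec_fn, format_functions, format_functions_rec) are what move measurement data between torch tensors, numpy arrays, and JSON-friendly lists. A minimal sketch of how they compose, assuming they are imported from wherever this new file lands (its name is hidden in this view):

import numpy as np
import torch

nested = {"inputs": [torch.ones(2, 2)], "params": {"weight": torch.zeros(3)}}

# Convert every tensor leaf to a numpy array, preserving the dict/list nesting.
as_np = format_functions_rec((torch.Tensor, np.ndarray))(nested)
assert isinstance(as_np["params"]["weight"], np.ndarray)

# ...and back to tensors, as convert_scales_to_tensors_dict does when reloading scales.
as_pt = format_functions_rec((np.ndarray, torch.Tensor))(as_np)
assert torch.equal(as_pt["inputs"][0], nested["inputs"][0])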

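Similarly, generate_model_info records, for every submodule, its attribute name and parent module in parent_child_mod_dict, presumably so the patching code can later swap a child on its parent. A short sketch under the same import assumption:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
generate_model_info(model)

info = parent_child_mod_dict[model[0]]
print(info.name)             # "0" -- the Linear child's attribute name within its parent
print(info.parent is model)  # True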