import os
import torch
import json
import numpy as np
import functools
import importlib.util

from .._quant_common.helper_modules import *
from .._quant_common.quant_config import get_hqt_config
from ..utils.logger import logger

deepspeed_exists = False
if importlib.util.find_spec("deepspeed"):  # check if deepspeed is installed
    deepspeed_exists = True

UNMEASURED_MODELS = "UnmeasuredModels"


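# Lightweight containers used throughout this module: ModuleInfo ties a module type
# name to its patched replacement class, ModuleConfig/ModuleExtraConfig hold per-module
# input/output/param data (e.g. measurements or scales), and ModuleType records the
# tensor counts expected for each module type.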
class ModuleInfo:
    def __init__(self, type, patched_module):
        self.type = type
        self.patched_module = patched_module


class ModuleConfig:
    def __init__(self, inputs=(None,), outputs=(None,), params=None):
        self.inputs = inputs
        self.outputs = outputs
        self.params = params if params is not None else {}


class ModuleExtraConfig:
    def __init__(self, inputs=(None,), outputs=(None,), params=None, scale=None, config_params=None):
        self.inputs = inputs
        self.outputs = outputs
        self.params = params if params is not None else {}
        self.scale = scale
        self.config_params = config_params if config_params is not None else {}


class ModuleType:
    def __init__(self, num_inputs, param_names, num_outputs, required_output):
        self.num_inputs = num_inputs
        self.param_names = param_names
        self.num_outputs = num_outputs
        self.required_output = required_output


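# Expected tensor layout per module type:
# ModuleType(num_inputs, param_names, num_outputs, required_output).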
mod_types = {
    "linear": ModuleType(1, ["weight"], 1, False),
    "matmul": ModuleType(2, [], 1, False),
    "kv_cache": ModuleType(1, [], 1, False),
    "softmax": ModuleType(1, [], 1, True),
    "fused_sdpa": ModuleType(3, [], 2, True),
}
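
# Elementwise helpers: apply/remove scales and cast to/from FP8 via Habana (HPU) ops.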
descale_fcn = lambda x, scale: torch.mul(x, scale)
scale_fcn = lambda x, scale: torch.div(x, scale)
mat_scale_fcn = lambda x, scale_col, scale_row: torch.div(torch.div(x, scale_col), scale_row)
cast_fcn = lambda x, dtype: x.to(dtype=dtype)
cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype)


class ShapeList:
    data = None


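# rec_fn applies fn to every leaf of a nested dict/list/tuple, preserving the container
# structure; the helpers below use it to convert between torch tensors, numpy arrays,
# and plain lists.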
def rec_fn(x, fn):
    if isinstance(x, dict):
        return {k: rec_fn(x[k], fn) for k in x}
    elif isinstance(x, list):
        return [rec_fn(k, fn) for k in x]
    elif isinstance(x, tuple):
        return tuple([rec_fn(k, fn) for k in x])
    else:
        return fn(x)


def np_to_pt(x):
    return rec_fn(x, lambda x: torch.tensor(x) if isinstance(x, np.ndarray) else x)


def pt_to_np(x):
    return rec_fn(
        x,
        lambda x: (x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x),
    )


def np_to_list(x):
    return rec_fn(x, lambda x: x.tolist() if isinstance(x, np.ndarray) else x)


def list_to_np(x):
    return rec_fn(x, lambda x: np.array(x) if isinstance(x, list) else x)


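# JSON/NPZ (de)serialization primitives used for measurement and scale files.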
def save_json(d, fname):
    with open(fname, "w") as f:
        json.dump(d, f, indent=4)


def load_json(fname):
    with open(fname, "r") as f:
        d = json.load(f)
    return d


def save_npz(d, fname):
    np.savez(fname, d)


def load_npz(fname):
    d = np.load(fname, allow_pickle=True)
    return d["arr_0"].item()


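# save_file converts the nested node data from source_format into the target file's
# native format (chosen by extension) and writes it together with rank and mode metadata.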
def save_file(model, d, source_format, fname, mode):
    config = get_hqt_config(model)
    logger.debug("Saving %s file: %s", mode, fname)
    ext = os.path.splitext(fname)[1]
    target_format = file_functions[ext][0]
    dc = rec_fn(d, format_functions[(source_format, target_format)])
    df = {
        "GlobalRank": config.cfg["global_rank"],
        "LocalRank": config.cfg["local_rank"],
        "Mode": mode,
        "Nodes": dc,
    }
    try:
        file_functions[ext][1](df, fname)
    except Exception as e:
        logger.warning("Failed to save %s file %s: %s", mode, fname, e)


# convert module config data to another format
def module_convert(m, fcn):
    outputs = (fcn(m.outputs),) if isinstance(m.outputs, np.ndarray) else tuple(fcn(y) for y in m.outputs)
    mt = ModuleConfig(
        tuple(fcn(x) for x in m.inputs),
        outputs,
        {k: fcn(m.params[k]) for k in m.params},
    )
    return mt


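# Normalize singular "input"/"output" keys to the plural field names expected by ModuleConfig.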
def fix_fields(d):
    if "input" in d:
        d["inputs"] = d.pop("input")
    if "output" in d:
        d["outputs"] = d.pop("output")
    return d


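# load_file reads a measurement/scale file and converts its "Nodes" entries from the
# file's native format into ModuleConfig objects in target_format.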
def load_file(fname, target_format, fail_on_file_not_exist):
    logger.debug("Loading file: %s", fname)
    ext = os.path.splitext(fname)[1]
    source_format = file_functions[ext][0]
    d = {}
    if os.path.isfile(fname):
        d = file_functions[ext][2](fname)
    elif fail_on_file_not_exist:
        raise FileNotFoundError(f"Failed to load file {fname}")
    if "Nodes" in d:
        dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]}
        dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc}
    else:
        dc = {}
    return dc


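# Thin wrappers for persisting and restoring per-module scale configurations.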
def save_scales(model, d, source_format, fname):
    dc = {k: d[k].__dict__ for k in d}
    save_file(model, dc, source_format, fname, "Scale")


def load_scales(fname, target_format):
    logger.debug("Loading scales file %s", fname)
    d = load_file(fname, target_format, False)
    return d


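# Convert loaded scale objects from the on-file format into torch tensors of hp_dtype
# on the HPU device, wrapped back into ModuleConfig objects.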
def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype):
    scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj}
    scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp)
    scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device="hpu"))
    scales = {k: ModuleConfig(**scales_temp[k]) for k in scales_temp}
    return scales


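# Per-extension file handlers: (native in-file data format, save function, load function).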
file_functions = {
    ".json": (list, save_json, load_json),
    ".npz": (np.ndarray, save_npz, load_npz),
}

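# Leaf converters between in-memory representations, keyed by (source type, target type).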
format_functions = {
    (torch.Tensor, torch.Tensor): lambda x: x,
    (np.ndarray, np.ndarray): lambda x: x,
    (list, list): lambda x: x,
    (torch.Tensor, np.ndarray): lambda x: x.detach().cpu().float().numpy(),
    (torch.Tensor, list): lambda x: x.detach().cpu().float().numpy().tolist(),
    (np.ndarray, torch.Tensor): torch.tensor,
    (np.ndarray, list): lambda x: x.tolist(),
    (list, torch.Tensor): torch.tensor,
    (list, np.ndarray): lambda x: np.array(x),
    (list, ShapeList): lambda x: [int(s) for s in x[0]],
}


format_functions_rec = lambda k: functools.partial(rec_fn, fn=format_functions[k])

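# Illustrative use of format_functions_rec (hypothetical data, not from the source):
#   to_pt = format_functions_rec((np.ndarray, torch.Tensor))
#   pt_tree = to_pt({"inputs": [np.zeros(3)], "params": {"weight": np.ones((2, 2))}})


# Default mapping from module class names to the ModuleInfo describing their type and
# their patched (quantized) replacement module.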
mod_default_dict = {
    "Matmul": ModuleInfo("matmul", PatchedMatmul),
    "Linear": ModuleInfo("linear", PatchedLinear),
    "RowParallelLinear": ModuleInfo("linear", PatchedRowParallelLinear),
    "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear),
    "FalconLinear": ModuleInfo("linear", PatchedLinear),
    "KVCache": ModuleInfo("kv_cache", PatchedKVCache),
    "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache),
    "Conv2d": ModuleInfo("linear", PatchedConv2d),
    "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear),
    "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv),
    "Softmax": ModuleInfo("softmax", PatchedSoftmax),
    "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
}


if deepspeed_exists:
    mod_default_dict.update(
        {
            "LinearLayer": ModuleInfo("linear", PatchedLinear),
            "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce),
            "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce),
            "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce),
        }
    )


class ModInstInfo:
    def __init__(self, name, parent):
        self.name = name
        self.parent = parent


parent_child_mod_dict = {}


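# Populate parent_child_mod_dict by recursively walking the model so that each submodule
# can later be mapped back to its name and parent module.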
def generate_model_info(model):
    def create_mod_info_recursion(parent):
        for name, mod in parent.named_children():
            parent_child_mod_dict[mod] = ModInstInfo(name, parent)
            create_mod_info_recursion(mod)

    create_mod_info_recursion(model)