
Commit 9d543af

Xu Kai authored and committed

[inference] add reference and fix some bugs (hpcaitech#4937)

* add reference and fix some bugs
* update gptq init

Co-authored-by: Xu Kai <[email protected]>
1 parent 8633a87 commit 9d543af

File tree: 7 files changed, +24 -10 lines changed


colossalai/inference/quant/smoothquant/models/base_model.py
Lines changed: 6 additions & 0 deletions

@@ -132,6 +132,7 @@ def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samp
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

+# Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device

@@ -163,6 +164,7 @@ def stat_input_hook(m, x, y, name):

return act_scales

+# Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):

@@ -189,6 +191,7 @@ def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
def create_quantized_model(model):
raise NotImplementedError("Not implement create_quantized_model method")

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_quantized(
self,
save_dir: str,

@@ -249,6 +252,7 @@ def save_quantized(

self.model.config.save_pretrained(save_dir)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
def save_pretrained(
self,
save_dir: str,

@@ -260,6 +264,7 @@ def save_pretrained(
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_pretrained(
cls,

@@ -354,6 +359,7 @@ def skip(*args, **kwargs):

return cls(model, False)

+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
@classmethod
def from_quantized(
cls,
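The references added above point at SmoothQuant's calibration and smoothing code: get_act_scales records per-channel activation maxima, and smooth_ln_fcs folds a per-channel factor into a LayerNorm and the linear layers it feeds so that activation outliers shrink before int8 quantization. Below is a minimal sketch of that smoothing step for a single linear layer; it follows the upstream smoothquant recipe rather than ColossalAI's exact code, and the helper name, single-layer signature, and 1e-5 clamps are assumptions.

import torch

@torch.no_grad()
def smooth_ln_fc_sketch(ln, fc, act_scales, alpha=0.5):
    # act_scales: per-input-channel max |activation| collected during calibration.
    weight_scales = fc.weight.abs().max(dim=0).values.clamp(min=1e-5)
    # Migrate quantization difficulty from activations to weights:
    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
    # Dividing the norm output and multiplying the FC weights by the same factor
    # leaves fc(ln(x)) numerically unchanged while shrinking activation outliers.
    ln.weight.div_(scales)
    if getattr(ln, "bias", None) is not None:
        ln.bias.div_(scales)
    fc.weight.mul_(scales)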

colossalai/inference/quant/smoothquant/models/linear.py
Lines changed: 2 additions & 0 deletions

@@ -62,6 +62,7 @@ def from_float(module: torch.nn.Linear, input_scale):
return int8_module


+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):

@@ -117,6 +118,7 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale):
return int8_module


+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
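The torch-int style layers referenced above (W8A8B8O8Linear for qkv_proj, W8A8BFP32OFP32Linear for fc2 and out_proj) run int8 GEMMs with per-tensor scales; the B8O8 variant keeps int8 bias and output, while the BFP32OFP32 variant keeps them in fp32. Purely as an illustration of the kind of quantization their from_float conversions rely on, a symmetric per-tensor int8 quantizer can be sketched as follows (the helper name is hypothetical):

import torch

def quantize_per_tensor_symmetric(x: torch.Tensor):
    # Choose a single scale so that x is approximately q * scale with q in int8 range.
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale.item()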

colossalai/inference/quant/smoothquant/models/llama.py
Lines changed: 3 additions & 0 deletions

@@ -419,6 +419,7 @@ def forward(self, x, cos, sin, position_ids):
return x_embed


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,

@@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False):
return _cos_cached, _sin_cached


+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
self,

@@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)

+# Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
def get_act_dict(
self,
tokenizer,
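One of the hunk headers above shows init_to_get_rotary(config, base=10000, use_elem=False) returning _cos_cached and _sin_cached, i.e. precomputed rotary-embedding tables. As a generic illustration only, not the ColossalAI implementation, such caches are typically built like this:

import torch

def build_rotary_cache(head_dim: int, max_positions: int, base: float = 10000.0):
    # One inverse frequency per pair of channels, outer product with position indices.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
    # Analogues of the cached _cos_cached / _sin_cached tables.
    return freqs.cos(), freqs.sin()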

colossalai/inference/tensor_parallel/engine.py
Lines changed: 6 additions & 1 deletion

@@ -21,6 +21,8 @@
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
+"LlamaGPTQForCausalLM",
+"BloomGPTQForCausalLM",
]


@@ -213,11 +215,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+
+model = model.model if self.shard_config.inference_gptq else model
+
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)

if self.shard_config.inference_gptq:
-self._post_init_gptq_buffer(model)
+self._post_init_gptq_buffer(self.model)

self.model = self.model.cuda()
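This hunk registers the GPTQ wrapper classes as supported models and fixes _shard_model_by for them: the wrapper's inner transformer (its .model attribute) is what gets sharded, and the GPTQ buffer post-initialization now runs on the sharded self.model instead of the pre-shard module. To make the unwrap step concrete, here is a hypothetical stand-in for such a wrapper; LlamaGPTQForCausalLM and BloomGPTQForCausalLM are not defined this way verbatim, but they expose the underlying transformer under .model in the same spirit.

import torch.nn as nn

class GPTQCausalLMWrapper(nn.Module):
    # Hypothetical sketch: the quantized wrapper keeps the actual transformer
    # under `.model`, which is why the engine unwraps it before sharding and
    # then rebuilds GPTQ buffers on the sharded module.
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)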

colossalai/kernel/triton/gptq_triton.py
Lines changed: 1 addition & 0 deletions

@@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)


+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
@autotune(
configs=[
triton.Config(
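For orientation, the cai_gptq_matmul_248 kernel referenced above multiplies activations by GPTQ weights stored as 2/4/8-bit values packed into int32 words (hence "248"). As an illustration of that packing only, and not of the Triton kernel itself, an unpacking step can be sketched in plain PyTorch as follows; the function name and the row-major packing layout are assumptions.

import torch

def unpack_gptq_qweight(qweight: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Each int32 word holds 32 // bits quantized values packed along the row axis.
    per_word = 32 // bits
    mask = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)
    # (rows, cols) -> (rows, per_word, cols) -> (rows * per_word, cols)
    unpacked = (qweight.unsqueeze(1) >> shifts.view(1, per_word, 1)) & mask
    return unpacked.reshape(-1, qweight.shape[1])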

colossalai/kernel/triton/smooth_attention.py
Lines changed: 6 additions & 6 deletions

@@ -13,10 +13,10 @@

if HAS_TRITON:
"""
-this function is modified from
-https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+this functions are modified from https://github.com/ModelTC/lightllm
"""

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@triton.jit
def _context_flash_attention_kernel(
Q,

@@ -145,20 +145,16 @@ def _context_flash_attention_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

-
-
@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):
-
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
-BLOCK_N = 128
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

@@ -203,6 +199,7 @@ def smooth_llama_context_attn_fwd(
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_kernel(
Q,

@@ -264,6 +261,7 @@ def _token_attn_1_kernel(
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_alibi_kernel(
Q,

@@ -413,6 +411,7 @@ def token_attn_fwd_1(
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,

@@ -479,6 +478,7 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen,
)
return

+# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_2_kernel(
Prob,
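For context, smooth_llama_context_attn_fwd above launches a fused int8 attention kernel whose inputs carry per-tensor quantization scales. A naive, non-fused reference of the intended math is sketched below under the assumption that q, k, and v are int8 tensors with the given scales and that the output is requantized by pv_output_scale; this is an assumption about the kernel's semantics for illustration only, not a drop-in replacement for the Triton path.

import math
import torch

def smooth_attention_reference(q_i8, k_i8, v_i8, q_scale, k_scale, v_scale, pv_output_scale):
    # Dequantize, run ordinary scaled-dot-product attention, requantize the PV output.
    q = q_i8.float() * q_scale
    k = k_i8.float() * k_scale
    v = v_i8.float() * v_scale
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    probs = torch.softmax((q @ k.transpose(-2, -1)) * sm_scale, dim=-1)
    out = probs @ v
    return (out / pv_output_scale).round().clamp(-127, 127).to(torch.int8)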

examples/inference/gptq_llama.py
Lines changed: 0 additions & 3 deletions

@@ -8,7 +8,6 @@

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

@@ -74,8 +73,6 @@ def run_llama_test(args):
quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
)

-init_to_get_rotary(model.model.model, base=10000)
-
model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
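Putting the example together: the script loads an AutoGPTQ-quantized LLaMA and hands it to TPInferEngine with inference_gptq=True, and the explicit rotary-cache initialization is no longer done in the script itself. A hedged sketch of that wiring follows; the TPInferEngine constructor arguments, the batch and length limits, and the model path are assumptions, not taken from this diff.

import torch
from auto_gptq import AutoGPTQForCausalLM

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

# Hypothetical checkpoint path; the engine limits below are assumed values.
model = AutoGPTQForCausalLM.from_quantized(
    "path/to/llama-gptq", device=torch.cuda.current_device(), inject_fused_attention=False
)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True, inference_gptq=True)
engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=1024, max_output_len=128)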
