
Commit 0a0cfb0

Fix: upstream transformers modeling inference code passes an impossible input shape (shape[0] == 0) to the module (#1365)
Signed-off-by: Qubitium <[email protected]>
1 parent cb786e5 commit 0a0cfb0
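
For context, the empty input can arise during MoE inference: when the router assigns no tokens to a particular expert, the upstream transformers code may still call that expert module with a batch dimension of zero. A minimal, purely illustrative sketch of how such an input appears (plain PyTorch indexing, not the actual transformers Qwen2-MoE code):

import torch

# Hypothetical illustration: gather the tokens routed to one expert.
hidden_states = torch.randn(4, 8)                    # 4 tokens, hidden size 8
ids_for_expert = torch.tensor([], dtype=torch.long)  # router selected no tokens
expert_input = hidden_states[ids_for_expert]         # shape (0, 8), i.e. shape[0] == 0
print(expert_input.shape)                            # torch.Size([0, 8])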

File tree: 5 files changed, +61 −36 lines


gptqmodel/nn_modules/qlinear/exllama.py

Lines changed: 6 additions & 1 deletion
@@ -161,7 +161,12 @@ def ext_q4_matmul(self, x, q4, q4_width):
         return output.view(outshape)


-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
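
The same guard appears in exllamav2.py and marlin.py below. As a standalone sketch of what it does (a stand-in module with a plain dense weight, not the actual ExllamaQuantLinear), an empty input short-circuits to an empty output of the layer's width before any kernel is invoked:

import torch
import torch.nn as nn

class GuardedLinear(nn.Module):
    # Stand-in for a quantized linear layer that applies the same empty-batch guard.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Guard from this commit: nothing to process, so skip the matmul entirely.
        if x.shape[0] == 0:
            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
        return x @ self.weight.t()

layer = GuardedLinear(8, 16)
print(layer(torch.randn(0, 8)).shape)  # torch.Size([0, 16])
print(layer(torch.randn(4, 8)).shape)  # torch.Size([4, 16])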

gptqmodel/nn_modules/qlinear/exllamav2.py

Lines changed: 6 additions & 1 deletion
@@ -217,7 +217,12 @@ def post_init(self, temp_dq):

         super().post_init()

-    def forward(self, x, force_cuda=False):
+    def forward(self, x: torch.Tensor, force_cuda=False):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 10 additions & 5 deletions
@@ -400,12 +400,17 @@ def post_init(self):

         super().post_init()

-    def forward(self, A: torch.Tensor):
-        if A.dtype != torch.float16:
-            A = A.to(torch.float16)
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
+        if x.dtype != torch.float16:
+            x = x.to(torch.float16)

         out = apply_gptq_marlin_linear(
-            input=A.contiguous() if self.is_lm_head else A,
+            input=x.contiguous() if self.is_lm_head else x,
             weight=self.qweight,
             weight_scale=self.scales,
             weight_zp=self.zp,
@@ -421,7 +426,7 @@ def forward(self, A: torch.Tensor):
         )

         if self.adapter:
-            out = self.adapter.apply(x=A, out=out)
+            out = self.adapter.apply(x=x, out=out)

         return out

gptqmodel_ext/marlin/marlin_cuda_kernel.cu

Lines changed: 0 additions & 29 deletions
@@ -1927,27 +1927,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

-#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
-
 template <typename scalar_t>
 void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
                      void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
@@ -2089,14 +2068,6 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
     GPTQ_CALL_IF(8, 8, 4, 128)
     GPTQ_CALL_IF(8, 4, 8, 128)

-    AWQ_CALL_IF(4, 16, 4, 256)
-    AWQ_CALL_IF(4, 8, 8, 256)
-    AWQ_CALL_IF(4, 8, 4, 128)
-    AWQ_CALL_IF(4, 4, 8, 128)
-    AWQ_CALL_IF(8, 16, 4, 256)
-    AWQ_CALL_IF(8, 8, 8, 256)
-    AWQ_CALL_IF(8, 8, 4, 128)
-    AWQ_CALL_IF(8, 4, 8, 128)
     else {
       TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                   ", ", prob_k, "]", ", has_act_order = ", has_act_order,

tests/models/test_qwen_15_moe.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from gptqmodel import GPTQModel, BACKEND
+import torch
+
+import unittest
+
+from datasets import load_dataset
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+
+
+class TestQwen15Moe(unittest.TestCase):
+    def test_inference(self):
+        model = GPTQModel.load("Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+                               device=torch.device("cuda:0"),
+                               backend=BACKEND.MARLIN)
+
+        tokenizer = model.tokenizer
+
+        prompt = "Give me a short introduction to large language model."
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=128
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        print(f"Response: `{response}`")
