diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py
index 3c0a046cf..ee4beb18f 100644
--- a/gptqmodel/nn_modules/qlinear/exllama.py
+++ b/gptqmodel/nn_modules/qlinear/exllama.py
@@ -161,7 +161,12 @@ def ext_q4_matmul(self, x, q4, q4_width):
 
         return output.view(outshape)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py
index 7250eddeb..d08d2b266 100644
--- a/gptqmodel/nn_modules/qlinear/exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -217,7 +217,12 @@ def post_init(self, temp_dq):
 
         super().post_init()
 
-    def forward(self, x, force_cuda=False):
+    def forward(self, x: torch.Tensor, force_cuda=False):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py
index 6b90427e6..69e084b35 100644
--- a/gptqmodel/nn_modules/qlinear/marlin.py
+++ b/gptqmodel/nn_modules/qlinear/marlin.py
@@ -400,12 +400,17 @@ def post_init(self):
 
         super().post_init()
 
-    def forward(self, A: torch.Tensor):
-        if A.dtype != torch.float16:
-            A = A.to(torch.float16)
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
+        if x.dtype != torch.float16:
+            x = x.to(torch.float16)
 
         out = apply_gptq_marlin_linear(
-            input=A.contiguous() if self.is_lm_head else A,
+            input=x.contiguous() if self.is_lm_head else x,
             weight=self.qweight,
             weight_scale=self.scales,
             weight_zp=self.zp,
@@ -421,7 +426,7 @@ def forward(self, A: torch.Tensor):
         )
 
         if self.adapter:
-            out = self.adapter.apply(x=A, out=out)
+            out = self.adapter.apply(x=x, out=out)
 
         return out
 
diff --git a/gptqmodel_ext/marlin/marlin_cuda_kernel.cu b/gptqmodel_ext/marlin/marlin_cuda_kernel.cu
index 312414b4d..531a67a0a 100644
--- a/gptqmodel_ext/marlin/marlin_cuda_kernel.cu
+++ b/gptqmodel_ext/marlin/marlin_cuda_kernel.cu
@@ -1927,27 +1927,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
     __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)   \
     __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
 
-  #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)               \
-    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)   \
-    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)    \
-                                                                               \
-    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)   \
-    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)    \
-                                                                               \
-    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)   \
-    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)    \
-                                                                               \
-    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS)   \
-    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)    \
-    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
-
 template <typename scalar_t>
 void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
                      void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
@@ -2089,14 +2068,6 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
     GPTQ_CALL_IF(8, 8, 4, 128)
     GPTQ_CALL_IF(8, 4, 8, 128)
 
-    AWQ_CALL_IF(4, 16, 4, 256)
-    AWQ_CALL_IF(4, 8, 8, 256)
-    AWQ_CALL_IF(4, 8, 4, 128)
-    AWQ_CALL_IF(4, 4, 8, 128)
-    AWQ_CALL_IF(8, 16, 4, 256)
-    AWQ_CALL_IF(8, 8, 8, 256)
-    AWQ_CALL_IF(8, 8, 4, 128)
-    AWQ_CALL_IF(8, 4, 8, 128)
     else {
       TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                   ", ", prob_k, "]", ", has_act_order = ", has_act_order,
diff --git a/tests/models/test_qwen_15_moe.py b/tests/models/test_qwen_15_moe.py
new file mode 100644
index 000000000..c4eb1b7e6
--- /dev/null
+++ b/tests/models/test_qwen_15_moe.py
@@ -0,0 +1,39 @@
+from gptqmodel import GPTQModel, BACKEND
+import torch
+
+import unittest
+
+from datasets import load_dataset
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+
+
+class TestQwen15Moe(unittest.TestCase):
+    def test_inference(self):
+        model = GPTQModel.load("Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+                               device=torch.device("cuda:0"),
+                               backend=BACKEND.MARLIN)
+
+        tokenizer = model.tokenizer
+
+        prompt = "Give me a short introduction to large language model."
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=128
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        print(f"Response: `{response}`")
\ No newline at end of file