7 changes: 6 additions & 1 deletion gptqmodel/nn_modules/qlinear/exllama.py
@@ -161,7 +161,12 @@ def ext_q4_matmul(self, x, q4, q4_width):
         return output.view(outshape)


-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
7 changes: 6 additions & 1 deletion gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -217,7 +217,12 @@ def post_init(self, temp_dq):

         super().post_init()

-    def forward(self, x, force_cuda=False):
+    def forward(self, x: torch.Tensor, force_cuda=False):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
15 changes: 10 additions & 5 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -400,12 +400,17 @@ def post_init(self):

         super().post_init()

-    def forward(self, A: torch.Tensor):
-        if A.dtype != torch.float16:
-            A = A.to(torch.float16)
+    def forward(self, x: torch.Tensor):
+        # TODO FIXME: parent should never call us if there is no data to process
+        # check: https://github.com/ModelCloud/GPTQModel/issues/1361
+        if x.shape[0] == 0:
+            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
+        if x.dtype != torch.float16:
+            x = x.to(torch.float16)

         out = apply_gptq_marlin_linear(
-            input=A.contiguous() if self.is_lm_head else A,
+            input=x.contiguous() if self.is_lm_head else x,
             weight=self.qweight,
             weight_scale=self.scales,
             weight_zp=self.zp,
@@ -421,7 +426,7 @@ def forward(self, A: torch.Tensor):
         )

         if self.adapter:
-            out = self.adapter.apply(x=A, out=out)
+            out = self.adapter.apply(x=x, out=out)

         return out

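The three qlinear changes above share one pattern: forward() now returns an empty (0, out_features) tensor instead of launching a kernel when the input has zero rows. A minimal sketch of that guard, shown on a hypothetical stand-in module rather than the real GPTQModel classes, looks like this:

import torch

class TinyLinearStub(torch.nn.Module):
    # Hypothetical stand-in for a quantized linear layer; only illustrates the guard.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same early return as the patch: skip the kernel entirely for an empty batch,
        # handing back a correctly shaped empty tensor with the caller's dtype and device.
        if x.shape[0] == 0:
            return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
        return x @ self.weight.t()

layer = TinyLinearStub(in_features=8, out_features=4)
empty = torch.empty((0, 8))
print(layer(empty).shape)  # torch.Size([0, 4]) -- no kernel call for an empty input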
29 changes: 0 additions & 29 deletions gptqmodel_ext/marlin/marlin_cuda_kernel.cu
@@ -1927,27 +1927,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)

-#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-  \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
-
 template <typename scalar_t>
 void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
                      void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
@@ -2089,14 +2068,6 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
     GPTQ_CALL_IF(8, 8, 4, 128)
     GPTQ_CALL_IF(8, 4, 8, 128)

-    AWQ_CALL_IF(4, 16, 4, 256)
-    AWQ_CALL_IF(4, 8, 8, 256)
-    AWQ_CALL_IF(4, 8, 4, 128)
-    AWQ_CALL_IF(4, 4, 8, 128)
-    AWQ_CALL_IF(8, 16, 4, 256)
-    AWQ_CALL_IF(8, 8, 8, 256)
-    AWQ_CALL_IF(8, 8, 4, 128)
-    AWQ_CALL_IF(8, 4, 8, 128)
     else {
       TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                   ", ", prob_k, "]", ", has_act_order = ", has_act_order,
39 changes: 39 additions & 0 deletions tests/models/test_qwen_15_moe.py
@@ -0,0 +1,39 @@
+from gptqmodel import GPTQModel, BACKEND
+import torch
+
+import unittest
+
+from datasets import load_dataset
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+
+
+class TestQwen15Moe(unittest.TestCase):
+    def test_inference(self):
+        model = GPTQModel.load("Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+                               device=torch.device("cuda:0"),
+                               backend=BACKEND.MARLIN)
+
+        tokenizer = model.tokenizer
+
+        prompt = "Give me a short introduction to large language model."
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=128
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        print(f"Response: `{response}`")
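The new test targets a MoE checkpoint (Qwen1.5-MoE-A2.7B), which appears to be the workload behind the empty-batch guard above: with token routing, an individual expert's linear layers can legitimately receive zero tokens in a forward pass (the situation referenced via issue #1361; that link is inferred from this diff, not stated in it). An illustrative routing sketch with plain tensors, independent of GPTQModel and Qwen code:

import torch

# Toy top-1 router: 4 tokens, 3 experts; nothing guarantees every expert gets work.
hidden = torch.randn(4, 8)
router_logits = torch.randn(4, 3)
expert_ids = router_logits.argmax(dim=-1)   # e.g. tensor([2, 0, 2, 0])

for expert in range(3):
    tokens = hidden[expert_ids == expert]   # may be shape (0, 8) for an idle expert
    # Without the guard, a quantized linear would be invoked with zero rows here;
    # with it, the layer simply returns an empty (0, out_features) tensor.
    print(f"expert {expert}: routed input shape {tuple(tokens.shape)}")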