
Commit 620fcf1

check compute capability for marlin in validate_device() (#1095)
* check cuda v8 for marlin
* check cuda 8 for installation
* update msg
* update skip marlin msg
* check rocm first
* check not ROCM_VERSION
* check compute capability with validate_device
* check rocm
* check all devices' capability
* use local model path
* Update marlin.py

---------

Co-authored-by: Qubitium-ModelCloud <[email protected]>
1 parent 0801e1a commit 620fcf1

File tree

2 files changed: +28 -15 lines changed

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 18 additions & 5 deletions

@@ -15,6 +15,7 @@
 
 # Adapted from vllm at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/gptq_marlin.py
 
+import os
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
@@ -306,14 +307,26 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
 
     @classmethod
     def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
-        if IS_ROCM:
-            return False, RuntimeError("marlin kernel is not supported by rocm.")
-        if not any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count())):
-            return False, RuntimeError("marlin kernel requires Compute Capability >= 8.0.")
         if marlin_import_exception is not None:
             return False, marlin_import_exception
         return cls._validate(**args)
 
+    @classmethod
+    def validate_device(cls, device: DEVICE):
+        super().validate_device(device)
+        CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES")
+        if device == DEVICE.CUDA:
+            if IS_ROCM:
+                raise NotImplementedError("Marlin kernel is not supported on ROCm.")
+
+            if CUDA_VISIBLE_DEVICES is None:
+                has_cuda_v8 = all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))
+            else:
+                has_cuda_v8 = all(torch.cuda.get_device_capability(int(i))[0] >= 8 for i in CUDA_VISIBLE_DEVICES.split(","))
+
+            if not has_cuda_v8:
+                raise NotImplementedError("Marlin kernel only supports compute capability >= 8.0.")
+
     def post_init(self):
         device = self.qweight.device
         # Allocate marlin workspace
@@ -420,4 +433,4 @@ def dequantize_qzeros(layer):
 
     return unpacked_qzeros
 
-__all__ = ["MarlinQuantLinear"]
+__all__ = ["MarlinQuantLinear"]
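
For context, here is a minimal standalone sketch of the check that the new validate_device() performs, mirroring the commit's logic; the helper name marlin_capability_ok is hypothetical and only torch is assumed:

import os

import torch


def marlin_capability_ok() -> bool:
    # Mirrors MarlinQuantLinear.validate_device(): every visible CUDA device
    # must report compute capability >= 8.0 (Ampere or newer) for Marlin.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # No mask set: check every device torch can enumerate.
        device_ids = range(torch.cuda.device_count())
    else:
        # Mask set: check exactly the indices listed in CUDA_VISIBLE_DEVICES,
        # as the commit does, so hidden GPUs do not affect the result.
        device_ids = [int(i) for i in visible.split(",")]
    return all(torch.cuda.get_device_capability(i)[0] >= 8 for i in device_ids)
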

tests/test_q4_marlin.py

Lines changed: 10 additions & 10 deletions

@@ -32,39 +32,39 @@ class TestQ4Marlin(unittest.TestCase):
     @parameterized.expand(
         [
             # act_order==False, group_size=128
-            ("TheBloke/Llama-2-7B-GPTQ", "main",
+            ("/monster/data/model/Llama-2-7B-GPTQ", "main",
              "<s> I am in Paris and I am in love. everybody knows that.\n"
              "I am in Paris and I am in love.\n"
              "I am in Paris and I am in love. everybody knows that.\n"
              "I am in Paris and I am in love. everybody knows that.\n"
              "I am in Paris and I am in love"),
 
             # act_order==True, group_size=128
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "main",
              "<s> I am in Paris and I am so excited to be here. I am here for the first time in my life and I am so grateful for this opportunity. I am here to learn and to grow and to meet new people and to experience new things. I am here to see the Eiffel Tower and to walk along"),
             # act_order==True, group_size=64
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq-4bit-64g-actorder_True",
              "<s> I am in Paris and I am so happy to be here. I have been here for 10 years and I have never been happier. I have been here for 10 years and I have never been happier. I have been here for 10 years and I have never been happier. I"),
             # act_order==True, group_size=32
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq-4bit-32g-actorder_True",
              "<s> I am in Paris and I am in love with you.\n"
              "\n"
              "Scene 2:\n"
              "\n"
              "(The stage is now dark, with only the sound of the rain falling on the windowpane. The lights come up on a young couple, JESSICA and JASON, sitting on a park ben"),
 
             # # 8-bit, act_order==True, group_size=channelwise
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq-8bit--1g-actorder_True",
              "<s> I am in Paris and I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy"),
             # # 8-bit, act_order==True, group_size=128
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq-8bit-128g-actorder_True",
              "<s> I am in Paris and I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy to be here. I am so happy"),
             # # 8-bit, act_order==True, group_size=32
-            ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True",
+            ("/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq-8bit-32g-actorder_True",
              "<s> I am in Paris and I am looking for a good restaurant for a special occasion. Can you recommend any restaurants in Paris that are known for their specialties? I am looking for something unique and special. Please let me know if you have any recommendations."),
 
             # # 4-bit, act_order==True, group_size=128
-            ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main",
+            ("/monster/data/model/gemma-1.1-2b-it-GPTQ", "main",
              "<bos>I am in Paris and I am looking for a good bakery with fresh bread.\n"
              "\n"
              "**What are some good bakeries in Paris with fresh bread?**\n"
@@ -76,12 +76,12 @@ class TestQ4Marlin(unittest.TestCase):
              "* I am open to both traditional bakeries and newer, trendy")
         ]
     )
-    def test_generation(self, model_id, revision, reference_output):
+    def test_generation(self, model_id, reference_output):
         prompt = "I am in Paris and"
         device = torch.device("cuda:0")
 
         try:
-            model_q = GPTQModel.load(model_id, revision=revision, device="cuda:0", backend=BACKEND.MARLIN)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
         except ValueError as e:
             raise e
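
A minimal sketch, not from the commit, of how the new guard surfaces to callers; it assumes a CUDA build of gptqmodel, and the checkpoint path is a placeholder:

from gptqmodel import BACKEND, GPTQModel

try:
    # With this commit, the Marlin path raises NotImplementedError from
    # validate_device() when the target device is unsupported.
    model = GPTQModel.load(
        "/path/to/a-gptq-quantized-model",  # placeholder path
        device="cuda:0",
        backend=BACKEND.MARLIN,
    )
except NotImplementedError as e:
    # Raised on ROCm, or when any visible GPU has compute capability < 8.0.
    print(f"Marlin backend unavailable: {e}")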