Commit 25f1607

ZX-ModelCloud, Qubitium, shihyangl, nbasyl, hutm authored
Eora Fused Kernel (#1309)
* add `extension` property to QuantizeConfig + EoRA Extension/Config
* test shihyang push
* match/validate correct kernel to extension
* model.quantize now returns the quantized weights for EoRA
* allow test_perplexity to run without buffered_fwd arg
* limit test to only 1 for fast debug
* reduce verbosity of logs (meant for debug)
* fix python 3.10 compat
* finish eora first version (not optimized; might only work for llama-type models)
* dummy (non-working) eora torch kernel
* add `BACKEND.EORA_TORCH` and correctly register the eora_torch kernel
* fix eora torch backend selection
* fix typo causing dtype mismatch
* trying to get eora loading working (still failing)
* refactor eora config/loading
* refactor eora config
* add `test_eora.py`, loading not fixed yet
* fix config loading, and quant model loading (non-lora weights) with eora config
* load A and B weights
* fix transposed tensors for inference
* move a/b to correct device
* rename `extension` to `adapter`
* half-way done with eora
* eora bug: device mismatch
* fix eora v2 generation code (non-concatenated version)
* added GPTQ-eora kernel based on the vLLM exllama GPTQ implementation
* refactor adapter a/b load and math into the EoRA adapter and out of the kernel
* fix adapter not being copied, causing shape errors since all adapters were the same instance
* fix loader cache CI bug
* create eora_load_and_infer.py at root to avoid recompiling
* use local model dir
* load local datasets
* fix setting CUDA_DEVICE_ORDER
* add local model path
* fix merge error
* move adapter code to adapter.py
* rename EoRA to Lora
* rename `lora.path_or_id` to `lora.path`
* added sweep test for different k and r that conform to the condition: (128 * r / k) is an integer >= 1
* relaxed r to be any rank < k
* add default value for pack_dtype & adapter
* Revert "add default value for pack_dtype & adapter" (reverts commit e56b86a)
* add pack_dtype & adapter for hf_select_quant_linear
* set adapter to None
* remove unexpected char
* default None for name and set it with kernel name
* 1. use dict for model args. 2. accept extra args
* use dict for model args
* add lm eval tests
* use triton backend
* optimization: reorder for loop to have unrolled inner for loops
* add test_kernel_output.py
* cleanup test (Signed-off-by: Qubitium <[email protected]>)
* fix: the eora kernel must use the gptq v1 format, not the internal v2 format used by other kernels with the zero-point offset fix (Signed-off-by: Qubitium <[email protected]>)
* also skip v1-to-v2 conversion for marlin (Signed-off-by: Qubitium <[email protected]>)
* enable fused eora kernel (Signed-off-by: Qubitium <[email protected]>)
* remove `eora_torch` backend (useless) (Signed-off-by: Qubitium <[email protected]>)
* add test_kernel_output_with_lora()
* remove FORMAT_FIELD_COMPAT_MARLIN
* test: add BACKEND.CUDA
* EXLLAMA-EORA: add SUPPORTS_BITS 2, 3
* fix gptq_marlin error (Signed-off-by: ZX-ModelCloud <[email protected]>)
* merge changes from main and fix: v2_to_v1 conversion should bypass the marlin + eora kernels (Signed-off-by: Qubitium <[email protected]>)
* fix wrong eq check
* fix kernel bug
* .
* add test file for eora kernel
* fix the eora_kernel bug
* format (Signed-off-by: Qubitium <[email protected]>; Conflicts: tests/test_kernel_output.py)
* format (Signed-off-by: Qubitium <[email protected]>)
* format (Signed-off-by: Qubitium <[email protected]>)

---------

Signed-off-by: Qubitium <[email protected]>
Signed-off-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium <[email protected]>
Co-authored-by: shihyangl <[email protected]>
Co-authored-by: nbasyl <[email protected]>
Co-authored-by: Maksim Khadkevich <[email protected]>
Co-authored-by: CSY-ModelCloud <[email protected]>
1 parent 0a0cfb0 · commit 25f1607

39 files changed: +6,042 −59 lines
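In short, the PR lets quantized linear modules carry a low-rank EoRA/LoRA adapter and adds an exllama-based kernel that fuses the adapter term into the GPTQ GEMM. Every backend has to reproduce the same reference math: y = x @ W_dequant + (x @ A) @ B. A minimal PyTorch sketch of that reference follows; deq_weight, lora_A, lora_B and the shapes are illustrative stand-ins, not the library's API (the real kernel path also expects float16 activations, per the warning in exllama_eora.py):

import torch

def eora_reference(x, deq_weight, lora_A, lora_B):
    # x: (batch, in_features); deq_weight: (in_features, out_features)
    # lora_A: (in_features, rank); lora_B: (rank, out_features)
    base = x @ deq_weight                  # what gptq_gemm computes on packed weights
    return base + (x @ lora_A) @ lora_B    # low-rank EoRA correction

x = torch.randn(4, 2048)
w = torch.randn(2048, 2048)   # stands in for the dequantized qweight
a = torch.randn(2048, 128)    # adapter rank 128, as used in the tests below
b = torch.randn(128, 2048)
print(eora_reference(x, w, a, b).shape)   # torch.Size([4, 2048])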

eora_load_and_infer.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import os

from gptqmodel import BACKEND, GPTQModel
from gptqmodel.adapter.adapter import Lora
from parameterized import parameterized


@parameterized.expand([
    (BACKEND.TORCH),
    (BACKEND.CUDA),
    (BACKEND.TRITON),
    (BACKEND.EXLLAMA_V1),
    # (BACKEND.EXLLAMA_V2), <-- adapter not working yet
    (BACKEND.MARLIN),
    # (BACKEND.IPEX), <-- not tested yet
    # (BACKEND.BITBLAS), <-- not tested yet
])
def test_load(backend: BACKEND):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    quant_model_path = "/home/shihyangl/gptqmodel_save/Llama-3.2-1B-gptqmodel-4bit"
    lora_path = "/home/shihyangl/llama3.2-1b-4bit-group128-eora-rank128-arc/adapter_model.safetensors"  # "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc/blob/main/adapter_model.safetensors"  # "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc"

    adapter = Lora(path=lora_path, rank=128)

    model = GPTQModel.load(
        quant_model_path,
        adapter=adapter,
        backend=backend,
        device_map="auto",
    )

    # print(model)
    tokens = model.generate("Capital of France is")[0]
    result = model.tokenizer.decode(tokens)
    print(f"Result: {result}")
    assert "paris" in result.lower()


# Earlier, commented-out variant of the same check (pre-rename EoRA adapter):
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# quant_model_path = "/home/shihyangl/gptqmodel_save/Llama-3.2-1B-gptqmodel-4bit"
# lora_path = "/home/shihyangl/llama3.2-1b-4bit-group128-eora-rank128-arc/adapter_model.safetensors"  # "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc/blob/main/adapter_model.safetensors"  # "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc"
#
# adapter = EoRA(lora_path=lora_path, rank=128)
#
# model = GPTQModel.load(
#     quant_model_path,
#     adapter=adapter,
#     backend=BACKEND.TORCH,
#     device_map="auto",
# )
#
# # print(model)
# tokens = model.generate("Capital of France is")[0]
# result = model.tokenizer.decode(tokens)
# print(f"Result: {result}")
# assert "paris" in result.lower()

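A quick, hedged way to see what the adapter contributes is to load the same quantized checkpoint with and without it and compare generations. This only reuses the API exercised above; the paths are the same local placeholders and the TRITON backend choice is arbitrary:

import os

from gptqmodel import BACKEND, GPTQModel
from gptqmodel.adapter.adapter import Lora

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
quant_model_path = "/home/shihyangl/gptqmodel_save/Llama-3.2-1B-gptqmodel-4bit"
lora_path = "/home/shihyangl/llama3.2-1b-4bit-group128-eora-rank128-arc/adapter_model.safetensors"

# Baseline: quantized model without the EoRA/LoRA adapter.
base = GPTQModel.load(quant_model_path, backend=BACKEND.TRITON, device_map="auto")
print("base:", base.tokenizer.decode(base.generate("Capital of France is")[0]))

# Same checkpoint with the low-rank adapter applied.
eora = GPTQModel.load(
    quant_model_path,
    adapter=Lora(path=lora_path, rank=128),
    backend=BACKEND.TRITON,
    device_map="auto",
)
print("eora:", eora.tokenizer.decode(eora.generate("Capital of France is")[0]))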
eora_no_bug.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import torch
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

# from gptqmodel.eora import get_eora, get_eora_optimize


bit = 4
model_id = "meta-llama/Llama-3.2-1B"
model = None

quant_path = "Llama-3.2-1B-gptqmodel-4bit"
fake_quant_path = "Llama-3.2-1B-gptqmodel-4bit-fakequantized/qw.pt"
eora_path = "Llama-3.2-1B-gptqmodel-4bit-eora-rank-128-v2/eora.pt"
quant_config = QuantizeConfig(bits=bit, group_size=128)


calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
).select(range(1024))["text"]

print(f"{type(calibration_dataset)}")

### 3-bit group_size = 128 leads to: IndexError: index 192 is out of bounds when packing
model = GPTQModel.load(model_id, quant_config)

# increase `batch_size` to match gpu/vram specs to speed up quantization
quant_log, quantized_weights = model.quantize(calibration_dataset, batch_size=2)

model.save(quant_path)

torch.save(quantized_weights, fake_quant_path)
quantized_weights = torch.load(fake_quant_path, map_location='cpu')

## 4-bit gs=128 Acc: 0.2850

batch_size = 2
from test_prepare_dataset import construct_ARC

calibration_dataset = construct_ARC(nsamples=1024)
eora_rank = 128
model = GPTQModel.load(model_id, quant_config)

eora_weight = model.get_eora(calibration_dataset, batch_size, quantized_weights, eora_rank)

torch.save(eora_weight, eora_path)

eora_weight = torch.load(eora_path, map_location='cpu')
print(eora_weight)

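The script relies on model.get_eora() to derive the low-rank correction from the quantization error. As background only, here is a conceptual sketch of the eigenspace low-rank approximation idea behind EoRA, i.e. a truncated SVD of the quantization residual in an activation-weighted space. The function name, the covariance weighting, and the numerics are assumptions of this sketch; the library's actual get_eora() may differ in all of them:

import torch

def eora_low_rank_correction(w_full: torch.Tensor,   # original fp weight, shape (out_features, in_features)
                             w_quant: torch.Tensor,  # dequantized GPTQ weight, same shape
                             x_cov: torch.Tensor,    # calibration activation covariance, (in_features, in_features)
                             rank: int):
    """Conceptual sketch: find lora_A (in, rank) and lora_B (rank, out) so that
    (x @ lora_A) @ lora_B approximates x @ (w_full - w_quant).T, with the error
    weighted toward directions the calibration activations actually excite."""
    residual_t = (w_full - w_quant).T                 # (in, out): what x is multiplied by

    # Any square root M of x_cov (M.T @ M == x_cov) defines the weighted norm;
    # use the eigendecomposition x_cov = V diag(e) V.T.
    eigvals, eigvecs = torch.linalg.eigh(x_cov)
    scale = eigvals.clamp_min(1e-8).sqrt()
    m = scale[:, None] * eigvecs.T                    # M = diag(sqrt(e)) V.T
    m_inv = eigvecs / scale                           # M^-1 = V diag(1/sqrt(e))

    # Truncated SVD of the weighted residual keeps the most important directions.
    u, s, vh = torch.linalg.svd(m @ residual_t, full_matrices=False)
    lora_A = m_inv @ (u[:, :rank] * s[:rank])         # (in_features, rank)
    lora_B = vh[:rank, :]                             # (rank, out_features)
    return lora_A, lora_B

# Toy usage: identity covariance reduces this to a plain SVD of the residual.
w = torch.randn(512, 256)
wq = w + 0.01 * torch.randn_like(w)                   # pretend quantization error
A, B = eora_low_rank_correction(w, wq, torch.eye(256), rank=16)
print(A.shape, B.shape)                               # torch.Size([256, 16]) torch.Size([16, 512])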
gptqmodel/adapter/adapter.py

Lines changed: 1 addition & 1 deletion
@@ -171,9 +171,9 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
                 pop_keys.append(k)
             elif k.endswith(lora_B_weight_key):
                 lora_B = v.T
+                lora_B = torch.clone(v.T, memory_format=torch.contiguous_format)
                 pop_keys.append(k)

-
         if pop_keys:
             for k in pop_keys:
                 lora_weights.pop(k)  # release lora weights from cache memory

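The one functional change here: v.T is a non-contiguous view that still shares storage with the loaded safetensors buffer, so cloning it into contiguous memory both hands the kernels a contiguous lora_B and lets the source weights actually be released (which is what the pop() below is for). A tiny plain-PyTorch illustration, independent of gptqmodel:

import torch

v = torch.randn(128, 2048)     # stands in for a lora_B tensor as stored on disk
view = v.T                     # transposed view: shares v's storage, non-contiguous
copy = torch.clone(v.T, memory_format=torch.contiguous_format)

print(view.is_contiguous())             # False
print(copy.is_contiguous())             # True
print(view.data_ptr() == v.data_ptr())  # True  -- still backed by v's storage
print(copy.data_ptr() == v.data_ptr())  # False -- independent, contiguous copy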
gptqmodel/eora/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .eora import *
from .eora_calibration_dataloader import *
from .modelutils import *

gptqmodel/eora/eora_calibration_dataloader.py

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
from typing import Dict, Optional, Sequence

## This is the old way of constructing the calibration dataset
import numpy as np
import torch
import transformers


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_mathqa_c4(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata_mathqa = load_dataset('math_qa', split='train')
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, seqlen=2048)

    import random
    random.seed(seed)
    trainloader = []
    mathqa_nsamples = int(20)
    print(f"mathqa_nsamples {mathqa_nsamples}")
    i = 0
    for _ in range(mathqa_nsamples):

        cur_len = 0
        input = ""
        while cur_len < seqlen:
            doc = traindata_mathqa[i]
            cur_input = "Question: " + doc["Problem"] + " Choices: " + doc["options"] + ". Rationale: " + doc["Rationale"] + ". "
            input = input + cur_input
            trainenc = tokenizer(input, return_tensors='pt')
            cur_len = (trainenc.input_ids.shape[1])  ## neglect the bos token
            i += 1

        ## reach seq_len
        final_inp = tokenizer(input, return_tensors='pt')
        inp = final_inp.input_ids[:, :seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    c4_nsamples = nsamples - mathqa_nsamples
    for _ in range(c4_nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    return trainloader


def get_arc_c4(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata_arc_easy = load_dataset('ai2_arc', 'ARC-Easy', split='train')
    traindata_arc_challenge = load_dataset('ai2_arc', 'ARC-Challenge', split='train')
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, seqlen=2048)


    import random
    random.seed(seed)
    trainloader = []
    arc_e_nsamples = int(20)
    print(f"arc_e_nsamples {arc_e_nsamples}")
    i = 0
    for _ in range(arc_e_nsamples):

        cur_len = 0
        input = ""
        while cur_len < seqlen:
            answer = traindata_arc_easy[i]['choices']['label'].index(traindata_arc_easy[i]['answerKey'])
            cur_input = traindata_arc_easy[i]['question'] + " " + traindata_arc_easy[i]['choices']['text'][answer] + ". "
            input = input + cur_input
            trainenc = tokenizer(input, return_tensors='pt')
            cur_len = (trainenc.input_ids.shape[1])  ## neglect the bos token
            i += 1

        final_inp = tokenizer(input, return_tensors='pt')
        inp = final_inp.input_ids[:, :seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))


    arc_c_nsamples = int(10)
    print(f"arc_c_nsamples {arc_c_nsamples}")
    i = 0
    for _ in range(arc_c_nsamples):

        cur_len = 0
        input = ""
        while cur_len < seqlen:
            answer = traindata_arc_challenge[i]['choices']['label'].index(traindata_arc_challenge[i]['answerKey'])
            cur_input = traindata_arc_challenge[i]['question'] + " " + traindata_arc_challenge[i]['choices']['text'][answer] + ". "
            input = input + cur_input
            trainenc = tokenizer(input, return_tensors='pt')
            cur_len = (trainenc.input_ids.shape[1])  ## neglect the bos token
            i += 1

        ## reach seq_len
        final_inp = tokenizer(input, return_tensors='pt')
        inp = final_inp.input_ids[:, :seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))


    # traindata = load_dataset("json", data_files=f"{c4_data}/c4-train.json")['train']
    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    c4_nsamples = nsamples - arc_c_nsamples - arc_e_nsamples
    for _ in range(c4_nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            # print(len(traindata[i]['text']))
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        # print(f"inp {inp.shape}")
        trainloader.append((inp, tar))

    return trainloader


def get_wikitext2(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader


def get_loaders(
    data_name, nsamples=128, seed=0, seqlen=2048, model=''
):
    if type(data_name) == list:
        raise NotImplementedError
    else:
        if 'wikitext2' in data_name:
            return get_wikitext2(nsamples, seed, seqlen, model)
        if "mathqa" in data_name:
            return get_mathqa_c4(nsamples, seed, seqlen, model)
        if "arc" in data_name:
            return get_arc_c4(nsamples, seed, seqlen, model)

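A short usage sketch for the loaders above. Each one returns a list of (input_ids, target) pairs of length seqlen with every target token except the last masked to -100. The model id is illustrative; note the loaders call AutoTokenizer with use_fast=False, so the checkpoint must ship a slow tokenizer:

from gptqmodel.eora import get_loaders

# "arc" packs 20 ARC-Easy and 10 ARC-Challenge sequences, then fills the rest from C4.
trainloader = get_loaders("arc", nsamples=128, seed=0, seqlen=2048,
                          model="meta-llama/Llama-3.2-1B")
inp, tar = trainloader[0]
print(inp.shape, tar.shape)   # torch.Size([1, 2048]) torch.Size([1, 2048])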
gptqmodel/eora/modelutils.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import functools

import torch
import torch.nn as nn


def recurse_getattr(obj, attr: str):
    """
    Recursive `getattr`.

    Args:
        obj:
            A class instance holding the attribute.
        attr (`str`):
            The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
    """

    def _getattr(obj, attr):
        return getattr(obj, attr)

    return functools.reduce(_getattr, [obj] + attr.split("."))


def recurse_setattr(module, name, value):
    """A function to recursively set attributes to a module."""
    if "." not in name:
        setattr(module, name, value)
    else:
        name, rest = name.split(".", 1)
        recurse_setattr(getattr(module, name), rest, value)


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res

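A brief usage sketch for these helpers on a throwaway module; the Sequential here is purely illustrative:

import torch.nn as nn

from gptqmodel.eora.modelutils import find_layers, recurse_getattr, recurse_setattr

block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

# find_layers walks the module tree and returns {dotted_name: module} for Linear/Conv2d layers.
print(find_layers(block))   # keys '0' and '2' map to the two Linear layers

# recurse_getattr / recurse_setattr resolve dotted paths, e.g. to swap a layer in place.
print(recurse_getattr(block, "0"))
recurse_setattr(block, "0", nn.Linear(16, 32, bias=False))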
gptqmodel/models/loader.py

Lines changed: 1 addition & 0 deletions
@@ -323,6 +323,7 @@ def from_quantized(
                 model, hf_config = load_model_by_sglang(
                     model=model_local_path,
                     trust_remote_code=trust_remote_code,
+                    dtype=torch.float16,
                     **kwargs,
                 )
                 model.config = hf_config

gptqmodel/nn_modules/qlinear/exllama_eora.py

Lines changed: 4 additions & 4 deletions
@@ -54,7 +54,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,


 class ExllamaEoraQuantLinear(BaseQuantLinear):
-    SUPPORTS_BITS = [4, 8]  # TODO: validate 2/3
+    SUPPORTS_BITS = [4, 8]
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True]  # TODO: validate False
@@ -157,7 +157,7 @@ def forward(self, x):
         x_dtype = x.dtype
         if x_dtype != torch.float16:
             logger.warning_once(
-                f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+                f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
             )

             x = x.to(dtype=torch.float16)
@@ -172,8 +172,8 @@ def forward(self, x):
         # x = F.pad(x, self.in_features_padding_shape)

         if self.adapter:
-            # output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B)  # fused
-            output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B)  # normal
+            output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B)  # fused
+            # output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B)  # normal
         else:
             output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)


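This hunk is the heart of the PR: the previously commented-out gptq_gemm_lora call is enabled, folding the (x @ lora_A) @ lora_B term into the GPTQ GEMM instead of running a separate add_ afterwards (the old unfused path is kept as a comment for reference). One way such a fusion can be reasoned about is the augmented-GEMM identity below, shown with dense stand-ins; whether the CUDA kernel literally appends columns or accumulates the term in its epilogue is not claimed here:

import torch

torch.manual_seed(0)
m, k, n, r = 8, 256, 512, 16
x = torch.randn(m, k)
w = torch.randn(k, n)   # stands in for the dequantized GPTQ weight
a = torch.randn(k, r)   # adapter lora_A
b = torch.randn(r, n)   # adapter lora_B

two_pass = (x @ w).add_((x @ a) @ b)                                # the "normal" unfused path
one_pass = torch.cat([x, x @ a], dim=1) @ torch.cat([w, b], dim=0)  # single GEMM over augmented operands

print(torch.allclose(two_pass, one_pass, rtol=1e-4, atol=1e-4))     # True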
gptqmodel/quantization/config.py

Lines changed: 0 additions & 2 deletions
@@ -383,8 +383,6 @@ def from_quant_config(cls, quantize_cfg, format: str = None):
                 raise ValueError(f"QuantizeConfig: Unknown quantization method: `{val}`.")
             else:
                 normalized[QUANT_METHOD_FIELD] = val
-        elif key == FORMAT_FIELD_COMPAT_MARLIN and val:
-            normalized[FORMAT_FIELD_CODE] = FORMAT.MARLIN
         elif key in field_names:
             normalized[key] = val
         else:
