
Commit 79941d7

allow passing model instance to evalplus & update tokenizer loading logic (#1284)

* remove tokenizer
* revert format
* receive tokenizer & get chat template
* pass tokenizer to lm_eval
* add evalplus patch
* check model_id_or_path instance type
* update evalplus, check model_id_or_path instance type
* also patch for str value
* remove tests file from pkg
* load tokenizer if model_id_or_path is str
* model_id_or_path must be a str
* Update auto.py
* Update evalplus.py

Co-authored-by: Qubitium-ModelCloud <[email protected]>
1 parent f0d81b5 commit 79941d7

File tree

4 files changed: +106 −12 lines changed

MANIFEST.in
gptqmodel/models/auto.py
gptqmodel/utils/eval.py
gptqmodel/utils/evalplus.py

MANIFEST.in

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@ global-include gptqmodel_ext/**/*.cpp
 global-include gptqmodel_ext/**/*.cu
 global-include gptqmodel_ext/**/*.py
 include requirements.txt
+prune tests/
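For reference: in setuptools MANIFEST.in syntax, `prune tests/` excludes the entire tests/ directory tree from the source distribution, which implements the "remove tests file from pkg" item in the commit message.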

gptqmodel/models/auto.py

Lines changed: 22 additions & 11 deletions

@@ -19,6 +19,7 @@
 import os
 
 from lm_eval.utils import make_table
+from tokenicer import Tokenicer
 
 if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True'
@@ -42,7 +43,7 @@
 import numpy  # noqa: E402
 import torch  # noqa: E402
 from huggingface_hub import list_repo_files  # noqa: E402
-from transformers import AutoConfig,AutoTokenizer  # noqa: E402
+from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizerBase  # noqa: E402
 
 from ..quantization import QUANT_CONFIG_FILENAME  # noqa: E402
 from ..utils import BACKEND  # noqa: E402
@@ -286,7 +287,8 @@ def from_quantized(
     def eval(
         cls,
         model_or_id_or_path: str=None,
-        tasks: Union[List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None,  # set to None to tifx mutable warning
+        tokenizer: PreTrainedTokenizerBase=None,
+        tasks: Union[List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None,  # set to None to fix mutable warning
         framework: EVAL = EVAL.LM_EVAL,
         batch_size: int = 1,
         trust_remote_code: bool = False,
@@ -316,20 +318,29 @@ def eval(
         if isinstance(model_or_id_or_path, str):
             model = None
             model_id_or_path = model_or_id_or_path
+        elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, PreTrainedModel):
+            model = model_or_id_or_path
+            model_id_or_path = model.config.name_or_path
         else:
-            model = model_or_id_or_path
-            model_id_or_path = model.model_local_path
+            raise ValueError(f"`model_or_id_or_path` is invalid. expected: `model instance or str` actual: `{model_or_id_or_path}`")
+
+        if tokenizer is None:
+            if isinstance(model, BaseGPTQModel):
+                tokenizer = model.tokenizer
+            elif isinstance(model, PreTrainedModel) or model_id_or_path.strip():
+                tokenizer = Tokenicer.load(model_id_or_path)
+
+        if tokenizer is None:
+            raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.")
+
+        model_args["tokenizer"] = tokenizer
 
         if framework == EVAL.LM_EVAL:
             for task in tasks:
                 if task not in EVAL.get_task_enums():
                     raise ValueError(f"lm_eval support tasks: {EVAL.get_all_tasks_string()}")
 
-            # model_id_or_path=model_id_or_path if model_id_or_path else model.model_id_or_path
-            # tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
-            tokenizer = model.tokenizer if model else AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
-
-            model_name = 'hf' if backend == 'gptqmodel' else backend
+            model_name = "hf" if backend == "gptqmodel" else backend
 
             if backend == "gptqmodel":
                 model_args["gptqmodel"] = True
@@ -349,13 +360,13 @@ def eval(
                 batch_size=batch_size,
                 trust_remote_code=trust_remote_code,
             )
-            apply_chat_template=args.pop("apply_chat_template", True if tokenizer.chat_template is not None else False)
+
            results = simple_evaluate(
                model=model_name,
                model_args=model_args,
                tasks=[task.value for task in tasks],
                batch_size=batch_size,
-                apply_chat_template=apply_chat_template,
+                apply_chat_template=args.pop("apply_chat_template", True if tokenizer.chat_template is not None else False),
                gen_kwargs=args.pop("gen_kwargs", "temperature=0.0,top_k=50"),
                random_seed=random_seed,
                numpy_random_seed=random_seed,
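To make the new call paths concrete, here is a minimal usage sketch of the updated eval() signature. The model id and the ARC_CHALLENGE task below are illustrative assumptions, not part of this commit:

    # Hypothetical usage of the updated GPTQModel.eval() signature.
    from gptqmodel import GPTQModel
    from gptqmodel.utils.eval import EVAL

    MODEL_ID = "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit"  # placeholder id

    # 1) Pass a model id/path (str): a tokenizer is auto-loaded via Tokenicer.
    GPTQModel.eval(MODEL_ID, tasks=[EVAL.LM_EVAL.ARC_CHALLENGE])

    # 2) Pass a loaded model instance: its attached tokenizer is reused.
    model = GPTQModel.load(MODEL_ID)
    GPTQModel.eval(model, tasks=[EVAL.LM_EVAL.ARC_CHALLENGE])

    # 3) Pass a tokenizer explicitly to bypass auto-loading.
    GPTQModel.eval(model, tokenizer=model.tokenizer, tasks=[EVAL.LM_EVAL.ARC_CHALLENGE])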

gptqmodel/utils/eval.py

Lines changed: 4 additions & 1 deletion

@@ -19,6 +19,7 @@
 from enum import Enum
 from typing import Optional
 
+from .evalplus import patch_evalplus
 
 class EVAL:
     class LM_EVAL(Enum):
@@ -56,13 +57,15 @@ def get_all_tasks_string(cls):
 
 
 def evalplus(
-    model: str,
+    model,
     dataset: str,
     batch: int = 1,
     trust_remote_code: bool = False,
     output_file: Optional[str] = None,
     backend: str = 'gptqmodel'
 ):
+    patch_evalplus(model)
+
     try:
         from evalplus.evaluate import evaluate
     except BaseException:
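With the relaxed `model` parameter, evalplus() should now accept either a path/id string or an already-loaded model instance, with patch_evalplus() applied up front. A hedged sketch of both invocations; the dataset name and model id are placeholders, not part of the diff:

    from gptqmodel import GPTQModel
    from gptqmodel.utils.eval import evalplus

    MODEL_ID = "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit"  # placeholder id

    # As before: pass the model id as a string.
    evalplus(MODEL_ID, dataset="humaneval")

    # New: pass a loaded model instance directly; patch_evalplus(model)
    # makes it usable where evalplus expects a string name.
    model = GPTQModel.load(MODEL_ID)
    evalplus(model, dataset="humaneval")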

gptqmodel/utils/evalplus.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+import types
+
+from tokenicer import Tokenicer
+from transformers import PreTrainedModel
+
+
+def patch_strip(self, *args, **kwargs):
+    return self.config.name_or_path.strip(*args, **kwargs)
+
+def patch_tostring(self):
+    return self.config.name_or_path
+
+def patch_evalplus(model):
+    from ..models.base import BaseGPTQModel
+    if isinstance(model, BaseGPTQModel) or isinstance(model, PreTrainedModel):
+        model.strip = types.MethodType(patch_strip, model)
+        model.__str__ = types.MethodType(patch_tostring, model)
+
+    import torch
+    from evalplus.provider.base import DecoderBase
+    from evalplus.provider.gptqmodel import GPTQModelDecoder
+    from evalplus.provider.utility import extra_eos_for_direct_completion
+    from gptqmodel.models import BaseGPTQModel
+
+    from .. import GPTQModel
+
+    class PatchedGPTQModelDecoder(DecoderBase):
+        def __init__(
+            self,
+            name: str,
+            dataset: str,
+            gptqmodel_backend: str = 'auto',
+            force_base_prompt: bool = False,
+            **kwargs,
+        ):
+
+            super(GPTQModelDecoder, self).__init__(name=name, **kwargs)
+
+            if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
+                device = torch.device("mps")
+            elif hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available():
+                device = torch.device("xpu")
+            elif hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
+
+            self.device = device
+
+            kwargs = {
+                "model_id_or_path": name,
+                "trust_remote_code": self.trust_remote_code,
+                "backend": gptqmodel_backend,
+                "device": device
+            }
+            self.skip_special_tokens = True
+            self.force_base_prompt = force_base_prompt
+            if isinstance(name, BaseGPTQModel):
+                self.model = name
+                self.tokenizer = self.model.tokenizer
+            elif isinstance(name, PreTrainedModel):
+                self.model = name
+                self.tokenizer = Tokenicer.load(name.config.name_or_path, trust_remote_code=self.trust_remote_code)
+            elif isinstance(name, str):
+                self.tokenizer = Tokenicer.load(name, trust_remote_code=self.trust_remote_code)
+                self.model = GPTQModel.load(**kwargs)
+                self.model = self.model.to(self.device)
+            else:
+                raise ValueError(f"`name` is invalid. expected: `model instance or str` actual: `{name}`")
+
+            if self.tokenizer is None:
+                raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `name`. Please pass in `tokenizer` as argument.")
+
+            if self.is_direct_completion():  # no chat template
+                self.eos += extra_eos_for_direct_completion(dataset)
+            else:  # with chat template
+                self.eos += ["\n```\n"]
+
+    GPTQModelDecoder.__init__ = PatchedGPTQModelDecoder.__init__
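Context for the instance-level patches: evalplus treats the `name` it receives as a string, calling string methods such as .strip() on it and formatting it into output paths. patch_evalplus() lets a model object answer those calls with its config.name_or_path. A minimal sketch of the effect, with a placeholder model id:

    from gptqmodel import GPTQModel
    from gptqmodel.utils.evalplus import patch_evalplus

    model = GPTQModel.load("ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit")  # placeholder
    patch_evalplus(model)

    print(model.strip())    # -> model.config.name_or_path
    print(model.__str__())  # -> model.config.name_or_path
    # Note: str(model) still resolves __str__ on the class, not the instance;
    # the instance binding covers explicit model.__str__() calls.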
