Merged
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -4,3 +4,4 @@ global-include gptqmodel_ext/**/*.cpp
 global-include gptqmodel_ext/**/*.cu
 global-include gptqmodel_ext/**/*.py
 include requirements.txt
+prune tests/
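The added `prune tests/` directive tells setuptools to exclude the entire tests/ tree from source distributions, so sdists built from this MANIFEST.in ship only the extension sources and requirements listed above.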
33 changes: 22 additions & 11 deletions gptqmodel/models/auto.py
@@ -19,6 +19,7 @@
 import os

 from lm_eval.utils import make_table
+from tokenicer import Tokenicer

 if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True'
@@ -42,7 +43,7 @@
 import numpy # noqa: E402
 import torch # noqa: E402
 from huggingface_hub import list_repo_files # noqa: E402
-from transformers import AutoConfig,AutoTokenizer # noqa: E402
+from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402

 from ..quantization import QUANT_CONFIG_FILENAME # noqa: E402
 from ..utils import BACKEND # noqa: E402
@@ -286,7 +287,8 @@ def from_quantized(
     def eval(
             cls,
             model_or_id_or_path: str=None,
-            tasks: Union[List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None, # set to None to tifx mutable warning
+            tokenizer: PreTrainedTokenizerBase=None,
+            tasks: Union[List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None, # set to None to fix mutable warning
             framework: EVAL = EVAL.LM_EVAL,
             batch_size: int = 1,
             trust_remote_code: bool = False,
@@ -316,20 +318,29 @@ def eval(
         if isinstance(model_or_id_or_path, str):
             model = None
             model_id_or_path = model_or_id_or_path
+        elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, PreTrainedModel):
+            model = model_or_id_or_path
+            model_id_or_path = model.config.name_or_path #
         else:
-            model = model_or_id_or_path
-            model_id_or_path = model.model_local_path
+            raise ValueError(f"`model_or_id_or_path` is invalid. expected: `model instance or str` actual: `{model_or_id_or_path}`")

+        if tokenizer is None:
+            if isinstance(model, BaseGPTQModel):
+                tokenizer = model.tokenizer
+            elif isinstance(model, PreTrainedModel) or model_id_or_path.strip():
+                tokenizer = Tokenicer.load(model_id_or_path)
+
+        if tokenizer is None:
+            raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.")
+
+        model_args["tokenizer"] = tokenizer
+
         if framework == EVAL.LM_EVAL:
             for task in tasks:
                 if task not in EVAL.get_task_enums():
                     raise ValueError(f"lm_eval support tasks: {EVAL.get_all_tasks_string()}")

-            # model_id_or_path=model_id_or_path if model_id_or_path else model.model_id_or_path
-            # tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
-            tokenizer = model.tokenizer if model else AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
-
-            model_name = 'hf' if backend == 'gptqmodel' else backend
+            model_name = "hf" if backend == "gptqmodel" else backend

             if backend == "gptqmodel":
                 model_args["gptqmodel"] = True
@@ -349,13 +360,13 @@
                 batch_size=batch_size,
                 trust_remote_code=trust_remote_code,
             )
-            apply_chat_template=args.pop("apply_chat_template", True if tokenizer.chat_template is not None else False)
+
             results = simple_evaluate(
                 model=model_name,
                 model_args=model_args,
                 tasks=[task.value for task in tasks],
                 batch_size=batch_size,
-                apply_chat_template=apply_chat_template,
+                apply_chat_template=args.pop("apply_chat_template", True if tokenizer.chat_template is not None else False),
                 gen_kwargs=args.pop("gen_kwargs", "temperature=0.0,top_k=50"),
                 random_seed=random_seed,
                 numpy_random_seed=random_seed,
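A minimal usage sketch of the updated eval() entry point (not part of the diff): the new tokenizer parameter can be passed explicitly, and when it is omitted eval() now resolves one from the model instance or id via Tokenicer. The model id and the EVAL.LM_EVAL.ARC_CHALLENGE task member below are illustrative assumptions.

from transformers import AutoTokenizer

from gptqmodel import GPTQModel
from gptqmodel.utils.eval import EVAL

model_id = "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit"  # hypothetical quantized model id

# Explicit tokenizer (new argument in this PR); drop it to exercise the auto-loading path.
tokenizer = AutoTokenizer.from_pretrained(model_id)

results = GPTQModel.eval(
    model_id,
    tokenizer=tokenizer,
    tasks=[EVAL.LM_EVAL.ARC_CHALLENGE],  # assumed task enum member
    framework=EVAL.LM_EVAL,
    batch_size=1,
)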
5 changes: 4 additions & 1 deletion gptqmodel/utils/eval.py
@@ -19,6 +19,7 @@
 from enum import Enum
 from typing import Optional

+from .evalplus import patch_evalplus

 class EVAL:
     class LM_EVAL(Enum):
@@ -56,13 +57,15 @@ def get_all_tasks_string(cls):


 def evalplus(
-    model: str,
+    model,
     dataset: str,
     batch: int = 1,
     trust_remote_code: bool = False,
     output_file: Optional[str] = None,
     backend: str = 'gptqmodel'
 ):
+    patch_evalplus(model)
+
     try:
         from evalplus.evaluate import evaluate
     except BaseException:
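A hedged sketch of the relaxed evalplus() signature (not part of the diff): model may now be an already-loaded instance rather than only a path string, and patch_evalplus() runs first so the instance survives the places where evalplus expects a string-like model name. The model id and dataset value are assumptions.

from gptqmodel import GPTQModel
from gptqmodel.utils.eval import evalplus

model = GPTQModel.load("ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit")  # hypothetical model id

evalplus(
    model=model,          # a BaseGPTQModel/PreTrainedModel instance instead of a str path
    dataset="humaneval",  # assumed evalplus dataset name
    batch=1,
)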
79 changes: 79 additions & 0 deletions gptqmodel/utils/evalplus.py
@@ -0,0 +1,79 @@
+import types
+
+from tokenicer import Tokenicer
+from transformers import PreTrainedModel
+
+
+def patch_strip(self, *args, **kwargs):
+    return self.config.name_or_path.strip(*args, **kwargs)
+
+def patch_tostring(self):
+    return self.config.name_or_path
+
+def patch_evalplus(model):
+    from ..models.base import BaseGPTQModel
+    if isinstance(model, BaseGPTQModel) or isinstance(model, PreTrainedModel):
+        model.strip = types.MethodType(patch_strip, model)
+        model.__str__ = types.MethodType(patch_tostring, model)
+
+    import torch
+    from evalplus.provider.base import DecoderBase
+    from evalplus.provider.gptqmodel import GPTQModelDecoder
+    from evalplus.provider.utility import extra_eos_for_direct_completion
+    from gptqmodel.models import BaseGPTQModel
+
+    from .. import GPTQModel
+
+    class PatchedGPTQModelDecoder(DecoderBase):
+        def __init__(
+            self,
+            name: str,
+            dataset: str,
+            gptqmodel_backend: str = 'auto',
+            force_base_prompt: bool = False,
+            **kwargs,
+        ):
+
+            super(GPTQModelDecoder, self).__init__(name=name, **kwargs)
+
+            if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
+                device = torch.device("mps")
+            elif hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available():
+                device = torch.device("xpu")
+            elif hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
+
+            self.device = device
+
+            kwargs = {
+                "model_id_or_path": name,
+                "trust_remote_code": self.trust_remote_code,
+                "backend": gptqmodel_backend,
+                "device": device
+            }
+            self.skip_special_tokens = True
+            self.force_base_prompt = force_base_prompt
+            if isinstance(name, BaseGPTQModel):
+                self.model = name
+                self.tokenizer = self.model.tokenizer
+            elif isinstance(name, PreTrainedModel):
+                self.model = name
+                self.tokenizer = Tokenicer.load(name.config.name_or_path, trust_remote_code=self.trust_remote_code)
+            elif isinstance(name, str):
+                self.tokenizer = Tokenicer.load(name, trust_remote_code=self.trust_remote_code)
+                self.model = GPTQModel.load(**kwargs)
+                self.model = self.model.to(self.device)
+            else:
+                raise ValueError(f"`name` is invalid. expected: `model instance or str` actual: `{name}`")
+
+            if self.tokenizer is None:
+                raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.")
+
+            if self.is_direct_completion(): # no chat template
+                self.eos += extra_eos_for_direct_completion(dataset)
+            else: # with chat template
+                self.eos += ["\n```\n"]
+
+    GPTQModelDecoder.__init__ = PatchedGPTQModelDecoder.__init__
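An end-to-end sketch of what the patched decoder enables (not part of the diff): an in-memory BaseGPTQModel or PreTrainedModel can be handed straight to the EVALPLUS path, with the device chosen by the mps → xpu → cuda → cpu fallback above. The model id and the EVAL.EVALPLUS.HUMAN task member are assumptions.

from gptqmodel import GPTQModel
from gptqmodel.utils.eval import EVAL

model = GPTQModel.load("ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit")  # hypothetical model id

GPTQModel.eval(
    model,                        # model instance, handled by the new isinstance branches in eval()
    tasks=[EVAL.EVALPLUS.HUMAN],  # assumed enum member mapping to the humaneval dataset
    framework=EVAL.EVALPLUS,
)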