
Commit 7bbc473

add pad_to_buckets in evaluation for hpu performance (#2011)
* add pad_to_buckets in evaluation for hpu performance

Signed-off-by: xin3he <[email protected]>
1 parent b6b7d7c commit 7bbc473
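
In short, the commit switches the weight-only example's evaluation onto Gaudi (HPU) whenever the Habana PyTorch extension is detected, and threads a pad_to_buckets option through the lm-eval wrapper, presumably so inputs are padded to a fixed set of bucket lengths and HPU graph recompilation is reduced. Below is a minimal sketch of the resulting call pattern; it mirrors the diffs further down except for the model name, task, batch size, and the pad_to_buckets keyword, which are illustrative assumptions rather than values taken from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser
from neural_compressor.torch.utils import is_hpex_available

# Small illustrative model; the example script would pass its quantized model here.
model_name = "facebook/opt-125m"
user_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

eval_args = LMEvalParser(
    model="hf",
    user_model=user_model,                           # evaluate the in-memory model object
    tokenizer=tokenizer,
    tasks="lambada_openai",                          # illustrative task
    batch_size=8,                                    # illustrative batch size
    device="hpu" if is_hpex_available() else "cpu",  # same fallback the diff adds
    # pad_to_buckets=True,  # assumed keyword, inferred from `lm.pad_to_buckets = args.pad_to_buckets` in accuracy.py
)
results = evaluate(eval_args)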

File tree

6 files changed: +102 −788 lines changed


examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py

Lines changed: 10 additions & 5 deletions
@@ -12,6 +12,7 @@
 from torch.nn.functional import pad
 from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
+from neural_compressor.torch.utils import is_hpex_available

 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -324,22 +325,26 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
     user_model, _ = get_user_model()
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     config = AutoConfig.from_pretrained(args.model)
-    user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), user_model)
+    user_model = load(
+        os.path.abspath(os.path.expanduser(args.output_dir)),
+        user_model,
+        device="hpu" if is_hpex_available() else "cpu",
+    )
     setattr(user_model, "config", config)
 else:
     user_model, tokenizer = get_user_model()


 if args.accuracy:
     user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser
     eval_args = LMEvalParser(
         model="hf",
         user_model=user_model,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         tasks=args.tasks,
-        device="cpu",
+        device="hpu" if is_hpex_available() else "cpu",
     )
     results = evaluate(eval_args)
     for task_name in args.tasks.split(","):
@@ -352,7 +357,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):

 if args.performance:
     user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser
     import time

     samples = args.iters * args.batch_size
@@ -363,7 +368,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
         batch_size=args.batch_size,
         tasks=args.tasks,
         limit=samples,
-        device="cpu",
+        device="hpu" if is_hpex_available() else "cpu",
     )
     start = time.time()
     results = evaluate(eval_args)
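
With the import swap above, the example no longer needs intel_extension_for_transformers for evaluation and picks HPU automatically when the Habana extension is present. A hypothetical invocation under those assumptions; the flag names are inferred from the args.model / args.accuracy / args.tasks / args.batch_size attributes visible in the diff and should be checked against the script's argparse definitions:

python run_clm_no_trainer.py \
    --model facebook/opt-125m \
    --accuracy \
    --tasks lambada_openai \
    --batch_size 8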

neural_compressor/evaluation/lm_eval/accuracy.py

Lines changed: 62 additions & 7 deletions
@@ -36,18 +36,26 @@
 from pathlib import Path
 from typing import Union

+import lm_eval
 import numpy as np
-from lm_eval import utils
+from lm_eval import evaluator, utils
 from lm_eval.loggers import WandbLogger
 from lm_eval.tasks import TaskManager
 from lm_eval.utils import make_table, simple_parse_args_string

-from neural_compressor.evaluation.lm_eval import evaluator
-from neural_compressor.evaluation.lm_eval.evaluator import request_caching_arg_to_dict
-
 DEFAULT_RESULTS_FILE = "results.json"


+def request_caching_arg_to_dict(cache_requests: str) -> dict:
+    request_caching_args = {
+        "cache_requests": cache_requests in {"true", "refresh"},
+        "rewrite_requests_cache": cache_requests == "refresh",
+        "delete_requests_cache": cache_requests == "delete",
+    }
+
+    return request_caching_args
+
+
 def _handle_non_serializable(o):
     if isinstance(o, np.int64) or isinstance(o, np.int32):
         return int(o)
@@ -143,8 +151,57 @@ def cli_evaluate(args) -> None:

     request_caching_args = request_caching_arg_to_dict(cache_requests=args.cache_requests)

+    ### update model with user_model ###
+    if args.model_args is None:
+        args.model_args = ""
+    # replace HFLM.
+    from .models.huggingface import HFLM
+
+    lm_eval.api.registry.MODEL_REGISTRY["hf-auto"] = HFLM
+    lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM
+    lm_eval.api.registry.MODEL_REGISTRY["huggingface"] = HFLM
+
+    if args.user_model is not None:
+        # use tiny model to built lm.
+        print(
+            "We use 'pretrained=Muennighoff/tiny-random-bert'"
+            + "to build `LM` instance, the actually run model is user_model you passed."
+        )
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string(
+            "pretrained=Muennighoff/tiny-random-bert",
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+        lm._model = args.user_model
+        if args.tokenizer is not None:
+            lm.tokenizer = args.tokenizer
+        else:
+            assert False, "Please provide tokenizer in evaluation function"
+    elif isinstance(args.model_args, dict):
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_obj(
+            args.model_args,
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+    else:
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string(
+            args.model_args,
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+    lm.pad_to_buckets = args.pad_to_buckets
+
     results = evaluator.simple_evaluate(
-        model=args.model,
+        model=lm,
         model_args=args.model_args,
         tasks=task_names,
         num_fewshot=args.num_fewshot,
@@ -163,8 +220,6 @@ def cli_evaluate(args) -> None:
         random_seed=args.seed[0],
         numpy_random_seed=args.seed[1],
         torch_random_seed=args.seed[2],
-        user_model=args.user_model, # to validate the model in memory,
-        tokenizer=args.tokenizer, # to use tokenizer in mem,
         **request_caching_args,
     )

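The request_caching_arg_to_dict helper that this diff inlines into accuracy.py is a pure string-to-flags mapping, so its outputs follow directly from the function body above. A quick check, assuming the module path of this file after the commit:

from neural_compressor.evaluation.lm_eval.accuracy import request_caching_arg_to_dict

# "refresh" both enables the request cache and forces it to be rewritten.
print(request_caching_arg_to_dict("refresh"))
# {'cache_requests': True, 'rewrite_requests_cache': True, 'delete_requests_cache': False}

# "delete" only clears any existing cache.
print(request_caching_arg_to_dict("delete"))
# {'cache_requests': False, 'rewrite_requests_cache': False, 'delete_requests_cache': True}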
