 from pathlib import Path
 from typing import Union
 
+import lm_eval
 import numpy as np
-from lm_eval import utils
+from lm_eval import evaluator, utils
 from lm_eval.loggers import WandbLogger
 from lm_eval.tasks import TaskManager
 from lm_eval.utils import make_table, simple_parse_args_string
 
-from neural_compressor.evaluation.lm_eval import evaluator
-from neural_compressor.evaluation.lm_eval.evaluator import request_caching_arg_to_dict
-
 DEFAULT_RESULTS_FILE = "results.json"
 
 
+def request_caching_arg_to_dict(cache_requests: str) -> dict:
+    request_caching_args = {
+        "cache_requests": cache_requests in {"true", "refresh"},
+        "rewrite_requests_cache": cache_requests == "refresh",
+        "delete_requests_cache": cache_requests == "delete",
+    }
+
+    return request_caching_args
+
+
 def _handle_non_serializable(o):
     if isinstance(o, np.int64) or isinstance(o, np.int32):
         return int(o)
@@ -143,8 +151,57 @@ def cli_evaluate(args) -> None:
 
     request_caching_args = request_caching_arg_to_dict(cache_requests=args.cache_requests)
 
+    ### update model with user_model ###
+    if args.model_args is None:
+        args.model_args = ""
+    # Replace the registered HFLM entries with the local implementation.
+    from .models.huggingface import HFLM
+
+    lm_eval.api.registry.MODEL_REGISTRY["hf-auto"] = HFLM
+    lm_eval.api.registry.MODEL_REGISTRY["hf"] = HFLM
+    lm_eval.api.registry.MODEL_REGISTRY["huggingface"] = HFLM
+
+    if args.user_model is not None:
+        # Use a tiny model to build the LM instance.
+        print(
+            "We use 'pretrained=Muennighoff/tiny-random-bert' "
+            "to build the `LM` instance; the model actually run is the user_model you passed."
+        )
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string(
+            "pretrained=Muennighoff/tiny-random-bert",
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+        lm._model = args.user_model
+        if args.tokenizer is not None:
+            lm.tokenizer = args.tokenizer
+        else:
+            assert False, "Please provide a tokenizer in the evaluation function"
+    elif isinstance(args.model_args, dict):
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_obj(
+            args.model_args,
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+    else:
+        lm = lm_eval.api.registry.get_model(args.model).create_from_arg_string(
+            args.model_args,
+            {
+                "batch_size": args.batch_size,
+                "max_batch_size": args.max_batch_size,
+                "device": args.device,
+            },
+        )
+    lm.pad_to_buckets = args.pad_to_buckets
+
     results = evaluator.simple_evaluate(
-        model=args.model,
+        model=lm,
         model_args=args.model_args,
         tasks=task_names,
         num_fewshot=args.num_fewshot,
@@ -163,8 +220,6 @@ def cli_evaluate(args) -> None:
         random_seed=args.seed[0],
         numpy_random_seed=args.seed[1],
         torch_random_seed=args.seed[2],
-        user_model=args.user_model,  # to validate the model in memory,
-        tokenizer=args.tokenizer,  # to use tokenizer in mem,
         **request_caching_args,
     )
 
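For reference, the new `request_caching_arg_to_dict` helper is a pure mapping from the `--cache_requests` flag to the three keyword arguments that `simple_evaluate` consumes. A standalone sketch of that mapping, copied from the hunk above and runnable on its own:

```python
# Standalone copy of the helper added in this patch.
def request_caching_arg_to_dict(cache_requests: str) -> dict:
    return {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

for flag in ("true", "refresh", "delete"):
    print(flag, request_caching_arg_to_dict(flag))
# true    -> cache_requests=True,  rewrite_requests_cache=False, delete_requests_cache=False
# refresh -> cache_requests=True,  rewrite_requests_cache=True,  delete_requests_cache=False
# delete  -> cache_requests=False, rewrite_requests_cache=False, delete_requests_cache=True
```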
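The core pattern of the second hunk can also be exercised outside the CLI: build an `HFLM` shell from the tiny checkpoint named in the patch, swap in an in-memory model and tokenizer, and hand the instance directly to `simple_evaluate`. A minimal sketch, assuming lm-eval's public API; the `facebook/opt-125m` model/tokenizer pair and the `lambada_openai` task are placeholder assumptions, not part of the patch:

```python
import lm_eval
from lm_eval import evaluator
from lm_eval.utils import make_table
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder user objects (assumption): any in-memory model/tokenizer pair.
user_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
user_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Build the LM shell from the tiny checkpoint named in the patch; its weights
# are never used once `_model` is replaced below.
lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
    "pretrained=Muennighoff/tiny-random-bert",
    {"batch_size": 8, "max_batch_size": None, "device": "cpu"},
)
lm._model = user_model          # the model that is actually evaluated
lm.tokenizer = user_tokenizer   # must be provided, as the patch asserts

results = evaluator.simple_evaluate(model=lm, tasks=["lambada_openai"])
print(make_table(results))
```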