from tokenizer import get_tokenizer
import time
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
- from torchao._models.llama.model import prepare_inputs_for_model
+ from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

def run_evaluation(
@@ -122,6 +122,51 @@ def run_evaluation(
        else:
            if not TORCH_VERSION_AT_LEAST_2_5:
                unwrap_tensor_subclass(model)
+         if "autoround" in quantization:
+             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
+             from transformers import AutoTokenizer
+
+             _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
+             # parse args from quantization string:
+             #   autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
+             _quant_args = quantization.split("-")
+             _default_quant_args = [False, 200, 128, 8, 2048, 128]
+             _model_device = _quant_args[1] if len(_quant_args) > 1 else device
+             _quant_args = _quant_args[2:]
+             quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
+                 int(x) for x in _quant_args
+             ] + _default_quant_args[len(_quant_args) :]
+             model = model.to(_model_device)
+             print(
+                 (
+                     f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
+                     f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
+                 )
+             )
+             with torch.device(_model_device):
+                 model.setup_caches(
+                     max_batch_size=batch_size, max_seq_length=seqlen, training=True
+                 )
+
+             if quant_lm_head:
+                 is_target_module = (
+                     lambda mod, fqn: isinstance(mod, TransformerBlock)
+                     or "output" in fqn
+                 )
+             else:
+                 is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock)
+             quantize_model_with_autoround_(
+                 model=model,
+                 tokenizer=_tokenizer,
+                 is_target_module=is_target_module,
+                 bits=4,
+                 seqlen=seqlen,
+                 bs=batch_size,
+                 iters=iters,
+                 nsamples=nsamples,
+             )
+             model.to(device)
+             model.reset_caches()

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)
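A quick aside on the quantization-string parsing added above (a standalone sketch, not part of the diff, reusing only the split-and-defaults logic visible in the hunk): a spec such as "autoround-cuda-1" expands into the full argument tuple, with any omitted trailing fields filled from the defaults [False, 200, 128, 8, 2048, 128].

def parse_autoround_spec(quantization, device="cuda"):
    # Same splitting logic as the hunk above: field 1 (if present) picks the
    # device used during quantization; the remaining fields are cast to int
    # and padded with the defaults.
    _quant_args = quantization.split("-")
    _default_quant_args = [False, 200, 128, 8, 2048, 128]
    model_device = _quant_args[1] if len(_quant_args) > 1 else device
    _quant_args = _quant_args[2:]
    quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
        int(x) for x in _quant_args
    ] + _default_quant_args[len(_quant_args) :]
    return model_device, quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples

# "autoround-cuda-1" -> ("cuda", 1, 200, 128, 8, 2048, 128)
# "autoround"        -> ("cuda", False, 200, 128, 8, 2048, 128)
print(parse_autoround_spec("autoround-cuda-1"))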
@@ -145,11 +190,15 @@ def run_evaluation(
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-     parser.add_argument('-q', '--quantization', type=str,
+     parser.add_argument(
+         "-q",
+         "--quantization",
+         type=str,
        help=(
-             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, autoquant, autoquant-int4, ' +
-             'int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
-         )
+             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, "
+             "autoquant, autoquant-int4, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, "
+             "sparse-marlin, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>"
+         ),
    )
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
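One more aside, on the --precision argument visible in the context lines above: its type callable resolves the dtype by attribute lookup on torch, so "bfloat16" and "torch.bfloat16" both resolve to torch.bfloat16. A tiny standalone sketch (not part of the diff):

import torch

# Same callable used as type= for --precision above.
to_dtype = lambda x: getattr(torch, x.split(".")[-1])
assert to_dtype("bfloat16") is torch.bfloat16
assert to_dtype("torch.bfloat16") is torch.bfloat16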