@@ -661,10 +661,37 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
        _get_source_transforms(
-            modelname=args.model,
            dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(
+                args, "use_custom_sdpa_with_attention_mask", False
+            ),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
        )
    )
@@ -1189,23 +1216,69 @@ def _load_llama_model(
def _get_source_transforms(  # noqa
-    modelname: str,
    dtype_override: DType,
    *,
+    checkpoint: Optional[str] = None,
    checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    """
    Return a list of functions that transform a graph.

    Args:
-        modelname: The name of the model.
        dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
        checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
            it means that you want to run quantize transformations on the weights represented
            in their original dtype, while the overall dtype of the model may be something
            different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

    Returns:
        A list of transformation functions.
@@ -1216,21 +1289,21 @@ def _get_source_transforms( # noqa
    transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_native_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
        """
        When this option is selected, it finds all embedding layers and transforms
        into quantized embedding equivalent module.
@@ -1240,12 +1313,27 @@ def _get_source_transforms( # noqa
        transformations based on the given checkpoint first. In those cases,
        this will be a no-op.
        """
-        modelname = f"{modelname}_e"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

    # quantization_mode should be applied after embedding_quantize
    # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
        """
        When this option is selected, it finds all linear layers and transforms
        into quantized linear equivalent module.
@@ -1259,7 +1347,25 @@ def _get_source_transforms( # noqa
        There are cases where this may be a no-op, namely, if all linears are
        quantized in the checkpoint.
        """
-        modelname = f"{modelname}_q"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
        transforms.append(
            get_quant_weight_transform(
                args=args,
@@ -1268,15 +1374,12 @@ def _get_source_transforms( # noqa
            )
        )

-    if args.expand_rope_table:
+    if expand_rope_table:
        transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
        transforms.append(replace_kv_cache_with_custom_kv_cache)
        # todo: do this optionally
        # if use attention mask instead of causal attention
@@ -1288,24 +1391,22 @@ def _get_source_transforms( # noqa
        else:
            transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
        # Right now
        transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
            from executorch.backends.qualcomm.utils.utils import (
                convert_linear_to_conv2d,
            )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                transforms.append(replace_attention_to_attention_sha)
                transforms.append(replace_causal_mask)
                transforms.append(replace_rms_norm_with_native_rms_norm)
@@ -1316,29 +1417,27 @@ def _get_source_transforms( # noqa
                transforms.append(replace_sdpa_with_flex_sdpa)
                transforms.append(replace_causal_mask)
                transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
            # Currently mps doesn't support sdpa op, use the simpler decomposition
            # to get free perf gain.
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
            # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                transforms.append(replace_sdpa_with_coreml_sdpa)
            else:
                transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
        transforms.append(replace_with_vulkan_rotary_emb)

    return transforms
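
Note for reviewers: with the `args` plumbing removed, `_get_source_transforms` can be exercised directly with keyword arguments. The snippet below is a minimal sketch of such a call under the keyword-only signature shown in this diff; the import paths, the `DType.fp32` member, and the checkpoint/tokenizer paths are illustrative assumptions, not part of the change.

```python
# Hypothetical usage sketch for the keyword-only signature introduced above.
# Import paths, the DType member, and the file paths are assumptions made for
# illustration; they are not taken from this diff.
from executorch.examples.models.llama.export_llama_lib import _get_source_transforms
from executorch.extension.llm.export.builder import DType

# Build the transform list directly from keyword arguments, with no
# argparse.Namespace involved.
transforms = _get_source_transforms(
    dtype_override=DType.fp32,
    checkpoint="/path/to/consolidated.00.pth",   # placeholder path
    tokenizer_path="/path/to/tokenizer.model",   # placeholder path
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)

# Each entry is a Callable[[torch.nn.Module], torch.nn.Module]; the export
# pipeline applies them to the eager model in order via source_transform().
for transform in transforms:
    print(transform)
```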