
Commit 6f4df1a

Refactor _get_source_transforms to remove args
Differential Revision: D73800023
Pull Request resolved: #10519
Parent: d7201ab
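
In plain terms, the refactor replaces the single catch-all `args` parameter of `_get_source_transforms` with explicit keyword arguments, so the function no longer depends on an argparse namespace. The sketch below is illustrative and abbreviated, not the full change: it assumes the surrounding context of `_prepare_for_llama_export`, where `args`, `dtype_override`, and `checkpoint_dtype` are already in scope; the complete argument list is in the diff below.

    # Before: the whole argparse namespace was threaded through.
    #
    #     _get_source_transforms(
    #         modelname=args.model,
    #         dtype_override=dtype_override,
    #         checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
    #         args=args,
    #     )
    #
    # After: each option is passed as its own keyword argument (abbreviated here).
    _get_source_transforms(
        dtype_override=dtype_override,
        checkpoint=args.checkpoint,
        checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
        tokenizer_path=args.tokenizer_path,
        use_kv_cache=args.use_kv_cache,
        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
        # ... the remaining options follow the same one-flag-per-keyword pattern
    )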

File tree: 1 file changed (+135, −36 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 135 additions & 36 deletions
@@ -661,10 +661,37 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(
+                args, "use_custom_sdpa_with_attention_mask", False
+            ),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1189,23 +1216,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1216,21 +1289,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1240,12 +1313,27 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1259,7 +1347,25 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1268,15 +1374,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1288,24 +1391,22 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
@@ -1316,29 +1417,27 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
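
One detail worth noting from the diff above: the quantization helpers `get_quant_embedding_transform` and `get_quant_weight_transform` still read their options off an args-like object, so the refactored function builds a small mock `Args` shim just before calling them. The snippet below is a self-contained sketch of that bridging pattern only, not ExecuTorch code; `legacy_quant_helper` and `build_quant_transform` are hypothetical stand-ins used for illustration.

    def legacy_quant_helper(args):
        # Hypothetical stand-in for a helper that still expects an argparse-style
        # object and reads options as attributes.
        return f"mode={args.quantization_mode}, group_size={args.group_size}"


    def build_quant_transform(*, quantization_mode=None, group_size=None):
        # Repackage explicit keyword arguments into a lightweight attribute bag,
        # mirroring the "mock args object" blocks in the diff above.
        class Args:
            pass

        args = Args()
        args.quantization_mode = quantization_mode
        args.group_size = group_size
        return legacy_quant_helper(args)


    print(build_quant_transform(quantization_mode="8da4w", group_size=128))
    # -> mode=8da4w, group_size=128

This keeps the new keyword-only signature free of argparse while leaving the downstream helpers unchanged in this commit.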
