Skip to content

Commit 1ae70de

Browse files
committed
qnn end to end flow for stories model
Pull Request resolved: #3038 Patch a few changes including: - support bool tensor type - support fp16 and fix the 8w8a quantization. - add two non-supported ops (slice_scatter and index_put) in common_defs.py stories model working end to end: AOT: fp16: ``` python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json ``` quantize: ``` python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize qnn_8a8w -c stories110M.pt -p params.json ``` Runtime: ``` /llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once" ``` Output: ``` Once upon a time, there was a little girl named Lily. She loved to play outside and explore the world around her. One day, she went on a walk with her mommy and they found a beautiful landscape with lots of trees and flowers. Lily said, "Mommy, this place is so pretty! Can we take a picture?" Mommy replied, "Of course, Lily! Let's take a picture to remember the original place we found." After they took the picture, they continued their walk and saw a bird flying in the sky. Lily said, "MomPyTorchObserver {"prompt_tokens":2,"generated_tokens":125,"model_load_start_ms":1713226585936,"model_load_end_ms":1713226586909,"inference_start_ms":1713226586909,"inference_end_ms":1713226590363,"prompt_eval_end_ms":1713226586966,"first_token_ms":1713226586994,"aggregate_sampling_time_ms":23,"SCALING_FACTOR_UNITS_PER_SECOND":1000} I 00:00:04.436699 executorch:runner.cpp:414] Prompt Tokens: 2 Generated Tokens: 125 I 00:00:04.436703 executorch:runner.cpp:420] Model Load Time: 0.973000 (seconds) I 00:00:04.436732 executorch:runner.cpp:430] Total inference time: 3.454000 (seconds) Rate: 36.189925 (tokens/second) I 00:00:04.436735 executorch:runner.cpp:438] Prompt evaluation: 0.057000 (seconds) Rate: 35.087719 (tokens/second) I 00:00:04.436739 executorch:runner.cpp:449] Generated 125 tokens: 3.397000 (seconds) Rate: 36.797174 (tokens/second) I 00:00:04.436742 executorch:runner.cpp:457] Time to first generated token: 0.085000 (seconds) I 00:00:04.436744 executorch:runner.cpp:464] Sampling time over 127 tokens: 0.023000 (seconds) [INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters [INFO] [Qnn ExecuTorch]: Destroy Qnn context ``` Stories model is too small and sensitive to qunatization. ghstack-source-id: 223136109 @exported-using-ghexport Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)
1 parent f8fbbe6 commit 1ae70de

File tree

3 files changed

+62
-9
lines changed

3 files changed

+62
-9
lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
3030
}
3131
QNN_TENSOR_TYPE_MAP = {
32+
torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
3233
torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
3334
torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
3435
torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
exir_ops.edge.aten.clone.default,
1414
exir_ops.edge.aten.index.Tensor,
1515
exir_ops.edge.aten.full.default,
16+
exir_ops.edge.aten.slice_scatter.default,
17+
exir_ops.edge.aten.index_put.default,
1618
]
1719

1820
allow_list_operator = [

examples/models/llama2/export_llama_lib.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pkg_resources
2121
import torch
22+
import torch.nn.functional as F
2223
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
2324
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
2425
XnnpackDynamicallyQuantizedPartitioner,
@@ -364,6 +365,13 @@ def build_args_parser() -> argparse.ArgumentParser:
364365
parser.add_argument(
365366
"--pt2e_quantize",
366367
default=None,
368+
choices=[
369+
"xnnpack_dynamic",
370+
"xnnpack_dynamic_qc4",
371+
"qnn_8a8w",
372+
"qnn_16a16w",
373+
"qnn_16a4w",
374+
],
367375
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
368376
)
369377
parser.add_argument(
@@ -633,6 +641,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
633641
if args.use_sdpa_with_kv_cache:
634642
transforms.append(replace_sdpa_with_custom_op)
635643

644+
if args.qnn and args.use_kv_cache:
645+
transforms.append(replace_sdpa_with_simple_sdpa)
646+
transforms.append(replace_causal_mask)
636647
return (
637648
load_llama_model(
638649
modelname=modelname,
@@ -656,13 +667,16 @@ def _export_llama(modelname, args) -> str: # noqa: C901
656667
# export_to_edge
657668
pt2e_quant_params = _get_pt2e_quantization_params(args)
658669
quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
659-
if args.qnn:
660-
assert (
661-
args.quantization_mode is None
662-
), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
670+
quant_dtype = None
671+
if args.qnn and args.pt2e_quantize:
663672
try:
664673
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
665-
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
674+
from executorch.backends.qualcomm.quantizer.quantizer import (
675+
get_16a4w_qnn_ptq_config,
676+
get_default_16bit_qnn_ptq_config,
677+
QnnQuantizer,
678+
QuantDtype,
679+
)
666680

667681
# reset quantizers and pt2e_quant_params from xnnpack backend
668682
pt2e_quant_params = None
@@ -672,10 +686,36 @@ def _export_llama(modelname, args) -> str: # noqa: C901
672686
"Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
673687
)
674688

689+
backend, quant_config = args.pt2e_quantize.split("_")
690+
assert (
691+
backend == "qnn"
692+
), f"The quantization config is for backend {backend} instead of qnn."
675693
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
676694
qnn_quantizer = QnnQuantizer()
677695
# more custom quantization are supported including 16a4w etc. default to 8bit quantized
678696
custom_annotations = ()
697+
if quant_config == "8a8w":
698+
quant_dtype = QuantDtype.use_8a8w
699+
pass
700+
elif quant_config == "16a16w":
701+
quant_dtype = QuantDtype.use_16a16w
702+
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
703+
qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
704+
elif quant_config == "16a4w":
705+
quant_dtype = QuantDtype.use_16a4w
706+
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
707+
qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
708+
qnn_quantizer.set_per_channel_weight_dtype(
709+
weight_dtype_for_16bit_act="int4"
710+
)
711+
else:
712+
raise AssertionError(
713+
f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
714+
)
715+
716+
assert (
717+
args.quantization_mode is None
718+
), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
679719
qnn_quantizer.add_custom_quant_annotations(custom_annotations)
680720
quantizers.append(qnn_quantizer)
681721

@@ -793,24 +833,34 @@ def _export_llama(modelname, args) -> str: # noqa: C901
793833
)
794834

795835
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
796-
backend_options = generate_htp_compiler_spec(use_fp16=False)
836+
use_fp16 = True
837+
skip_node_op_set = {}
838+
if args.pt2e_quantize:
839+
use_fp16 = False
840+
# TODO: fix the lowering error without skipping nodes
841+
if quant_dtype == QuantDtype.use_8a8w:
842+
raise NotImplementedError("8a8w for llama is still under development")
843+
elif quant_dtype == QuantDtype.use_16a16w:
844+
raise NotImplementedError("16a16w for llama is still under development")
845+
elif quant_dtype == QuantDtype.use_16a4w:
846+
raise NotImplementedError("16a4w for llama is still under development")
797847
partitioners.append(
798848
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
799849
QnnPartitioner(
800850
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
801851
generate_qnn_executorch_compiler_spec(
802852
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
803853
soc_model=QcomChipset.SM8650, # default to SM8650
804-
backend_options=backend_options,
854+
backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
805855
debug=False,
806856
saver=False,
807857
),
808858
skip_node_id_set={},
809-
skip_node_op_set={},
859+
skip_node_op_set=skip_node_op_set,
810860
)
811861
)
812862
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
813-
_transform(builder_exported_to_edge.export_program())
863+
_transform(builder_exported_to_edge.edge_manager.exported_program())
814864

815865
if args.generate_etrecord:
816866
if not builder_exported_to_edge.edge_manager:

0 commit comments

Comments
 (0)