
Commit d8ac866

Qualcomm AI Engine Direct - Refactor llama runner (#10578)
Summary:

- Refactored io_manager into five distinct components:
  - DecoderRunner: Module wrapper class.
  - PromptProcessor: Handles prompt processing using the decoder and key-value manager.
  - TokenGenerator: Generates tokens using the decoder and key-value manager.
  - KVManager: Manages the key-value cache with kv_updater, including data buffer allocation, cache updates, and buffer updates in TensorImpl.
  - IBufferAlloc: Allocates data buffers from RPC memory or a client buffer.
- Validated story llama with CL=128, prefill_ar_len=16, QNN SDK 2.32.
- Original:

| CL | prefill_ar_len | eval_mode | kv_updater | Model Load Time (seconds) | Prompt evaluation (seconds) | Generated token rate (tokens/seconds) | Time to first generated token (seconds) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 128 | 16 | KV | shift_pointer | 0.3082 | 0.0105 | 237.5553131 | 0.0152 |
| 128 | 16 | KV | smart_mask | 0.2691 | 0.0501 | 258.9103433 | 0.0544 |
| 128 | 16 | hybrid | shift_pointer | 0.3408 | 0.008 | 232.1754892 | 0.008 |
| 128 | 16 | hybrid | smart_mask | 0.3175 | 0.0447 | 237.7134587 | 0.0447 |

- Refactor:

| CL | prefill_ar_len | eval_mode | kv_updater | Model Load Time (seconds) | Prompt evaluation (seconds) | Generated token rate (tokens/seconds) | Time to first generated token (seconds) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 128 | 16 | KV | shift_pointer | 0.2808 | 0.0124 | 234.835 | 0.0124 |
| 128 | 16 | KV | smart_mask | 0.238 | 0.027 | 251.004016 | 0.027 |
| 128 | 16 | hybrid | shift_pointer | 0.3305 | 0.0082 | 229.1122162 | 0.0082 |
| 128 | 16 | hybrid | smart_mask | 0.258 | 0.013 | 239.463602 | 0.013 |

- Support the multi-turn use case.
  - Validated on story llama. To simulate the scenario, decode mode was forced to generate 5 tokens each time, and tokens of random length are inserted after each prefill->decode round finishes.
  - Reproduce command (note that some whitespace is missing in the result due to decoding, but the tokens are the same as the golden output):

```
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --prompt "Once" "a little girl named Lily." "toys and her favorite toy was a big, red ball." "s mom asked her to help her with the laundry." "and she put all the clothes in the washing machine." --temperature 0 --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --ptq 16a4w -m SM8650 -H ${HOST} -s ${DEVICE} -a ${ARTIFACTS} --max_seq_len 128 --prefill_ar_len 16

Result:
Once upon a time, there wasa little girl named Lily. She loved to play with hertoys and her favorite toy was a big, red ball. One day, Lily's mom asked her to help her with the laundry. Lily was happy to helpand she put all the clothes in the washing machine. After the clothes were
```

- Apply the patch below to force decode mode to generate 5 tokens each time:
```diff
diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
index 8a81b598d..a8ec53cdb 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
@@ -170,7 +170,10 @@ Result<int64_t> TokenGenerator::generate(
       "Failed to set output tensor for module %s",
       forward_name_.c_str());
   // Generate our tokens
-  while (pos < seq_len - 1) {
+  // force decode to generate 5 runs at most
+  int64_t max_pos = std::min(pos + 5, (int64_t)seq_len - 1);
+// while (pos < seq_len - 1) {
+  while (pos < max_pos) {
     // Fill in the token and position data
     prepare_io(cur_token, pos);
     // Only update data pointer of the cache to the tensor for SHIFT_POINTER
```
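To make the decomposition described above concrete, here is a minimal, self-contained C++ sketch of how a decoder runner, KV manager, prompt processor, and token generator could be wired together for a prefill-then-decode flow. Every class and method name below is a hypothetical stand-in for illustration, not the actual interfaces added by this commit (those live under examples/qualcomm/oss_scripts/llama/runner/).

```cpp
// Illustrative sketch only: simplified stand-ins for the refactored
// components (DecoderRunner, KVManager, PromptProcessor, TokenGenerator).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

namespace sketch {

// Wraps the decoder module and runs one forward step (stubbed here).
class DecoderRunner {
 public:
  int64_t step(int64_t token, int64_t pos) {
    // A real implementation would run the exported method and sample a token.
    return (token + pos) % 32000;
  }
};

// Owns the KV-cache buffers and records which positions have been cached.
class KVManager {
 public:
  void update(int64_t pos) { cached_positions_.push_back(pos); }
  std::size_t cached() const { return cached_positions_.size(); }

 private:
  std::vector<int64_t> cached_positions_;
};

// Consumes the prompt tokens and fills the cache via the decoder.
class PromptProcessor {
 public:
  PromptProcessor(DecoderRunner& decoder, KVManager& kv)
      : decoder_(decoder), kv_(kv) {}
  int64_t prefill(const std::vector<int64_t>& prompt_tokens) {
    int64_t last = 0;
    for (std::size_t pos = 0; pos < prompt_tokens.size(); ++pos) {
      last = decoder_.step(prompt_tokens[pos], static_cast<int64_t>(pos));
      kv_.update(static_cast<int64_t>(pos));
    }
    return last;  // first generated token comes from the last prefill step
  }

 private:
  DecoderRunner& decoder_;
  KVManager& kv_;
};

// Autoregressively generates tokens until the sequence limit is reached.
class TokenGenerator {
 public:
  TokenGenerator(DecoderRunner& decoder, KVManager& kv)
      : decoder_(decoder), kv_(kv) {}
  int64_t generate(int64_t cur_token, int64_t pos, int64_t seq_len) {
    int64_t generated = 0;
    while (pos < seq_len - 1) {
      cur_token = decoder_.step(cur_token, pos);
      kv_.update(pos);
      ++pos;
      ++generated;
    }
    return generated;
  }

 private:
  DecoderRunner& decoder_;
  KVManager& kv_;
};

}  // namespace sketch

int main() {
  sketch::DecoderRunner decoder;
  sketch::KVManager kv;
  sketch::PromptProcessor prompt_processor(decoder, kv);
  sketch::TokenGenerator token_generator(decoder, kv);

  std::vector<int64_t> prompt_tokens = {1, 15, 27, 42};  // toy prompt
  int64_t first = prompt_processor.prefill(prompt_tokens);
  int64_t n = token_generator.generate(
      first, static_cast<int64_t>(prompt_tokens.size()), /*seq_len=*/16);
  std::cout << "generated " << n << " tokens, cache entries: " << kv.cached()
            << "\n";
  return 0;
}
```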
1 parent 380eb5f commit d8ac866

25 files changed: +1988 additions, -2367 deletions

backends/qualcomm/runtime/SharedBuffer.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ std::size_t std::hash<CustomMemTensorInfo>::operator()(
   hash_val ^= std::hash<size_t>()(info.pos);
   hash_val ^= std::hash<size_t>()(info.tensor_bytes);
   for (int i = 0; i < info.rank; ++i) {
-    hash_val ^= info.shape[i];
+    hash_val ^= std::hash<uint32_t>()(info.shape[i]);
   }
   hash_val ^= std::hash<uint32_t>()(info.rank);
   hash_val ^= std::hash<executorch::aten::ScalarType>()(info.dtype);
```

backends/qualcomm/runtime/backends/QnnBackendFactory.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -80,7 +80,9 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
           options->soc_info(),
           htp_options);
       backend_params->qnn_mem_manager_ptr_ = std::make_unique<QnnMemManager>(
-          implementation, backend_params->qnn_context_ptr_.get());
+          implementation,
+          backend_params->qnn_context_ptr_.get(),
+          options->log_level());
       backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED;
     } break;
     case QnnExecuTorchBackendType::kGpuBackend:
```

backends/qualcomm/runtime/backends/QnnMemManager.cpp

Lines changed: 11 additions & 6 deletions
```diff
@@ -47,9 +47,12 @@ Error QnnMemManager::RegisterIonMem(
   }
   tensor_wrapper->SetMemHandle(handle);
   registered_map_.insert({handle, mem_ptr});
-  QNN_EXECUTORCH_LOG_INFO(
-      "Tensor %s is successfully registered to ION shared memory.",
-      tensor_wrapper->GetName().c_str());
+  if (log_level_ >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    QNN_EXECUTORCH_LOG_INFO(
+        "Tensor %s is successfully registered to ION shared memory.",
+        tensor_wrapper->GetName().c_str());
+  }
+
   return Error::Ok;
 }

@@ -92,9 +95,11 @@ Error QnnMemManager::RegisterCustomMem(
   }
   tensor_wrapper->SetMemHandle(handle);
   registered_map_.insert({handle, mem_ptr});
-  QNN_EXECUTORCH_LOG_INFO(
-      "Tensor %s is successfully registered to custom shared memory.",
-      tensor_wrapper->GetName().c_str());
+  if (log_level_ >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    QNN_EXECUTORCH_LOG_INFO(
+        "Tensor %s is successfully registered to custom shared memory.",
+        tensor_wrapper->GetName().c_str());
+  }
   return Error::Ok;
 }
```

backends/qualcomm/runtime/backends/QnnMemManager.h

Lines changed: 6 additions & 2 deletions
```diff
@@ -21,8 +21,11 @@ class QnnMemManager {
  public:
   explicit QnnMemManager(
       const QnnImplementation& implementation,
-      QnnContext* context)
-      : implementation_(implementation), context_(context) {}
+      QnnContext* context,
+      QnnExecuTorchLogLevel log_level)
+      : implementation_(implementation),
+        context_(context),
+        log_level_(log_level) {}
   ~QnnMemManager() {
     DeRegisterMem();
   }
@@ -63,6 +66,7 @@ class QnnMemManager {

   const QnnImplementation& implementation_;
   QnnContext* context_;
+  QnnExecuTorchLogLevel log_level_;
   std::unordered_map<Qnn_MemHandle_t, void*> registered_map_;
   std::unordered_map<CustomMemTensorInfo, void*> pre_registered_handles_;
   std::unordered_map<executorch::aten::ScalarType, Qnn_DataType_t>
```

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3681,7 +3681,7 @@ def test_llama3_2_1b(self):
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

-        golden_start_with = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
+        golden_start_with = "<|start_header_id|>user<|end_header_id|>"
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()
```

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 13 additions & 3 deletions
```diff
@@ -28,8 +28,18 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
 )

 list(
@@ -42,7 +52,7 @@ list(
 # build qnn llama runner
 add_executable(qnn_llama_runner ${_llama_runner__srcs})
 target_include_directories(
-  qnn_llama_runner PUBLIC ${_common_include_directories}
+  qnn_llama_runner PUBLIC ${_common_include_directories} ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include
 )

 target_link_options_shared_lib(quantized_ops_lib)
```

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 22 additions & 48 deletions
```diff
@@ -403,7 +403,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         logging.info("Quantizing the model...")
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
-            args.prompt,
+            args.prompt[0],
             fx_graph_module,
             tokenizer=tokenizer,
             ar_len=self.llama_meta["get_ar_len"],
@@ -756,7 +756,7 @@ def permute(w, heads):
     return quant_attrs


-def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
+def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
     workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

     if args.model_mode == "kv":
@@ -782,14 +782,13 @@ def post_process():
             outputs.append(f.read())

     seq_len = args.max_seq_len
+    multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
     runner_args = " ".join(
         [
-            f'--prompt "{args.prompt}"',
+            multi_prompts,
             f"--eval_mode {eval_mode}",
             f"--temperature {args.temperature}",
             f"--system_prompt '{args.system_prompt}'",
-            f"--logits_scale {quant_attrs['scale']}",
-            f"--logits_offset {quant_attrs['zero_point']}",
         ]
     )

@@ -932,9 +931,10 @@ def _build_parser():

     parser.add_argument(
         "--prompt",
-        help="User prompts for llama.",
+        help="User prompts for Llama. When multiple prompts are entered, a multi-turn conversation will be initiated. Note that this feature is currently for testing purposes only.",
         required=True,
         type=str,
+        nargs="+",
     )

     parser.add_argument(
@@ -1018,7 +1018,7 @@ def _build_parser():

 def export_llama(args) -> None:
     if args.compile_only and args.pre_gen_pte:
-        exit("Cannot set both compile_only and pre_gen_pte as true")
+        raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")

     if args.model_mode == "kv":
         pte_filename = "kv_llama_qnn"
@@ -1054,29 +1054,15 @@ def export_llama(args) -> None:
     elif args.kv_updater == "shift_pointer":
         args.kv_updater = shift_pointer_updater
     else:
-        exit(f"Using an unkown kv update {args.kv_updater}")
+        raise RuntimeError(f"Using an unknown kv update {args.kv_updater}")

     if args.pre_gen_pte:
-        quant_attrs = json.load(
-            open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
-        )
-        inference(
-            args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte
-        )
-        exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
+        inference(args, pte_filename, runtime_tokenizer_path, args.pre_gen_pte)
+        print(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
+        return

     if args.compile_only:
-        quant_attrs = compile(args, pte_filename, tokenizer)
-        if quant_attrs:
-            json.dump(
-                {
-                    "scale": quant_attrs["scale"],
-                    "zero_point": quant_attrs["zero_point"],
-                },
-                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
-            )
-        else:
-            logging.warning("Quant attributes of the logit is None.")
+        compile(args, pte_filename, tokenizer)

     if args.ip and args.port != -1:
         pte_path = f"{args.artifact}/{pte_filename}.pte"
@@ -1089,24 +1075,18 @@ def export_llama(args) -> None:
                 }
             )
         )
-        exit(f"Finish compile_only and save to {args.artifact}")
+        print(f"Finish compile_only and save to {args.artifact}")
+        return
+
+    compile(args, pte_filename, tokenizer)
+    inference(args, pte_filename, runtime_tokenizer_path)

+
+def main():
+    parser = _build_parser()
+    args = parser.parse_args()
     try:
-        quant_attrs = compile(args, pte_filename, tokenizer)
-        if quant_attrs:
-            logging.info(
-                f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
-            )
-            json.dump(
-                {
-                    "scale": quant_attrs["scale"],
-                    "zero_point": quant_attrs["zero_point"],
-                },
-                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
-            )
-        else:
-            logging.warning("Quant attributes of the logit is None.")
-        inference(args, quant_attrs, pte_filename, runtime_tokenizer_path)
+        export_llama(args)
     except Exception as e:
         if args.ip and args.port != -1:
             with Client((args.ip, args.port)) as conn:
@@ -1115,12 +1095,6 @@ def export_llama(args) -> None:
         raise Exception(e)


-def main():
-    parser = _build_parser()
-    args = parser.parse_args()
-    export_llama(args)
-
-
 # flake8: noqa: C901
 if __name__ == "__main__":
     main()
```

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 40 additions & 11 deletions
```diff
@@ -34,7 +34,10 @@ DEFINE_string(
     "inference_speed.txt",
     "Records inference speed. For CI purpose.");
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
-DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
+DEFINE_string(
+    prompt,
+    "The answer to the ultimate question is",
+    "User prompts for Llama. When multiple prompts are entered, a multi-turn conversation will be initiated. Note that this feature is currently for testing purposes only.");
 DEFINE_string(
     system_prompt,
     "",
@@ -49,10 +52,8 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    1,
+    0,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
-DEFINE_double(logits_scale, 0.0, "Logits scale");
-DEFINE_int32(logits_offset, 0, "Logits offset");
 DEFINE_string(
     kv_updater,
     "How to update kv cache. Choose between SmartMask and ShiftPointer",
@@ -72,20 +73,46 @@ std::vector<std::string> CollectPrompts(int argc, char** argv) {
   return prompts;
 }

+std::string get_formatted_prompt(
+    const std::string& prompt,
+    const std::string& system_prompt,
+    example::LlamaVersion llama_version) {
+  std::string formatted_prompt;
+  switch (llama_version) {
+    case example::LlamaVersion::kLlama2:
+      formatted_prompt.append(prompt);
+      break;
+    case example::LlamaVersion::kLlama3:
+      if (!system_prompt.empty()) {
+        formatted_prompt.append(
+            "<|start_header_id|>system<|end_header_id|>\n\n");
+        formatted_prompt.append(system_prompt);
+        formatted_prompt.append("<|eot_id|>");
+      }
+      formatted_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append(
+          "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+      break;
+    default:
+      ET_CHECK_MSG(false, "unsupported llama version");
+      break;
+  }
+  return formatted_prompt;
+}
+
 int main(int argc, char** argv) {
   std::vector<std::string> prompts = CollectPrompts(argc, argv);
   gflags::ParseCommandLineFlags(&argc, &argv, true);
   // create llama runner
   example::Runner runner(
-      {FLAGS_model_path},
+      FLAGS_model_path.c_str(),
       FLAGS_tokenizer_path.c_str(),
       FLAGS_performance_output_path.c_str(),
-      FLAGS_logits_scale,
-      FLAGS_logits_offset,
       FLAGS_temperature,
       FLAGS_eval_mode,
-      FLAGS_kv_updater,
-      FLAGS_num_iters);
+      FLAGS_kv_updater);
+  auto llama_version = runner.get_llama_version();
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
@@ -97,8 +124,10 @@ int main(int argc, char** argv) {
   // generate tokens & store inference output
   for (int i = 0; i < FLAGS_num_iters; i++) {
     for (const auto& prompt : prompts) {
-      runner.generate(
-          FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
+      std::string formatted_prompt;
+      formatted_prompt = get_formatted_prompt(
+          prompt, FLAGS_system_prompt, llama_version.get());
+      runner.generate(formatted_prompt.c_str(), FLAGS_seq_len, callback);
     }
   }
   fout.write(buf.data(), buf.size());
```
examples/qualcomm/oss_scripts/llama/runner/client_mem.h

Lines changed: 48 additions & 0 deletions

```diff
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/imem_alloc.h>
+#include <vector>
+
+namespace example {
+/**
+ * @class ClientMem
+ * @brief Final class for client buffer allocation, implementing IBufferAlloc
+ * interface. Used for SHIFT_POINTER mode.
+ */
+class ClientMem final : public IMemAlloc {
+ public:
+  ClientMem(){};
+  // Disable copy constructors, r-value referencing, etc
+  ClientMem(const ClientMem&) = delete;
+  ClientMem& operator=(const ClientMem&) = delete;
+  ClientMem(ClientMem&&) = delete;
+  ClientMem& operator=(ClientMem&&) = delete;
+  virtual ~ClientMem(){};
+  /**
+   * @brief Allocate buffer of specified size with vector.
+   * @param data_size Size of the data to allocate.
+   * @return Pointer to the allocated buffer.
+   */
+  std::byte* allocate(size_t data_size) override {
+    allocated_buffers_.push_back(std::vector<std::byte>(data_size));
+    return allocated_buffers_.back().data();
+  };
+  // Only used for SMART_MASK mode
+  void add_memory_info(
+      void* data_ptr,
+      size_t data_size,
+      executorch::runtime::TensorInfo tensor_info) override {};
+
+ private:
+  std::vector<std::vector<std::byte>> allocated_buffers_;
+};
+
+} // namespace example
```
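ClientMem keeps every allocation in its own std::vector<std::byte> inside an outer vector, so growing the outer container moves the inner vector objects without touching their heap storage, and pointers returned by earlier allocate() calls stay valid. The snippet below is a standalone sketch of that pattern using a hypothetical BufferPool class (it does not use the ExecuTorch headers); it only illustrates the pointer-stability property this kind of allocator relies on.

```cpp
// Standalone illustration of the allocation pattern: each buffer is its own
// std::vector<std::byte>, so pointers handed out earlier remain valid even
// after the outer container reallocates. Names here are illustrative only.
#include <cassert>
#include <cstddef>
#include <vector>

class BufferPool {
 public:
  // Allocate a zero-initialized buffer and return a stable pointer to it.
  std::byte* allocate(std::size_t data_size) {
    buffers_.push_back(std::vector<std::byte>(data_size));
    return buffers_.back().data();
  }

 private:
  std::vector<std::vector<std::byte>> buffers_;
};

int main() {
  BufferPool pool;
  std::byte* first = pool.allocate(64);  // e.g. one cache block
  first[0] = std::byte{0x7f};

  // Many further allocations may force the outer vector to reallocate, but
  // moving the inner vectors does not move their heap storage, so `first`
  // still points at valid memory.
  for (int i = 0; i < 1000; ++i) {
    pool.allocate(64);
  }

  assert(first[0] == std::byte{0x7f});  // earlier pointer is still valid
  return 0;
}
```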
