
Commit a6700c7

cccclai authored and facebook-github-bot committed

fix qnn export (#9808)
Summary: One missing item for the new tokenizer lib.

Differential Revision: D72263224
1 parent b66c319 · commit a6700c7
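For context, the export script previously imported the in-tree Tiktoken class from executorch; this commit switches it to TiktokenTokenizer from the new pytorch_tokenizers library. A minimal before/after sketch, assuming pytorch_tokenizers is installed and that, like the old class, the constructor takes the path to a tokenizer model file (the path and prompt below are placeholders):

# Old import, removed by this commit:
# from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken

# New import from the pytorch_tokenizers library:
from pytorch_tokenizers import TiktokenTokenizer

# Placeholder path; assumes the constructor takes the tokenizer model path.
tokenizer = TiktokenTokenizer("/path/to/tokenizer.model")
tokens = tokenizer.encode("hello", bos=True, eos=False, allowed_special="all")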

File tree

  • examples/qualcomm/oss_scripts/llama/llama.py

1 file changed: +4 −4 lines changed

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 4 additions & 4 deletions
@@ -57,7 +57,6 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
-from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -77,6 +76,7 @@
 from executorch.extension.llm.export.builder import DType
 from pytorch_tokenizers import get_tokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
+from pytorch_tokenizers import TiktokenTokenizer

 from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -141,7 +141,7 @@ def _kv_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -213,7 +213,7 @@ def _prefill_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -1111,7 +1111,7 @@ def export_llama(args) -> None:
         runtime_tokenizer_path = args.tokenizer_bin
     elif args.llama_model == "llama3_2":
         assert isinstance(
-            tokenizer, Tiktoken
+            tokenizer, TiktokenTokenizer
         ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
     else:
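The two calibrate hunks apply the same dispatch, so the pattern can be read in isolation. A self-contained sketch of that logic, using only names that appear in the diff (get_tokenizer, SentencePieceTokenizer, TiktokenTokenizer), and assuming get_tokenizer takes the tokenizer file path and returns the matching class, as its use in llama.py suggests; the helper name encode_prompt, the model path, and the prompt are illustrative:

from pytorch_tokenizers import TiktokenTokenizer, get_tokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

def encode_prompt(tokenizer, user_prompts):
    # Llama2's SentencePiece tokenizer has no special tokens, so encode()
    # takes no allowed_special argument; the Tiktoken-based Llama3 tokenizer
    # must be told which special tokens to allow.
    if isinstance(tokenizer, SentencePieceTokenizer):
        return tokenizer.encode(user_prompts, bos=True, eos=False)
    elif isinstance(tokenizer, TiktokenTokenizer):
        return tokenizer.encode(
            user_prompts, bos=True, eos=False, allowed_special="all"
        )
    raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")

# Placeholder path; assumes get_tokenizer infers the tokenizer class from the file.
tokenizer = get_tokenizer("/path/to/tokenizer.model")
tokens = encode_prompt(tokenizer, "What is the capital of France?")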
