 from typing import Optional

 import numpy as np
-
 import pytest
 from transformers import AutoModelForCausalLM
@@ -89,15 +88,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         Constants.CTX_LEN,
     )

-    # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
+    pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
     is_tlm = False if num_speculative_tokens is None else True
     qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)

     pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

-    # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
-    #     "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
-    # )
+    assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
+        "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
+    )

     onnx_model_path = qeff_model.export()
     ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
@@ -117,6 +116,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
+
     gen_len = ort_tokens.shape[-1]
     assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
         "Tokens don't match for ONNXRT output and Cloud AI 100 output."
@@ -128,20 +128,20 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     config = model_hf.config
     full_batch_size = 4
     fbs_prompts = Constants.INPUT_STR * 4
-    # api_runner = ApiRunner(
-    #     batch_size,
-    #     tokenizer,
-    #     config,
-    #     fbs_prompts,
-    #     Constants.PROMPT_LEN,
-    #     Constants.CTX_LEN,
-    #     full_batch_size,
-    # )
-
-    # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
-    # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
-
-    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        fbs_prompts,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+        full_batch_size,
+    )
+
+    pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
+    pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
+
+    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=False)
     onnx_model_path = qeff_model.export()

     if not get_available_device_id():
@@ -151,13 +151,12 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
-        mxfp6=False,
+        mxfp6_matmul=False,
         aic_enable_depth_first=False,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
     )
-    # exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
-    qeff_model.generate(tokenizer, prompts=fbs_prompts)
+    exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts, device_id=[0])


 """