Commit 3a6d41e

restored tests file

Signed-off-by: Onkar Chougule <[email protected]>
1 parent: 10e2c1f

File tree: 1 file changed (+21, -22 lines)


tests/transformers/models/test_causal_lm_models.py

Lines changed: 21 additions & 22 deletions
@@ -9,7 +9,6 @@
 from typing import Optional

 import numpy as np
-
 import pytest
 from transformers import AutoModelForCausalLM

@@ -89,15 +88,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         Constants.CTX_LEN,
     )

-    # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
+    pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
     is_tlm = False if num_speculative_tokens is None else True
     qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)

     pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)

-    # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
-    #     "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
-    # )
+    assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
+        "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
+    )

     onnx_model_path = qeff_model.export()
     ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
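
Note on the hunk above: it re-enables the HF-vs-KV parity check that had been commented out. A minimal runnable sketch of the restored assertion pattern, with toy numpy arrays standing in for the real ApiRunner outputs (token values invented for illustration):

    import numpy as np

    # Stand-ins for api_runner.run_hf_model_on_pytorch(model_hf) and
    # api_runner.run_kv_model_on_pytorch(qeff_model.model); shape (batch, gen_len).
    pytorch_hf_tokens = np.array([[101, 7592, 2088, 102]])
    pytorch_kv_tokens = np.array([[101, 7592, 2088, 102]])

    assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
        "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
    )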
@@ -117,6 +116,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
+
     gen_len = ort_tokens.shape[-1]
     assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
         "Tokens don't match for ONNXRT output and Cloud AI 100 output."
@@ -128,20 +128,20 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     config = model_hf.config
     full_batch_size = 4
     fbs_prompts = Constants.INPUT_STR * 4
-    # api_runner = ApiRunner(
-    #     batch_size,
-    #     tokenizer,
-    #     config,
-    #     fbs_prompts,
-    #     Constants.PROMPT_LEN,
-    #     Constants.CTX_LEN,
-    #     full_batch_size,
-    # )
-
-    # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
-    # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
-
-    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
+    api_runner = ApiRunner(
+        batch_size,
+        tokenizer,
+        config,
+        fbs_prompts,
+        Constants.PROMPT_LEN,
+        Constants.CTX_LEN,
+        full_batch_size,
+    )
+
+    pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
+    pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
+
+    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=False)
     onnx_model_path = qeff_model.export()

     if not get_available_device_id():
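
Note on the restored continuous-batching reference path above: run_hf_model_on_pytorch_CB presumably returns one token array per prompt in the full batch, and np.vstack folds them into a single (full_batch_size, gen_len) array for later row-wise comparison. A toy sketch of that step (values invented for illustration):

    import numpy as np

    # Stand-in for the per-prompt outputs of run_hf_model_on_pytorch_CB
    # with full_batch_size = 4.
    per_prompt_tokens = [
        np.array([1, 2, 3]),
        np.array([4, 5, 6]),
        np.array([7, 8, 9]),
        np.array([10, 11, 12]),
    ]

    pytorch_hf_tokens = np.vstack(per_prompt_tokens)
    assert pytorch_hf_tokens.shape == (4, 3)  # (full_batch_size, gen_len)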
@@ -151,13 +151,12 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
-        mxfp6=False,
+        mxfp6_matmul=False,
         aic_enable_depth_first=False,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
     )
-    # exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
-    qeff_model.generate(tokenizer, prompts=fbs_prompts)
+    exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts, device_id=[0])


     """
