 import torch_tensorrt
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs
+from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+import sys
+import os

+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from register_sdpa import *

 DEVICE = torch.device("cuda:0")

 def get_model(args):
     with torch.no_grad():
-        if args.model == "meta-llama/Llama-2-7b-chat-hf":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",
-                    num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-        elif args.model == "meta-llama/Llama-3.2-1B-Instruct":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",
-                    num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-
-        elif args.model == "meta-llama/Llama-3.2-3B-Instruct":
-            model = (
+        # Supported list of models:
+        # - meta-llama/Llama-3.2-1B-Instruct
+        # - meta-llama/Llama-3.2-3B-Instruct
+        # - meta-llama/Llama-3.1-8B-Instruct
+        # - Qwen/Qwen2.5-1.5B-Instruct
+        model = (
             AutoModelForCausalLM.from_pretrained(
                 args.model,
                 use_cache=False,
                 attn_implementation="sdpa",
-                # num_hidden_layers=2
-            )
-            .eval()
-            .cuda()
-        )
-        elif args.model == "meta-llama/Llama-3.1-8B-Instruct":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",  # num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-        elif args.model == "google/gemma-3-1b-it":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    "google/gemma-3-1b-it",
-                    use_cache=False,
-                    attn_implementation="sdpa"
+                # num_hidden_layers=1
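+                # (uncomment num_hidden_layers above to load a truncated model, e.g. for quicker debugging)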
             )
             .eval()
             .cuda()
@@ -91,9 +57,9 @@ def get_model(args):


 def compile_torchtrt(model, input_ids, args):
-    max_seq_len = input_ids.shape[1] + args.max_tokens
+    max_seq_len = input_ids.shape[1] + args.num_tokens
     ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
-
+
     # Set precision specific flags
     use_fp32_acc = False
     use_explicit_typing = False
@@ -119,6 +85,7 @@ def compile_torchtrt(model, input_ids, args):
         disable_tf32=True,
         use_python_runtime=True,
         debug=args.debug,
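+        # Offload the original PyTorch module to CPU during compilation to reduce peak GPU memory usage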
+        offload_module_to_cpu=True,
         min_block_size=args.min_block_size,
     )

@@ -170,23 +137,29 @@ def measure_perf(trt_model, input_signature, backend_name):
170137 "--model" , type = str , default = "meta-llama/Llama-3.2-1B-Instruct" , help = "Name of LLM model"
171138 )
172139 arg_parser .add_argument (
173- "--tokenizer_path " ,
140+ "--tokenizer " ,
174141 type = str ,
175- default = "meta-llama/Llama-3.2-1B-Instruct " ,
142+ default = "" ,
176143 help = "Name of LLM model tokenizer" ,
177144 )
178145 arg_parser .add_argument (
179146 "--prompt" , type = str , default = "What is parallel programming ?" , help = "Prompt"
180147 )
181- arg_parser .add_argument ("--precision" , type = str , default = "FP16" , help = "Prompt " )
148+ arg_parser .add_argument ("--precision" , type = str , default = "FP16" , help = "Precision to use in the model. Options: FP16, BF16, FP32 " )
182149 arg_parser .add_argument (
183150 "--iterations" , type = int , default = 5 , help = "no. of iterations to run"
184151 )
185152 arg_parser .add_argument (
186153 "--min_block_size" , type = int , default = 1 , help = "no. of iterations to run"
     )
     arg_parser.add_argument(
-        "--max_tokens", type=int, default=128, help="no. of max tokens to be generated"
+        "--num_tokens", type=int, default=128, help="no. of output tokens to be generated"
+    )
+    arg_parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size used for benchmarking"
+    )
+    arg_parser.add_argument(
+        "--isl", type=int, default=2048, help="Input sequence length used for benchmarking"
     )
     arg_parser.add_argument(
         "--enable_pytorch_run",
@@ -196,8 +169,8 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--cache",
         type=str,
-        default="static",
-        help="Type of KV cache to use",
+        default="",
+        help="Type of KV cache to use. Options: static_v1, static_v2, dynamic",
     )
     arg_parser.add_argument(
         "--cudagraph",
@@ -214,22 +187,24 @@ def measure_perf(trt_model, input_signature, backend_name):
         action="store_true",
         help="Enable benchmark (default: False)"
     )
+
     args = arg_parser.parse_args()
     with torch.inference_mode():
         model = get_model(args)

-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
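+        # Fall back to the model name when no separate tokenizer is specified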
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)

-        prompt = "What is parallel programming ?"
-        # prompt = "What is the capital of France ?"
-        model_inputs = tokenizer(prompt, return_tensors="pt")
-        input_ids = model_inputs["input_ids"].to(DEVICE)
-        # Prepare input prompt
-        # word = "What"
-        # word_ids = tokenizer(word, return_tensors="pt").input_ids[0]  # Get the first (and only) sequence
-        # input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device)  # Add batch dimension and move to device
+        # Prepare input for benchmarking or evaluation
+        if args.benchmark:
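+            # Synthetic prompt: random token ids of shape (batch_size, isl) are sufficient for timing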
+            input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+        else:
+            model_inputs = tokenizer(args.prompt, return_tensors="pt")
+            input_ids = model_inputs["input_ids"].to(DEVICE)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+

-        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.max_tokens
+        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
         # Pyt
         pyt_gen_tokens = None
         pyt_timings = None
@@ -238,7 +213,6 @@ def measure_perf(trt_model, input_signature, backend_name):
             pyt_gen_tokens = generate(
                 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
             )
-
             if args.benchmark:
                 pyt_timings = time_generate(
                     generate,
@@ -249,71 +223,22 @@ def measure_perf(trt_model, input_signature, backend_name):
                     iterations=args.iterations,
                 )
                 pyt_stats = recordStats(
-                    "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None
+                    "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
                 )

-        # TRT
-        pyt_logits_tok1 = model.cuda()(input_ids)
-        next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
-        input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        pyt_logits_tok2 = model.cuda()(input_seq)
-        from lower_sdpa import *
-        if args.cache == "static":
-            # This import is required to register static KV cache transformations as lowering passes
-            from static_cache2 import *
-            trt_model = compile_torchtrt(model, input_ids, args)
-            kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-
-            # First token generation
-            pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
-            trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
-            next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
-
-            # Second token generation
-            trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1] + 1)
-            pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
-            breakpoint()
+        if args.cache == "static_v1":
+            # This import is required to register static v1 KV cache transformations as lowering passes
+            import static_cache_v1
+        if args.cache == "static_v2":
+            # This import is required to register static v2 KV cache transformations as lowering passes
+            import static_cache_v2
         elif args.cache == "dynamic":
-            from dynamic_cache import *
-            trt_model = compile_torchtrt(model, input_ids, args)
-            breakpoint()
-            kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-        else:
-            # pyt_logits = model.cuda()(input_ids.clone())
-            trt_model = compile_torchtrt(model, input_ids, args)
-            # trt_logits = trt_model(input_ids.clone(), True)
-            # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
-            # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
-        if args.cache == "static":
-            if args.cudagraph:
-                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
-                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-            trt_gen_tokens = generate_with_kv_cache(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-            )
+            import dynamic_cache

-            if args.benchmark:
-                trt_timings = time_generate(
-                    generate_with_kv_cache,
-                    trt_model,
-                    input_ids.clone(),
-                    MAX_OUTPUT_SEQ_LENGTH,
-                    tokenizer.eos_token_id,
-                    iterations=args.iterations,
-                )
-        elif args.cache == "dynamic":
+
+        trt_model = compile_torchtrt(model, input_ids, args)
+
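+        # When a KV cache variant is selected, generation uses the KV-cache decode path; otherwise the plain generate loop below is used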
+        if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic":
             if args.cudagraph:
                 # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
                 # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
@@ -332,7 +257,6 @@ def measure_perf(trt_model, input_signature, backend_name):
                     tokenizer.eos_token_id,
                     iterations=args.iterations,
                 )
-
         else:
             trt_gen_tokens = generate(
                 trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
@@ -349,14 +273,20 @@ def measure_perf(trt_model, input_signature, backend_name):

         if args.benchmark:
             trt_stats = recordStats(
-                "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None
+                "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
             )

-        if args.enable_pytorch_run:
-            print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
-        print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+
+        if not args.benchmark:
+            if args.enable_pytorch_run:
+                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+
+            print_outputs("TensorRT", trt_gen_tokens, tokenizer)

-        if args.benchmark:
+            if args.enable_pytorch_run:
+                print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}")
+
+        if args.benchmark:
             if args.enable_pytorch_run:
                 print("=========PyTorch PERFORMANCE============\n")
                 print(pyt_stats)