 )
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
-
+torch.backends.cuda.enable_cudnn_sdp(True)
 
 class HostEvent:
     def __init__(self):
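Note: `torch.backends.cuda.enable_cudnn_sdp(True)` globally opts scaled-dot-product attention into the cuDNN backend (available since PyTorch 2.4). A minimal sketch of the same toggle next to the scoped alternative, assuming a CUDA device; shapes and dtype are illustrative:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Global toggle, as in this PR: allow the cuDNN SDPA backend everywhere.
torch.backends.cuda.enable_cudnn_sdp(True)

# Scoped alternative: force the cuDNN backend for one region only.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```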
@@ -256,6 +256,7 @@ def _load_model(checkpoint_path, device, precision):
 def main(
     prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
+    demo_summarize_prompt: Optional[str] = None,
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
@@ -285,7 +286,11 @@ def main(
 
     if prefill_size is not None and prefill_size > 0:
         # create prompt of prefill size
-        prompt = "prompt " * (int(prefill_size) - 3)
+        if demo_summarize_prompt is None:
+            prompt = "prompt " * (int(prefill_size) - 2)
+        else:
+            with open(demo_summarize_prompt, "r") as f:
+                prompt = f.read()
 
     torchao.quantization.utils.recommended_inductor_config_setter()
 
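The branch above either synthesizes a padded prompt of roughly `prefill_size` tokens or reads the text to summarize from the given file. A minimal sketch of that logic in isolation; the `-2` of slack for special tokens such as BOS is an assumption carried over from the diff:

```python
from typing import Optional


def build_prompt(prefill_size: int, demo_summarize_prompt: Optional[str]) -> str:
    """Return a synthetic prompt of ~prefill_size tokens, or the file contents."""
    if demo_summarize_prompt is None:
        # Each "prompt " repeat tokenizes to roughly one token on common BPE
        # tokenizers; the -2 leaves headroom for special tokens (assumption).
        return "prompt " * (int(prefill_size) - 2)
    with open(demo_summarize_prompt, "r") as f:
        return f.read()
```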
@@ -306,6 +311,12 @@ def main(
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+
+    if demo_summarize_prompt is not None:
+        end_tag = encode_tokens(tokenizer, "\n <END_TEXT>", bos=False, device=device)
+        encoded = encoded[:prefill_size - end_tag.size(0)]
+        encoded = torch.cat((encoded, end_tag), dim=0)
+
     prompt_length = encoded.size(0)
 
     torch.manual_seed(1234)
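With a file-based prompt, the encoded sequence is clamped so that prompt plus end tag comes out to exactly `prefill_size` tokens. A toy sketch with stand-in token ids:

```python
import torch

prefill_size = 8
encoded = torch.arange(20)         # stand-in for the tokenized file contents
end_tag = torch.tensor([99, 100])  # stand-in for the encoded "\n <END_TEXT>" tag

encoded = encoded[:prefill_size - end_tag.size(0)]  # keep the first 6 tokens
encoded = torch.cat((encoded, end_tag), dim=0)      # append the end tag
assert encoded.size(0) == prefill_size
```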
@@ -390,6 +401,8 @@ def ffn_or_attn_only(mod, fqn):
             quantize_(
                 model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
             )
+        elif "int8dq_prefill_wo_decode" in quantization:
+            quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))
         else:
             quantize_(model, int8_dynamic_activation_int8_weight())
     if "int4wo" in quantization:
@@ -809,14 +822,23 @@ def callback(x):
             nonlocal done_generating
             if done_generating:
                 return
-            buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+            buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
             if x.item() == tokenizer.eos_id():
                 done_generating = True
             if len(buffer) == 4 or done_generating:
                 print("".join(buffer), end="", flush=True)
                 buffer.clear()
-            # print(, end='', flush=True)
+            # print(, end="", flush=True)
+
+        elif demo_summarize_prompt is not None and i >= 0:
+            buffer = []
+            period_id = tokenizer.encode(".")[0]
 
+            def callback(x):
+                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
+                if len(buffer) == 4:
+                    print("".join(buffer), end="", flush=True)
+                    buffer.clear()
         else:
             callback = lambda x: x
         t0 = time.perf_counter()
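Two details in these callbacks are worth noting. SentencePiece drops the leading space when decoding a lone token, so the callback decodes `[period_id, token]` and strips the period's text (`[1:]`) to preserve spacing. The new `x.squeeze(0)` suggests the per-step token tensor is now batched as `[1, 1]` (an inference from the diff); without the squeeze, `.tolist()` would yield a nested list that cannot be concatenated with `[period_id]`:

```python
import torch

x = torch.tensor([[42]])  # one generated token, batched: shape [1, 1]
x.tolist()                # [[42]] -> [period_id] + [[42]] is malformed
x.squeeze(0).tolist()     # [42]   -> [period_id, 42] decodes cleanly
```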
@@ -851,15 +873,15 @@ def callback(x):
             decode_start_event=decode_start_event,
             decode_end_event=decode_end_event,
         )
-        if i == -1:
+        if i < 0:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
             prof.export_chrome_trace(f"{profile}.json")
         device_sync(device=device)  # MKG
         t = time.perf_counter() - t0
 
-        if not interactive and prefill_size is None:
+        if not interactive and demo_summarize_prompt is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = (
@@ -869,7 +891,7 @@ def callback(x):
             )
             print(tokenizer.decode(tokens))
         else:
-            print()
+            print("\n")
         tokens_generated = y.size(-1) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -913,7 +935,7 @@ def callback(x):
     bandwidth = model_size * tokpersec
     mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average overall tokens/sec: {tokpersec:.2f}")
-    print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
+    print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
     print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() / 1e9
@@ -975,6 +997,9 @@ def callback(x):
     parser.add_argument(
         "--prompt", type=str, default="Hello, my name is", help="Input prompt."
     )
+    parser.add_argument(
+        "--demo_summarize_prompt", type=str, help="Read prompt from text file"
+    )
     parser.add_argument(
         "--interactive",
         action="store_true",
@@ -1073,6 +1098,7 @@ def callback(x):
     main(
         args.prefill_size,
         args.prompt,
+        args.demo_summarize_prompt,
         args.interactive,
         args.num_samples,
         args.max_new_tokens,
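Taken together, a summarization demo run might look like `python generate.py --prefill_size 8192 --demo_summarize_prompt article.txt --quantization int8dq_prefill_wo_decode` (flag names taken from the argparse changes above; the script's checkpoint and tokenizer flags are assumed unchanged by this PR).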