import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})
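# NOTE: launch_from_torch reads the distributed rank/world size from the environment, so this
# script is meant to be started through torchrun (or `colossalai run`) rather than plain `python`.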


def data_gen(batch_size: int = 4, seq_len: int = 512):
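    # Build a dummy request: one random token sequence plus its attention mask, repeated
    # along the batch dimension and moved to the GPU.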
    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    for k, v in data.items():
        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = batch_size
            data[k] = v.to("cuda").repeat(*new_shape)
    return data


def print_details_info(timestamps, model_config, args, whole_end2end):
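    # Rank 0 turns the per-microbatch timestamps into prefill / decode / end-to-end latencies
    # and writes them, together with throughput estimates, to a log file.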
    if dist.get_rank() == 0:
        prefill = []
        encoder = []
        end2end = []
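        # Each entry of `timestamps` holds the step times recorded for one micro-batch: the
        # first gap is the prefill step, the remaining gaps are the per-token decode steps.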
        for timestamp in timestamps:
            prefill.append(timestamp[1] - timestamp[0])
            encoder.append(
                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
            )
            end2end.append(timestamp[-1] - timestamp[0])
        print(whole_end2end)
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "w+",
        ) as f:
            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
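            # Rough parameter count per pipeline stage: ~12 * hidden_size^2 weights per layer,
            # split evenly across the pp_size stages.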
            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
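            # Bytes per parameter for the chosen dtype; the "flops" line below reports
            # parameters * bytes per parameter / per-token latency.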
            if args.dtype in ["fp16", "bf16"]:
                num_bytes = 2
            else:
                num_bytes = 4

            f.write(
                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
            )
            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
            f.write("----------------------------------------------------------\n")

    if torch.cuda.is_available():
        current_device = torch.cuda.current_device()

        # Snapshot of CUDA memory usage on the local device, appended to the benchmark log.
        global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
        memory_allocated = torch.cuda.memory_allocated()
        max_memory_allocated = torch.cuda.max_memory_allocated()
        memory_reserved = torch.cuda.memory_reserved()
        max_memory_reserved = torch.cuda.max_memory_reserved()
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "a",
        ) as f:
            f.write(
                f"\nCurrently using GPU: {current_device}\n"
                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
                f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
7789 f"Max CUDA memory reserved/cached: { max_memory_reserved / GIGABYTE :.4f} GB,\n "
7890 )
7991
80- if __name__ == '__main__' :
92+
93+ if __name__ == "__main__" :
8194 parser = argparse .ArgumentParser ()
    parser.add_argument("--model", default="toy", help="the size of model")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
    parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    args = parser.parse_args()
91104
92- if args .model == ' toy' :
105+ if args .model == " toy" :
93106 model = transformers .LlamaForCausalLM (transformers .LlamaConfig (num_hidden_layers = 8 ))
94- elif args .model == '7b' :
95- model = transformers .LlamaForCausalLM (transformers .AutoConfig .from_pretrained (' decapoda-research/llama-7b-hf' ))
96- elif args .model == ' 13b' :
97- model = transformers .LlamaForCausalLM (transformers .AutoConfig .from_pretrained (' decapoda-research/llama-13b-hf' ))
107+ elif args .model == "7b" :
108+ model = transformers .LlamaForCausalLM (transformers .AutoConfig .from_pretrained (" decapoda-research/llama-7b-hf" ))
109+ elif args .model == " 13b" :
110+ model = transformers .LlamaForCausalLM (transformers .AutoConfig .from_pretrained (" decapoda-research/llama-13b-hf" ))
98111 else :
99112 raise NotImplementedError
100-
101-
102- engine = PPInferEngine (pp_size = args .pp_size , dtype = args .dtype , micro_batch_size = args .mb_size , new_length = args .new_length , model = model , model_policy = LlamaForCausalLMPipelinePolicy (),verbose = True )
113+
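    # Build the pipeline-parallel inference engine: the model is split into pp_size stages,
    # each batch is fed through as micro-batches of mb_size, and verbose=True asks the engine
    # to record the per-step timestamps consumed by print_details_info.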
    engine = PPInferEngine(
        pp_size=args.pp_size,
        dtype=args.dtype,
        micro_batch_size=args.mb_size,
        new_length=args.new_length,
        model=model,
        model_policy=LlamaForCausalLMPipelinePolicy(),
        verbose=True,
    )
    data = data_gen(args.batch_size, args.seq_len)

    torch.cuda.synchronize()
    whole_end2end = time.time()
    # Assumed call shape for this engine version: with verbose=True, inference() returns the
    # generated output together with the per-step timestamps used by print_details_info.
    output, timestamps = engine.inference([data])
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(timestamps, model.config, args, whole_end2end)
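
# Example launch (2 GPUs -> 2 pipeline stages). The file name "benchmark.py" and the exact
# launcher flags are assumptions; adjust them to the local setup:
#   torchrun --nproc_per_node 2 benchmark.py --model toy --pp_size 2 -b 8 -s 512 --new_length 32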