@@ -60,61 +60,62 @@ def run_torch_compile(model, backend='openvino', dynamic=None, options=None, chi
 
 def create_text_gen_model(model_path, device, memory_data_collector, **kwargs):
     model_path = Path(model_path)
-    from_pretrain_time = 0
-    if model_path.exists():
-        if model_path.is_dir() and len(os.listdir(model_path)) != 0:
-            log.info(f'Load text model from model path:{model_path}')
-            model_class = kwargs['use_case'].pt_cls
-            token_class = kwargs['use_case'].tokenizer_cls
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.start()
-            start = time.perf_counter()
-            trust_remote_code = False
-            try:
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            except Exception:
-                start = time.perf_counter()
-                trust_remote_code = True
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            tokenizer = token_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            end = time.perf_counter()
-            from_pretrain_time = end - start
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.stop_and_collect_data('from_pretrained_phase')
-                memory_data_collector.log_data(compilation_phase=True)
-        else:
-            raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
-    else:
+    is_gguf_model = model_path.suffix == '.gguf'
+    if not model_path.exists():
         raise RuntimeError(f'==Failure ==: model path:{model_path} is not exist')
+    if not is_gguf_model and not (model_path.is_dir() and len(os.listdir(model_path)) != 0):
+        raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
+    if not device:
+        raise RuntimeError('==Failure ==: no device to load')
+
+    log.info(f'Load text model from model path:{model_path}')
+    model_class = kwargs['use_case'].pt_cls
+    token_class = kwargs['use_case'].tokenizer_cls
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.start()
+    start = time.perf_counter()
+    load_model_kwargs = {'trust_remote_code': False}
+    if is_gguf_model:
+        load_model_kwargs |= {'gguf_file': str(model_path)}
+        model_path = model_path.parent
+    try:
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    except Exception:
+        start = time.perf_counter()
+        load_model_kwargs['trust_remote_code'] = True
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    tokenizer = token_class.from_pretrained(model_path, **load_model_kwargs)
+    end = time.perf_counter()
+    from_pretrain_time = end - start
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.stop_and_collect_data('from_pretrained_phase')
+        memory_data_collector.log_data(compilation_phase=True)
 
     log.info(f'model path:{model_path}, from pretrained time: {from_pretrain_time:.2f}s')
 
-    if device is not None:
-        gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
-        lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
-        bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
-        gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
-        gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
-        chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
-        real_base_model_name = str(type(model)).lower()
-        log.info(f'Real base model={real_base_model_name}')
-        # bfclm will trigger generate crash.
+    gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
+    lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
+    bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
+    gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
+    gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
+    chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
+    real_base_model_name = str(type(model)).lower()
+    log.info(f'Real base model={real_base_model_name}')
+    # bfclm will trigger generate crash.
 
-        # If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
-        if device.upper() == 'GPU':
-            device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
-        else:
-            device = torch.device(device.lower())
-        log.info(f'Torch device was set to: {device}')
+    # If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
+    if device.upper() == 'GPU':
+        device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
+    else:
+        device = torch.device(device.lower())
+    log.info(f'Torch device was set to: {device}')
 
-        if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
-            model = set_bf16(model, device, **kwargs)
-        else:
-            if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
-                log.info('Param [bf16/prec_bf16] will not work.')
-            model.to(device)
+    if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
+        model = set_bf16(model, device, **kwargs)
     else:
-        raise RuntimeError('==Failure ==: no device to load')
+        if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
+            log.info('Param [bf16/prec_bf16] will not work.')
+        model.to(device)
 
     bench_hook = hook_common.get_bench_hook(kwargs['num_beams'], model)
 
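For reference, the new GGUF branch leans on the gguf_file keyword that transformers' from_pretrained accepts in versions with GGUF support: the file's parent directory is passed as the model path, the GGUF file itself via gguf_file, and transformers dequantizes the quantized tensors into a regular PyTorch model, so the usual generate() API keeps working. Below is a minimal standalone sketch of the same loading pattern, assuming a transformers build with GGUF support; the file path is a hypothetical placeholder, not taken from this commit.

# Minimal sketch of the GGUF loading pattern used above, assuming a
# transformers version with GGUF support. The path is hypothetical.
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = Path('models/llama-2-7b.Q4_K_M.gguf')  # hypothetical GGUF file

load_model_kwargs = {'trust_remote_code': False}
if model_path.suffix == '.gguf':
    # transformers expects the containing directory (or repo id) as the
    # model path and the GGUF file via the gguf_file keyword; the weights
    # are dequantized into an ordinary float PyTorch model on load.
    load_model_kwargs |= {'gguf_file': str(model_path)}
    model_path = model_path.parent

model = AutoModelForCausalLM.from_pretrained(model_path, **load_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, **load_model_kwargs)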