@@ -57,6 +57,7 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
+        model_name=None,
         attn_implementation=best_fit_attn_implementation,
         use_flash_attention_2=True,
         device_map="auto",
@@ -83,8 +84,20 @@ def __init__(
             llava_model_args["attn_implementation"] = attn_implementation
         if customized_config:
             llava_model_args["customized_config"] = customized_config
-        llava_model_args["use_flash_attention_2"] = False
-        self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, **llava_model_args)
+        if attn_implementation is not None:
+            llava_model_args["attn_implementation"] = attn_implementation
+        if "use_flash_attention_2" in kwargs:
+            llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]
+
+        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+        try:
+            # Try to load the model with the multimodal argument
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+        except TypeError:
+            # for older versions of LLaVA that don't have multimodal argument
+            llava_model_args.pop("multimodal", None)
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
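
The try/except in the added lines is a version-compatibility guard: newer LLaVA releases accept a "multimodal" entry in llava_model_args, while older releases raise TypeError for the unknown keyword, in which case the argument is dropped and loading is retried. Below is a minimal, self-contained sketch of that pattern in isolation; load_model_new and load_model_old are hypothetical stand-ins for load_pretrained_model and are not part of the patch.

def load_model_new(path, device_map="auto", multimodal=True):
    # Stand-in for a newer loader that understands the multimodal keyword.
    return {"path": path, "device_map": device_map, "multimodal": multimodal}

def load_model_old(path, device_map="auto"):
    # Stand-in for an older loader that rejects the multimodal keyword.
    return {"path": path, "device_map": device_map}

def load_compat(loader, path, **kwargs):
    try:
        return loader(path, **kwargs)
    except TypeError:
        # Older versions raise TypeError for unknown keywords; retry without them.
        kwargs.pop("multimodal", None)
        return loader(path, **kwargs)

print(load_compat(load_model_new, "liuhaotian/llava-v1.5-7b", multimodal=True))
print(load_compat(load_model_old, "liuhaotian/llava-v1.5-7b", multimodal=True))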