diff --git a/tests/python_tests/samples/test_tools_llm_benchmark.py b/tests/python_tests/samples/test_tools_llm_benchmark.py
index f7176ef79b..8ab52588cf 100644
--- a/tests/python_tests/samples/test_tools_llm_benchmark.py
+++ b/tests/python_tests/samples/test_tools_llm_benchmark.py
@@ -5,8 +5,10 @@
 import pytest
 import sys
 
-from conftest import SAMPLES_PY_DIR, convert_model, download_test_content
 from test_utils import run_sample
+from data.models import get_gguf_model_list
+from utils.hugging_face import download_gguf_model
+from conftest import SAMPLES_PY_DIR, convert_model, download_test_content
 from utils.hugging_face import download_and_convert_embeddings_models, download_and_convert_model
 
 convert_draft_model = convert_model
@@ -286,3 +288,20 @@ def test_python_tool_llm_benchmark_text_reranking_qwen3(self, model_id, sample_a
             "-m", models_path,
         ] + sample_args
         run_sample(benchmark_py_command)
+
+
+    @pytest.mark.samples
+    @pytest.mark.parametrize("sample_args", [
+        ["-d", "cpu", "-n", "1"],
+        ["-d", "cpu", "-n", "1", "-f", "pt"],
+    ])
+    def test_python_tool_llm_benchmark_gguf_format(self, sample_args):
+        benchmark_script = os.path.join(SAMPLES_PY_DIR, 'llm_bench/benchmark.py')
+        gguf_model = get_gguf_model_list()[0]
+        gguf_full_path = download_gguf_model(gguf_model["gguf_model_id"], gguf_model["gguf_filename"])
+        benchmark_py_command = [
+            sys.executable,
+            benchmark_script,
+            "-m", gguf_full_path,
+        ] + sample_args
+        run_sample(benchmark_py_command)
diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py
index bc926d750d..a7d548287e 100644
--- a/tools/llm_bench/benchmark.py
+++ b/tools/llm_bench/benchmark.py
@@ -40,7 +40,7 @@ def num_infer_count_type(x):
 
 def get_argprser():
     parser = argparse.ArgumentParser('LLM benchmarking tool', add_help=True, formatter_class=argparse.RawTextHelpFormatter)
-    parser.add_argument('-m', '--model', help='model folder including IR files or Pytorch files', required=TabError)
+    parser.add_argument('-m', '--model', help='model folder including IR files or Pytorch files or path to GGUF model', required=TabError)
     parser.add_argument('-d', '--device', default='cpu', help='inference device')
     parser.add_argument('-r', '--report', help='report csv')
     parser.add_argument('-rj', '--report_json', help='report json')
diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py
index 30398b3a51..56800cbea8 100644
--- a/tools/llm_bench/llm_bench_utils/model_utils.py
+++ b/tools/llm_bench/llm_bench_utils/model_utils.py
@@ -250,6 +250,12 @@ def get_use_case(model_name_or_path: str | Path, task: str | None = None):
                 log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
                 return case, model_name
         model_id = config.get("model_type").lower().replace('_', '-')
+    elif Path(model_name_or_path).suffix == '.gguf':
+        import gguf_parser
+        parser = gguf_parser.GGUFParser(model_name_or_path)
+        parser.parse()
+        if parser.metadata and parser.metadata.get('general.architecture'):
+            model_id = parser.metadata.get('general.architecture').lower()
 
     if model_id is not None:
         case, model_id = get_use_case_by_model_id(model_id, task)
diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py
index f48bfcb83a..bdd810b5b7 100644
--- a/tools/llm_bench/llm_bench_utils/ov_utils.py
+++ b/tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -208,15 +208,13 @@ def cb_pipeline_required(args):
 
 def create_genai_text_gen_model(model_path, device, ov_config, memory_data_collector, **kwargs):
     import openvino_genai
-    from transformers import AutoTokenizer
     from packaging.version import parse
 
-    if not (model_path / "openvino_tokenizer.xml").exists() or not (model_path / "openvino_detokenizer.xml").exists():
+    if Path(model_path).suffix != '.gguf'\
+            and (not (model_path / "openvino_tokenizer.xml").exists() or not (model_path / "openvino_detokenizer.xml").exists()):
         raise ValueError("OpenVINO Tokenizer model is not found in model directory. Please convert tokenizer using following command:\n"
                          "convert_tokenizer --with-detokenizer MODEL_DIR --output MODEL_DIR ")
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
     config = {}
     draft_model_path = kwargs.get("draft_model", '')
     cb_config = kwargs.get("cb_config")
@@ -288,7 +286,7 @@ def get_time_list(self):
             return self.token_generation_time
 
     streamer = TokenStreamer(llm_pipe.get_tokenizer()) if use_streamer_metrics else None
-    return llm_pipe, tokenizer, end - start, streamer, True
+    return llm_pipe, None, end - start, streamer, True
 
 
 def convert_ov_tokenizer(tokenizer_path):
diff --git a/tools/llm_bench/llm_bench_utils/pt_utils.py b/tools/llm_bench/llm_bench_utils/pt_utils.py
index b335db8c99..8e7a6fcf55 100644
--- a/tools/llm_bench/llm_bench_utils/pt_utils.py
+++ b/tools/llm_bench/llm_bench_utils/pt_utils.py
@@ -60,61 +60,62 @@ def run_torch_compile(model, backend='openvino', dynamic=None, options=None, chi
 
 def create_text_gen_model(model_path, device, memory_data_collector, **kwargs):
     model_path = Path(model_path)
-    from_pretrain_time = 0
-    if model_path.exists():
-        if model_path.is_dir() and len(os.listdir(model_path)) != 0:
-            log.info(f'Load text model from model path:{model_path}')
-            model_class = kwargs['use_case'].pt_cls
-            token_class = kwargs['use_case'].tokenizer_cls
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.start()
-            start = time.perf_counter()
-            trust_remote_code = False
-            try:
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            except Exception:
-                start = time.perf_counter()
-                trust_remote_code = True
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            tokenizer = token_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            end = time.perf_counter()
-            from_pretrain_time = end - start
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.stop_and_collect_data('from_pretrained_phase')
-                memory_data_collector.log_data(compilation_phase=True)
-        else:
-            raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
-    else:
+    is_gguf_model = model_path.suffix == '.gguf'
+    if not model_path.exists():
         raise RuntimeError(f'==Failure ==: model path:{model_path} is not exist')
+    if not is_gguf_model and not (model_path.is_dir() and len(os.listdir(model_path)) != 0):
+        raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
+    if not device:
+        raise RuntimeError('==Failure ==: no device to load')
+
+    log.info(f'Load text model from model path:{model_path}')
+    model_class = kwargs['use_case'].pt_cls
+    token_class = kwargs['use_case'].tokenizer_cls
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.start()
+    start = time.perf_counter()
+    load_model_kwargs = {'trust_remote_code': False}
+    if is_gguf_model:
+        load_model_kwargs |= {'gguf_file': str(model_path)}
+        model_path = model_path.parent
+    try:
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    except Exception:
+        start = time.perf_counter()
+        load_model_kwargs['trust_remote_code'] = True
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    tokenizer = token_class.from_pretrained(model_path, **load_model_kwargs)
+    end = time.perf_counter()
+    from_pretrain_time = end - start
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.stop_and_collect_data('from_pretrained_phase')
+        memory_data_collector.log_data(compilation_phase=True)
     log.info(f'model path:{model_path}, from pretrained time: {from_pretrain_time:.2f}s')
-    if device is not None:
-        gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
-        lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
-        bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
-        gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
-        gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
-        chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
-        real_base_model_name = str(type(model)).lower()
-        log.info(f'Real base model={real_base_model_name}')
-        # bfclm will trigger generate crash.
+    gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
+    lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
+    bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
+    gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
+    gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
+    chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
+    real_base_model_name = str(type(model)).lower()
+    log.info(f'Real base model={real_base_model_name}')
+    # bfclm will trigger generate crash.
 
-        # If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
-        if device.upper() == 'GPU':
-            device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
-        else:
-            device = torch.device(device.lower())
-        log.info(f'Torch device was set to: {device}')
+    # If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
+    if device.upper() == 'GPU':
+        device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
+    else:
+        device = torch.device(device.lower())
+    log.info(f'Torch device was set to: {device}')
 
-        if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
-            model = set_bf16(model, device, **kwargs)
-        else:
-            if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
-                log.info('Param [bf16/prec_bf16] will not work.')
-            model.to(device)
+    if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
+        model = set_bf16(model, device, **kwargs)
     else:
-        raise RuntimeError('==Failure ==: no device to load')
+        if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
+            log.info('Param [bf16/prec_bf16] will not work.')
+        model.to(device)
 
     bench_hook = hook_common.get_bench_hook(kwargs['num_beams'], model)
diff --git a/tools/llm_bench/requirements.txt b/tools/llm_bench/requirements.txt
index 6c46eb193d..ef28ad7112 100644
--- a/tools/llm_bench/requirements.txt
+++ b/tools/llm_bench/requirements.txt
@@ -18,3 +18,5 @@ librosa # For Whisper
 matplotlib
 jinja2>=3.1.0
 scipy
+gguf_parser
+gguf>=0.10