33 changes: 25 additions & 8 deletions tests/python_tests/samples/test_tools_llm_benchmark.py
@@ -7,6 +7,8 @@

from conftest import SAMPLES_PY_DIR, convert_model, download_test_content
from test_utils import run_sample
from data.models import get_gguf_model_list
from utils.hugging_face import download_gguf_model

convert_draft_model = convert_model
download_mask_image = download_test_content
@@ -221,9 +223,9 @@ def test_python_tool_llm_benchmark_optimum(self, convert_model, download_test_co
@pytest.mark.samples
@pytest.mark.parametrize("convert_model", ["bge-small-en-v1.5"], indirect=True)
@pytest.mark.parametrize("sample_args", [
["-d", "cpu", "-n", "2"],
["-d", "cpu", "-n", "2", "--embedding_max_length", "128", "--embedding_normalize", "--embedding_pooling", "mean"],
["-d", "cpu", "-n", "2", "--optimum"],
["-d", "cpu", "-n", "2"],
["-d", "cpu", "-n", "2", "--embedding_max_length", "128", "--embedding_normalize", "--embedding_pooling", "mean"],
["-d", "cpu", "-n", "2", "--optimum"],
["-d", "cpu", "-n", "1", "--embedding_max_length", "128", "--embedding_normalize", "--embedding_pooling", "mean", "--optimum"]
])
def test_python_tool_llm_benchmark_text_embeddings(self, convert_model, sample_args):
@@ -234,21 +236,36 @@ def test_python_tool_llm_benchmark_text_embeddings(self, convert_model, sample_args):
"-m", convert_model,
] + sample_args
run_sample(benchmark_py_command)



@pytest.mark.samples
@pytest.mark.parametrize("convert_model", ["ms-marco-TinyBERT-L2-v2"], indirect=True)
@pytest.mark.parametrize("sample_args", [
["-d", "cpu", "-n", "2", "--rerank"],
["-d", "cpu", "-n", "2", "--rerank"],
["-d", "cpu", "-n", "2", "--reranking_max_length", "10", "--reranking_top_n", "1", "--rerank"],
["-d", "cpu", "-n", "2", "--optimum", "--rerank"],
["-d", "cpu", "-n", "2", "--optimum", "--rerank"],
["-d", "cpu", "-n", "1", "--reranking_max_length", "10", "--reranking_top_n", "1", "--optimum", "--rerank"]
])
def test_python_tool_llm_benchmark_text_reranking(self, convert_model, sample_args):
benchmark_script = os.path.join(SAMPLES_PY_DIR, 'llm_bench/benchmark.py')
benchmark_py_command = [
sys.executable,
benchmark_script,
"-m", convert_model,
"-m", convert_model,
] + sample_args
run_sample(benchmark_py_command)

@pytest.mark.samples
@pytest.mark.parametrize("sample_args", [
["-d", "cpu", "-n", "1"],
["-d", "cpu", "-n", "1", "-f", "pt"],
])
def test_python_tool_llm_benchmark_gguf_format(self, sample_args):
benchmark_script = os.path.join(SAMPLES_PY_DIR, 'llm_bench/benchmark.py')
gguf_model = get_gguf_model_list()[0]
gguf_full_path = download_gguf_model(gguf_model["gguf_model_id"], gguf_model["gguf_filename"])
benchmark_py_command = [
sys.executable,
benchmark_script,
"-m", gguf_full_path,
] + sample_args
run_sample(benchmark_py_command)
2 changes: 1 addition & 1 deletion tools/llm_bench/benchmark.py
@@ -40,7 +40,7 @@ def num_infer_count_type(x):

def get_argprser():
parser = argparse.ArgumentParser('LLM benchmarking tool', add_help=True, formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('-m', '--model', help='model folder including IR files or Pytorch files', required=TabError)
parser.add_argument('-m', '--model', help='model folder including IR files or Pytorch files or path to GGUF model', required=TabError)
parser.add_argument('-d', '--device', default='cpu', help='inference device')
parser.add_argument('-r', '--report', help='report csv')
parser.add_argument('-rj', '--report_json', help='report json')
31 changes: 20 additions & 11 deletions tools/llm_bench/llm_bench_utils/model_utils.py
@@ -234,26 +234,35 @@ def analyze_args(args):
return model_path, model_framework, model_args, model_name


def get_use_case(model_name_or_path):
config_file = Path(model_name_or_path) / "config.json"
config = None
if config_file.exists():
config = json.loads(config_file.read_text())
def get_use_case(model_name_or_path: str | Path):
if (Path(model_name_or_path) / "model_index.json").exists():
diffusers_config = json.loads((Path(model_name_or_path) / "model_index.json").read_text())
pipe_type = diffusers_config.get("_class_name")
if pipe_type in ["StableDiffusionPipeline", "StableDiffusionXLPipeline", "StableDiffusion3Pipeline", "StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline", "FluxPipeline", "LatentConsistencyModelPipeline"]:
return "image_gen", pipe_type.replace("Pipeline", "")

if config is not None:
case, model_name = resolve_complex_model_types(config)
if case is not None:
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
return case, model_name
model_type = None
config_file = Path(model_name_or_path) / "config.json"
if config_file.exists():
config = json.loads(config_file.read_text())
if config is not None:
case, model_name = resolve_complex_model_types(config)
if case is not None:
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
return case, model_name
model_type = config.get("model_type").lower().replace('_', '-')
elif Path(model_name_or_path).suffix in '.gguf':
import gguf_parser
parser = gguf_parser.GGUFParser(model_name_or_path)
parser.parse()
if parser.metadata and parser.metadata.get('general.architecture'):
model_type = parser.metadata.get('general.architecture').lower()

if model_type is not None:
for case, model_ids in USE_CASES.items():
for idx, model_id in enumerate(normalize_model_ids(model_ids)):
if config.get("model_type").lower().replace('_', '-').startswith(model_id):
if model_type.startswith(model_id):
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_id}')
return case, model_ids[idx]

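The new GGUF branch in get_use_case above resolves the use case from metadata embedded in the file itself rather than from config.json. A minimal standalone sketch of that lookup with gguf_parser, using only the calls shown in the diff and a hypothetical local checkpoint name:

import gguf_parser

# Hypothetical local GGUF checkpoint; any single-file quantized model works the same way.
gguf_path = "models/llama-2-7b.Q4_K_M.gguf"

parser = gguf_parser.GGUFParser(gguf_path)
parser.parse()  # parse the GGUF header and key/value metadata

model_type = None
if parser.metadata and parser.metadata.get("general.architecture"):
    # e.g. "llama" or "qwen2"; llm_bench matches this prefix against USE_CASES
    model_type = parser.metadata.get("general.architecture").lower()
print(model_type)
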
8 changes: 3 additions & 5 deletions tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -216,15 +216,13 @@ def cb_pipeline_required(args):

def create_genai_text_gen_model(model_path, device, ov_config, memory_data_collector, **kwargs):
import openvino_genai
from transformers import AutoTokenizer
from packaging.version import parse

if not (model_path / "openvino_tokenizer.xml").exists() or not (model_path / "openvino_detokenizer.xml").exists():
if Path(model_path).suffix not in '.gguf'\
and (not (model_path / "openvino_tokenizer.xml").exists() or not (model_path / "openvino_detokenizer.xml").exists()):
raise ValueError("OpenVINO Tokenizer model is not found in model directory. Please convert tokenizer using following command:\n"
"convert_tokenizer --with-detokenizer MODEL_DIR --output MODEL_DIR ")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

config = {}
draft_model_path = kwargs.get("draft_model", '')
cb_config = kwargs.get("cb_config")
@@ -296,7 +294,7 @@ def get_time_list(self):
return self.token_generation_time
streamer = TokenStreamer(llm_pipe.get_tokenizer()) if use_streamer_metrics else None

return llm_pipe, tokenizer, end - start, streamer, True
return llm_pipe, None, end - start, streamer, True


def convert_ov_tokenizer(tokenizer_path):
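With the tokenizer check above relaxed for .gguf paths, create_genai_text_gen_model can hand a single GGUF file straight to openvino_genai and no longer loads a Hugging Face tokenizer (it now returns None in that slot). A minimal sketch of the resulting GenAI usage, assuming an openvino_genai build with GGUF support and a hypothetical local file:

import openvino_genai

# Hypothetical local GGUF checkpoint; the GenAI pipeline builds its tokenizer from GGUF metadata.
pipe = openvino_genai.LLMPipeline("models/llama-2-7b.Q4_K_M.gguf", "CPU")
print(pipe.generate("What is OpenVINO?", max_new_tokens=32))
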
101 changes: 51 additions & 50 deletions tools/llm_bench/llm_bench_utils/pt_utils.py
@@ -66,63 +66,64 @@ def run_torch_compile(model, backend='openvino', dynamic=None, options=None, chi

def create_text_gen_model(model_path, device, memory_data_collector, **kwargs):
model_path = Path(model_path)
from_pretrain_time = 0
if model_path.exists():
if model_path.is_dir() and len(os.listdir(model_path)) != 0:
log.info(f'Load text model from model path:{model_path}')
default_model_type = DEFAULT_MODEL_CLASSES[kwargs['use_case']]
model_type = kwargs.get('model_type', default_model_type)
model_class = PT_MODEL_CLASSES_MAPPING.get(model_type, PT_MODEL_CLASSES_MAPPING[default_model_type])
token_class = TOKENIZE_CLASSES_MAPPING.get(model_type, TOKENIZE_CLASSES_MAPPING[default_model_type])
if kwargs.get("mem_consumption"):
memory_data_collector.start()
start = time.perf_counter()
trust_remote_code = False
try:
model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
except Exception:
start = time.perf_counter()
trust_remote_code = True
model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
tokenizer = token_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
end = time.perf_counter()
from_pretrain_time = end - start
if kwargs.get("mem_consumption"):
memory_data_collector.stop_and_collect_data('from_pretrained_phase')
memory_data_collector.log_data(compilation_phase=True)
else:
raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
else:
is_gguf_model = model_path.suffix in '.gguf'
if not model_path.exists():
raise RuntimeError(f'==Failure ==: model path:{model_path} is not exist')
if not is_gguf_model and not (model_path.is_dir() and len(os.listdir(model_path)) != 0):
raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
if not device:
raise RuntimeError('==Failure ==: no device to load')

log.info(f'Load text model from model path:{model_path}')
default_model_type = DEFAULT_MODEL_CLASSES[kwargs['use_case']]
model_type = kwargs.get('model_type', default_model_type)
model_class = PT_MODEL_CLASSES_MAPPING.get(model_type, PT_MODEL_CLASSES_MAPPING[default_model_type])
token_class = TOKENIZE_CLASSES_MAPPING.get(model_type, TOKENIZE_CLASSES_MAPPING[default_model_type])
if kwargs.get("mem_consumption"):
memory_data_collector.start()
start = time.perf_counter()
load_model_kwargs = {'trust_remote_code': False}
if is_gguf_model:
load_model_kwargs |= {'gguf_file': str(model_path)}
model_path = model_path.parent
try:
model = model_class.from_pretrained(model_path, **load_model_kwargs)
except Exception:
start = time.perf_counter()
load_model_kwargs['trust_remote_code'] = True
model = model_class.from_pretrained(model_path, **load_model_kwargs)
tokenizer = token_class.from_pretrained(model_path, **load_model_kwargs)
end = time.perf_counter()
from_pretrain_time = end - start
if kwargs.get("mem_consumption"):
memory_data_collector.stop_and_collect_data('from_pretrained_phase')
memory_data_collector.log_data(compilation_phase=True)

log.info(f'model path:{model_path}, from pretrained time: {from_pretrain_time:.2f}s')

if device is not None:
gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
real_base_model_name = str(type(model)).lower()
log.info(f'Real base model={real_base_model_name}')
# bfclm will trigger generate crash.
gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
real_base_model_name = str(type(model)).lower()
log.info(f'Real base model={real_base_model_name}')
# bfclm will trigger generate crash.

# If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
if device.upper() == 'GPU':
device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
else:
device = torch.device(device.lower())
log.info(f'Torch device was set to: {device}')
# If the device is set to GPU there's a need to substitute it with 'cuda' so it will be accepted by PyTorch
if device.upper() == 'GPU':
device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
else:
device = torch.device(device.lower())
log.info(f'Torch device was set to: {device}')

if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
model = set_bf16(model, device, **kwargs)
else:
if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
log.info('Param [bf16/prec_bf16] will not work.')
model.to(device)
if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
model = set_bf16(model, device, **kwargs)
else:
raise RuntimeError('==Failure ==: no device to load')
if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
log.info('Param [bf16/prec_bf16] will not work.')
model.to(device)

bench_hook = hook_common.get_bench_hook(kwargs['num_beams'], model)

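For the PyTorch path (-f pt), the rewritten loader above relies on transformers' gguf_file argument to dequantize a GGUF checkpoint into a regular torch model. A minimal sketch of the same pattern outside llm_bench, assuming a transformers/gguf install recent enough to support it and a hypothetical local file:

from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = Path("models/llama-2-7b.Q4_K_M.gguf")  # hypothetical local GGUF checkpoint

# Mirror pt_utils: point gguf_file at the checkpoint and load from its parent directory,
# so transformers dequantizes the GGUF weights into a standard PyTorch model.
load_kwargs = {"trust_remote_code": False, "gguf_file": str(model_path)}
model = AutoModelForCausalLM.from_pretrained(model_path.parent, **load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path.parent, **load_kwargs)
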
2 changes: 2 additions & 0 deletions tools/llm_bench/requirements.txt
@@ -18,3 +18,5 @@ librosa # For Whisper
matplotlib
jinja2>=3.1.0
scipy
gguf_parser
gguf>=0.10