diff --git a/.gitignore b/.gitignore index 6e06c6b8..f610940f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +# Output files +benchmarks.json +benchmarks.yaml +benchmarks.csv + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py index 929d046e..3bf38f03 100644 --- a/src/guidellm/__init__.py +++ b/src/guidellm/__init__.py @@ -21,7 +21,29 @@ hf_logging.set_verbosity_error() logging.getLogger("transformers").setLevel(logging.ERROR) -from .config import settings +from .config import ( + settings, + DatasetSettings, + Environment, + LoggingSettings, + OpenAISettings, + print_config, + Settings, + reload_settings, +) from .logger import configure_logger, logger -__all__ = ["configure_logger", "logger", "settings", "generate_benchmark_report"] +__all__ = [ + # Config + "DatasetSettings", + "Environment", + "LoggingSettings", + "OpenAISettings", + "print_config", + "Settings", + "reload_settings", + "settings", + # Logger + "logger", + "configure_logger", +] diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 096614de..f6b0e3d8 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -210,6 +210,15 @@ def cli(): callback=parse_json, help="A JSON string of extra data to save with the output benchmarks", ) +@click.option( + "--output-sampling", + type=int, + help=( + "The number of samples to save in the output file. " + "If None (default), will save all samples." + ), + default=None, +) @click.option( "--random-seed", default=42, @@ -237,6 +246,7 @@ def benchmark( disable_console_outputs, output_path, output_extras, + output_sampling, random_seed, ): asyncio.run( @@ -261,6 +271,7 @@ def benchmark( output_console=not disable_console_outputs, output_path=output_path, output_extras=output_extras, + output_sampling=output_sampling, random_seed=random_seed, ) ) diff --git a/src/guidellm/benchmark/__init__.py b/src/guidellm/benchmark/__init__.py index dc100596..5eaeb57c 100644 --- a/src/guidellm/benchmark/__init__.py +++ b/src/guidellm/benchmark/__init__.py @@ -1,7 +1,19 @@ from .aggregator import AggregatorT, BenchmarkAggregator, GenerativeBenchmarkAggregator -from .benchmark import Benchmark, BenchmarkT, GenerativeBenchmark +from .benchmark import ( + Benchmark, + BenchmarkArgs, + BenchmarkMetrics, + BenchmarkRunStats, + BenchmarkT, + GenerativeBenchmark, + GenerativeMetrics, + GenerativeTextErrorStats, + GenerativeTextResponseStats, + StatusBreakdown, +) from .benchmarker import Benchmarker, BenchmarkerResult, GenerativeBenchmarker from .entrypoints import benchmark_generative_text +from .output import GenerativeBenchmarksConsole, GenerativeBenchmarksReport from .profile import ( AsyncProfile, ConcurrentProfile, @@ -12,17 +24,39 @@ ThroughputProfile, create_profile, ) +from .progress import ( + BenchmarkerProgressDisplay, + BenchmarkerTaskProgressState, + GenerativeTextBenchmarkerProgressDisplay, + GenerativeTextBenchmarkerTaskProgressState, +) __all__ = [ + # Aggregator "AggregatorT", - "BenchmarkT", - "Benchmark", "BenchmarkAggregator", - "GenerativeBenchmark", "GenerativeBenchmarkAggregator", + # Benchmark + "Benchmark", + "BenchmarkArgs", + "BenchmarkMetrics", + "BenchmarkRunStats", + "BenchmarkT", + "GenerativeBenchmark", + "GenerativeMetrics", + "GenerativeTextErrorStats", + "GenerativeTextResponseStats", + "StatusBreakdown", + # Benchmarker "Benchmarker", "BenchmarkerResult", "GenerativeBenchmarker", + # Entry points + "benchmark_generative_text", + # Output + 
"GenerativeBenchmarksConsole", + "GenerativeBenchmarksReport", + # Profile "AsyncProfile", "ConcurrentProfile", "Profile", @@ -31,5 +65,9 @@ "SynchronousProfile", "ThroughputProfile", "create_profile", - "benchmark_generative_text", + # Progress + "BenchmarkerProgressDisplay", + "BenchmarkerTaskProgressState", + "GenerativeTextBenchmarkerProgressDisplay", + "GenerativeTextBenchmarkerTaskProgressState", ] diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index f1f9187c..3d94d9d8 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -457,7 +457,12 @@ def time_per_output_token_ms(self) -> Optional[float]: # type: ignore[override] This includes the time to generate the first token and all other tokens. None if the output_tokens is None or 0. """ - if self.output_tokens is None or self.output_tokens == 0: + if ( + self.output_tokens is None + or self.output_tokens == 0 + or self.first_token_time is None + or self.last_token_time is None + ): return None return super().time_per_output_token_ms @@ -614,41 +619,46 @@ def duration(self) -> float: ), ) - def create_sampled(self, sample_size: int) -> "GenerativeBenchmark": + def set_sample_size(self, sample_size: Optional[int]) -> "GenerativeBenchmark": """ - Create a new benchmark instance with a random sample of the completed and - errored requests based on the given sample sizes. If the sample sizes are - larger than the total number of requests, the sample sizes are capped at - the total number of requests. + Set the sample size for the benchmark. This will randomly sample the + requests for each status type to the given sample size or the maximum + number of requests for that status type, whichever is smaller. + This is applied to requests.successful, requests.errored, and + requests.incomplete. + If None, no sampling is applied and the state is kept. :param sample_size: The number of requests to sample for each status type. - :return: A new benchmark instance with the sampled requests. - :raises ValueError: If the sample sizes are negative. + :return: The benchmark with the sampled requests. + :raises ValueError: If the sample size is invalid. 
""" - if sample_size < 0: - raise ValueError(f"Sample size must be non-negative, given {sample_size}") - sample_size = min(sample_size, len(self.requests.successful)) - error_sample_size = min(sample_size, len(self.requests.errored)) - incomplete_sample_size = min(sample_size, len(self.requests.incomplete)) + if sample_size is not None: + if sample_size < 0 or not isinstance(sample_size, int): + raise ValueError( + f"Sample size must be non-negative integer, given {sample_size}" + ) - sampled_instance = self.model_copy() - sampled_instance.requests.successful = random.sample( - self.requests.successful, sample_size - ) - sampled_instance.requests.errored = random.sample( - self.requests.errored, error_sample_size - ) - sampled_instance.requests.incomplete = random.sample( - self.requests.incomplete, incomplete_sample_size - ) - sampled_instance.request_samples = StatusBreakdown( - successful=len(sampled_instance.requests.successful), - incomplete=len(sampled_instance.requests.incomplete), - errored=len(sampled_instance.requests.errored), - ) + sample_size = min(sample_size, len(self.requests.successful)) + error_sample_size = min(sample_size, len(self.requests.errored)) + incomplete_sample_size = min(sample_size, len(self.requests.incomplete)) + + self.requests.successful = random.sample( + self.requests.successful, sample_size + ) + self.requests.errored = random.sample( + self.requests.errored, error_sample_size + ) + self.requests.incomplete = random.sample( + self.requests.incomplete, incomplete_sample_size + ) + self.request_samples = StatusBreakdown( + successful=len(self.requests.successful), + incomplete=len(self.requests.incomplete), + errored=len(self.requests.errored), + ) - return sampled_instance + return self @staticmethod def from_stats( diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index fc98219e..676b9a94 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Iterable, List, Literal, Optional, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import ( # type: ignore[import] @@ -7,11 +7,10 @@ ) from guidellm.backend import Backend, BackendType -from guidellm.benchmark.benchmark import GenerativeBenchmark from guidellm.benchmark.benchmarker import GenerativeBenchmarker from guidellm.benchmark.output import ( GenerativeBenchmarksConsole, - save_generative_benchmarks, + GenerativeBenchmarksReport, ) from guidellm.benchmark.profile import ProfileType, create_profile from guidellm.benchmark.progress import GenerativeTextBenchmarkerProgressDisplay @@ -48,8 +47,9 @@ async def benchmark_generative_text( output_console: bool, output_path: Optional[Union[str, Path]], output_extras: Optional[Dict[str, Any]], + output_sampling: Optional[int], random_seed: int, -) -> List[GenerativeBenchmark]: +) -> Tuple[GenerativeBenchmarksReport, Optional[Path]]: console = GenerativeBenchmarksConsole(enabled=show_progress) console.print_line("Creating backend...") backend = Backend.create( @@ -100,7 +100,7 @@ async def benchmark_generative_text( if show_progress else None ) - benchmarks = [] + report = GenerativeBenchmarksReport() async for result in benchmarker.run( profile=profile, @@ -115,15 +115,26 @@ async def benchmark_generative_text( if result.type_ == "benchmark_compiled": if result.current_benchmark is 
None: raise ValueError("Current benchmark is None") - benchmarks.append(result.current_benchmark) + report.benchmarks.append( + result.current_benchmark.set_sample_size(output_sampling) + ) if output_console: - console.benchmarks = benchmarks + orig_enabled = console.enabled + console.enabled = True + console.benchmarks = report.benchmarks console.print_benchmarks_metadata() console.print_benchmarks_info() console.print_benchmarks_stats() + console.enabled = orig_enabled if output_path: - save_generative_benchmarks(benchmarks=benchmarks, path=output_path) + console.print_line("\nSaving benchmarks report...") + saved_path = report.save_file(output_path) + console.print_line(f"Benchmarks report saved to {saved_path}") + else: + saved_path = None - return benchmarks + console.print_line("\nBenchmarking complete.") + + return report, saved_path diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 68031c12..7f316e47 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -1,86 +1,385 @@ +import csv import json +import math from collections import OrderedDict from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import yaml +from pydantic import Field from rich.console import Console from rich.padding import Padding -from rich.table import Table from rich.text import Text -from guidellm.benchmark.benchmark import GenerativeBenchmark +from guidellm.benchmark.benchmark import GenerativeBenchmark, GenerativeMetrics from guidellm.benchmark.profile import ( AsyncProfile, ConcurrentProfile, SweepProfile, ThroughputProfile, ) -from guidellm.objects import StandardBaseModel +from guidellm.config import settings +from guidellm.objects import ( + DistributionSummary, + StandardBaseModel, + StatusDistributionSummary, +) from guidellm.scheduler import strategy_display_str -from guidellm.utils import Colors +from guidellm.utils import Colors, split_text_list_by_length __all__ = [ "GenerativeBenchmarksReport", - "save_generative_benchmarks", "GenerativeBenchmarksConsole", ] class GenerativeBenchmarksReport(StandardBaseModel): - benchmarks: List[GenerativeBenchmark] - - def save_file(self, path: Path): - if path.is_dir(): - path = path / "benchmarks.json" - - path.parent.mkdir(parents=True, exist_ok=True) - extension = path.suffix.lower() - - if extension == ".json": - self.save_json(path) - elif extension in [".yaml", ".yml"]: - self.save_yaml(path) - elif extension in [".csv"]: - self.save_csv(path) - else: - raise ValueError(f"Unsupported file extension: {extension} for {path}.") + """ + A pydantic model representing a completed benchmark report. + Contains a list of benchmarks along with convenience methods for finalizing + and saving the report. + """ + + @staticmethod + def load_file(path: Union[str, Path]) -> "GenerativeBenchmarksReport": + """ + Load a report from a file. The file type is determined by the file extension. + If the file is a directory, it expects a file named benchmarks.json under the + directory. + + :param path: The path to load the report from. + :return: The loaded report. 
+ """ + path, type_ = GenerativeBenchmarksReport._file_setup(path) + + if type_ == "json": + with path.open("r") as file: + model_dict = json.load(file) + + return GenerativeBenchmarksReport.model_validate(model_dict) + + if type_ == "yaml": + with path.open("r") as file: + model_dict = yaml.safe_load(file) + + return GenerativeBenchmarksReport.model_validate(model_dict) + + if type_ == "csv": + raise ValueError(f"CSV file type is not supported for loading: {path}.") + + raise ValueError(f"Unsupported file type: {type_} for {path}.") + + benchmarks: List[GenerativeBenchmark] = Field( + description="The list of completed benchmarks contained within the report.", + default_factory=list, + ) + + def set_sample_size( + self, sample_size: Optional[int] + ) -> "GenerativeBenchmarksReport": + """ + Set the sample size for each benchmark in the report. In doing this, it will + reduce the contained requests of each benchmark to the sample size. + If sample size is None, it will return the report as is. + + :param sample_size: The sample size to set for each benchmark. + If None, the report will be returned as is. + :return: The report with the sample size set for each benchmark. + """ + + if sample_size is not None: + for benchmark in self.benchmarks: + benchmark.set_sample_size(sample_size) + + return self + + def save_file(self, path: Union[str, Path]) -> Path: + """ + Save the report to a file. The file type is determined by the file extension. + If the file is a directory, it will save the report to a file named + benchmarks.json under the directory. + + :param path: The path to save the report to. + :return: The path to the saved report. + """ + path, type_ = GenerativeBenchmarksReport._file_setup(path) + + if type_ == "json": + return self.save_json(path) + + if type_ == "yaml": + return self.save_yaml(path) + + if type_ == "csv": + return self.save_csv(path) + + raise ValueError(f"Unsupported file type: {type_} for {path}.") + + def save_json(self, path: Union[str, Path]) -> Path: + """ + Save the report to a JSON file containing all of the report data which is + reloadable using the pydantic model. If the file is a directory, it will save + the report to a file named benchmarks.json under the directory. + + :param path: The path to save the report to. + :return: The path to the saved report. + """ + path, type_ = GenerativeBenchmarksReport._file_setup(path, "json") + + if type_ != "json": + raise ValueError( + f"Unsupported file type for saving a JSON: {type_} for {path}." + ) - def save_json(self, path: Path): model_dict = self.model_dump() model_json = json.dumps(model_dict) with path.open("w") as file: file.write(model_json) - def save_yaml(self, path: Path): + return path + + def save_yaml(self, path: Union[str, Path]) -> Path: + """ + Save the report to a YAML file containing all of the report data which is + reloadable using the pydantic model. If the file is a directory, it will save + the report to a file named benchmarks.yaml under the directory. + + :param path: The path to save the report to. + :return: The path to the saved report. + """ + + path, type_ = GenerativeBenchmarksReport._file_setup(path, "yaml") + + if type_ != "yaml": + raise ValueError( + f"Unsupported file type for saving a YAML: {type_} for {path}." 
+ ) + model_dict = self.model_dump() model_yaml = yaml.dump(model_dict) with path.open("w") as file: file.write(model_yaml) - def save_csv(self, path: Path): - raise NotImplementedError("CSV format is not implemented yet.") + return path + + def save_csv(self, path: Union[str, Path]) -> Path: + """ + Save the report to a CSV file containing the summarized statistics and values + for each report. Note, this data is not reloadable using the pydantic model. + If the file is a directory, it will save the report to a file named + benchmarks.csv under the directory. + + :param path: The path to save the report to. + :return: The path to the saved report. + """ + path, type_ = GenerativeBenchmarksReport._file_setup(path, "csv") + + if type_ != "csv": + raise ValueError( + f"Unsupported file type for saving a CSV: {type_} for {path}." + ) + + with path.open("w", newline="") as file: + writer = csv.writer(file) + headers: List[str] = [] + rows: List[List[Union[str, float, List[float]]]] = [] + + for benchmark in self.benchmarks: + benchmark_headers: List[str] = [] + benchmark_values: List[Union[str, float, List[float]]] = [] + + desc_headers, desc_values = self._benchmark_desc_headers_and_values( + benchmark + ) + benchmark_headers += desc_headers + benchmark_values += desc_values + + for status in StatusDistributionSummary.model_fields: + status_headers, status_values = ( + self._benchmark_status_headers_and_values(benchmark, status) + ) + benchmark_headers += status_headers + benchmark_values += status_values + + benchmark_extra_headers, benchmark_extra_values = ( + self._benchmark_extras_headers_and_values(benchmark) + ) + benchmark_headers += benchmark_extra_headers + benchmark_values += benchmark_extra_values + + if not headers: + headers = benchmark_headers + rows.append(benchmark_values) + + writer.writerow(headers) + for row in rows: + writer.writerow(row) + + return path + + @staticmethod + def _file_setup( + path: Union[str, Path], + default_file_type: Literal["json", "yaml", "csv"] = "json", + ) -> Tuple[Path, Literal["json", "yaml", "csv"]]: + path = Path(path) if not isinstance(path, Path) else path + + if path.is_dir(): + path = path / f"benchmarks.{default_file_type}" + + path.parent.mkdir(parents=True, exist_ok=True) + path_suffix = path.suffix.lower() + + if path_suffix == ".json": + return path, "json" + + if path_suffix in [".yaml", ".yml"]: + return path, "yaml" + + if path_suffix in [".csv"]: + return path, "csv" + + raise ValueError(f"Unsupported file extension: {path_suffix} for {path}.") + + @staticmethod + def _benchmark_desc_headers_and_values( + benchmark: GenerativeBenchmark, + ) -> Tuple[List[str], List[Union[str, float]]]: + headers = [ + "Type", + "Run Id", + "Id", + "Name", + "Start Time", + "End Time", + "Duration", + ] + values: List[Union[str, float]] = [ + benchmark.type_, + benchmark.run_id, + benchmark.id_, + strategy_display_str(benchmark.args.strategy), + datetime.fromtimestamp(benchmark.start_time).strftime("%Y-%m-%d %H:%M:%S"), + datetime.fromtimestamp(benchmark.end_time).strftime("%Y-%m-%d %H:%M:%S"), + benchmark.duration, + ] + + if len(headers) != len(values): + raise ValueError("Headers and values length mismatch.") + + return headers, values + + @staticmethod + def _benchmark_extras_headers_and_values( + benchmark: GenerativeBenchmark, + ) -> Tuple[List[str], List[str]]: + headers = ["Args", "Worker", "Request Loader", "Extras"] + values: List[str] = [ + json.dumps(benchmark.args.model_dump()), + json.dumps(benchmark.worker.model_dump()), + 
json.dumps(benchmark.request_loader.model_dump()), + json.dumps(benchmark.extras), + ] + + if len(headers) != len(values): + raise ValueError("Headers and values length mismatch.") + + return headers, values + + @staticmethod + def _benchmark_status_headers_and_values( + benchmark: GenerativeBenchmark, status: str + ) -> Tuple[List[str], List[Union[float, List[float]]]]: + headers = [ + f"{status.capitalize()} Requests", + ] + values = [ + getattr(benchmark.request_totals, status), + ] + + for metric in GenerativeMetrics.model_fields: + metric_headers, metric_values = ( + GenerativeBenchmarksReport._benchmark_status_metrics_stats( + benchmark, status, metric + ) + ) + headers += metric_headers + values += metric_values + + if len(headers) != len(values): + raise ValueError("Headers and values length mismatch.") + + return headers, values + + @staticmethod + def _benchmark_status_metrics_stats( + benchmark: GenerativeBenchmark, + status: str, + metric: str, + ) -> Tuple[List[str], List[Union[float, List[float]]]]: + status_display = status.capitalize() + metric_display = metric.replace("_", " ").capitalize() + status_dist_summary: StatusDistributionSummary = getattr( + benchmark.metrics, metric + ) + dist_summary: DistributionSummary = getattr(status_dist_summary, status) + headers = [ + f"{status_display} {metric_display} mean", + f"{status_display} {metric_display} median", + f"{status_display} {metric_display} std dev", + ( + f"{status_display} {metric_display} " + "[min, 0.1, 1, 5, 10, 25, 75, 90, 95, 99, max]" + ), + ] + values: List[Union[float, List[float]]] = [ + dist_summary.mean, + dist_summary.median, + dist_summary.std_dev, + [ + dist_summary.min, + dist_summary.percentiles.p001, + dist_summary.percentiles.p01, + dist_summary.percentiles.p05, + dist_summary.percentiles.p10, + dist_summary.percentiles.p25, + dist_summary.percentiles.p75, + dist_summary.percentiles.p90, + dist_summary.percentiles.p95, + dist_summary.percentiles.p99, + dist_summary.max, + ], + ] + if len(headers) != len(values): + raise ValueError("Headers and values length mismatch.") -def save_generative_benchmarks( - benchmarks: List[GenerativeBenchmark], path: Union[Path, str] -): - path = Path(path) if isinstance(path, str) else path - report = GenerativeBenchmarksReport(benchmarks=benchmarks) - report.save_file(path) + return headers, values class GenerativeBenchmarksConsole: + """ + A class for outputting progress and benchmark results to the console. + Utilizes the rich library for formatting, enabling colored and styled output. + """ + def __init__(self, enabled: bool = True): + """ + :param enabled: Whether to enable console output. Defaults to True. + If False, all console output will be suppressed. + """ self.enabled = enabled self.benchmarks: Optional[List[GenerativeBenchmark]] = None self.console = Console() @property def benchmarks_profile_str(self) -> str: + """ + :return: A string representation of the profile used for the benchmarks. + """ profile = self.benchmarks[0].args.profile if self.benchmarks else None if profile is None: @@ -108,6 +407,9 @@ def benchmarks_profile_str(self) -> str: @property def benchmarks_args_str(self) -> str: + """ + :return: A string representation of the arguments used for the benchmarks. + """ args = self.benchmarks[0].args if self.benchmarks else None if args is None: @@ -128,14 +430,23 @@ def benchmarks_args_str(self) -> str: @property def benchmarks_worker_desc_str(self) -> str: + """ + :return: A string representation of the worker used for the benchmarks. 
+ """ return str(self.benchmarks[0].worker) if self.benchmarks else "None" @property def benchmarks_request_loader_desc_str(self) -> str: + """ + :return: A string representation of the request loader used for the benchmarks. + """ return str(self.benchmarks[0].request_loader) if self.benchmarks else "None" @property def benchmarks_extras_str(self) -> str: + """ + :return: A string representation of the extras used for the benchmarks. + """ extras = self.benchmarks[0].extras if self.benchmarks else None if not extras: @@ -143,7 +454,64 @@ def benchmarks_extras_str(self) -> str: return ", ".join(f"{key}={value}" for key, value in extras.items()) - def print_section_header(self, title: str, new_lines: int = 2): + def print_section_header(self, title: str, indent: int = 0, new_lines: int = 2): + """ + Print out a styled section header to the console. + The title is underlined, bolded, and colored with the INFO color. + + :param title: The title of the section. + :param indent: The number of spaces to indent the title. + Defaults to 0. + :param new_lines: The number of new lines to print before the title. + Defaults to 2. + """ + self.print_line( + value=f"{title}:", + style=f"bold underline {Colors.INFO}", + indent=indent, + new_lines=new_lines, + ) + + def print_labeled_line( + self, label: str, value: str, indent: int = 4, new_lines: int = 0 + ): + """ + Print out a styled, labeled line (label: value) to the console. + The label is bolded and colored with the INFO color, + and the value is italicized. + + :param label: The label of the line. + :param value: The value of the line. + :param indent: The number of spaces to indent the line. + Defaults to 4. + :param new_lines: The number of new lines to print before the line. + Defaults to 0. + """ + self.print_line( + value=[label + ":", value], + style=["bold " + Colors.INFO, "italic"], + new_lines=new_lines, + indent=indent, + ) + + def print_line( + self, + value: Union[str, List[str]], + style: Union[str, List[str]] = "", + indent: int = 0, + new_lines: int = 0, + ): + """ + Print out a a value to the console as a line with optional indentation. + + :param value: The value to print. + :param style: The style to apply to the value. + Defaults to none. + :param indent: The number of spaces to indent the line. + Defaults to 0. + :param new_lines: The number of new lines to print before the value. + Defaults to 0. + """ if not self.enabled: return @@ -152,45 +520,248 @@ def print_section_header(self, title: str, new_lines: int = 2): for _ in range(new_lines): text.append("\n") - text.append(f"{title}:", style=f"bold underline {Colors.INFO}") - self.console.print(text) + if not isinstance(value, list): + value = [value] - def print_labeled_line(self, label: str, value: str, indent: int = 4): - if not self.enabled: - return + if not isinstance(style, list): + style = [style for _ in range(len(value))] - text = Text() - text.append(label + ": ", style=f"bold {Colors.INFO}") - text.append(": ") - text.append(value, style="italic") - self.console.print( - Padding.indent(text, indent), - ) + if len(value) != len(style): + raise ValueError( + f"Value and style length mismatch. Value length: {len(value)}, " + f"Style length: {len(style)}." 
+ ) - def print_line(self, value: str, indent: int = 0): - if not self.enabled: - return + for val, sty in zip(value, style): + text.append(val, style=sty) + + self.console.print(Padding.indent(text, indent)) + + def print_table( + self, + headers: List[str], + rows: List[List[Any]], + title: str, + sections: Optional[Dict[str, Tuple[int, int]]] = None, + max_char_per_col: int = 2**10, + indent: int = 0, + new_lines: int = 2, + ): + """ + Print a table to the console with the given headers and rows. + + :param headers: The headers of the table. + :param rows: The rows of the table. + :param title: The title of the table. + :param sections: The sections of the table grouping columns together. + This is a mapping of the section display name to a tuple of the start and + end column indices. If None, no sections are added (default). + :param max_char_per_col: The maximum number of characters per column. + :param indent: The number of spaces to indent the table. + Defaults to 0. + :param new_lines: The number of new lines to print before the table. + Defaults to 0. + """ + + if rows and any(len(row) != len(headers) for row in rows): + raise ValueError( + f"Headers and rows length mismatch. Headers length: {len(headers)}, " + f"Row length: {len(rows[0]) if rows else 'N/A'}." + ) - text = Text(value) - self.console.print( - Padding.indent(text, indent), + max_characters_per_column = self.calculate_max_chars_per_column( + headers, rows, sections, max_char_per_col ) - def print_table(self, headers: List[str], rows: List[List[Any]], title: str): - if not self.enabled: - return - - self.print_section_header(title) - table = Table(*headers, header_style=f"bold {Colors.INFO}") - + self.print_section_header(title, indent=indent, new_lines=new_lines) + self.print_table_divider( + max_characters_per_column, include_separators=False, indent=indent + ) + if sections: + self.print_table_sections( + sections, max_characters_per_column, indent=indent + ) + self.print_table_row( + split_text_list_by_length(headers, max_characters_per_column), + style=f"bold {Colors.INFO}", + indent=indent, + ) + self.print_table_divider( + max_characters_per_column, include_separators=True, indent=indent + ) for row in rows: - table.add_row(*[Text(item, style="italic") for item in row]) + self.print_table_row( + split_text_list_by_length(row, max_characters_per_column), + style="italic", + indent=indent, + ) + self.print_table_divider( + max_characters_per_column, include_separators=False, indent=indent + ) - self.console.print(table) + def calculate_max_chars_per_column( + self, + headers: List[str], + rows: List[List[Any]], + sections: Optional[Dict[str, Tuple[int, int]]], + max_char_per_col: int, + ) -> List[int]: + """ + Calculate the maximum number of characters per column in the table. + This is done by checking the length of the headers, rows, and optional sections + to ensure all columns are accounted for and spaced correctly. + + :param headers: The headers of the table. + :param rows: The rows of the table. + :param sections: The sections of the table grouping columns together. + This is a mapping of the section display name to a tuple of the start and + end column indices. If None, no sections are added (default). + :param max_char_per_col: The maximum number of characters per column. + :return: A list of the maximum number of characters per column. 
+ """ + max_characters_per_column = [] + for ind in range(len(headers)): + max_characters_per_column.append(min(len(headers[ind]), max_char_per_col)) + + for row in rows: + max_characters_per_column[ind] = max( + max_characters_per_column[ind], len(str(row[ind])) + ) + + if not sections: + return max_characters_per_column + + for section in sections: + start_col, end_col = sections[section] + min_section_len = len(section) + ( + end_col - start_col + ) # ensure we have enough space for separators + chars_in_columns = sum( + max_characters_per_column[start_col : end_col + 1] + ) + 2 * (end_col - start_col) + if min_section_len > chars_in_columns: + add_chars_per_col = math.ceil( + (min_section_len - chars_in_columns) / (end_col - start_col + 1) + ) + for col in range(start_col, end_col + 1): + max_characters_per_column[col] += add_chars_per_col + + return max_characters_per_column + + def print_table_divider( + self, max_chars_per_column: List[int], include_separators: bool, indent: int = 0 + ): + """ + Print a divider line for the table (top and bottom of table with '=' characters) + + :param max_chars_per_column: The maximum number of characters per column. + :param include_separators: Whether to include separators between columns. + :param indent: The number of spaces to indent the line. + Defaults to 0. + """ + if include_separators: + columns = [ + settings.table_headers_border_char * max_chars + + settings.table_column_separator_char + + settings.table_headers_border_char + for max_chars in max_chars_per_column + ] + else: + columns = [ + settings.table_border_char * (max_chars + 2) + for max_chars in max_chars_per_column + ] + + columns[-1] = columns[-1][:-2] + self.print_line(value=columns, style=Colors.INFO, indent=indent) + + def print_table_sections( + self, + sections: Dict[str, Tuple[int, int]], + max_chars_per_column: List[int], + indent: int = 0, + ): + """ + Print the sections of the table with corresponding separators to the columns + the sections are mapped to to ensure it is compliant with a CSV format. + For example, a section named "Metadata" with columns 0-3 will print this: + Metadata ,,,, + Where the spaces plus the separators at the end will span the columns 0-3. + All columns must be accounted for in the sections. + + :param sections: The sections of the table. + :param max_chars_per_column: The maximum number of characters per column. + :param indent: The number of spaces to indent the line. + Defaults to 0. 
+ """ + section_tuples = [(start, end, name) for name, (start, end) in sections.items()] + section_tuples.sort(key=lambda x: x[0]) + + if any(start > end for start, end, _ in section_tuples): + raise ValueError(f"Invalid section ranges: {section_tuples}") + + if ( + any( + section_tuples[ind][1] + 1 != section_tuples[ind + 1][0] + for ind in range(len(section_tuples) - 1) + ) + or section_tuples[0][0] != 0 + or section_tuples[-1][1] != len(max_chars_per_column) - 1 + ): + raise ValueError(f"Invalid section ranges: {section_tuples}") + + line_values = [] + line_styles = [] + for section, (start_col, end_col) in sections.items(): + section_length = sum(max_chars_per_column[start_col : end_col + 1]) + 2 * ( + end_col - start_col + 1 + ) + num_separators = end_col - start_col + line_values.append(section) + line_styles.append("bold " + Colors.INFO) + line_values.append( + " " * (section_length - len(section) - num_separators - 2) + ) + line_styles.append("") + line_values.append(settings.table_column_separator_char * num_separators) + line_styles.append("") + line_values.append(settings.table_column_separator_char + " ") + line_styles.append(Colors.INFO) + line_values = line_values[:-1] + line_styles = line_styles[:-1] + self.print_line(value=line_values, style=line_styles, indent=indent) + + def print_table_row( + self, column_lines: List[List[str]], style: str, indent: int = 0 + ): + """ + Print a single row of a table to the console. + + :param column_lines: The lines of text to print for each column. + :param indent: The number of spaces to indent the line. + Defaults to 0. + """ + for row in range(len(column_lines[0])): + print_line = [] + print_styles = [] + for column in range(len(column_lines)): + print_line.extend( + [ + column_lines[column][row], + settings.table_column_separator_char, + " ", + ] + ) + print_styles.extend([style, Colors.INFO, ""]) + print_line = print_line[:-2] + print_styles = print_styles[:-2] + self.print_line(value=print_line, style=print_styles, indent=indent) def print_benchmarks_metadata(self): - if not self.enabled: - return + """ + Print out the metadata of the benchmarks to the console including the run id, + duration, profile, args, worker, request loader, and extras. 
+ """ if not self.benchmarks: raise ValueError( @@ -198,55 +769,77 @@ def print_benchmarks_metadata(self): ) start_time = self.benchmarks[0].run_stats.start_time - end_time = self.benchmarks[0].run_stats.end_time + end_time = self.benchmarks[-1].run_stats.end_time duration = end_time - start_time - self.print_section_header("Benchmarks Completed") - self.print_labeled_line("Run id", str(self.benchmarks[0].run_id)) + self.print_section_header(title="Benchmarks Metadata") self.print_labeled_line( - "Duration", - f"{duration:.1f} seconds", + label="Run id", + value=str(self.benchmarks[0].run_id), + ) + self.print_labeled_line( + label="Duration", + value=f"{duration:.1f} seconds", ) self.print_labeled_line( - "Profile", - self.benchmarks_profile_str, + label="Profile", + value=self.benchmarks_profile_str, ) self.print_labeled_line( - "Args", - self.benchmarks_args_str, + label="Args", + value=self.benchmarks_args_str, ) self.print_labeled_line( - "Worker", - self.benchmarks_worker_desc_str, + label="Worker", + value=self.benchmarks_worker_desc_str, ) self.print_labeled_line( - "Request Loader", - self.benchmarks_request_loader_desc_str, + label="Request Loader", + value=self.benchmarks_request_loader_desc_str, ) self.print_labeled_line( - "Extras", - self.benchmarks_extras_str, + label="Extras", + value=self.benchmarks_extras_str, ) def print_benchmarks_info(self): - if not self.enabled: - return - + """ + Print out the benchmark information to the console including the start time, + end time, duration, request totals, and token totals for each benchmark. + """ if not self.benchmarks: raise ValueError( "No benchmarks to print info for. Please set benchmarks first." ) + sections = { + "Metadata": (0, 3), + "Requests Made": (4, 6), + "Prompt Tok/Req": (7, 9), + "Output Tok/Req": (10, 12), + "Prompt Tok Total": (13, 15), + "Output Tok Total": (16, 18), + } headers = [ "Benchmark", "Start Time", "End Time", - "Duration (sec)", - "Requests Made \n(comp / inc / err)", - "Prompt Tok / Req \n(comp / inc / err)", - "Output Tok / Req \n(comp / inc / err)", - "Prompt Tok Total \n(comp / inc / err)", - "Output Tok Total \n(comp / inc / err)", + "Duration (s)", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", ] rows = [] @@ -257,55 +850,68 @@ def print_benchmarks_info(self): f"{datetime.fromtimestamp(benchmark.start_time).strftime('%H:%M:%S')}", f"{datetime.fromtimestamp(benchmark.end_time).strftime('%H:%M:%S')}", f"{(benchmark.end_time - benchmark.start_time):.1f}", - ( - f"{benchmark.request_totals.successful:>5} / " - f"{benchmark.request_totals.incomplete} / " - f"{benchmark.request_totals.errored}" - ), - ( - f"{benchmark.metrics.prompt_token_count.successful.mean:>5.1f} / " # noqa: E501 - f"{benchmark.metrics.prompt_token_count.incomplete.mean:.1f} / " - f"{benchmark.metrics.prompt_token_count.errored.mean:.1f}" - ), - ( - f"{benchmark.metrics.output_token_count.successful.mean:>5.1f} / " # noqa: E501 - f"{benchmark.metrics.output_token_count.incomplete.mean:.1f} / " - f"{benchmark.metrics.output_token_count.errored.mean:.1f}" - ), - ( - f"{benchmark.metrics.prompt_token_count.successful.total_sum:>6.0f} / " # noqa: E501 - f"{benchmark.metrics.prompt_token_count.incomplete.total_sum:.0f} / " # noqa: E501 - f"{benchmark.metrics.prompt_token_count.errored.total_sum:.0f}" - ), - ( - f"{benchmark.metrics.output_token_count.successful.total_sum:>6.0f} / " # noqa: E501 - 
f"{benchmark.metrics.output_token_count.incomplete.total_sum:.0f} / " # noqa: E501 - f"{benchmark.metrics.output_token_count.errored.total_sum:.0f}" - ), + f"{benchmark.request_totals.successful:.0f}", + f"{benchmark.request_totals.incomplete:.0f}", + f"{benchmark.request_totals.errored:.0f}", + f"{benchmark.metrics.prompt_token_count.successful.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.incomplete.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.errored.mean:.1f}", + f"{benchmark.metrics.output_token_count.successful.mean:.1f}", + f"{benchmark.metrics.output_token_count.incomplete.mean:.1f}", + f"{benchmark.metrics.output_token_count.errored.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.successful.total_sum:.0f}", + f"{benchmark.metrics.prompt_token_count.incomplete.total_sum:.0f}", + f"{benchmark.metrics.prompt_token_count.errored.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.successful.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.incomplete.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.errored.total_sum:.0f}", ] ) - self.print_table(headers=headers, rows=rows, title="Benchmarks Info") + self.print_table( + headers=headers, rows=rows, title="Benchmarks Info", sections=sections + ) def print_benchmarks_stats(self): - if not self.enabled: - return - + """ + Print out the benchmark statistics to the console including the requests per + second, request concurrency, output tokens per second, total tokens per second, + request latency, time to first token, inter token latency, and time per output + token for each benchmark. + """ if not self.benchmarks: raise ValueError( "No benchmarks to print stats for. Please set benchmarks first." ) + sections = { + "Metadata": (0, 0), + "Request Stats": (1, 2), + "Out Tok/sec": (3, 3), + "Tot Tok/sec": (4, 4), + "Req Latency (ms)": (5, 7), + "TTFT (ms)": (8, 10), + "ITL (ms)": (11, 13), + "TPOT (ms)": (14, 16), + } headers = [ "Benchmark", - "Requests / sec", - "Requests Concurrency", - "Output Tok / sec", - "Total Tok / sec", - "Req Latency (sec)\n(mean / median / p99)", - "TTFT (ms)\n(mean / median / p99)", - "ITL (ms)\n(mean / median / p99)", - "TPOT (ms)\n(mean / median / p99)", + "Per Second", + "Concurrency", + "mean", + "mean", + "mean", + "median", + "p99", + "mean", + "median", + "p99", + "mean", + "median", + "p99", + "mean", + "median", + "p99", ] rows = [] @@ -315,28 +921,20 @@ def print_benchmarks_stats(self): strategy_display_str(benchmark.args.strategy), f"{benchmark.metrics.requests_per_second.successful.mean:.2f}", f"{benchmark.metrics.request_concurrency.successful.mean:.2f}", - f"{benchmark.metrics.output_tokens_per_second.total.mean:.1f}", - f"{benchmark.metrics.tokens_per_second.total.mean:.1f}", - ( - f"{benchmark.metrics.request_latency.successful.mean:.2f} / " - f"{benchmark.metrics.request_latency.successful.median:.2f} / " - f"{benchmark.metrics.request_latency.successful.percentiles.p99:.2f}" - ), - ( - f"{benchmark.metrics.time_to_first_token_ms.successful.mean:.1f} / " # noqa: E501 - f"{benchmark.metrics.time_to_first_token_ms.successful.median:.1f} / " # noqa: E501 - f"{benchmark.metrics.time_to_first_token_ms.successful.percentiles.p99:.1f}" - ), - ( - f"{benchmark.metrics.inter_token_latency_ms.successful.mean:.1f} / " # noqa: E501 - f"{benchmark.metrics.inter_token_latency_ms.successful.median:.1f} / " # noqa: E501 - f"{benchmark.metrics.inter_token_latency_ms.successful.percentiles.p99:.1f}" - ), - ( - 
f"{benchmark.metrics.time_per_output_token_ms.successful.mean:.1f} / " # noqa: E501 - f"{benchmark.metrics.time_per_output_token_ms.successful.median:.1f} / " # noqa: E501 - f"{benchmark.metrics.time_per_output_token_ms.successful.percentiles.p99:.1f}" - ), + f"{benchmark.metrics.output_tokens_per_second.successful.mean:.1f}", + f"{benchmark.metrics.tokens_per_second.successful.mean:.1f}", + f"{benchmark.metrics.request_latency.successful.mean:.2f}", + f"{benchmark.metrics.request_latency.successful.median:.2f}", + f"{benchmark.metrics.request_latency.successful.percentiles.p99:.2f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.mean:.1f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.median:.1f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.percentiles.p99:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.mean:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.median:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.percentiles.p99:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.mean:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.median:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.percentiles.p99:.1f}", ] ) @@ -344,4 +942,5 @@ def print_benchmarks_stats(self): headers=headers, rows=rows, title="Benchmarks Stats", + sections=sections, ) diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index ea1577d3..1d443784 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -327,6 +327,14 @@ def from_standard_args( # type: ignore[override] if "sweep_size" in kwargs: raise ValueError("Sweep size must not be provided, use rate instead.") + if isinstance(rate, Sequence): + if len(rate) != 1: + raise ValueError( + "Rate must be a single value for sweep profile, received " + f"{len(rate)} values." 
+ ) + rate = rate[0] + if not rate: rate = settings.default_sweep_number @@ -341,7 +349,8 @@ def from_standard_args( # type: ignore[override] or rate <= 1 ): raise ValueError( - f"Rate (sweep_size) must be a positive integer > 1, received {rate}" + f"Rate (sweep_size) must be a positive integer > 1, received {rate} " + f"with type {type(rate)}" ) if not kwargs: diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index 059c4b06..e6f83ce1 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -32,6 +32,13 @@ ) from guidellm.utils import Colors +__all__ = [ + "BenchmarkerTaskProgressState", + "BenchmarkerProgressDisplay", + "GenerativeTextBenchmarkerTaskProgressState", + "GenerativeTextBenchmarkerProgressDisplay", +] + @dataclass class BenchmarkerTaskProgressState: diff --git a/src/guidellm/config.py b/src/guidellm/config.py index ece9d63f..87957e42 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/config.py @@ -11,7 +11,6 @@ "LoggingSettings", "OpenAISettings", "print_config", - "ReportGenerationSettings", "Settings", "reload_settings", "settings", @@ -87,16 +86,6 @@ class OpenAISettings(BaseModel): max_output_tokens: int = 16384 -class ReportGenerationSettings(BaseModel): - """ - Report generation settings for the application - """ - - source: str = "" - report_html_match: str = "window.report_data = {};" - report_html_placeholder: str = "{}" - - class Settings(BaseSettings): """ All the settings are powered by pydantic_settings and could be @@ -139,22 +128,21 @@ class Settings(BaseSettings): # Request/stats settings preferred_prompt_tokens_source: Optional[ Literal["request", "response", "local"] - ] = None + ] = "response" preferred_output_tokens_source: Optional[ Literal["request", "response", "local"] - ] = None + ] = "response" preferred_backend: Literal["openai"] = "openai" openai: OpenAISettings = OpenAISettings() - # Report settings - report_generation: ReportGenerationSettings = ReportGenerationSettings() + # Output settings + table_border_char: str = "=" + table_headers_border_char: str = "-" + table_column_separator_char: str = "|" @model_validator(mode="after") @classmethod def set_default_source(cls, values): - if not values.report_generation.source: - values.report_generation.source = ENV_REPORT_MAPPING.get(values.env) - return values def generate_env_file(self) -> str: diff --git a/src/guidellm/objects/pydantic.py b/src/guidellm/objects/pydantic.py index b6e998fa..8365be33 100644 --- a/src/guidellm/objects/pydantic.py +++ b/src/guidellm/objects/pydantic.py @@ -13,7 +13,7 @@ class StandardBaseModel(BaseModel): """ model_config = ConfigDict( - extra="allow", + extra="ignore", use_enum_values=True, validate_assignment=True, from_attributes=True, diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 3620a3d3..f3f832af 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -10,6 +10,7 @@ is_puncutation, load_text, split_text, + split_text_list_by_length, ) __all__ = [ @@ -22,4 +23,5 @@ "load_text", "is_puncutation", "EndlessTextCreator", + "split_text_list_by_length", ] diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index 92a0284a..8c999b05 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,8 +1,9 @@ import gzip import re +import textwrap from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path -from typing import List, Optional, Union +from typing import 
Any, List, Optional, Union import ftfy import httpx @@ -12,6 +13,7 @@ from guidellm.config import settings __all__ = [ + "split_text_list_by_length", "filter_text", "clean_text", "split_text", @@ -23,6 +25,54 @@ MAX_PATH_LENGTH = 4096 + +def split_text_list_by_length( + text_list: List[Any], + max_characters: Union[int, List[int]], + pad_horizontal: bool = True, + pad_vertical: bool = True, +) -> List[List[str]]: + """ + Split a list of strings into a list of lists of strings, + wrapping each string into lines of at most max_characters + + :param text_list: the list of strings to split + :param max_characters: the maximum length of each line, either a single value + or one value per string + :param pad_horizontal: whether to pad the strings horizontally, defaults to True + :param pad_vertical: whether to pad the strings vertically, defaults to True + :return: a list of lists of strings, one list of wrapped lines per input string + """ + if not isinstance(max_characters, list): + max_characters = [max_characters] * len(text_list) + + if len(max_characters) != len(text_list): + raise ValueError( + f"max_characters must be a list of the same length as text_list, " + f"but got {len(max_characters)} and {len(text_list)}" + ) + + result: List[List[str]] = [] + for index, text in enumerate(text_list): + lines = textwrap.wrap(text, max_characters[index]) + result.append(lines) + + if pad_vertical: + max_lines = max(len(lines) for lines in result) + for lines in result: + while len(lines) < max_lines: + lines.append(" ") + + if pad_horizontal: + for index in range(len(result)): + lines = result[index] + max_chars = max_characters[index] + new_lines = [] + for line in lines: + new_lines.append(line.rjust(max_chars)) + result[index] = new_lines + + return result + + def filter_text( text: str, filter_start: Optional[Union[str, int]] = None, diff --git a/tests/unit/benchmark/__init__.py b/tests/unit/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/benchmark/test_output.py b/tests/unit/benchmark/test_output.py new file mode 100644 index 00000000..5089dbb2 --- /dev/null +++ b/tests/unit/benchmark/test_output.py @@ -0,0 +1,205 @@ +import csv +import json +from pathlib import Path +from unittest.mock import patch + +import pytest +import yaml +from pydantic import ValidationError + +from guidellm.benchmark import ( + GenerativeBenchmarksReport, +) +from guidellm.benchmark.output import GenerativeBenchmarksConsole +from tests.unit.mock_benchmark import mock_generative_benchmark + + +def test_generative_benchmark_initialization(): + report = GenerativeBenchmarksReport() + assert len(report.benchmarks) == 0 + + mock_benchmark = mock_generative_benchmark() + report_with_benchmarks = GenerativeBenchmarksReport(benchmarks=[mock_benchmark]) + assert len(report_with_benchmarks.benchmarks) == 1 + assert report_with_benchmarks.benchmarks[0] == mock_benchmark + + +def test_generative_benchmark_invalid_initialization(): + with pytest.raises(ValidationError): + GenerativeBenchmarksReport(benchmarks="invalid_type") # type: ignore[arg-type] + + +def test_generative_benchmark_marshalling(): + mock_benchmark = mock_generative_benchmark() + report = GenerativeBenchmarksReport(benchmarks=[mock_benchmark]) + + serialized = report.model_dump() + deserialized = GenerativeBenchmarksReport.model_validate(serialized) + deserialized_benchmark = deserialized.benchmarks[0] + + for field in mock_benchmark.model_fields: + assert getattr(mock_benchmark, field) == getattr(deserialized_benchmark, field) + + +def test_file_json(): + mock_benchmark = mock_generative_benchmark() + report = 
GenerativeBenchmarksReport(benchmarks=[mock_benchmark]) + + mock_path = Path("mock_report.json") + report.save_file(mock_path) + + with mock_path.open("r") as file: + saved_data = json.load(file) + assert saved_data == report.model_dump() + + loaded_report = GenerativeBenchmarksReport.load_file(mock_path) + loaded_benchmark = loaded_report.benchmarks[0] + + for field in mock_benchmark.model_fields: + assert getattr(mock_benchmark, field) == getattr(loaded_benchmark, field) + + mock_path.unlink() + + +def test_file_yaml(): + mock_benchmark = mock_generative_benchmark() + report = GenerativeBenchmarksReport(benchmarks=[mock_benchmark]) + + mock_path = Path("mock_report.yaml") + report.save_file(mock_path) + + with mock_path.open("r") as file: + saved_data = yaml.safe_load(file) + assert saved_data == report.model_dump() + + loaded_report = GenerativeBenchmarksReport.load_file(mock_path) + loaded_benchmark = loaded_report.benchmarks[0] + + for field in mock_benchmark.model_fields: + assert getattr(mock_benchmark, field) == getattr(loaded_benchmark, field) + + mock_path.unlink() + + +def test_file_csv(): + mock_benchmark = mock_generative_benchmark() + report = GenerativeBenchmarksReport(benchmarks=[mock_benchmark]) + + mock_path = Path("mock_report.csv") + report.save_csv(mock_path) + + with mock_path.open("r") as file: + reader = csv.reader(file) + headers = next(reader) + rows = list(reader) + + assert "Type" in headers + assert len(rows) == 1 + + mock_path.unlink() + + +def test_console_benchmarks_profile_str(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + assert ( + console.benchmarks_profile_str == "type=synchronous, strategies=['synchronous']" + ) + + +def test_console_benchmarks_args_str(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + assert console.benchmarks_args_str == ( + "max_number=None, max_duration=10.0, warmup_number=None, " + "warmup_duration=None, cooldown_number=None, cooldown_duration=None" + ) + + +def test_console_benchmarks_worker_desc_str(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + assert console.benchmarks_worker_desc_str == str(mock_benchmark.worker) + + +def test_console_benchmarks_request_loader_desc_str(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + assert console.benchmarks_request_loader_desc_str == str( + mock_benchmark.request_loader + ) + + +def test_console_benchmarks_extras_str(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + assert console.benchmarks_extras_str == "None" + + +def test_console_print_section_header(): + console = GenerativeBenchmarksConsole(enabled=True) + with patch.object(console.console, "print") as mock_print: + console.print_section_header("Test Header") + mock_print.assert_called_once() + + +def test_console_print_labeled_line(): + console = GenerativeBenchmarksConsole(enabled=True) + with patch.object(console.console, "print") as mock_print: + console.print_labeled_line("Label", "Value") + mock_print.assert_called_once() + + +def test_console_print_line(): + console = GenerativeBenchmarksConsole(enabled=True) + with patch.object(console.console, 
"print") as mock_print: + console.print_line("Test Line") + mock_print.assert_called_once() + + +def test_console_print_table(): + console = GenerativeBenchmarksConsole(enabled=True) + headers = ["Header1", "Header2"] + rows = [["Row1Col1", "Row1Col2"], ["Row2Col1", "Row2Col2"]] + with patch.object(console, "print_section_header") as mock_header, patch.object( + console, "print_table_divider" + ) as mock_divider, patch.object(console, "print_table_row") as mock_row: + console.print_table(headers, rows, "Test Table") + mock_header.assert_called_once() + mock_divider.assert_called() + mock_row.assert_called() + + +def test_console_print_benchmarks_metadata(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + with patch.object(console, "print_section_header") as mock_header, patch.object( + console, "print_labeled_line" + ) as mock_labeled: + console.print_benchmarks_metadata() + mock_header.assert_called_once() + mock_labeled.assert_called() + + +def test_console_print_benchmarks_info(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + with patch.object(console, "print_table") as mock_table: + console.print_benchmarks_info() + mock_table.assert_called_once() + + +def test_console_print_benchmarks_stats(): + console = GenerativeBenchmarksConsole(enabled=True) + mock_benchmark = mock_generative_benchmark() + console.benchmarks = [mock_benchmark] + with patch.object(console, "print_table") as mock_table: + console.print_benchmarks_stats() + mock_table.assert_called_once() diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py new file mode 100644 index 00000000..81364fa1 --- /dev/null +++ b/tests/unit/mock_benchmark.py @@ -0,0 +1,271 @@ +from guidellm.benchmark import ( + BenchmarkArgs, + BenchmarkRunStats, + GenerativeBenchmark, + GenerativeTextErrorStats, + GenerativeTextResponseStats, + SynchronousProfile, +) +from guidellm.objects import StatusBreakdown +from guidellm.request import GenerativeRequestLoaderDescription +from guidellm.scheduler import ( + GenerativeRequestsWorkerDescription, + SchedulerRequestInfo, + SynchronousStrategy, +) + +__all__ = ["mock_generative_benchmark"] + + +def mock_generative_benchmark() -> GenerativeBenchmark: + return GenerativeBenchmark.from_stats( + run_id="fa4a92c1-9a1d-4c83-b237-83fcc7971bd3", + successful=[ + GenerativeTextResponseStats( + request_id="181a63e2-dc26-4268-9cfc-2ed9279aae63", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728125.203447, + queued_time=1744728125.204123, + dequeued_time=1744728125.2048807, + scheduled_time=1744728125.2048993, + worker_start=1744728125.2049701, + request_start=1744728125.2052872, + request_end=1744728126.7004411, + worker_end=1744728126.701175, + process_id=0, + ), + prompt="such a sacrifice to her advantage as years of gratitude cannot enough acknowledge. By this time she is actually with them! If such goodness does not make her miserable now, she will never deserve to be happy! What a meeting for her, when she first sees my aunt! We must endeavour to forget all that has passed on either side, said Jane I hope and trust they will yet be happy. His consenting to marry her is a proof, I will believe, that he is come to a right way of thinking. 
Their mutual affection will steady them; and I flatter myself they will settle so quietly, and live in so rational a manner", # noqa: E501 + output=", as to make their long life together very comfortable and very useful. I feel, if they and the honourable Mr. Thorpe, who still lives amongst us, should be all I need, I could perfectly rest happy. Writes to meet them in that kind of obedience which is necessary and honourable, and such", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728125.2052872, + end_time=1744728126.7004411, + first_token_time=1744728125.2473357, + last_token_time=1744728126.699908, + ), + GenerativeTextResponseStats( + request_id="8a7846d5-7624-420d-a269-831e568a848f", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728125.204613, + queued_time=1744728125.2047558, + dequeued_time=1744728126.7025175, + scheduled_time=1744728126.7025256, + worker_start=1744728126.702579, + request_start=1744728126.7027814, + request_end=1744728128.1961868, + worker_end=1744728128.196895, + process_id=0, + ), + prompt="a reconciliation; and, after a little further resistance on the part of his aunt, her resentment gave way, either to her affection for him, or her curiosity to see how his wife conducted herself; and she condescended to wait on them at Pemberley, in spite of that pollution which its woods had received, not merely from the presence of such a mistress, but the visits of her uncle and aunt from the city. With the Gardiners they were always on the most intimate terms. Darcy, as well as Elizabeth, really loved them; and they were both ever sensible of the warmest gratitude towards the persons who,", # noqa: E501 + output=" in their own days of poverty, had been so hotel and hospitable to a young couple leaving Pemberley. Till the size of Mr. Bennet\u2019s salary had been altered, the blessing of their friendship was much more greatly needed by the family than it appeared after that event.\n- Mr. Darcy soon deserved", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728126.7027814, + end_time=1744728128.1961868, + first_token_time=1744728126.7526379, + last_token_time=1744728128.1956792, + ), + GenerativeTextResponseStats( + request_id="4cde0e6c-4531-4e59-aac1-07bc8b6e4139", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728126.7031465, + queued_time=1744728126.7034643, + dequeued_time=1744728128.198447, + scheduled_time=1744728128.1984534, + worker_start=1744728128.198509, + request_start=1744728128.1986883, + request_end=1744728129.6919055, + worker_end=1744728129.692606, + process_id=0, + ), + prompt="struck her, that _she_ was selected from among her sisters as worthy of being the mistress of Hunsford Parsonage, and of assisting to form a quadrille table at Rosings, in the absence of more eligible visitors. The idea soon reached to conviction, as she observed his increasing civilities towards herself, and heard his frequent attempt at a compliment on her wit and vivacity; and though more astonished than gratified herself by this effect of her charms, it was not long before her mother gave her to understand that the probability of their marriage was exceedingly agreeable to _her_. 
Elizabeth, however, did not choose", # noqa: E501 + output=" to improve this conversation into a prophecy, and her mother would hardly take on herself to announce so important a phenomenon. At last he was to drive to Hunsford from Meryton on Sunday; they staid for an hour at eight o'clock, and the following day appeared to be hung up on the walls of", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728128.1986883, + end_time=1744728129.6919055, + first_token_time=1744728128.2481627, + last_token_time=1744728129.6914039, + ), + GenerativeTextResponseStats( + request_id="a95b96be-05d4-4130-b0dd-9528c01c9909", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728128.1987216, + queued_time=1744728128.1991177, + dequeued_time=1744728129.6953137, + scheduled_time=1744728129.695318, + worker_start=1744728129.695379, + request_start=1744728129.6955585, + request_end=1744728131.187553, + worker_end=1744728131.188169, + process_id=0, + ), + prompt="were comfortable on this subject. Day after day passed away without bringing any other tidings of him than the report which shortly prevailed in Meryton of his coming no more to Netherfield the whole winter; a report which highly incensed Mrs. Bennet, and which she never failed to contradict as a most scandalous falsehood. Even Elizabeth began to fear not that Bingley was indifferent but that his sisters would be successful in keeping him away. Unwilling as she was to admit an idea so destructive to Jane s happiness, and so dishonourable to the stability of her lover, she could not prevent its frequently recurring", # noqa: E501 + output=" during these indefinite disputes; and was often seriously engaged in blaming her sisters for increasing a suspense which might only be caused by their own inattention to a subject of so much moment. Whether she had really made that impression on the s+.ayers, or whether she had merely imagined it, she could decide no farther, for", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728129.6955585, + end_time=1744728131.187553, + first_token_time=1744728129.7438853, + last_token_time=1744728131.187019, + ), + GenerativeTextResponseStats( + request_id="714b751c-bbfe-4b2a-a0af-7c1bf2c224ae", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728129.6975086, + queued_time=1744728129.6978767, + dequeued_time=1744728131.190093, + scheduled_time=1744728131.190101, + worker_start=1744728131.1901798, + request_start=1744728131.1904676, + request_end=1744728132.6833503, + worker_end=1744728132.6839745, + process_id=0, + ), + prompt="? cried Elizabeth, brightening up for a moment. Upon my word, said Mrs. Gardiner, I begin to be of your uncle s opinion. It is really too great a violation of decency, honour, and interest, for him to be guilty of it. I cannot think so very ill of Wickham. Can you, yourself, Lizzie, so wholly give him up, as to believe him capable of it? Not perhaps of neglecting his own interest. But of every other neglect I can believe him capable. If, indeed, it should be so! But I dare not hope it. Why should they not go on", # noqa: E501 + output=" together? This is still a motive incapable of being denied. He has such a faculty of pleasing, and you know how much she likes him. 
\nQuestion: What made elder sisters the center of their families?\nSometimes early this would be discussed in the family circle, but that was a very exceptional treatment.\nThank you,", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728131.1904676, + end_time=1744728132.6833503, + first_token_time=1744728131.2394557, + last_token_time=1744728132.6828275, + ), + GenerativeTextResponseStats( + request_id="ef73ae8a-4c8f-4c88-b303-cfff152ce378", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=True, + errored=False, + canceled=False, + targeted_start_time=1744728131.1891043, + queued_time=1744728131.1893764, + dequeued_time=1744728132.6859632, + scheduled_time=1744728132.6859682, + worker_start=1744728132.6860242, + request_start=1744728132.6862206, + request_end=1744728134.1805167, + worker_end=1744728134.1813161, + process_id=0, + ), + prompt="was. But her commendation, though costing her some trouble, could by no means satisfy Mr. Collins, and he was very soon obliged to take her Ladyship s praise into his own hands. Sir William stayed only a week at Hunsford; but his visit was long enough to convince him of his daughter s being most comfortably settled, and of her possessing such a husband and such a neighbour as were not often met with. While Sir William was with them, Mr. Collins devoted his mornings to driving him out in his gig, and showing him the country but when he went away, the whole family returned to their usual employments", # noqa: E501 + output=", and the sides of the family in which he was more particularly interested, to their respective places in the establishment. Here Jane was occasionally up as a substitute to her indolent sister, in her matron s stead, but was more frequently left idle, and with her hours of quietness, the unwelcome intrusion", # noqa: E501 + prompt_tokens=128, + output_tokens=64, + start_time=1744728132.6862206, + end_time=1744728134.1805167, + first_token_time=1744728132.7354612, + last_token_time=1744728134.1797993, + ), + ], + errored=[], + incomplete=[ + GenerativeTextErrorStats( + request_id="1b3def04-ca81-4f59-a56c-452a069d91af", + request_type="text_completions", + scheduler_info=SchedulerRequestInfo( + requested=True, + completed=False, + errored=True, + canceled=True, + targeted_start_time=1744728132.686177, + queued_time=1744728132.6866345, + dequeued_time=1744728134.1831052, + scheduled_time=1744728134.1831107, + worker_start=1744728134.183183, + request_start=1744728134.183544, + request_end=1744728135.2031732, + worker_end=1744728135.2033112, + process_id=0, + ), + prompt="is to tempt anyone to our humble abode. Our plain manner of living, our small rooms, and few domestics, and the little we see of the world, must make Hunsford extremely dull to a young lady like yourself; but I hope you will believe us grateful for the condescension, and that we have done everything in our power to prevent you spending your time unpleasantly. Elizabeth was eager with her thanks and assurances of happiness. She had spent six weeks with great enjoyment; and the pleasure of being with Charlotte, and the kind attention she had received, must make _her_ feel the obliged. Mr. 
Collins", # noqa: E501 + output=", who certainly had an eye to Elizabeth's manner, was glad _he was not to lose the curiosity she had given, and requested her away_ , _for the politeness of her conciliating manner would", # noqa: E501 + prompt_tokens=128, + output_tokens=43, + start_time=1744728134.183544, + end_time=1744728135.2031732, + first_token_time=1744728134.2323751, + last_token_time=1744728135.1950455, + error="TimeoutError: The request timed out before completing.", + ) + ], + args=BenchmarkArgs( + profile=SynchronousProfile(), + strategy_index=0, + strategy=SynchronousStrategy(), + max_number=None, + max_duration=10.0, + warmup_number=None, + warmup_duration=None, + cooldown_number=None, + cooldown_duration=None, + ), + run_stats=BenchmarkRunStats( + start_time=1744728125.0772898, + end_time=1744728135.8407037, + requests_made=StatusBreakdown( + successful=6, + errored=0, + incomplete=1, + total=7, + ), + queued_time_avg=1.2821388585226876, + scheduled_time_delay_avg=7.96999250139509e-6, + scheduled_time_sleep_avg=0.0, + worker_start_delay_avg=6.399835859026228e-5, + worker_time_avg=1.4266603674207414, + worker_start_time_targeted_delay_avg=1.2825865745544434, + request_start_time_delay_avg=0.6414163964135307, + request_start_time_targeted_delay_avg=1.2827096836907523, + request_time_delay_avg=0.0004316908972603934, + request_time_avg=1.426228676523481, + ), + worker=GenerativeRequestsWorkerDescription( + backend_type="openai_http", + backend_target="http://localhost:8000", + backend_model="neuralmagic/Qwen2.5-7B-quantized.w8a8", + backend_info={ + "max_output_tokens": 16384, + "timeout": 300, + "http2": True, + "authorization": False, + "organization": None, + "project": None, + "text_completions_path": "/v1/completions", + "chat_completions_path": "/v1/chat/completions", + }, + ), + requests_loader=GenerativeRequestLoaderDescription( + data='{"prompt_tokens": 128, "output_tokens": 64}', + data_args=None, + processor="neuralmagic/Qwen2.5-7B-quantized.w8a8", + processor_args=None, + ), + extras={}, + ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 13e1699d..c32159ec 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,11 +1,14 @@ import pytest from guidellm.config import ( + DatasetSettings, Environment, LoggingSettings, OpenAISettings, - ReportGenerationSettings, Settings, + print_config, + reload_settings, + settings, ) @@ -15,10 +18,6 @@ def test_default_settings(): assert settings.env == Environment.PROD assert settings.logging == LoggingSettings() assert settings.openai == OpenAISettings() - assert ( - settings.report_generation.source - == "https://guidellm.neuralmagic.com/local-report/index.html" - ) @pytest.mark.smoke() @@ -30,7 +29,6 @@ def test_settings_from_env_variables(mocker): "GUIDELLM__logging__disabled": "true", "GUIDELLM__OPENAI__API_KEY": "test_key", "GUIDELLM__OPENAI__BASE_URL": "http://test.url", - "GUIDELLM__REPORT_GENERATION__SOURCE": "http://custom.url", }, ) @@ -39,31 +37,6 @@ def test_settings_from_env_variables(mocker): assert settings.logging.disabled is True assert settings.openai.api_key == "test_key" assert settings.openai.base_url == "http://test.url" - assert settings.report_generation.source == "http://custom.url" - - -@pytest.mark.smoke() -def test_report_generation_default_source(): - settings = Settings(env=Environment.LOCAL) - assert settings.report_generation.source == "tests/dummy/report.html" - - settings = Settings(env=Environment.DEV) - assert ( - settings.report_generation.source - == 
"https://dev.guidellm.neuralmagic.com/local-report/index.html" - ) - - settings = Settings(env=Environment.STAGING) - assert ( - settings.report_generation.source - == "https://staging.guidellm.neuralmagic.com/local-report/index.html" - ) - - settings = Settings(env=Environment.PROD) - assert ( - settings.report_generation.source - == "https://guidellm.neuralmagic.com/local-report/index.html" - ) @pytest.mark.sanity() @@ -88,6 +61,90 @@ def test_openai_settings(): @pytest.mark.sanity() -def test_report_generation_settings(): - report_settings = ReportGenerationSettings(source="http://custom.report") - assert report_settings.source == "http://custom.report" +def test_generate_env_file(): + settings = Settings() + env_file_content = settings.generate_env_file() + assert "GUIDELLM__LOGGING__DISABLED" in env_file_content + assert "GUIDELLM__OPENAI__API_KEY" in env_file_content + + +@pytest.mark.sanity() +def test_reload_settings(mocker): + mocker.patch.dict( + "os.environ", + { + "GUIDELLM__env": "staging", + "GUIDELLM__logging__disabled": "false", + }, + ) + reload_settings() + assert settings.env == Environment.STAGING + assert settings.logging.disabled is False + + +@pytest.mark.sanity() +def test_print_config(capsys): + print_config() + captured = capsys.readouterr() + assert "Settings:" in captured.out + assert "GUIDELLM__LOGGING__DISABLED" in captured.out + assert "GUIDELLM__OPENAI__API_KEY" in captured.out + + +@pytest.mark.sanity() +def test_dataset_settings_defaults(): + dataset_settings = DatasetSettings() + assert dataset_settings.preferred_data_columns == [ + "prompt", + "instruction", + "input", + "inputs", + "question", + "context", + "text", + "content", + "body", + "data", + ] + assert dataset_settings.preferred_data_splits == [ + "test", + "tst", + "validation", + "val", + "train", + ] + + +@pytest.mark.sanity() +def test_openai_settings_defaults(): + openai_settings = OpenAISettings() + assert openai_settings.api_key is None + assert openai_settings.bearer_token is None + assert openai_settings.organization is None + assert openai_settings.project is None + assert openai_settings.base_url == "http://localhost:8000" + assert openai_settings.max_output_tokens == 16384 + + +@pytest.mark.sanity() +def test_table_properties_defaults(): + settings = Settings() + assert settings.table_border_char == "=" + assert settings.table_headers_border_char == "-" + assert settings.table_column_separator_char == "|" + + +@pytest.mark.sanity() +def test_settings_with_env_variables(mocker): + mocker.patch.dict( + "os.environ", + { + "GUIDELLM__DATASET__PREFERRED_DATA_COLUMNS": '["custom_column"]', + "GUIDELLM__OPENAI__API_KEY": "env_api_key", + "GUIDELLM__TABLE_BORDER_CHAR": "*", + }, + ) + settings = Settings() + assert settings.dataset.preferred_data_columns == ["custom_column"] + assert settings.openai.api_key == "env_api_key" + assert settings.table_border_char == "*"