diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index ff0277c2056..dcd2afa7a13 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -28,3 +28,6 @@ matplotlib>=3.9.4 myst-parser==0.18.1 sphinx_design==0.4.1 sphinx-copybutton==0.5.0 + +# script unit test requirements +yaspin==3.1.0 diff --git a/.ci/scripts/benchmark_tooling/README.md b/.ci/scripts/benchmark_tooling/README.md new file mode 100644 index 00000000000..25ba2a739a4 --- /dev/null +++ b/.ci/scripts/benchmark_tooling/README.md @@ -0,0 +1,171 @@ +# Executorch Benchmark Tooling + +A library providing tools for fetching, processing, and analyzing ExecutorchBenchmark data from the HUD Open API. This tooling helps compare performance metrics between private and public devices with identical settings. + +## Table of Contents + +- [Overview](#overview) +- [Installation](#installation) +- [Tools](#tools) + - [get_benchmark_analysis_data.py](#get_benchmark_analysis_datapy) + - [Quick Start](#quick-start) + - [Command Line Options](#command-line-options) + - [Example Usage](#example-usage) + - [Working with Output Files](#working-with-output-files-csv-and-excel) + - [Python API Usage](#python-api-usage) +- [Running Unit Tests](#running-unit-tests) + +## Overview + +The Executorch Benchmark Tooling provides a suite of utilities designed to: + +- Fetch benchmark data from HUD Open API for specified time ranges +- Clean and process data by filtering out failures +- Compare metrics between private and public devices with matching configurations +- Generate analysis reports in various formats (CSV, Excel, JSON) +- Support filtering by device pools, backends, and models + +This tooling is particularly useful for performance analysis, regression testing, and cross-device comparisons. + +## Installation + +Install dependencies: + +```bash +pip install -r requirements.txt +``` + +## Tools + +### get_benchmark_analysis_data.py + +This script is mainly used to generate analysis data comparing private devices with public devices using the same settings. + +It fetches benchmark data from HUD Open API for a specified time range, cleans the data by removing entries with FAILURE indicators, and retrieves all private device metrics along with equivalent public device metrics based on matching [model, backend, device_pool_names, arch] configurations. Users can filter the data by specifying private device_pool_names, backends, and models. 
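+
+As a rough illustration of the matching idea (not the script's exact implementation), a comparison key can be built from each table's `groupInfo` and used to pair private tables with their public counterparts. The sketch below assumes `groupInfo` dictionaries carrying the default grouping fields `model`, `backend`, `device`, and `arch`:
+
+```python
+import re
+
+
+def table_key(group_info: dict) -> str:
+    """Build a normalized [model, backend, device, arch] key for matching."""
+
+    def norm(value: str) -> str:
+        # Lowercase, collapse spaces to underscores, and drop the "(private)" suffix
+        # so private and public variants of the same device share a key.
+        value = value.lower().strip().replace(" ", "_").replace("(private)", "")
+        return re.sub(r"_{2,}", "_", value).strip("_")
+
+    fields = ("model", "backend", "device", "arch")
+    return "-".join(norm(group_info[f]) for f in fields if group_info.get(f))
+
+
+def match_public(private_tables: list, public_tables: list) -> list:
+    """Keep only public tables whose key also appears among the private tables."""
+    private_keys = {table_key(t["groupInfo"]) for t in private_tables}
+    return [t for t in public_tables if table_key(t["groupInfo"]) in private_keys]
+```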
+
+#### Quick Start
+
+```bash
+# Generate Excel sheets for all private devices with public devices using the same settings
+python3 .ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py \
+  --startTime "2025-06-11T00:00:00" \
+  --endTime "2025-06-17T18:00:00" \
+  --outputType "excel"
+
+python3 .ci/scripts/benchmark_tooling/analyze_benchmark_stability.py \
+--primary-file private.xlsx \
+--reference-file public.xlsx
+```
+
+#### Command Line Options
+
+##### Basic Options:
+- `--startTime`: Start time in ISO format (e.g., "2025-06-11T00:00:00") (required)
+- `--endTime`: End time in ISO format (e.g., "2025-06-17T18:00:00") (required)
+- `--env`: Choose environment ("local" or "prod", default: "prod")
+- `--no-silent`: Show processing logs (default: only show results and minimal logging)
+
+##### Output Options:
+- `--outputType`: Choose output format (default: "print")
+  - `print`: Display results in the console
+  - `json`: Generate a JSON file
+  - `df`: Display results in DataFrame format: `{'private': List[{'groupInfo': Dict, 'df': DF}, ...], 'public': List[{'groupInfo': Dict, 'df': DF}, ...]}`
+  - `excel`: Generate Excel files with multiple sheets; the cell in the first row and first column contains the JSON string of the raw metadata
+  - `csv`: Generate CSV files in separate folders; the cell in the first row and first column contains the JSON string of the raw metadata
+- `--outputDir`: Directory to save output files (default: current directory)
+
+##### Filtering Options:
+
+- `--private-device-pools`: Filter by private device pool names (choices: "apple_iphone_15_private", "samsung_s22_private")
+- `--backends`: Filter by specific backend names (e.g., "xnnpack_q8", "llama3_spinquant")
+- `--models`: Filter by specific model names (e.g., "mv3", "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")
+
+#### Example Usage
+
+Filter by multiple private device pools and models:
+```bash
+# This fetches all private table data for models 'llama-3.2-1B' and 'mv3'
+python3 get_benchmark_analysis_data.py \
+  --startTime "2025-06-01T00:00:00" \
+  --endTime "2025-06-11T00:00:00" \
+  --private-device-pools 'apple_iphone_15_private' 'samsung_s22_private' \
+  --models 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3'
+```
+
+Filter by a specific device pool and models:
+```bash
+# This fetches all private iPhone table data for models 'llama-3.2-1B' and 'mv3',
+# and the associated public iPhone data
+python3 get_benchmark_analysis_data.py \
+  --startTime "2025-06-01T00:00:00" \
+  --endTime "2025-06-11T00:00:00" \
+  --private-device-pools 'apple_iphone_15_private' \
+  --models 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3'
+```
+
+#### Working with Output Files (CSV and Excel)
+
+You can use the methods in `common.py` to convert the file data back to DataFrame format. These methods read the JSON metadata from the first row of the CSV/Excel files and return results in the format `List[{"groupInfo": Dict, "df": pd.DataFrame, "sheetName": str}]`.
+ +```python +import logging +logging.basicConfig(level=logging.INFO) +from .ci.scripts.benchmark_tooling.common import read_all_csv_with_metadata, read_excel_with_json_header + +# For CSV files (assuming the 'private' folder is in the current directory) +folder_path = './private' +res = read_all_csv_with_metadata(folder_path) +logging.info(res) + +# For Excel files (assuming the Excel file is in the current directory) +file_path = "./private.xlsx" +res = read_excel_with_json_header(file_path) +logging.info(res) +``` + +#### Python API Usage + +To use the benchmark fetcher in your own scripts: + +```python +from .ci.scripts.benchmark_tooling.get_benchmark_analysis_data import ExecutorchBenchmarkFetcher + +# Initialize the fetcher +fetcher = ExecutorchBenchmarkFetcher(env="prod", disable_logging=False) + +# Fetch data for a specific time range +fetcher.run( + start_time="2025-06-11T00:00:00", + end_time="2025-06-17T18:00:00" +) + +# Get results in different formats +# As DataFrames +df_results = fetcher.to_df() + +# Export to Excel +fetcher.to_excel(output_dir="./results") + +# Export to CSV +fetcher.to_csv(output_dir="./results") + +# Export to JSON +json_path = fetcher.to_json(output_dir="./results") + +# Get raw dictionary results +dict_results = fetcher.to_dict() + +# Use the output_data method for flexible output +results = fetcher.output_data(output_type="excel", output_dir="./results") +``` + +## Running Unit Tests + +The benchmark tooling includes unit tests to ensure functionality. + +### Using pytest for unit tests + +```bash +# From the executorch root directory +pytest -c /dev/null .ci/scripts/tests/test_get_benchmark_analysis_data.py +``` diff --git a/.ci/scripts/benchmark_tooling/__init__.py b/.ci/scripts/benchmark_tooling/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/.ci/scripts/analyze_benchmark_stability.py b/.ci/scripts/benchmark_tooling/analyze_benchmark_stability.py similarity index 88% rename from .ci/scripts/analyze_benchmark_stability.py rename to .ci/scripts/benchmark_tooling/analyze_benchmark_stability.py index 47f984b7ce3..64e4b05df86 100644 --- a/.ci/scripts/analyze_benchmark_stability.py +++ b/.ci/scripts/benchmark_tooling/analyze_benchmark_stability.py @@ -1,10 +1,10 @@ import argparse import os -import re import matplotlib.pyplot as plt import numpy as np import pandas as pd +from common import read_excel_with_json_header from tabulate import tabulate @@ -15,40 +15,23 @@ def print_section_header(title): print("=" * 100 + "\n") -def normalize_tab_name(name): +def normalize_name(name): """Normalize tab name for better matching""" # Convert to lowercase and remove spaces - return name.lower().replace(" ", "") + return name.lower().replace(" ", "").replace("(private)", "") -def parse_model_device(sheet_name): - """Extract model and device from sheet name using the 'model+device' pattern""" - parts = sheet_name.split("+", 1) - if len(parts) < 2: - return sheet_name, "Unknown" - return parts[0], parts[1] - - -def extract_model_device_os(sheet_name): - """ - Extract model, device, and OS from sheet name - Format expected: model+device_osname - Returns: (model, device_base, os_version) - """ - model, device_full = parse_model_device(sheet_name) - - # Use regex to separate device base name from OS version - # Pattern looks for device name followed by underscore or android/ios - match = re.match(r"(.*?)(android|ios|_)(.*)", device_full, re.IGNORECASE) - - if match: - device_base = match.group(1).rstrip("_") - os_name = match.group(2) - 
os_version = match.group(3) - return model, device_base, f"{os_name}{os_version}" - else: - # If no OS version found, return the device as is with empty OS - return model, device_full, "" +def parse_model_device_config(config): + """Extract model and device from config""" + model = config.get("model", "") + backend = config.get("backend", "") + full_model = f"{model}({backend})" if backend else model + base_device = config.get("device", "") + os_version = config.get("arch", "") + full_device = f"{base_device}({os_version})" if os_version else base_device + if not base_device: + return full_model, "unkown", "unknown", "" + return full_model, full_device, base_device, os_version def is_matching_dataset(primary_sheet, reference_sheet): @@ -56,10 +39,20 @@ def is_matching_dataset(primary_sheet, reference_sheet): Check if two datasets match for comparison based on model and device Allows different OS versions for the same device """ - primary_model, primary_device, primary_os = extract_model_device_os(primary_sheet) - reference_model, reference_device, reference_os = extract_model_device_os( - reference_sheet - ) + primary_model = normalize_name(primary_sheet.get("model", "")) + primary_device = normalize_name(primary_sheet.get("base_device", "")) + # primary_os = normalize_name(primary_sheet.get("os_version", "")) + + reference_model = normalize_name(reference_sheet.get("model", "")) + reference_device = normalize_name(reference_sheet.get("base_device", "")) + # reference_os = normalize_name(reference_sheet.get("os_version", "")) + + if not primary_model: + print("Warning: Primary sheet {} has no model info, for {primary_model} ") + return False + if not reference_model: + print("Warning: Reference sheet {} has no model info, for {reference_model}") + return False # Model must match exactly if primary_model != reference_model: @@ -69,26 +62,12 @@ def is_matching_dataset(primary_sheet, reference_sheet): if primary_device != reference_device: return False - # If we get here, model and device base match, so it's a valid comparison - # even if OS versions differ return True def analyze_latency_stability( # noqa: C901 primary_file, reference_file=None, output_dir="stability_analysis_results" ): - """ - Analyze latency stability metrics from benchmark data in Excel files. 
- - Parameters: - ----------- - primary_file : str - Path to the Excel file containing primary (private) benchmark data - reference_file : str, optional - Path to the Excel file containing reference (public) benchmark data - output_dir : str - Directory to save output files - """ print(f"Analyzing latency stability from primary file: {primary_file}") if reference_file: print(f"Using reference file for comparison: {reference_file}") @@ -100,15 +79,28 @@ def analyze_latency_stability( # noqa: C901 # Load primary datasets print_section_header("LOADING PRIMARY DATASETS (Private)") primary_datasets = {} - primary_xls = pd.ExcelFile(primary_file) + documents = read_excel_with_json_header(primary_file) + + for document in documents: + sheetName = document.get("sheetName", None) + df = document.get("df", None) + config = document.get("groupInfo", None) + print(f"Loading dataset: {sheetName} with config: {config} ") + + if df is None or df.empty: + print(f"Skipping sheet {sheetName} because it has no df data") + continue + + if not config or not sheetName: + print( + f" Skipping document: Missing required data groupInfo:{config} sheetName:{sheetName}" + ) + continue - for sheet in primary_xls.sheet_names: - print(f"Loading dataset: {sheet}") - df = pd.read_excel(primary_xls, sheet_name=sheet) - model, device = parse_model_device(sheet) + model, full_device, base_device, os_version = parse_model_device_config(config) # Check if required columns exist - required_cols = ["InferenceTime", "Date"] + required_cols = ["avg_inference_latency(ms)", "metadata_info.timestamp"] if "trimmean_inference_latency(ms)" in df.columns: trimmed_col = "trimmean_inference_latency(ms)" required_cols.append(trimmed_col) @@ -123,36 +115,54 @@ def analyze_latency_stability( # noqa: C901 # Skip sheets without required columns if not all(col in df.columns for col in required_cols): - print(f" Skipping {sheet}: Missing required columns") + print(f" Skipping {sheetName}: Missing required columns") continue # Convert Date to datetime - df["Date"] = pd.to_datetime(df["Date"]) + df["Date"] = pd.to_datetime(df["metadata_info.timestamp"]) # Calculate stability metrics - metrics = calculate_stability_metrics(df, "InferenceTime", trimmed_col, tps_col) + metrics = calculate_stability_metrics( + df, "avg_inference_latency(ms)", trimmed_col, tps_col + ) - primary_datasets[sheet] = { + primary_datasets[sheetName] = { "df": df, "metrics": metrics, "model": model, - "device": device, - "sheet_name": sheet, + "full_device": full_device, + "base_device": base_device, + "os_version": os_version, + "sheet_name": sheetName, } # Load reference datasets if provided reference_datasets = {} if reference_file: print_section_header("LOADING REFERENCE DATASETS (Public)") - reference_xls = pd.ExcelFile(reference_file) + documents = read_excel_with_json_header(reference_file) + + for document in documents: + sheetName = document.get("sheetName", None) + df = document.get("df", None) + config = document.get("groupInfo", None) + print(f"Loading dataset: {sheetName} with config:{config}") + if df is None or df.empty: + print(f"Skipping sheet {sheetName} because it has no df data") + continue + + if not config or not sheetName: + print( + f" Skipping document: Missing required data groupInfo:{config} sheetName:{sheetName}" + ) + continue - for sheet in reference_xls.sheet_names: - print(f"Loading reference dataset: {sheet}") - df = pd.read_excel(reference_xls, sheet_name=sheet) - model, device = parse_model_device(sheet) + model, full_device, base_device, 
os_version = parse_model_device_config( + config + ) # Check if required columns exist - required_cols = ["InferenceTime", "Date"] + required_cols = ["avg_inference_latency(ms)", "metadata_info.timestamp"] if "trimmean_inference_latency(ms)" in df.columns: trimmed_col = "trimmean_inference_latency(ms)" required_cols.append(trimmed_col) @@ -167,23 +177,27 @@ def analyze_latency_stability( # noqa: C901 # Skip sheets without required columns if not all(col in df.columns for col in required_cols): - print(f" Skipping reference {sheet}: Missing required columns") + print( + f" Skipping reference {sheetName}: Missing required columns{required_cols}" + ) continue # Convert Date to datetime - df["Date"] = pd.to_datetime(df["Date"]) + df["Date"] = pd.to_datetime(df["metadata_info.timestamp"]) # Calculate stability metrics metrics = calculate_stability_metrics( - df, "InferenceTime", trimmed_col, tps_col + df, "avg_inference_latency(ms)", trimmed_col, tps_col ) - reference_datasets[sheet] = { + reference_datasets[sheetName] = { "df": df, "metrics": metrics, "model": model, - "device": device, - "sheet_name": sheet, + "full_device": full_device, + "sheet_name": sheetName, + "base_device": base_device, + "os_version": os_version, } # Process primary datasets @@ -193,7 +207,7 @@ def analyze_latency_stability( # noqa: C901 generate_dataset_report( sheet, info["model"], - info["device"], + info["full_device"], "Primary", info["df"], info["metrics"], @@ -212,7 +226,7 @@ def analyze_latency_stability( # noqa: C901 generate_dataset_report( sheet, info["model"], - info["device"], + info["full_device"], "Reference", info["df"], info["metrics"], @@ -232,7 +246,7 @@ def analyze_latency_stability( # noqa: C901 found_match = False for ref_sheet, ref_info in reference_datasets.items(): - if is_matching_dataset(primary_sheet, ref_sheet): + if is_matching_dataset(primary_info, ref_info): # Found a match print( f"Matched: {primary_sheet} (Private) with {ref_sheet} (Public)" @@ -240,11 +254,8 @@ def analyze_latency_stability( # noqa: C901 generate_comparison_report( primary_sheet, ref_sheet, - primary_info["model"], - primary_info["device"], - ref_info["device"], - primary_info["metrics"], - ref_info["metrics"], + primary_info, + ref_info, output_dir, ) found_match = True @@ -252,7 +263,9 @@ def analyze_latency_stability( # noqa: C901 break if not found_match: - print(f"Warning: No matching reference dataset for {primary_sheet}") + print( + f"Warning: No matching reference dataset for {primary_sheet} with config: {primary_info['model']}{primary_info['full_device']} " + ) if not matches_found: print("No matching datasets found between primary and reference files.") @@ -620,7 +633,12 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): df_sorted = df.sort_values("Date") # Plot raw latency - plt.plot(df_sorted["Date"], df_sorted["InferenceTime"], "b-", label="Raw Latency") + plt.plot( + df_sorted["Date"], + df_sorted["avg_inference_latency(ms)"], + "b-", + label="Raw Latency", + ) # Plot trimmed latency if available if "trimmean_inference_latency(ms)" in df_sorted.columns: @@ -634,7 +652,9 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): # Add rolling mean window = min(5, len(df_sorted)) if window > 1: - rolling_mean = df_sorted["InferenceTime"].rolling(window=window).mean() + rolling_mean = ( + df_sorted["avg_inference_latency(ms)"].rolling(window=window).mean() + ) plt.plot( df_sorted["Date"], rolling_mean, "r--", label=f"{window}-point Rolling Mean" ) @@ -658,11 
+678,8 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): def generate_comparison_report( # noqa: C901 primary_sheet, reference_sheet, - model, - primary_device, - reference_device, - primary_metrics, - reference_metrics, + primary_info, + reference_info, output_dir, ): """Generate a comparison report between primary and reference datasets""" @@ -671,6 +688,12 @@ def generate_comparison_report( # noqa: C901 # Create a string buffer to hold the report content report_content = [] + model = (primary_info["model"],) + primary_device = (primary_info["full_device"],) + reference_device = reference_info["full_device"] + primary_metrics = primary_info["metrics"] + reference_metrics = reference_info["metrics"] + # Header report_content.append("Private vs Public Stability Comparison") report_content.append("=" * 80) @@ -971,8 +994,10 @@ def generate_comparison_report( # noqa: C901 ) # Note about OS version difference if applicable - _, primary_device_base, primary_os = extract_model_device_os(primary_sheet) - _, reference_device_base, reference_os = extract_model_device_os(reference_sheet) + primary_device_base = primary_info.get("base_device", "") + primary_os = primary_info.get("os_version", "") + reference_device_base = reference_info.get("base_device", "") + reference_os = reference_info.get("os_version", "") if primary_os != reference_os and primary_os and reference_os: report_content.append("") @@ -1030,7 +1055,7 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 { "Sheet": sheet_name, "Model": info["model"], - "Device": info["device"], + "Device": info["full_device"], "Mean Latency (ms)": info["metrics"]["mean_raw_latency"], "CV (%)": info["metrics"]["cv_raw_latency"], "Stability Score": info["metrics"]["stability_score"], @@ -1103,8 +1128,8 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 # Device-based comparison # First, extract base device names for grouping device_base_map = {} - for sheet_name in primary_datasets: - _, device_base, _ = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + device_base = info.get("base_device", "") device_base_map[sheet_name] = device_base # Add base device to DataFrame @@ -1138,8 +1163,8 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 # OS version comparison if multiple OS versions exist os_versions = {} - for sheet_name in primary_datasets: - _, _, os_version = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + os_version = info.get("os_version", "") if os_version: # Only include if OS version was extracted os_versions[sheet_name] = os_version @@ -1254,9 +1279,13 @@ def generate_summary_report( # noqa: C901 # Primary datasets summary primary_data = [] for sheet_name, info in primary_datasets.items(): - model, device_base, os_version = extract_model_device_os(sheet_name) + model, device_base, os_version = ( + info.get("model", ""), + info.get("base_device", ""), + info.get("os_version", ""), + ) device_display = ( - f"{device_base} ({os_version})" if os_version else info["device"] + f"{device_base}({os_version})" if os_version else info["device"] ) primary_data.append( @@ -1287,9 +1316,13 @@ def generate_summary_report( # noqa: C901 if reference_datasets: reference_data = [] for sheet_name, info in reference_datasets.items(): - model, device_base, os_version = extract_model_device_os(sheet_name) + model, device_base, os_version = ( + info.get("model", ""), + 
info.get("base_device", ""), + info.get("os_version", ""), + ) device_display = ( - f"{device_base} ({os_version})" if os_version else info["device"] + f"{device_base}({os_version})" if os_version else info["device"] ) reference_data.append( @@ -1322,29 +1355,31 @@ def generate_summary_report( # noqa: C901 # Comparison summary for matching datasets comparison_data = [] - for primary_sheet, primary_info in primary_datasets.items(): - for ref_sheet, ref_info in reference_datasets.items(): - if is_matching_dataset(primary_sheet, ref_sheet): + for _, primary_info in primary_datasets.items(): + for _, ref_info in reference_datasets.items(): + if is_matching_dataset(primary_info, ref_info): primary_metrics = primary_info["metrics"] reference_metrics = ref_info["metrics"] # Extract model and device info for display - model, primary_device_base, primary_os = extract_model_device_os( - primary_sheet - ) - _, reference_device_base, reference_os = extract_model_device_os( - ref_sheet + model, primary_device_base, primary_os = ( + primary_info.get("model", ""), + primary_info.get("base_device", ""), + primary_info.get("os_version", ""), ) + reference_device_base, reference_os = ref_info.get( + "base_device", "" + ), ref_info.get("os_version", "") primary_device_display = ( f"{primary_device_base} ({primary_os})" if primary_os - else primary_info["device"] + else primary_info["full_device"] ) reference_device_display = ( f"{reference_device_base} ({reference_os})" if reference_os - else ref_info["device"] + else ref_info["full_device"] ) comparison_data.append( @@ -1424,15 +1459,15 @@ def generate_summary_report( # noqa: C901 # OS version insights if available os_versions = {} - for sheet_name in primary_datasets: - _, _, os_version = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + os_version = info.get("os_version", "") if os_version: os_versions[sheet_name] = os_version if os_versions and len(set(os_versions.values())) > 1: # Add OS version to primary DataFrame primary_df["OS Version"] = primary_df["Dataset"].map( - lambda x: extract_model_device_os(x)[2] + lambda x: primary_datasets[x].get("os_version", np.nan) ) # Remove rows with no OS version @@ -1498,11 +1533,11 @@ def main(): description="Analyze ML model latency stability from benchmark data." 
     )
     parser.add_argument(
-        "primary_file",
+        "--primary-file",
         help="Path to Excel file containing primary (private) benchmark data",
     )
     parser.add_argument(
-        "--reference_file",
+        "--reference-file",
         help="Path to Excel file containing reference (public) benchmark data for comparison",
         default=None,
     )
diff --git a/.ci/scripts/benchmark_tooling/common.py b/.ci/scripts/benchmark_tooling/common.py
new file mode 100644
index 00000000000..521e9f3b3ce
--- /dev/null
+++ b/.ci/scripts/benchmark_tooling/common.py
@@ -0,0 +1,51 @@
+import json
+import os
+from typing import Any, Dict, List
+
+import pandas as pd
+
+
+def read_excel_with_json_header(path: str) -> List[Dict[str, Any]]:
+    # Read all sheets into a dict of DataFrames, without treating any row as a header
+    all_sheets = pd.read_excel(path, sheet_name=None, header=None, engine="openpyxl")
+
+    results = []
+    for sheet, df in all_sheets.items():
+        # Extract the JSON metadata string from cell A1 (row 0, col 0)
+        json_str = df.iat[0, 0]
+        meta = json.loads(json_str) if isinstance(json_str, str) else {}
+
+        # The actual data starts from the next row; treat that row as the header
+        df_data = pd.read_excel(path, sheet_name=sheet, skiprows=1, engine="openpyxl")
+        results.append({"groupInfo": meta, "df": df_data, "sheetName": sheet})
+    print(f"successfully fetched {len(results)} sheets from {path}")
+    return results
+
+
+def read_all_csv_with_metadata(folder_path: str) -> List[Dict[str, Any]]:
+    results = []  # list of {"groupInfo": dict, "df": DataFrame, "sheetName": str}
+    for fname in os.listdir(folder_path):
+        if not fname.lower().endswith(".csv"):
+            continue
+        path = os.path.join(folder_path, fname)
+        with open(path, "r", encoding="utf-8") as f:
+            first_line = f.readline().strip()
+        try:
+            meta = json.loads(first_line)
+        except json.JSONDecodeError:
+            meta = {}
+        df = pd.read_csv(path, skiprows=1)
+        results.append({"groupInfo": meta, "df": df, "sheetName": fname})
+    print(f"successfully fetched {len(results)} sheets from {folder_path}")
+    return results
+
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(level=logging.INFO)
+
+    # For Excel files (assuming the Excel file is in the current directory)
+    file_path = "./private.xlsx"
+    res = read_excel_with_json_header(file_path)
+    logging.info(res)
diff --git a/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py b/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py
new file mode 100644
index 00000000000..aecf4ff9744
--- /dev/null
+++ b/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py
@@ -0,0 +1,763 @@
+"""
+ExecutorchBenchmark Analysis Data Retrieval
+
+This module provides tools for fetching, processing, and analyzing benchmark data
+from the HUD Open API for ExecutorchBenchmark. It supports filtering data by (private) device pool names,
+backends, and models, exporting results in various formats (JSON, DataFrame, Excel, CSV),
+and customizing data retrieval parameters.
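+
+Typical CLI usage (mirroring the Quick Start in the README):
+
+    python3 .ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py --startTime "2025-06-11T00:00:00" --endTime "2025-06-17T18:00:00" --outputType excel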
+""" + +import argparse +import json +import logging +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import pandas as pd +import requests +from yaspin import yaspin + +logging.basicConfig(level=logging.INFO) + +# add here just for the records +VALID_PRIVATE_DEVICE_POOLS_MAPPINGS = { + "apple_iphone_15_private": [ + ("Apple iPhone 15 Pro (private)", "iOS 18.4.1"), + ("Apple iPhone 15 (private)", "iOS 18.0"), + ("Apple iPhone 15 Plus (private)", "iOS 17.4.1"), + ], + "samsung_s22_private": [ + ("Samsung Galaxy S22 Ultra 5G (private)", "Android 14"), + ("Samsung Galaxy S22 5G (private)", "Android 13"), + ], +} + +VALID_PRIVATE_DEVICE_POOLS_NAMES = list(VALID_PRIVATE_DEVICE_POOLS_MAPPINGS.keys()) + + +class OutputType(Enum): + """ + Enumeration of supported output formats for benchmark data. + + Values: + EXCEL: Export data to Excel spreadsheets + PRINT: Print data to console (default) + CSV: Export data to CSV files + JSON: Export data to JSON files + DF: Return data as pandas DataFrames + """ + + EXCEL = "excel" + PRINT = "print" + CSV = "csv" + JSON = "json" + DF = "df" + + +@dataclass +class BenchmarkQueryGroupDataParams: + """ + Parameters for querying benchmark data from HUD API. + + Attributes: + repo: Repository name (e.g., "pytorch/executorch") + benchmark_name: Name of the benchmark (e.g., "ExecuTorch") + start_time: ISO8601 formatted start time + end_time: ISO8601 formatted end time + group_table_by_fields: Fields to group tables by + group_row_by_fields: Fields to group rows by + """ + + repo: str + benchmark_name: str + start_time: str + end_time: str + group_table_by_fields: list + group_row_by_fields: list + + +@dataclass +class MatchingGroupResult: + """ + Container for benchmark results grouped by category. + + Attributes: + category: Category name (e.g., 'private', 'public') + data: List of benchmark data for this category + """ + + category: str + data: list + + +@dataclass +class BenchmarkFilters: + models: list + backends: list + devicePoolNames: list + + +BASE_URLS = { + "local": "http://localhost:3000", + "prod": "https://hud.pytorch.org", +} + + +def validate_iso8601_no_ms(value: str): + """ + Validate that a string is in ISO8601 format without milliseconds. + Args: + value: String to validate (format: YYYY-MM-DDTHH:MM:SS) + """ + try: + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").strftime( + "%Y-%m-%dT%H:%M:%S" + ) + except ValueError: + raise argparse.ArgumentTypeError( + f"Invalid datetime format for '{value}'. Expected: YYYY-MM-DDTHH:MM:SS" + ) + + +class ExecutorchBenchmarkFetcher: + """ + Fetch and process benchmark data from HUD API for ExecutorchBenchmark. + + This class provides methods to: + 1. Fetch all benchmark data for a specified time range + 2. Get all private device info within the time range + 3. Filter the private device data if filter is provided + 4. Then use the filtered private device data to find matched the public device data using [model, backend, device, arch] + 3. Export results in various formats (JSON, DataFrame, Excel, CSV) + + Usage: + fetcher = ExecutorchBenchmarkFetcher() + fetcher.run(start_time, end_time) + fetcher.output_data(OutputType.EXCEL, output_dir="./results") + """ + + def __init__( + self, + env: str = "prod", + disable_logging: bool = False, + group_table_fields=None, + group_row_fields=None, + ): + """ + Initialize the ExecutorchBenchmarkFetcher. 
+ + Args: + env: Environment to use ('local' or 'prod') + disable_logging: Whether to suppress log output + group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model) + group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket) + """ + self.env = env + self.base_url = self._get_base_url() + self.query_group_table_by_fields = ( + group_table_fields + if group_table_fields + else ["model", "backend", "device", "arch"] + ) + self.query_group_row_by_fields = ( + group_row_fields + if group_row_fields + else ["workflow_id", "job_id", "metadata_info.timestamp"] + ) + self.data = None + self.disable_logging = disable_logging + self.matching_groups: Dict[str, MatchingGroupResult] = {} + + def run( + self, + start_time: str, + end_time: str, + filters: Optional[BenchmarkFilters] = None, + ) -> None: + # reset group & raw data for new run + self.matching_groups = {} + self.data = None + + data = self._fetch_execu_torch_data(start_time, end_time) + if data is None: + logging.warning("no data fetched from the HUD API") + return None + self._proces_raw_data(data) + self._process_private_public_data(filters) + + def _filter_out_failure_only( + self, data_list: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Clean data by removing rows that only contain FAILURE_REPORT metrics. + + Args: + data_list: List of benchmark data dictionaries + + Returns: + Filtered list with rows containing only FAILURE_REPORT removed + """ + ONLY = {"workflow_id", "metadata_info.timestamp", "job_id", "FAILURE_REPORT"} + for item in data_list: + filtered_rows = [ + row + for row in item.get("rows", []) + # Keep row only if it has additional fields beyond ONLY + if not set(row.keys()).issubset(ONLY) + ] + item["rows"] = filtered_rows + return [item for item in data_list if item.get("rows")] + + def _filter_public_result(self, private_list, all_public): + # find intersection betwen private and public tables. + common = list( + set([item["table_name"] for item in private_list]) + & set([item["table_name"] for item in all_public]) + ) + + if not self.disable_logging: + logging.info( + f"Found {len(common)} table names existed in both private and public, use it to filter public tables:" + ) + logging.info(json.dumps(common, indent=1)) + filtered_public = [item for item in all_public if item["table_name"] in common] + return filtered_public + + def get_result(self) -> Dict[str, List[Dict[str, Any]]]: + """ + Get a deep copy of the benchmark results. + + Returns: + Dictionary containing benchmark results grouped by category + """ + return deepcopy(self.to_dict()) + + def to_excel(self, output_dir: str = ".") -> None: + """ + Export benchmark results to Excel files. + Creates two Excel files: + - res_private.xlsx: Results for private devices + - res_public.xlsx: Results for public devices + Each file contains multiple sheets, one per benchmark configuration for private and public. 
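+        Sheet layout: cell A1 of every sheet holds that table's groupInfo metadata as a
+        compact JSON string, and the DataFrame header plus data rows start on the next row.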
+ Args: + output_dir: Directory to save Excel files + """ + for item in self.matching_groups.values(): + self._write_multi_sheet_excel(item.data, output_dir, item.category) + + def _write_multi_sheet_excel(self, data_list, output_dir, file_name): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + file = os.path.join(output_dir, f"{file_name}.xlsx") + with pd.ExcelWriter(file, engine="xlsxwriter") as writer: + workbook = writer.book + for idx, entry in enumerate(data_list): + sheet_name = f"table{idx+1}" + df = pd.DataFrame(entry.get("rows", [])) + + # Encode metadata as compact JSON string + meta = entry.get("groupInfo", {}) + json_str = json.dumps(meta, separators=(",", ":")) + + worksheet = workbook.add_worksheet(sheet_name) + writer.sheets[sheet_name] = worksheet + + # Write JSON into A1 + worksheet.write_string(0, 0, json_str) + + logging.info( + f"Wrting excel sheet to file {file} with sheet name {sheet_name} for {entry['table_name']}" + ) + # Write DataFrame starting at row 2 (index 1) + df.to_excel(writer, sheet_name=sheet_name, startrow=1, index=False) + + def output_data( + self, output_type: OutputType = OutputType.PRINT, output_dir: str = "." + ) -> Any: + """ + Generate output in the specified format. + + Supports multiple output formats: + - PRINT: Print results to console + - JSON: Export to JSON files + - DF: Return as pandas DataFrames + - EXCEL: Export to Excel files + - CSV: Export to CSV files + + Args: + output_type: Format to output the data in + output_dir: Directory to save output files (for file-based formats) + + Returns: + Benchmark results in the specified format + """ + logging.info( + f"Generating output with type {output_type}: {[self.matching_groups.keys()]}" + ) + + o_type = self._to_output_type(output_type) + if o_type == OutputType.PRINT: + logging.info("\n ========= Generate print output ========= \n") + logging.info(json.dumps(self.get_result(), indent=2)) + elif o_type == OutputType.JSON: + logging.info("\n ========= Generate json output ========= \n") + file_path = self.to_json(output_dir) + logging.info(f"success, please check {file_path}") + elif o_type == OutputType.DF: + logging.info("\n ========= Generate dataframe output ========= \n") + res = self.to_df() + logging.info(res) + return res + elif o_type == OutputType.EXCEL: + logging.info("\n ========= Generate excel output ========= \n") + self.to_excel(output_dir) + elif o_type == OutputType.CSV: + logging.info("\n ========= Generate csv output ========= \n") + self.to_csv(output_dir) + return self.get_result() + + def _to_output_type(self, output_type: Any) -> OutputType: + if isinstance(output_type, str): + try: + return OutputType(output_type.lower()) + except ValueError: + logging.warning( + f"Invalid output type string: {output_type}. Defaulting to PRINT" + ) + return OutputType.JSON + elif isinstance(output_type, OutputType): + return output_type + logging.warning(f"Invalid output type: {output_type}. Defaulting to JSON") + return OutputType.JSON + + def to_json(self, output_dir: str = ".") -> Any: + """ + Export benchmark results to a JSON file. 
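+
+        The file is written as "benchmark_results.json" inside output_dir.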
+ + Args: + output_dir: Directory to save the JSON file + + Returns: + Path to the generated JSON file + """ + data = self.get_result() + return self.generate_json_file(data, "benchmark_results", output_dir) + + def generate_json_file(self, data, file_name, output_dir: str = "."): + """ + Generate a JSON file from the provided data. + + Args: + data: Data to write to the JSON file + file_name: Name for the JSON file (without extension) + output_dir: Directory to save the JSON file + + Returns: + Path to the generated JSON file + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + path = os.path.join(output_dir, file_name + ".json") + with open(path, "w") as f: + json.dump(data, f, indent=2) + return path + + def to_dict(self) -> Dict[str, List[Dict[str, Any]]]: + """ + Convert benchmark results to a dictionary. + + Returns: + Dictionary with categories as keys and benchmark data as values + """ + result = {} + for item in self.matching_groups.values(): + result[item.category] = item.data + return result + + def to_df(self) -> Dict[str, List[Dict[str, Union[Dict[str, Any], pd.DataFrame]]]]: + """ + Convert benchmark results to pandas DataFrames. + + Creates a dictionary with categories as keys and lists of DataFrames as values. + Each DataFrame represents one benchmark configuration. + + Returns: + Dictionary mapping categories ['private','public'] to lists of DataFrames "df" with metadata 'groupInfo'. + + """ + result = {} + for item in self.matching_groups.values(): + result[item.category] = [ + { + "groupInfo": item.get("groupInfo", {}), + "df": pd.DataFrame(item.get("rows", [])), + } + for item in item.data + ] + return result + + def to_csv(self, output_dir: str = ".") -> None: + """ + Export benchmark results to CSV files. + + Creates two CSV folders and one json file: + - private/: Results for private devices + - public/: Results for public devices + - benchmark_name_mappings.json: json dict which maps the generated csv file_name to + + Each file contains multiple CSV files, one per benchmark configuration for private and public. + + Args: + output_dir: Directory to save CSV files + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + + for item in self.matching_groups.values(): + path = os.path.join(output_dir, item.category) + self._write_multiple_csv_files(item.data, path) + + def _write_multiple_csv_files( + self, data_list: List[Dict[str, Any]], output_dir: str, prefix: str = "" + ) -> None: + """ + Write multiple benchmark results to CSV files. + + Creates a CSV file for each benchmark configuration, with metadata + as a JSON string in the first row and data in subsequent rows. + + Args: + data_list: List of benchmark result dictionaries + output_dir: Directory to save CSV files + prefix: Optional prefix for CSV filenames + """ + os.makedirs(output_dir, exist_ok=True) + for idx, entry in enumerate(data_list): + filename = f"{prefix}_table{idx+1}.csv" if prefix else f"table{idx+1}.csv" + file_path = os.path.join(output_dir, filename) + + # Prepare DataFrame + df = pd.DataFrame(entry.get("rows", [])) + + # Prepare metadata JSON (e.g. 
groupInfo) + meta = entry.get("groupInfo", {}) + json_str = json.dumps(meta, separators=(",", ":")) + + logging.info(f"Wrting csv file to {file_path}") + + # Write metadata and data + with open(file_path, "w", encoding="utf-8", newline="") as f: + f.write(json_str + "\n") # First row: JSON metadata + df.to_csv(f, index=False) # Remaining rows: DataFrame rows + + def _get_base_url(self) -> str: + """ + Get the base URL for API requests based on environment. + + Returns: + Base URL string for the configured environment + """ + return BASE_URLS[self.env] + + def get_all_private_devices(self) -> Tuple[List[Any], List[Any]]: + """ + Print all devices found in the data. + Separates results by category and displays counts. + This is useful for debugging and understanding what data is available. + """ + if not self.data: + logging.info("No data found, please call get_data() first") + return ([], []) + + all_private = { + (group.get("device", ""), group.get("arch", "")) + for item in self.data + if (group := item.get("groupInfo", {})).get("aws_type") == "private" + } + iphone_set = {pair for pair in all_private if "iphone" in pair[0].lower()} + samsung_set = {pair for pair in all_private if "samsung" in pair[0].lower()} + + # logging + logging.info( + f"Found private {len(iphone_set)} iphone devices: {list(iphone_set)}" + ) + logging.info( + f"Found private {len(samsung_set)} samsung devices: {list(samsung_set)}" + ) + return (list(iphone_set), list(samsung_set)) + + def _generate_table_name( + self, group_info: Dict[str, Any], fields: List[str] + ) -> str: + """ + Generate a table name from group info fields. + + Creates a normalized string by joining specified fields from group info. + + Args: + group_info: Dictionary containing group information + fields: List of field names to include in the table name + + Returns: + Normalized table name string + """ + name = "-".join( + self.normalize_string(group_info[k]) + for k in fields + if k in group_info and group_info[k] + ) + + return name + + def _proces_raw_data(self, input_data: List[Dict[str, Any]]): + """ + Process raw benchmark data. + """ + logging.info(f"fetched {len(input_data)} data from HUD") + data = self._clean_data(input_data) + + for item in data: + org_group = item.get("groupInfo", {}) + if org_group.get("device", "").find("private") != -1: + item["groupInfo"]["aws_type"] = "private" + else: + item["groupInfo"]["aws_type"] = "public" + # Add full name joined by the group key fields + item["table_name"] = self._generate_table_name( + org_group, self.query_group_table_by_fields + ) + self.data = deepcopy(data) + + def _process_private_public_data(self, filters: Optional[BenchmarkFilters]): + """ + Process raw benchmark data. 
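+
+        Specifically: split self.data into private and public groups by aws_type,
+        apply any provided filters to the private group, and keep only the public
+        tables whose table_name also appears in the filtered private group.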
+ """ + if not self.data: + logging.info("No data found, please call get_data() first") + return + + # + private_list = sorted( + ( + item + for item in self.data + if item.get("groupInfo", {}).get("aws_type") == "private" + ), + key=lambda x: x["table_name"], + ) + + if filters: + logging.info(f"Found {len(private_list)} private tables before filtering") + private_list = self.filter_private_results(private_list, filters) + else: + logging.info("filters is None, using all private results") + + all_public = sorted( + ( + item + for item in self.data + if item.get("groupInfo", {}).get("aws_type") == "public" + ), + key=lambda x: x["table_name"], + ) + public_list = self._filter_public_result(private_list, all_public) + + logging.info( + f"Found {len(private_list)} private tables, {[item['table_name'] for item in private_list]}" + ) + logging.info( + f"Found assoicated {len(public_list)} public tables, {json.dumps([item['table_name'] for item in public_list],indent=2)}" + ) + + self.matching_groups["private"] = MatchingGroupResult( + category="private", data=private_list + ) + self.matching_groups["public"] = MatchingGroupResult( + category="public", data=public_list + ) + + def _clean_data(self, data_list): + # filter data with arch equal exactly "",ios and android, this normally + # indicates it's job-level falure indicator + removed_gen_arch = [ + item + for item in data_list + if (arch := item.get("groupInfo", {}).get("arch")) is not None + and arch.lower() not in ("ios", "android") + ] + data = self._filter_out_failure_only(removed_gen_arch) + return data + + def _fetch_execu_torch_data( + self, start_time: str, end_time: str + ) -> Optional[List[Dict[str, Any]]]: + url = f"{self.base_url}/api/benchmark/group_data" + params_object = BenchmarkQueryGroupDataParams( + repo="pytorch/executorch", + benchmark_name="ExecuTorch", + start_time=start_time, + end_time=end_time, + group_table_by_fields=self.query_group_table_by_fields, + group_row_by_fields=self.query_group_row_by_fields, + ) + params = {k: v for k, v in params_object.__dict__.items() if v is not None} + with yaspin(text="Waiting for response", color="cyan") as spinner: + response = requests.get(url, params=params) + if response.status_code == 200: + spinner.ok("V") + return response.json() + else: + logging.info(f"Failed to fetch benchmark data ({response.status_code})") + logging.info(response.text) + spinner.fail("x") + return None + + def normalize_string(self, s: str) -> str: + s = s.lower().strip() + s = s.replace("+", "plus") + s = s.replace("-", "_") + s = s.replace(" ", "_") + s = re.sub(r"[^\w\-\.\(\)]", "_", s) + s = re.sub(r"_{2,}", "_", s) + s = s.replace("_(", "(").replace("(_", "(") + s = s.replace(")_", ")").replace("_)", ")") + s = s.replace("(private)", "") + return s + + def filter_private_results( + self, all_privates: List[Dict[str, Any]], filters: BenchmarkFilters + ): + """ + dynamically filter private device data based on filters, if any. + fetch all private devices within the time range, and then filter based on filter parameters + such as device_pool, backends, and models. 
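+
+        Filter semantics: values within a single filter type are ORed together, while
+        different filter types are ANDed (a record must match every provided type).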
+ """ + private_devices = self.get_all_private_devices() + + device_pool = filters.devicePoolNames or set() + backends = filters.backends or set() + models = filters.models or set() + + if not backends and not device_pool and not models: + logging.info("No filters provided, using all private results") + return all_privates + + device_ios_match = set() + # hardcoded since we only have 2 device pools, each for iphone and samsung + if "apple_iphone_15_private" in device_pool: + device_ios_match.update( + private_devices[0] + ) # assumed to be list of (device, arch) + if "samsung_s22_private" in device_pool: + device_ios_match.update(private_devices[1]) + logging.info( + f"Applying filter: backends={backends}, devices={device_pool}, models={models}, pair_filter={bool(device_ios_match)}" + ) + results = [] + for item in all_privates: + info = item.get("groupInfo", {}) + if backends and info.get("backend") not in backends: + continue + + if device_ios_match: + # must match both device and arch in a record, otherwise skip + pair = (info.get("device", ""), info.get("arch", "")) + if pair not in device_ios_match: + continue + if models and info.get("model", "") not in models: + continue + results.append(item) + + logging.info( + f"Filtered from private data {len(all_privates)} → {len(results)} results" + ) + if not results: + logging.info("No results matched the filters. Something is wrong.") + return results + + +def argparsers(): + parser = argparse.ArgumentParser(description="Benchmark Analysis Runner") + + # Required common args + parser.add_argument( + "--startTime", + type=validate_iso8601_no_ms, + required=True, + help="Start time, ISO format (e.g. 2025-06-01T00:00:00)", + ) + parser.add_argument( + "--endTime", + type=validate_iso8601_no_ms, + required=True, + help="End time, ISO format (e.g. 2025-06-06T00:00:00)", + ) + parser.add_argument( + "--env", choices=["local", "prod"], default="prod", help="Environment" + ) + + parser.add_argument( + "--no-silent", + action="store_false", + dest="silent", + default=True, + help="Allow output (disable silent mode)", + ) + + # Options for generate_data + parser.add_argument( + "--outputType", + choices=["json", "df", "csv", "print", "excel"], + default="print", + help="Output format (only for generate_data)", + ) + + parser.add_argument( + "--outputDir", default=".", help="Output directory, default is ." + ) + parser.add_argument( + "--backends", + nargs="+", + help="Filter results by one or more backend full name(e.g. --backends qlora mv3) (OR logic within backends scope, AND logic with other filter type)", + ) + parser.add_argument( + "--private-device-pools", + nargs="+", # allow one or more values + choices=VALID_PRIVATE_DEVICE_POOLS_NAMES, + help="List of devices to include [apple_iphone_15_private, samsung_s22_private, you can include both] (OR logic within private-device-pools scope, AND logic with other filter type)", + ) + parser.add_argument( + "--models", + nargs="+", + help="Filter by one or more models (e.g. 
--backend 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3') (OR logic withn models scope, AND logic with other filter type)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = argparsers() + fetcher = ExecutorchBenchmarkFetcher(args.env, args.silent) + result = fetcher.run( + args.startTime, + args.endTime, + filters=BenchmarkFilters( + models=args.models, + backends=args.backends, + devicePoolNames=args.private_device_pools, + ), + ) + fetcher.output_data(args.outputType, args.outputDir) diff --git a/.ci/scripts/benchmark_tooling/requirements.txt b/.ci/scripts/benchmark_tooling/requirements.txt new file mode 100644 index 00000000000..3a2d69c0676 --- /dev/null +++ b/.ci/scripts/benchmark_tooling/requirements.txt @@ -0,0 +1,7 @@ +requests>=2.32.3 +xlsxwriter>=3.2.3 +pandas>=2.3.0 +yaspin>=3.1.0 +tabulate +matplotlib +openpyxl diff --git a/.ci/scripts/tests/test_get_benchmark_analysis_data.py b/.ci/scripts/tests/test_get_benchmark_analysis_data.py new file mode 100644 index 00000000000..673452ab481 --- /dev/null +++ b/.ci/scripts/tests/test_get_benchmark_analysis_data.py @@ -0,0 +1,903 @@ +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, mock_open, patch + +import pandas as pd + + +class TestBenchmarkAnalysis(unittest.TestCase): + @classmethod + def setUpClass(cls): + script_path = os.path.join( + ".ci", "scripts", "benchmark_tooling", "get_benchmark_analysis_data.py" + ) + spec = importlib.util.spec_from_file_location( + "get_benchmark_analysis_data", script_path + ) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module # Register before execution + spec.loader.exec_module(module) + cls.module = module + + """Test the validate_iso8601_no_ms function.""" + + def test_valid_iso8601(self): + """Test with valid ISO8601 format.""" + valid_date = "2025-06-01T00:00:00" + result = self.module.validate_iso8601_no_ms(valid_date) + self.assertEqual(result, valid_date) + + def test_invalid_iso8601(self): + """Test with invalid ISO8601 format.""" + invalid_dates = [ + "2025-06-01", # Missing time + "2025-06-01 00:00:00", # Space instead of T + "2025-06-01T00:00:00.000", # With milliseconds + "not-a-date", # Not a date at all + ] + for invalid_date in invalid_dates: + with self.subTest(invalid_date=invalid_date): + with self.assertRaises(self.module.argparse.ArgumentTypeError): + self.module.validate_iso8601_no_ms(invalid_date) + + def test_output_type_values(self): + """Test that OutputType has the expected values.""" + self.assertEqual(self.module.OutputType.EXCEL.value, "excel") + self.assertEqual(self.module.OutputType.PRINT.value, "print") + self.assertEqual(self.module.OutputType.CSV.value, "csv") + self.assertEqual(self.module.OutputType.JSON.value, "json") + self.assertEqual(self.module.OutputType.DF.value, "df") + + def setUp(self): + """Set up test fixtures.""" + self.maxDiff = None + + self.fetcher = self.module.ExecutorchBenchmarkFetcher( + env="prod", disable_logging=True + ) + + # Sample data for testing + self.sample_data_1 = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 1, + "metadata_info.timestamp": "2025-06-15T15:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 2, + "job_id": 2, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + { + "groupInfo": { + "model": "mv3", + "backend": 
"xnnpack_q8", + "device": "s22_5g", + "arch": "android_13", + }, + "rows": [ + { + "workflow_id": 3, + "job_id": 3, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + ] + + self.sample_data_2 = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 1, + "metadata_info.timestamp": "2025-06-15T15:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 2, + "job_id": 2, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 6, + "job_id": 6, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 1.0, + }, + { + "workflow_id": 8, + "job_id": 8, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 1.0, + }, + ], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "s22_5g", + "arch": "android_13", + }, + "rows": [ + { + "workflow_id": 3, + "job_id": 3, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + ] + + def test_init(self): + """Test initialization of ExecutorchBenchmarkFetcher.""" + self.assertEqual(self.fetcher.env, "prod") + self.assertEqual(self.fetcher.base_url, "https://hud.pytorch.org") + self.assertEqual( + self.fetcher.query_group_table_by_fields, + ["model", "backend", "device", "arch"], + ) + self.assertEqual( + self.fetcher.query_group_row_by_fields, + ["workflow_id", "job_id", "metadata_info.timestamp"], + ) + self.assertTrue(self.fetcher.disable_logging) + self.assertEqual(self.fetcher.matching_groups, {}) + + def test_get_base_url(self): + """Test _get_base_url method.""" + self.assertEqual(self.fetcher._get_base_url(), "https://hud.pytorch.org") + + # Test with local environment + local_fetcher = self.module.ExecutorchBenchmarkFetcher(env="local") + self.assertEqual(local_fetcher._get_base_url(), "http://localhost:3000") + + def test_normalize_string(self): + """Test normalize_string method.""" + test_cases = [ + ("Test String", "test_string"), + ("test_string", "test_string"), + ("test string", "test_string"), + ("test--string", "test_string"), + ("test (private)", "test"), + ("test@#$%^&*", "test_"), + ] + + for input_str, expected in test_cases: + with self.subTest(input_str=input_str): + result = self.fetcher.normalize_string(input_str) + self.assertEqual(result, expected) + + @patch("requests.get") + def test_fetch_execu_torch_data_success(self, mock_get): + """Test _fetch_execu_torch_data method with successful response.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = self.sample_data_1 + mock_get.return_value = mock_response + + result = self.fetcher._fetch_execu_torch_data( + "2025-06-01T00:00:00", "2025-06-02T00:00:00" + ) + + self.assertEqual(result, self.sample_data_1) + mock_get.assert_called_once() + + @patch("requests.get") + def test_fetch_execu_torch_data_failure(self, mock_get): + """Test _fetch_execu_torch_data method with failed response.""" + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = 
"Not Found" + mock_get.return_value = mock_response + + result = self.fetcher._fetch_execu_torch_data( + "2025-06-01T00:00:00", "2025-06-02T00:00:00" + ) + + self.assertIsNone(result) + mock_get.assert_called_once() + + def test_filter_out_failure_only(self): + """Test _filter_out_failure_only method.""" + test_data = [ + { + "rows": [ + { + "workflow_id": 1, + "job_id": 2, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 8, + "job_id": 9, + "metadata_info.timestamp": 10, + "metric": 11.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 10, + "job_id": 12, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 21, + "job_id": 15, + "metadata_info.timestamp": 6, + "FAILURE_REPORT": "0", + }, + ] + }, + ] + + expected = [ + { + "rows": [ + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 8, + "job_id": 9, + "metadata_info.timestamp": 10, + "metric": 11.0, + }, + ] + }, + ] + + result = self.fetcher._filter_out_failure_only(test_data) + self.assertEqual(result, expected) + + def test_filter_public_result(self): + """Test _filter_public_result method.""" + private_list = [ + {"table_name": "model1_backend1"}, + {"table_name": "model2_backend2"}, + ] + + public_list = [ + {"table_name": "model1_backend1"}, + {"table_name": "model3_backend3"}, + ] + + expected = [{"table_name": "model1_backend1"}] + + result = self.fetcher._filter_public_result(private_list, public_list) + self.assertEqual(result, expected) + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_filter_private_results(self, mock_fetch): + """Test filter_private_results method with various filter combinations.""" + # Create test data + test_data = [ + { + "groupInfo": { + "model": "mv3", + "backend": "coreml_fp16", + "device": "Apple iPhone 15 Pro (private)", + "arch": "iOS 18.0", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 1.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "test_backend", + "device": "Apple iPhone 15 Pro (private)", + "arch": "iOS 14.1.0", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 1.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "Samsung Galaxy S22 Ultra 5G (private)", + "arch": "Android 14", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "Samsung Galaxy S22 Ultra 5G (private)", + "arch": "Android 13", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + "backend": "llama3_spinquant", + "device": "Apple iPhone 15", + "arch": "iOS 18.0", + "total_rows": 19, + "aws_type": "public", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "coreml_fp16", + "device": "Apple iPhone 15 Pro Max", + "arch": "iOS 17.0", + "total_rows": 10, + "aws_type": "public", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + "backend": "test", + "device": "Samsung Galaxy S22 Ultra 5G", + "arch": "Android 14", + "total_rows": 10, + "aws_type": "public", + 
+                },
+                "rows": [{"metric_1": 2.0}],
+            },
+        ]
+
+        mock_fetch.return_value = test_data
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+        # Test with no filters
+        empty_filters = self.module.BenchmarkFilters(
+            models=None, backends=None, devicePoolNames=None
+        )
+
+        result = self.fetcher.filter_private_results(test_data, empty_filters)
+        self.assertEqual(result, test_data)
+
+        # Test with model filter
+        model_filters = self.module.BenchmarkFilters(
+            models=["meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"],
+            backends=None,
+            devicePoolNames=None,
+        )
+        result = self.fetcher.filter_private_results(test_data, model_filters)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(
+            all(
+                item["groupInfo"]["model"]
+                == "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
+                for item in result
+            )
+        )
+
+        # Test with backend filter
+        backend_filters = self.module.BenchmarkFilters(
+            models=None, backends=["coreml_fp16", "test"], devicePoolNames=None
+        )
+        result = self.fetcher.filter_private_results(test_data, backend_filters)
+        self.assertEqual(len(result), 3)
+        self.assertTrue(
+            all(
+                item["groupInfo"]["backend"] in ["coreml_fp16", "test"]
+                for item in result
+            )
+        )
+
+        # Test with device filter
+        device_filters = self.module.BenchmarkFilters(
+            models=None, backends=None, devicePoolNames=["samsung_s22_private"]
+        )
+        result = self.fetcher.filter_private_results(test_data, device_filters)
+        self.assertEqual(len(result), 2)
+        self.assertTrue(
+            all(
+                "Samsung Galaxy S22 Ultra 5G (private)" in item["groupInfo"]["device"]
+                for item in result
+            )
+        )
+
+        # Test with combined filters (AND logic fails)
+        combined_filters = self.module.BenchmarkFilters(
+            models=["meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"],
+            backends=["xnnpack_q8"],
+            devicePoolNames=None,
+        )
+        result = self.fetcher.filter_private_results(test_data, combined_filters)
+        self.assertEqual(len(result), 0)
+
+        # Test with combined filters (AND logic succeeds)
+        combined_filters = self.module.BenchmarkFilters(
+            models=["mv3"],
+            backends=None,
+            devicePoolNames=["apple_iphone_15_private"],
+        )
+        result = self.fetcher.filter_private_results(test_data, combined_filters)
+        self.assertEqual(len(result), 2)
+
+    @patch(
+        "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
+    )
+    def test_run_without_public_match(self, mock_fetch):
+        """Test run method when private results have no matching public entries."""
+        # Setup mocks
+        mock_fetch.return_value = self.sample_data_1
+        # Run the method
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+        result = self.fetcher.get_result()
+
+        # Verify results
+        self.assertEqual(result, {"private": [self.sample_data_1[0]], "public": []})
+        self.assertEqual(len(self.fetcher.matching_groups), 2)
+        self.assertIn("private", self.fetcher.matching_groups)
+        self.assertIn("public", self.fetcher.matching_groups)
+
+        # Verify mocks were called
+        mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+    @patch(
+        "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
+    )
+    def test_run_with_public_match(self, mock_fetch):
+        """Test run method when private results have matching public entries."""
+        # Setup mocks
+        mock_fetch.return_value = self.sample_data_2
+
+        # Run the method
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+        result = self.fetcher.get_result()
+
+        # Verify results
+        self.assertEqual(
+            result,
+            {"private": [self.sample_data_2[0]], "public": [self.sample_data_2[1]]},
+        )
+        self.assertEqual(len(self.fetcher.matching_groups), 2)
+        self.assertIn("private", self.fetcher.matching_groups)
+        self.assertIn("public", self.fetcher.matching_groups)
+        # Verify mocks were called
+        mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+    @patch(
+        "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
+    )
+    def test_run_with_failure_report(self, mock_fetch):
+        """Test run method when rows contain FAILURE_REPORT entries."""
+        # Setup mocks
+        mock_data = [
+            {
+                "groupInfo": {
+                    "model": "llama3",
+                    "backend": "qlora",
+                    "device": "Iphone 15 pro max (private)",
+                    "arch": "ios_17.4.3",
+                },
+                "rows": [
+                    {
+                        "workflow_id": 1,
+                        "job_id": 2,
+                        "metadata_info.timestamp": 3,
+                        "FAILURE_REPORT": "0",
+                    },
+                    {
+                        "workflow_id": 4,
+                        "job_id": 5,
+                        "metadata_info.timestamp": 6,
+                        "metric": 7.0,
+                    },
+                ],
+            },
+            {
+                "groupInfo": {
+                    "model": "llama3",
+                    "backend": "qlora",
+                    "device": "Iphone 15 pro max",
+                    "arch": "ios_17.4.3",
+                },
+                "rows": [
+                    {
+                        "workflow_id": 1,
+                        "job_id": 2,
+                        "metadata_info.timestamp": 3,
+                        "FAILURE_REPORT": "0",
+                    },
+                    {
+                        "workflow_id": 1,
+                        "job_id": 2,
+                        "metadata_info.timestamp": 3,
+                        "FAILURE_REPORT": "0",
+                    },
+                ],
+            },
+        ]
+
+        expected_private = {
+            "groupInfo": {
+                "model": "llama3",
+                "backend": "qlora",
+                "device": "Iphone 15 pro max (private)",
+                "arch": "ios_17.4.3",
+                "aws_type": "private",
+            },
+            "rows": [
+                {
+                    "workflow_id": 4,
+                    "job_id": 5,
+                    "metadata_info.timestamp": 6,
+                    "metric": 7.0,
+                },
+            ],
+            "table_name": "llama3-qlora-iphone_15_pro_max-ios_17.4.3",
+        }
+        mock_fetch.return_value = mock_data
+        # Run the method
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+        result = self.fetcher.get_result()
+        # Verify results
+        self.assertEqual(result.get("private", []), [expected_private])
+        self.assertEqual(len(self.fetcher.matching_groups), 2)
+        self.assertIn("private", self.fetcher.matching_groups)
+        self.assertIn("public", self.fetcher.matching_groups)
+        # Verify mocks were called
+        mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+    @patch(
+        "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
+    )
+    def test_run_no_data(self, mock_fetch):
+        """Test run method when no data is fetched."""
+        mock_fetch.return_value = None
+
+        result = self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+        self.assertIsNone(result)
+        self.assertEqual(self.fetcher.matching_groups, {})
+        mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00")
+
+    @patch(
+        "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
+    )
+    def test_run_with_filters(self, mock_fetch):
+        """Test run method with filters."""
+        # Setup mock data
+        mock_data = [
+            {
+                "groupInfo": {
+                    "model": "llama3",
+                    "backend": "qlora",
+                    "device": "Iphone 15 pro max (private)",
+                    "arch": "ios_17",
+                },
+                "rows": [{"metric_1": 1.0}],
+            },
+            {
+                "groupInfo": {
+                    "model": "mv3",
+                    "backend": "xnnpack_q8",
+                    "device": "s22_5g (private)",
+                    "arch": "android_13",
+                },
+                "rows": [{"metric_1": 2.0}],
+            },
+            {
+                "groupInfo": {
+                    "model": "mv3",
+                    "backend": "xnnpack_q8",
+                    "device": "s22_5g",
+                    "arch": "android_13",
+                },
+                "rows": [{"metric_1": 3.0}],
+            },
+        ]
+        mock_fetch.return_value = mock_data
+
+        # Create filters for llama3 model only
+        filters = self.module.BenchmarkFilters(
+            models=["llama3"], backends=None, devicePoolNames=None
+        )
+        # Run the method with filters
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters)
+        result = self.fetcher.get_result()
+
+        # Verify results - should only have llama3 in private results
+        self.assertEqual(len(result["private"]), 1)
+        self.assertEqual(result["private"][0]["groupInfo"]["model"], "llama3")
+
+        # Public results should be empty since there's no matching table_name
+        self.assertEqual(result["public"], [])
+
+        # Test with backend filter
+        filters = self.module.BenchmarkFilters(
+            models=None, backends=["xnnpack_q8"], devicePoolNames=None
+        )
+        self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters)
+        result = self.fetcher.get_result()
+
+        # Verify results - should only have xnnpack_q8 in private results
+        self.assertEqual(len(result["private"]), 1)
+        self.assertEqual(result["private"][0]["groupInfo"]["backend"], "xnnpack_q8")
+
+        # Public results should have the matching xnnpack_q8 entry
+        self.assertEqual(len(result["public"]), 1)
+        self.assertEqual(result["public"][0]["groupInfo"]["backend"], "xnnpack_q8")
+
+    def test_to_dict(self):
+        """Test to_dict method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private", data=[{"key": "private_value"}]
+            ),
+            "public": self.module.MatchingGroupResult(
+                category="public", data=[{"key": "public_value"}]
+            ),
+        }
+
+        expected = {
+            "private": [{"key": "private_value"}],
+            "public": [{"key": "public_value"}],
+        }
+
+        result = self.fetcher.to_dict()
+        self.assertEqual(result, expected)
+
+    def test_to_df(self):
+        """Test to_df method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private",
+                data=[{"groupInfo": {"model": "llama3"}, "rows": [{"metric1": 1.0}]}],
+            ),
+        }
+
+        result = self.fetcher.to_df()
+
+        self.assertIn("private", result)
+        self.assertEqual(len(result["private"]), 1)
+        self.assertIn("groupInfo", result["private"][0])
+        self.assertIn("df", result["private"][0])
+        self.assertIsInstance(result["private"][0]["df"], pd.DataFrame)
+        self.assertEqual(result["private"][0]["groupInfo"], {"model": "llama3"})
+
+    @patch("os.makedirs")
+    @patch("json.dump")
+    @patch("builtins.open", new_callable=mock_open)
+    def test_to_json(self, mock_file, mock_json_dump, mock_makedirs):
+        """Test to_json method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private", data=[{"key": "value"}]
+            ),
+        }
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            result = self.fetcher.to_json(temp_dir)
+
+            # Check that the file path is returned
+            self.assertEqual(result, os.path.join(temp_dir, "benchmark_results.json"))
+
+            # Check that the file was opened for writing
+            mock_file.assert_called_once_with(
+                os.path.join(temp_dir, "benchmark_results.json"), "w"
+            )
+
+            # Check that json.dump was called with the expected data
+            mock_json_dump.assert_called_once()
+            args, _ = mock_json_dump.call_args
+            self.assertEqual(args[0], {"private": [{"key": "value"}]})
+
+    @patch("pandas.DataFrame.to_excel")
+    @patch("pandas.ExcelWriter")
+    @patch("os.makedirs")
+    def test_to_excel(self, mock_makedirs, mock_excel_writer, mock_to_excel):
+        """Test to_excel method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private",
+                data=[
+                    {
+                        "groupInfo": {"model": "llama3"},
+                        "rows": [{"metric1": 1.0}],
+                        "table_name": "llama3_table",
+                    }
+                ],
+            ),
+        }
+
+        # Mock the context manager for ExcelWriter
+        mock_writer = MagicMock()
+        mock_excel_writer.return_value.__enter__.return_value = mock_writer
+        mock_writer.book = MagicMock()
+        mock_writer.book.add_worksheet.return_value = MagicMock()
+        mock_writer.sheets = {}
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            self.fetcher.to_excel(temp_dir)
+
+            # Check that ExcelWriter was called with the expected path
+            mock_excel_writer.assert_called_once_with(
+                os.path.join(temp_dir, "private.xlsx"), engine="xlsxwriter"
+            )
+
+            # Check that to_excel was called
+            mock_to_excel.assert_called_once()
+
+    @patch("os.makedirs")
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("pandas.DataFrame.to_csv")
+    def test_to_csv(self, mock_to_csv, mock_file, mock_makedirs):
+        """Test to_csv method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private",
+                data=[{"groupInfo": {"model": "llama3"}, "rows": [{"metric1": 1.0}]}],
+            ),
+        }
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            self.fetcher.to_csv(temp_dir)
+
+            # Check that the directory was created
+            mock_makedirs.assert_called()
+
+            # Check that the file was opened for writing
+            mock_file.assert_called_once()
+
+            # Check that to_csv was called
+            mock_to_csv.assert_called_once()
+
+    def test_to_output_type(self):
+        """Test _to_output_type method."""
+        # Test with string values
+        self.assertEqual(
+            self.fetcher._to_output_type("excel"), self.module.OutputType.EXCEL
+        )
+        self.assertEqual(
+            self.fetcher._to_output_type("print"), self.module.OutputType.PRINT
+        )
+        self.assertEqual(
+            self.fetcher._to_output_type("csv"), self.module.OutputType.CSV
+        )
+        self.assertEqual(
+            self.fetcher._to_output_type("json"), self.module.OutputType.JSON
+        )
+        self.assertEqual(self.fetcher._to_output_type("df"), self.module.OutputType.DF)
+
+        # Test with enum values
+        self.assertEqual(
+            self.fetcher._to_output_type(self.module.OutputType.EXCEL),
+            self.module.OutputType.EXCEL,
+        )
+
+        # Test with invalid values
+        self.assertEqual(
+            self.fetcher._to_output_type("invalid"), self.module.OutputType.JSON
+        )
+        self.assertEqual(self.fetcher._to_output_type(123), self.module.OutputType.JSON)
+
+    @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_json")
+    @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_df")
+    @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_excel")
+    @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_csv")
+    def test_output_data(self, mock_to_csv, mock_to_excel, mock_to_df, mock_to_json):
+        """Test output_data method."""
+        # Setup test data
+        self.fetcher.matching_groups = {
+            "private": self.module.MatchingGroupResult(
+                category="private", data=[{"key": "value"}]
+            ),
+        }
+
+        # Test PRINT output
+        result = self.fetcher.output_data(self.module.OutputType.PRINT)
+        self.assertEqual(result, {"private": [{"key": "value"}]})
+
+        # Test JSON output
+        mock_to_json.return_value = "/path/to/file.json"
+        result = self.fetcher.output_data(self.module.OutputType.JSON)
+        self.assertEqual(result, {"private": [{"key": "value"}]})
+        mock_to_json.assert_called_once_with(".")
+
+        # Test DF output
+        mock_to_df.return_value = {"private": [{"df": "value"}]}
+        result = self.fetcher.output_data(self.module.OutputType.DF)
+        self.assertEqual(result, {"private": [{"df": "value"}]})
+        mock_to_df.assert_called_once()
+
+        # Test EXCEL output
+        result = self.fetcher.output_data(self.module.OutputType.EXCEL)
+        self.assertEqual(result, {"private": [{"key": "value"}]})
+        mock_to_excel.assert_called_once_with(".")
+
+        # Test CSV output
+        result = self.fetcher.output_data(self.module.OutputType.CSV)
+        self.assertEqual(result, {"private": [{"key": "value"}]})
+        mock_to_csv.assert_called_once_with(".")
+
+
+if __name__ == "__main__":
+    unittest.main()