diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py
index f7c531e9..b357bd00 100755
--- a/examples/09_gemm_one_shot_all_reduce/benchmark.py
+++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py
@@ -73,6 +73,106 @@ def parse_args():
     return vars(parser.parse_args())
 
 
+def gemm_one_shot_all_reduce(A, B, shmem, args_dict):
+    """
+    Core GEMM one-shot all-reduce function that can be reused by both example and tests.
+
+    Args:
+        A: Input matrix A
+        B: Input matrix B
+        shmem: Iris shared memory object
+        args_dict: Dictionary containing algorithm parameters
+
+    Returns:
+        global_C: The result matrix after GEMM and all-reduce
+    """
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+    cu_count = shmem.get_cu_count()
+
+    # Validate divisibility requirements
+    assert args_dict["n"] % world_size == 0, f"N ({args_dict['n']}) must be divisible by world size ({world_size})."
+    assert args_dict["k"] % world_size == 0, f"K ({args_dict['k']}) must be divisible by world size ({world_size})."
+
+    # Splitting
+    rows_per_gpu = args_dict["k"] // world_size
+    start_row = rank * rows_per_gpu
+    end_row = start_row + rows_per_gpu
+    local_B = B[start_row:end_row, :]
+    local_A = A[:, start_row:end_row]
+
+    # Create output tensors
+    global_C = shmem.zeros((args_dict["m"], args_dict["n"]), device="cuda", dtype=A.dtype)
+    local_C = shmem.zeros((args_dict["m"], args_dict["n"]), device="cuda", dtype=A.dtype)
+
+    # Calculate tile information
+    total_blocks_M = triton.cdiv(args_dict["m"], args_dict["BLK_M"])
+    total_blocks_N = triton.cdiv(args_dict["n"], args_dict["BLK_N"])
+    total_tiles = total_blocks_M * total_blocks_N
+
+    if args_dict["gemm_sms"] >= args_dict["total_sms"]:
+        raise ValueError(f"Invalid number of stream-K SMs. {args_dict['gemm_sms']} >= {args_dict['total_sms']}")
+
+    # Create synchronization tensors
+    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
+    locks = shmem.zeros((args_dict["gemm_sms"],), device="cuda", dtype=torch.int32)
+    P = shmem.zeros(
+        (args_dict["gemm_sms"], args_dict["BLK_M"] * args_dict["BLK_N"]),
+        device="cuda",
+        dtype=torch.float32,
+    )
+    bias = None
+
+    # Timestamps for tracing (optional)
+    timestamps = Timestamps(num_tiles=total_tiles)
+
+    def preamble():
+        shmem.barrier()
+        iris.memset_tensor(tile_completed, 0)
+        shmem.barrier()
+
+    # Prepare for computation
+    shmem.barrier()
+    preamble()
+    shmem.barrier()
+
+    # Run the GEMM + all-reduce
+    shmem.barrier()
+
+    local_C = matmul.apply(
+        local_A,
+        local_B,
+        local_C,
+        global_C,
+        bias,
+        P,
+        locks,
+        tile_completed,
+        rank,
+        world_size,
+        args_dict["gemm_sms"],
+        args_dict["BLK_M"],
+        args_dict["BLK_N"],
+        args_dict["BLK_K"],
+        args_dict["gsize_m"],
+        args_dict["two_tiles"],
+        args_dict["num_stages"],
+        args_dict["num_warps"],
+        args_dict["waves_per_eu"],
+        args_dict["mfmaInstrSize"],
+        args_dict["kpack"],
+        shmem.get_heap_bases(),
+        cu_count,
+        args_dict.get("trace_tiles", False),
+        timestamps.mm_begin_timestamp,
+        timestamps.mm_end_timestamp,
+    )
+
+    shmem.barrier()
+
+    return global_C
+
+
 def main():
     args = parse_args()
 
@@ -95,9 +195,6 @@ def main():
         print("Unknown datatype.")
         exit(1)
 
-    assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})."
-    assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})."
-
     A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype)
     B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T
     C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
@@ -109,41 +206,9 @@ def main():
     json_writer = JSONWriter(args["output_file"])
     json_writer.add_field("world_size", world_size)
 
-    # Splitting
-    rows_per_gpu = args["k"] // world_size
-    args["k"] = rows_per_gpu
-    start_row = rank * rows_per_gpu
-    end_row = start_row + rows_per_gpu
-    local_B = B[start_row:end_row, :]
-    local_A = A[:, start_row:end_row]
-
     for key, value in args.items():
         json_writer.add_field(key, value)
 
-    global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype)
-    local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
-
-    total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
-    total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
-    total_tiles = total_blocks_M * total_blocks_N
-
-    if args["gemm_sms"] >= args["total_sms"]:
-        print(f"Invalid number of stream-K SMs. {args['gemm_sms']} >= {args['total_sms']}")
-        exit(1)
-
-    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
-
-    locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32)
-
-    P = shmem.zeros(
-        (args["gemm_sms"], args["BLK_M"] * args["BLK_N"]),
-        device="cuda",
-        dtype=torch.float32,
-    )
-    bias = None
-
-    gemm_stream = torch.cuda.Stream()
-
     json_writer.add_field("gemm_sms", args["gemm_sms"])
 
     kernel_timing = {
@@ -156,55 +221,22 @@ def main():
     }
 
     # Timestamps
+    total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
+    total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
+    total_tiles = total_blocks_M * total_blocks_N
     timestamps = Timestamps(num_tiles=total_tiles)
 
-    def preamble():
-        shmem.barrier()
-        iris.memset_tensor(tile_completed, 0)
-        shmem.barrier()
-
     def run_experiment():
-        nonlocal local_C
-        nonlocal global_C
        nonlocal kernel_timing
 
-        shmem.barrier()
-
        if args["trace_tiles"]:
            timestamps.reset()
            shmem.barrier()
 
        torch.cuda.nvtx.range_push("GEMM + Communication")
-        with torch.cuda.stream(gemm_stream):
+        with torch.cuda.stream(torch.cuda.Stream()):
            kernel_timing["gemm"]["start_event"].record()
-            local_C = matmul.apply(
-                local_A,
-                local_B,
-                local_C,
-                global_C,
-                bias,
-                P,
-                locks,
-                tile_completed,
-                rank,
-                world_size,
-                args["gemm_sms"],
-                args["BLK_M"],
-                args["BLK_N"],
-                args["BLK_K"],
-                args["gsize_m"],
-                args["two_tiles"],
-                args["num_stages"],
-                args["num_warps"],
-                args["waves_per_eu"],
-                args["mfmaInstrSize"],
-                args["kpack"],
-                shmem.get_heap_bases(),
-                cu_count,
-                args["trace_tiles"],
-                timestamps.mm_begin_timestamp,
-                timestamps.mm_end_timestamp,
-            )
+            global_C = gemm_one_shot_all_reduce(A, B, shmem, args)
            kernel_timing["gemm"]["end_event"].record()
            kernel_timing["gemm"]["experiments"] += 1
 
@@ -214,15 +246,15 @@ def run_experiment():
        for k in ["gemm"]:
            ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
            kernel_timing[k]["ms"] += ms
+
+        return global_C
 
     # Synchronize across all GPUs
     shmem.barrier()
 
     # Warmup
-    run_experiment()
+    global_C = run_experiment()
 
-    shmem.barrier()
-    preamble()
     shmem.barrier()
 
     for k in ["gemm"]:
@@ -253,6 +285,10 @@ def run_experiment():
     if args["benchmark"]:
         shmem.info("Benchmarking...")
         perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
+
+        def preamble():
+            shmem.barrier()
+
         triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
         triton_tflops = perf(triton_ms)
         shmem.info(f"tile matmul + all_reduce (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")
diff --git a/tests/examples/test_gemm_one_shot_all_reduce.py b/tests/examples/test_gemm_one_shot_all_reduce.py
new file mode 100644
index 00000000..3fce38cf
--- /dev/null
+++ b/tests/examples/test_gemm_one_shot_all_reduce.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Test suite for the 09_gemm_one_shot_all_reduce example.
+
+This test suite provides comprehensive testing for the GEMM one-shot all-reduce
+algorithm implementation. Tests expect ROCm/HIP to be available in the CI environment.
+"""
+
+import pytest
+import torch
+import triton
+import numpy as np
+import sys
+import os
+
+import importlib.util
+from pathlib import Path
+
+# Add the project root to the Python path to help with imports
+current_dir = Path(__file__).parent
+project_root = (current_dir / "../..").resolve()
+if str(project_root) not in sys.path:
+    sys.path.insert(0, str(project_root))
+
+# Import the example module
+example_dir = (project_root / "examples/09_gemm_one_shot_all_reduce").resolve()
+if str(example_dir) not in sys.path:
+    sys.path.insert(0, str(example_dir))
+
+# Import necessary modules
+import iris
+from examples.common.validation import validate_gemm
+
+# Import the benchmark module
+benchmark_file = example_dir / "benchmark.py"
+spec = importlib.util.spec_from_file_location("benchmark", benchmark_file)
+benchmark_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(benchmark_module)
+
+
+def test_gemm_one_shot_all_reduce_import():
+    """Test that the benchmark module can be imported and has the required functions."""
+    assert hasattr(benchmark_module, "main"), "Benchmark module should have a main function"
+    assert hasattr(benchmark_module, "parse_args"), "Benchmark module should have a parse_args function"
+    assert hasattr(benchmark_module, "gemm_one_shot_all_reduce"), (
+        "Benchmark module should have a gemm_one_shot_all_reduce function"
+    )
+
+
+def test_parse_args():
+    """Test argument parsing functionality."""
+    # Temporarily replace sys.argv to test argument parsing
+    original_argv = sys.argv
+    try:
+        # Test with minimal arguments
+        sys.argv = ["benchmark.py", "-m", "128", "-n", "128", "-k", "128", "--validate"]
+
+        args = benchmark_module.parse_args()
+
+        # Check that arguments are parsed correctly
+        assert args["m"] == 128, f"Expected m=128, got {args['m']}"
+        assert args["n"] == 128, f"Expected n=128, got {args['n']}"
+        assert args["k"] == 128, f"Expected k=128, got {args['k']}"
+        assert args["validate"], f"Expected validate=True, got {args['validate']}"
+
+        # Check that defaults are set
+        assert "datatype" in args, "Args should contain datatype"
+        assert "BLK_M" in args, "Args should contain BLK_M"
+        assert "BLK_N" in args, "Args should contain BLK_N"
+        assert "BLK_K" in args, "Args should contain BLK_K"
+
+    finally:
+        sys.argv = original_argv
+
+
+@pytest.mark.parametrize(
+    "M, N, K, world_size",
+    [
+        (256, 256, 256, 2),  # Basic case with 2 ranks
+        (512, 512, 512, 4),  # Larger case with 4 ranks
+    ],
+)
+def test_matrix_dimension_divisibility(M, N, K, world_size):
+    """Test that matrix dimensions are divisible by the world size, as required by the algorithm."""
+    # Test the assertions that are made in the benchmark code
+    assert N % world_size == 0, f"N ({N}) must be divisible by world size ({world_size})"
+    assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})"
+
+    # Test matrix splitting logic
+    rows_per_gpu = K // world_size
+    assert rows_per_gpu > 0, "Each GPU should get at least one row"
+    assert rows_per_gpu * world_size == K, "Total rows should equal K"
+
+
+@pytest.mark.parametrize(
+    "datatype, M, N, K",
+    [
+        (torch.float16, 128, 128, 128),
+        (torch.float32, 128, 128, 128),
+        (torch.bfloat16, 128, 128, 128),
+    ],
+)
+def test_gemm_one_shot_all_reduce_function(datatype, M, N, K):
+    """Test the core gemm_one_shot_all_reduce function with different data types."""
+    # Set up iris environment
+    heap_size = 1 << 30  # 1GB heap
+    shmem = iris.iris(heap_size)
+    world_size = shmem.get_num_ranks()
+
+    # Skip if dimensions are not divisible by world size
+    if N % world_size != 0 or K % world_size != 0:
+        pytest.skip(f"Matrix dimensions ({M}x{N}x{K}) not divisible by world_size ({world_size})")
+
+    # Create input matrices
+    A = shmem.randn(M, K, device="cuda", dtype=datatype)
+    B = shmem.randn(N, K, device="cuda", dtype=datatype).T
+
+    # Set up algorithm parameters
+    args_dict = {
+        "m": M,
+        "n": N,
+        "k": K,
+        "BLK_M": 64,  # Smaller blocks for testing
+        "BLK_N": 64,
+        "BLK_K": 32,
+        "gsize_m": 1,
+        "two_tiles": True,
+        "num_stages": 1,
+        "num_warps": 4,
+        "waves_per_eu": 0,
+        "mfmaInstrSize": 16,
+        "kpack": 1,
+        "gemm_sms": min(64, 288),  # Reduced for testing
+        "total_sms": 304,
+        "trace_tiles": False,
+    }
+
+    # Run the GEMM one-shot all-reduce
+    result_C = benchmark_module.gemm_one_shot_all_reduce(A, B, shmem, args_dict)
+
+    # Basic shape and type checks
+    assert result_C.shape == (M, N), f"Expected output shape ({M}, {N}), got {result_C.shape}"
+    assert result_C.dtype == datatype, f"Expected output dtype {datatype}, got {result_C.dtype}"
+
+    # Validate the result using the existing validation function
+    success = validate_gemm(A, B, result_C, shmem, atol=2)
+    assert success, "GEMM validation failed"
+
+
+def test_block_size_calculations():
+    """Test block size calculations used in the GEMM kernel."""
+    # Test triton.cdiv functionality which is used in the benchmark
+    M, N, K = 1000, 2000, 3000
+    BLK_M, BLK_N, BLK_K = 256, 256, 32
+
+    # Test ceiling division
+    import math
+
+    expected_blocks_M = math.ceil(M / BLK_M)
+    expected_blocks_N = math.ceil(N / BLK_N)
+    expected_blocks_K = math.ceil(K / BLK_K)
+
+    actual_blocks_M = triton.cdiv(M, BLK_M)
+    actual_blocks_N = triton.cdiv(N, BLK_N)
+    actual_blocks_K = triton.cdiv(K, BLK_K)
+
+    assert actual_blocks_M == expected_blocks_M, (
+        f"Block M calculation mismatch: {actual_blocks_M} != {expected_blocks_M}"
+    )
+    assert actual_blocks_N == expected_blocks_N, (
+        f"Block N calculation mismatch: {actual_blocks_N} != {expected_blocks_N}"
+    )
+    assert actual_blocks_K == expected_blocks_K, (
+        f"Block K calculation mismatch: {actual_blocks_K} != {expected_blocks_K}"
+    )
+
+
+def test_file_structure():
+    """Test that all required files exist in the example directory."""
+    example_dir = Path(__file__).parent / "../../examples/09_gemm_one_shot_all_reduce"
+    example_dir = example_dir.resolve()
+
+    required_files = [
+        "benchmark.py",
+        "gemm_one_shot_all_reduce.py",
+        "matmul_wrapper.py",
+    ]
+
+    for filename in required_files:
+        file_path = example_dir / filename
+        assert file_path.exists(), f"Required file {filename} not found in {example_dir}"
+        assert file_path.is_file(), f"{filename} exists but is not a file"
+
+        # Check that file is not empty
+        assert file_path.stat().st_size > 0, f"{filename} is empty"
+
+
+def test_algorithm_parameters_validation():
+    """Test validation of algorithm parameters."""
+    # Test that invalid gemm_sms configuration is caught
+    args_dict = {
+        "gemm_sms": 350,  # Greater than total_sms
+        "total_sms": 304,
+    }
+
+    with pytest.raises(ValueError, match="Invalid number of stream-K SMs"):
+        # This should raise an error
+        heap_size = 1 << 30
+        shmem = iris.iris(heap_size)
+        A = shmem.randn(64, 64, device="cuda", dtype=torch.float16)
+        B = shmem.randn(64, 64, device="cuda", dtype=torch.float16).T
+
+        # Add required parameters
+        args_dict.update(
+            {
+                "m": 64,
+                "n": 64,
+                "k": 64,
+                "BLK_M": 32,
+                "BLK_N": 32,
+                "BLK_K": 16,
+                "gsize_m": 1,
+                "two_tiles": True,
+                "num_stages": 1,
+                "num_warps": 4,
+                "waves_per_eu": 0,
+                "mfmaInstrSize": 16,
+                "kpack": 1,
+            }
+        )
+
+        benchmark_module.gemm_one_shot_all_reduce(A, B, shmem, args_dict)
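
Taken together, the patch lets the core algorithm be exercised outside of `main()`. Below is a minimal usage sketch (not part of the patch) under the same assumptions the new tests make: the example directory is added to `sys.path` so `benchmark` is importable, a 1 GB Iris heap is allocated, the block sizes and SM counts mirror `test_gemm_one_shot_all_reduce_function`, and the script is launched across ranks the same way as the existing example.

```python
# Illustrative driver for the extracted helper; parameter values mirror the tests.
import sys
from pathlib import Path

import torch
import iris

# Assumes this runs from the repository root, as the tests arrange via sys.path.
example_dir = Path("examples/09_gemm_one_shot_all_reduce").resolve()
sys.path.insert(0, str(example_dir))
from benchmark import gemm_one_shot_all_reduce

shmem = iris.iris(1 << 30)  # 1 GB symmetric heap, as in the tests
world_size = shmem.get_num_ranks()

M = N = K = 256  # N and K must be divisible by world_size
A = shmem.randn(M, K, device="cuda", dtype=torch.float16)
B = shmem.randn(N, K, device="cuda", dtype=torch.float16).T

args_dict = {
    "m": M, "n": N, "k": K,
    "BLK_M": 64, "BLK_N": 64, "BLK_K": 32,
    "gsize_m": 1, "two_tiles": True,
    "num_stages": 1, "num_warps": 4, "waves_per_eu": 0,
    "mfmaInstrSize": 16, "kpack": 1,
    "gemm_sms": 64, "total_sms": 304,
    "trace_tiles": False,
}

# Each rank computes its K-slice of the GEMM; the partial results are
# all-reduced into the (M, N) output that every rank receives.
C = gemm_one_shot_all_reduce(A, B, shmem, args_dict)
assert C.shape == (M, N)
```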