diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index c862dec81fcc..0d2d304156a5 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -40,6 +40,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
+    block_quant_shape: list[int] = None,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
                                dtype=torch.float32)
         w2_scale = torch.randn((hidden_size, num_experts),
                                dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        if block_quant_shape:
+            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+            E = num_experts
+            N = shard_intermediate_size // 2
+            K = hidden_size
+            factor_for_scale = 1e-2
+            n_tiles_w1 = (2 * N + block_n - 1) // block_n
+            n_tiles_w2 = (K + block_n - 1) // block_n
+            k_tiles_w1 = (K + block_k - 1) // block_k
+            k_tiles_w2 = (N + block_k - 1) // block_k
+            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+                                  dtype=torch.float32) * factor_for_scale
+            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+                                  dtype=torch.float32) * factor_for_scale
+        else:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
@@ -111,6 +128,7 @@ def run():
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_quant_shape,
         )

     # JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
     return param_ranges


-def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16,
+                              block_quant_shape) -> list[dict[str, int]]:
     configs: list[BenchmarkConfig] = []

     if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
         for config_values in product(*values):
             config = dict(zip(keys, config_values))
             configs.append(config)
+
+    # Remove configs that are not compatible with fp8 block quantization
+    # BLOCK_SIZE_K must be a multiple of block_k
+    # BLOCK_SIZE_N must be a multiple of block_n
+    if block_quant_shape is not None and not use_fp16:
+        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+        for config in configs[:]:
+            if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+                    "BLOCK_SIZE_N"] % block_n != 0:
+                configs.remove(config)
     return configs


 def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
-                            search_space, is_fp16):
+                            search_space, is_fp16, topk):
     N1, K1 = shard_intermediate_size, hidden_size
     N2, K2 = hidden_size, shard_intermediate_size // 2
-    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
-                                        is_fp16)
-    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
-                                        is_fp16)
+    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+                                        search_space, is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+                                        search_space, is_fp16)
     search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
     return search_space
@@ -372,6 +401,7 @@ def tune(
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
+        block_quant_shape: list[int],
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -380,21 +410,23 @@
         search_space = prune_rocm_search_space(num_tokens,
                                                shard_intermediate_size,
                                                hidden_size, search_space,
-                                               is_fp16)
+                                               is_fp16, topk)
         with torch.cuda.device(self.device_id):
             for config in tqdm(search_space):
                 try:
-                    kernel_time = benchmark_config(config,
-                                                   num_tokens,
-                                                   num_experts,
-                                                   shard_intermediate_size,
-                                                   hidden_size,
-                                                   topk,
-                                                   dtype,
-                                                   use_fp8_w8a8,
-                                                   use_int8_w8a16,
-                                                   num_iters=20)
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a16,
+                        num_iters=20,
+                        block_quant_shape=block_quant_shape)
                 except triton.runtime.autotuner.OutOfResources:
                     # Some configurations may be invalid and fail to compile.
                     continue
@@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:

 def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                  shard_intermediate_size: int, hidden_size: int, topk: int,
-                 dtype: torch.dtype, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool) -> None:
+                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
+                 block_quant_shape: list[int]) -> None:
     dtype_str = get_config_dtype_str(dtype,
                                      use_int8_w8a16=use_int8_w8a16,
                                      use_fp8_w8a8=use_fp8_w8a8)
@@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
-                                    dtype_str)
+                                    dtype_str, block_quant_shape)

     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
@@ -455,7 +487,7 @@

 def main(args: argparse.Namespace):
     print(args)
-
+    block_quant_shape = None
     config = AutoConfig.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code)
     if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        block_quant_shape = config.quantization_config['weight_block_size']
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -511,27 +544,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:

     if args.tune:
         is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16)
+        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
         print(f"Start tuning over {len(search_space)} configurations...")

         start = time.time()
         configs = _distribute(
-            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
-                     for batch_size in batch_sizes])
+            "tune",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+             for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+                     block_quant_shape)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute(
-            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
-                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
-                          for batch_size in batch_sizes])
+            "benchmark",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+             for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
             print(f"Batch size: {batch_size}, config: {config}")
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
index 2b1167fc71e2..63e118746fd8 100644
--- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
+++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -1,28 +1,28 @@
 {
     "1": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2": {
-        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "4": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
@@ -31,15 +31,15 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "16": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 1,
         "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
@@ -49,13 +49,13 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "32": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -64,7 +64,7 @@
     },
     "48": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -73,7 +73,7 @@
     },
     "64": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 2,
@@ -82,46 +82,82 @@
     },
     "96": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
-        "num_warps": 4,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "128": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "256": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "512": {
         "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "1024": {
         "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     }