diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index bc9573b36df7..a9f7429373bc 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1231,15 +1231,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
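
For context, the aliasing trick in the patch works like this: two intermediate tensors whose live ranges do not overlap can both be carved out of a single flat allocation sized for the larger of the two, so the second tensor costs no extra memory. Below is a minimal, self-contained sketch of the pattern; the dimensions (M, top_k, N, K) are hypothetical stand-ins, not the values used in fused_experts_impl.

    import torch

    # Hypothetical MoE dimensions for illustration only.
    M, top_k, N, K = 4, 2, 128, 96

    # One flat allocation sized for the larger of the two views.
    # Sizing in elements (not bytes) is valid because both views
    # share the same dtype.
    flat = torch.empty(M * top_k * max(N, K), dtype=torch.float16)

    cache1 = flat[:M * top_k * N].view(M, top_k, N)  # live first
    cache3 = flat[:M * top_k * K].view(M, top_k, K)  # reused after cache1 is dead

    # Both views alias the same storage, so no second buffer exists.
    assert cache1.data_ptr() == cache3.data_ptr()

This is safe only because, in the fused path, the first GEMM's output (cache1) is fully consumed by the activation before the second GEMM writes cache3; cache2 overlaps in time with cache1, which is why the patch keeps it in its own allocation.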