Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,15 +1231,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,

config = get_config_func(M)

intermediate_cache1 = torch.empty((M, top_k_num, N),
device=hidden_states.device,
dtype=hidden_states.dtype)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view(
(M, topk_ids.shape[1], N))
intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
(M, topk_ids.shape[1], w2.shape[1]))

# This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)

if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
Expand Down