Skip to content

Commit

Permalink
[BugFix] 1D query fix for MoE models (3597)
Browse files Browse the repository at this point in the history
MoE models were broken by vLLM PR 3236.
  • Loading branch information
njhill committed Mar 24, 2024
1 parent f79c766 commit c6a36a9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
10 changes: 6 additions & 4 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,21 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data

# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)

# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(inputs)
vllm_states = vllm_moe.forward(inputs)
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs)

mixtral_moe_tol = {
torch.float32: 1e-3,
torch.float16: 1e-3,
torch.bfloat16: 1e-2,
}

assert torch.allclose(hf_states,
assert torch.allclose(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
7 changes: 3 additions & 4 deletions vllm/model_executor/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def pack_params(self):
self.w2 = self.w2.view(len(w2), *w2s[0].shape)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (batch * sequence_length, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
Expand All @@ -169,8 +169,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states.view(batch_size, sequence_length,
hidden_dim)
return final_hidden_states.view(num_tokens, hidden_dim)


class DeepseekAttention(nn.Module):
Expand Down
7 changes: 3 additions & 4 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
param_data[expert_id, :, :] = loaded_weight[:, shard]

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_size = hidden_states.shape
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
Expand All @@ -140,8 +140,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states.view(batch_size, sequence_length,
hidden_size)
return final_hidden_states.view(num_tokens, hidden_size)


class MixtralAttention(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/mixtral_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ def __init__(
linear_method=None)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
Expand All @@ -158,7 +158,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states.add_(current_hidden_states)

return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
num_tokens, hidden_dim)


class MixtralAttention(nn.Module):
Expand Down

0 comments on commit c6a36a9

Please sign in to comment.