Skip to content

Commit

Permalink
fix: adjust model config vars and other refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jan 24, 2024
1 parent 9939237 commit 9bcd21a
Showing 1 changed file with 14 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
rope_scaling=None,
rope_theta=10000.0,
resid_pdrop=0.1, # llama doesn't have this
partial_rotary_factor=0.5,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -54,6 +55,7 @@ def __init__(
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
self.partial_rotary_factor = partial_rotary_factor

super().__init__(
pad_token_id=pad_token_id,
Expand All @@ -68,21 +70,13 @@ def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
if config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=True,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)

def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
Expand Down Expand Up @@ -130,6 +124,7 @@ def __init__(
)

self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
Expand Down Expand Up @@ -188,7 +183,7 @@ def forward(
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
cos, sin
)

Expand Down Expand Up @@ -243,7 +238,7 @@ def __init__(self, prefix, config, weights):
)

# llama weights are up_proj and down_proj and bias=False
self.gate_up_proj = TensorParallelRowLinear.load(
self.up_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,
Expand All @@ -259,9 +254,7 @@ def __init__(self, prefix, config, weights):
def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
gate_up_states = self.gate_up_proj(hidden_states)
post_act = self.act(gate_up_states)
return self.down_proj(post_act)
return self.down_proj(self.act(self.up_proj(hidden_states)))


class FlashPhiLayer(nn.Module):
Expand Down Expand Up @@ -304,10 +297,7 @@ def forward(
max_s,
)

attn_output = self.resid_dropout(attn_output)

feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_output + feed_forward_hidden_states
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))

return hidden_states, res

Expand Down

0 comments on commit 9bcd21a

Please sign in to comment.