Provide clamp_min option
yzhangcs committed Feb 17, 2024
1 parent e4dcdb8 commit fb4408b
Showing 3 changed files with 12 additions and 4 deletions.
13 changes: 9 additions & 4 deletions fla/layers/gla.py
@@ -4,6 +4,8 @@

 from __future__ import annotations

+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -27,6 +29,7 @@ def __init__(
         gate_logit_normalizer: int = 16,
         gate_low_rank_dim: int = 16,
         mode: str = 'fused_chunk',
+        clamp_min: Optional[float] = None,
         fuse_norm: bool = True,
         *args, **kwargs
     ) -> GatedLinearAttention:
@@ -36,6 +39,8 @@ def __init__(
         self.mode = mode
         self.value_dim = int(d_model * expand_v)
         self.key_dim = int(d_model * expand_k)
+        self.clamp_min = clamp_min
+
         assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
         assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
         assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
@@ -82,6 +87,8 @@ def forward(self, x):
         gk = rearrange(self.gk_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
         gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)

+        if self.clamp_min is not None:
+            gk = torch.clamp_min(gk, self.clamp_min)
         if mode == 'fused_recurrent':
             o = fused_recurrent_gla(q, k, v, gk, None)
         elif mode == 'fused_chunk':
@@ -119,10 +126,8 @@ def forward(self, x):
     # print(x.grad.shape)

     for act in ['swish']:
-        org = GatedLinearAttention(
-            d_model=d_model, gate_fn=act, fuse_norm=False).to(torch.bfloat16).cuda()
-        fused = GatedLinearAttention(
-            d_model=d_model, gate_fn=act, fuse_norm=True).to(torch.bfloat16).cuda()
+        org = GatedLinearAttention(d_model=d_model, gate_fn=act, fuse_norm=False).to(torch.bfloat16).cuda()
+        fused = GatedLinearAttention(d_model=d_model, gate_fn=act, fuse_norm=True).to(torch.bfloat16).cuda()
         fused.q_proj.weight.data.copy_(org.q_proj.weight.data)
         fused.k_proj.weight.data.copy_(org.k_proj.weight.data)
         fused.v_proj.weight.data.copy_(org.v_proj.weight.data)
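The two lines added to forward bound the log forget gates from below: gk is first pushed into (-inf, 0] by F.logsigmoid and scaled by gate_logit_normalizer, and clamp_min then keeps it from drifting toward -inf. A minimal, self-contained sketch of that step (the tensor shapes and the -5.0 value are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

gate_logit_normalizer = 16
clamp_min = -5.0  # hypothetical lower bound; None would disable clamping

# (batch, heads, seq_len, head_dim) gate logits, shapes chosen for illustration
gk = torch.randn(2, 4, 128, 64)
gk = F.logsigmoid(gk) / gate_logit_normalizer  # log forget gates, all <= 0

if clamp_min is not None:
    gk = torch.clamp_min(gk, clamp_min)  # no gate may fall below clamp_min

print(gk.min().item() >= clamp_min)  # True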
2 changes: 2 additions & 0 deletions fla/models/gla/configuration_gla.py
@@ -20,6 +20,7 @@ def __init__(
         num_hidden_layers: int = 24,
         num_attention_heads: int = 8,
         num_key_value_heads: Optional[int] = None,
+        clamp_min: Optional[float] = None,
         hidden_act: str = "swish",
         max_position_embeddings: int = 2048,
         rms_norm_eps: float = 1e-6,
@@ -48,6 +49,7 @@ def __init__(
             num_key_value_heads = num_attention_heads

         self.num_key_value_heads = num_key_value_heads
+        self.clamp_min = clamp_min
         self.hidden_act = hidden_act
         self.rms_norm_eps = rms_norm_eps
         self.use_gk = use_gk
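On the config side the option is a plain passthrough attribute that defaults to None (no clamping). A sketch of setting it, assuming GLAConfig is importable from the module path shown in the diff; the other arguments are just the defaults visible above and the -5.0 value is hypothetical:

from fla.models.gla.configuration_gla import GLAConfig

config = GLAConfig(
    num_hidden_layers=24,
    num_attention_heads=8,
    hidden_act="swish",
    clamp_min=-5.0,  # hypothetical value; leave as None to keep the old behaviour
)
print(config.clamp_min)  # -5.0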
1 change: 1 addition & 0 deletions fla/models/gla/modeling_gla.py
@@ -59,6 +59,7 @@ def __init__(self, config: GLAConfig, layer_idx: int):
            num_heads=config.num_attention_heads,
            gate_fn=config.hidden_act,
            layernorm_eps=config.rms_norm_eps,
+           clamp_min=config.clamp_min,
            fuse_norm=config.fuse_norm,
            layer_idx=layer_idx
        )
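The model block simply forwards config.clamp_min into the attention layer, so the option can also be exercised on the layer directly. A sketch using only constructor arguments visible in this diff (d_model, gate_fn, clamp_min); the sizes, the CUDA/bfloat16 setup mirroring the test code, and the -5.0 value are illustrative:

import torch
from fla.layers.gla import GatedLinearAttention

layer = GatedLinearAttention(d_model=1024, gate_fn='swish', clamp_min=-5.0).to(torch.bfloat16).cuda()
x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device='cuda')
y = layer(x)  # log forget gates are clamped at -5.0 inside forward
print(y.shape)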
