diff --git a/fla/modules/featue_map.py b/fla/modules/featue_map.py
index d06d20456..a0efe1d35 100644
--- a/fla/modules/featue_map.py
+++ b/fla/modules/featue_map.py
@@ -37,14 +37,13 @@ class DPFPFeatureMap(nn.Module):
     def __init__(self, head_dim: int, nu: int = 4):
        super().__init__()
        self.nu = nu
-
+
    def forward(self, x: torch.Tensor):
        x = torch.cat([x.relu(), -x.relu()], dim=-1)
        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
        x_repeat = torch.cat([x] * self.nu, dim=-1)
        return x_repeat * x_rolled

-
 class HadamardFeatureMap(nn.Module):
    def __init__(self, head_dim: int):
        super().__init__()
@@ -54,4 +53,23 @@ def __init__(self, head_dim: int):

    def forward(self, x: torch.Tensor):
        return self.layer1(x) * self.layer2(x)
-
\ No newline at end of file
+
+
+def flatten_outer_product(x, y):
+    z = x.unsqueeze(-1) * y.unsqueeze(-2)
+    N = z.size(-1)
+    indices = torch.triu_indices(N, N)
+    indices = N * indices[0] + indices[1]
+    return z.flatten(-2)[..., indices]
+
+class LearnableOuterProductFeatureMap(nn.Module):
+    def __init__(self, head_dim: int, feature_dim: int):
+        super().__init__()
+        # Trainable map
+        self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
+        self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
+        self.normalizer = feature_dim ** -0.5
+
+    def forward(self, x: torch.Tensor):
+        # x = x * self.normalizer
+        return flatten_outer_product(self.layer1(x), self.layer2(x))
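
For reviewer context, a minimal usage sketch of the new `LearnableOuterProductFeatureMap`, assuming the import path matches the file touched by this diff (`fla.modules.featue_map`); the batch/sequence shapes and dimension values below are illustrative assumptions, not part of the patch:

```python
import torch

from fla.modules.featue_map import LearnableOuterProductFeatureMap

# Illustrative sizes (assumptions, not from the patch).
head_dim, feature_dim = 16, 4
phi = LearnableOuterProductFeatureMap(head_dim, feature_dim)

# A (batch, seq_len, head_dim) input, as typically fed per attention head.
x = torch.randn(2, 8, head_dim)
out = phi(x)

# flatten_outer_product keeps the upper triangle (diagonal included) of the
# feature_dim x feature_dim outer product, i.e. feature_dim * (feature_dim + 1) / 2 entries.
assert out.shape == (2, 8, feature_dim * (feature_dim + 1) // 2)
```

Note that `self.normalizer` is defined in `__init__` but the scaling line in `forward` is commented out in this patch, so the feature map currently returns unnormalized outer-product features.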