Skip to content

Commit

Permalink
add learnable outer product feature map
Browse files Browse the repository at this point in the history
  • Loading branch information
sustcsonglin committed Feb 23, 2024
1 parent 92cfa7a commit 9763394
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions fla/modules/featue_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@ class DPFPFeatureMap(nn.Module):
def __init__(self, head_dim: int, nu: int = 4):
    """DPFP feature map (deterministic parameter-free projection).

    Args:
        head_dim: per-head input dimension. NOTE(review): accepted for
            signature parity with the other feature maps but not used —
            confirm this is intentional.
        nu: number of rolled copies used to form the outer-product-like
            expansion in ``forward`` (output dim is ``2 * nu * head_dim``).
    """
    super().__init__()
    # Only nu is stored; the map itself has no trainable parameters.
    self.nu = nu

def forward(self, x: torch.Tensor):
    """Apply the DPFP feature map to the last dimension of ``x``.

    Concatenates ``relu(x)`` with ``relu(-x)`` (a non-negative, sparse
    encoding), then multiplies that vector element-wise with ``nu``
    rolled copies of itself, yielding ``2 * nu * d`` features.

    Args:
        x: input tensor; the map acts on the last dimension.

    Returns:
        Tensor with last dimension expanded by a factor of ``2 * nu``.
    """
    # BUG FIX: the original `-x.relu()` parses as `-(relu(x))`, producing
    # non-positive mirror features. DPFP (Schlag et al., 2021) requires
    # relu(-x), i.e. the ReLU of the negated input.
    x = torch.cat([x.relu(), (-x).relu()], dim=-1)
    # nu shifted copies of the feature vector along the feature axis.
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
    # The un-shifted features, tiled to match x_rolled's width.
    x_repeat = torch.cat([x] * self.nu, dim=-1)
    # Element-wise products of features with their shifted partners.
    return x_repeat * x_rolled


class HadamardFeatureMap(nn.Module):
def __init__(self, head_dim: int):
super().__init__()
Expand All @@ -54,4 +53,23 @@ def __init__(self, head_dim: int):

def forward(self, x: torch.Tensor):
return self.layer1(x) * self.layer2(x)



def flatten_outer_product(x, y):
    """Outer product of ``x`` and ``y`` over the last dimension, keeping only
    the upper-triangular entries of the resulting matrix, flattened.

    For last-dim size ``N`` the output's last dimension is ``N * (N + 1) / 2``.
    """
    outer = x.unsqueeze(-1) * y.unsqueeze(-2)
    dim = outer.size(-1)
    # Row/column indices of the upper triangle, converted to flat offsets.
    rows, cols = torch.triu_indices(dim, dim)
    flat_idx = rows * dim + cols
    return outer.flatten(-2)[..., flat_idx]

class LearnableOuterProductFeatureMap(nn.Module):
    """Learnable outer-product feature map.

    Projects the input through two independent trainable linear maps and
    returns the flattened upper-triangular outer product of the two
    projections (output dim: ``feature_dim * (feature_dim + 1) / 2``).
    """

    def __init__(self, head_dim: int, feature_dim: int):
        super().__init__()
        # Two independent trainable projections (no bias).
        self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
        self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
        # NOTE(review): currently unused — the input scaling in forward()
        # is disabled upstream; attribute kept for interface parity.
        self.normalizer = feature_dim ** -0.5

    def forward(self, x: torch.Tensor):
        # Input normalization (x * self.normalizer) is intentionally
        # disabled in the original implementation.
        left = self.layer1(x)
        right = self.layer2(x)
        return flatten_outer_product(left, right)

0 comments on commit 9763394

Please sign in to comment.