Commit
value residual learning (#312)
* cite

* add value residual learning

* oops

* slip in value residual learning for pairformer stack

* also cite Nguyen, whose initial paper led here
lucidrains authored Nov 2, 2024
1 parent 9e5bb92 commit 95d3ab6
Showing 5 changed files with 152 additions and 37 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -494,3 +494,23 @@ docker run -v .:/data --gpus all -it af3
url = {https://api.semanticscholar.org/CorpusID:267657558}
}
```

```bibtex
@article{Nguyen2023MitigatingOI,
title = {Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals},
author = {Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.00751},
url = {https://api.semanticscholar.org/CorpusID:264300597}
}
```

```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
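
The two citations above are what this commit implements: Zhou et al. propose giving every attention layer's values a residual connection to the values computed by the first layer, and the Nguyen et al. paper is the earlier work on over-smoothing that led here (per the commit message). Below is a minimal sketch of the idea, with hypothetical module names rather than this repository's API; the `0.5` mixing and the `orig_v` bookkeeping mirror the changes to `alphafold3_pytorch/attention.py` further down.

```python
# minimal sketch of value residual learning (hypothetical names, not this repo's API)

import torch
from torch import nn
import torch.nn.functional as F

class SketchAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        dim_inner = heads * dim_head
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x, value_residual = None):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        orig_v = v  # values before mixing, handed onward for later layers

        if value_residual is not None:
            v = 0.5 * (v + value_residual)  # mix in the first layer's values

        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out), orig_v

# the first layer's values become the shared residual for every layer after it
layers = nn.ModuleList([SketchAttention(128) for _ in range(4)])
x = torch.randn(1, 16, 128)
value_residual = None

for layer in layers:
    out, values = layer(x, value_residual = value_residual)
    x = x + out
    value_residual = value_residual if value_residual is not None else values
```

Keeping only the first layer's values, rather than refreshing the residual at every layer, is the same choice made by the `default(value_residual, attn_values)` calls in the diffs below.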
128 changes: 98 additions & 30 deletions alphafold3_pytorch/alphafold3.py
@@ -607,7 +607,10 @@ def forward(
self,
x: Float['... n d'],
**kwargs
) -> Float['... n d']:
) -> (
Float['... n d'] |
tuple[Float['... n d'] | Any]
):

x = self.norm(x)
return self.fn(x, **kwargs)
@@ -678,13 +681,26 @@ def forward(
*,
cond: Float['b n dc'],
**kwargs
) -> Float['b n d']:
) -> (
Float['b n d'] |
tuple[Float['b n d'], Float['b _ _']]
):
x = self.adaptive_norm(x, cond = cond)

out = self.fn(x, **kwargs)

tuple_output = isinstance(out, tuple)

if tuple_output:
out, *rest = out

gamma = self.to_adaln_zero_gamma(cond)
return out * gamma
out = out * gamma

if tuple_output:
out = (out, *rest)

return out

# triangle multiplicative module
# seems to be unchanged from alphafold2
@@ -762,7 +778,10 @@ def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **
self.window_size = window_size

self.attn = Attention(
heads=heads, window_size=window_size, num_memory_kv=num_memory_kv, **attn_kwargs
heads = heads,
window_size = window_size,
num_memory_kv = num_memory_kv,
**attn_kwargs
)

# line 8 of Algorithm 24
@@ -777,8 +796,14 @@ def forward(
*,
pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
return_values: bool = False,
value_residual: Float['b _ _'] | None = None,
**kwargs,
) -> Float["b n ds"]: # type: ignore
) -> (
Float['b n ds'] |
tuple[Float['b n ds'], Float['b _ _']]
): # type: ignore

"""Perform the forward pass.
:param single_repr: The single representation tensor.
@@ -837,9 +862,22 @@ def forward(
else:
attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias

out = self.attn(single_repr, attn_bias=attn_bias, **kwargs)
# attention

return out
out, values = self.attn(
single_repr,
attn_bias = attn_bias,
value_residual = value_residual,
return_values = True,
**kwargs
)

# whether to return values for value residual learning

if not return_values:
return out

return out, values

class TriangleAttention(Module):
def __init__(
@@ -1360,6 +1398,7 @@ def __init__(
dropout_row_prob = 0.25,
num_register_tokens = 0,
checkpoint = False,
add_value_residual = False,
pairwise_block_kwargs: dict = dict(),
pair_bias_attn_kwargs: dict = dict()
):
@@ -1395,6 +1434,8 @@ def __init__(

self.layers = layers

self.add_value_residual = add_value_residual

# checkpointing

self.checkpoint = checkpoint
@@ -1423,6 +1464,8 @@ def to_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

value_residual = None

for _ in range(self.recurrent_depth):
for (
pairwise_block,
@@ -1432,7 +1475,13 @@ def to_layers(

pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)

single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
attn_out, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)

single_repr = single_repr + attn_out

if self.add_value_residual:
value_residual = default(value_residual, attn_values)

single_repr = single_transition(single_repr) + single_repr

return single_repr, pairwise_repr
@@ -1447,30 +1496,35 @@ def to_checkpointed_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

inputs = (single_repr, pairwise_repr, mask)
inputs = (single_repr, pairwise_repr, mask, None)

def pairwise_block_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
return single_repr, pairwise_repr, mask
return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

def pair_bias_attn_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
return single_repr, pairwise_repr, mask
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
single_repr = single_repr + attn_out

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)

return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

def single_transition_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask = inputs
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
single_repr = layer(single_repr) + single_repr
return single_repr, pairwise_repr, mask
return single_repr, pairwise_repr, mask, maybe_value_residual
return inner

wrapped_layers = []
@@ -1489,7 +1543,7 @@ def inner(inputs, *args, **kwargs):
for layer in wrapped_layers:
inputs = checkpoint(layer, inputs)

single_repr, pairwise_repr, _ = inputs
single_repr, pairwise_repr, *_ = inputs
return single_repr, pairwise_repr

@typecheck
@@ -1915,9 +1969,9 @@ def __init__(
attn_num_memory_kv = False,
trans_expansion_factor = 2,
num_register_tokens = 0,
add_residual = True,
use_linear_attn = False,
checkpoint = False,
add_value_residual = False,
linear_attn_kwargs = dict(
heads = 8,
dim_head = 16
@@ -1997,7 +2051,7 @@ def __init__(

self.layers = layers

self.add_residual = add_residual
self.add_value_residual = add_value_residual

self.has_registers = num_register_tokens > 0
self.num_registers = num_register_tokens
@@ -2018,32 +2072,37 @@ def to_checkpointed_serial_layers(
windowed_mask: Bool['b nw w (w*2)'] | None = None
):

inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask)
inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

wrapped_layers = []

def efficient_attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, mask = mask) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

def attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
noised_repr = attn_out + noised_repr

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)

return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

def transition_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

for linear_attn, colt5_attn, attn, transition in self.layers:
@@ -2074,6 +2133,8 @@ def to_serial_layers(
windowed_mask: Bool['b nw w (w*2)'] | None = None
):

value_residual = None

for linear_attn, colt5_attn, attn, transition in self.layers:

if exists(linear_attn):
@@ -2082,13 +2143,20 @@ def to_serial_layers(
if exists(colt5_attn):
noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr

noised_repr = attn(
attn_out, attn_values = attn(
noised_repr,
cond = single_repr,
pairwise_repr = pairwise_repr,
mask = mask,
windowed_mask = windowed_mask
) + noised_repr
windowed_mask = windowed_mask,
return_values = True,
value_residual = value_residual
)

noised_repr = noised_repr + attn_out

if self.add_value_residual:
value_residual = default(value_residual, attn_values)

noised_repr = transition(
noised_repr,
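A note on the checkpointed paths in this file: since each wrapped layer runs inside `torch.utils.checkpoint`, the value residual is not kept in an outer loop variable but rides along in the checkpointed inputs tuple, starting as `None` and being filled in by the first attention layer. A simplified sketch of that pattern, where `block(x, value_residual = ...)` is a hypothetical stand-in for the attention blocks above:

```python
# simplified sketch: threading the value residual through torch.utils.checkpoint
# (`block` is a hypothetical callable returning (output, values))

from functools import wraps

import torch
from torch.utils.checkpoint import checkpoint

def attn_wrapper(block, add_value_residual = True):
    @wraps(block)
    def inner(inputs):
        x, maybe_value_residual = inputs

        out, values = block(x, value_residual = maybe_value_residual)
        x = x + out

        if add_value_residual and maybe_value_residual is None:
            maybe_value_residual = values  # first layer fills the slot, later layers reuse it

        return x, maybe_value_residual
    return inner

def run_checkpointed(blocks, x):
    # every wrapper consumes and returns the same tuple shape,
    # so checkpoint() only ever sees one packed argument
    inputs = (x, None)

    for block in blocks:
        inputs = checkpoint(attn_wrapper(block), inputs, use_reentrant = False)

    x, _ = inputs
    return x
```
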
27 changes: 24 additions & 3 deletions alphafold3_pytorch/attention.py
@@ -237,15 +237,29 @@ def forward(
mask: Bool['b n']| None = None,
context: Float['b j d'] | None = None,
windowed_mask: Bool['b nw w (w*2)'] | None = None,
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
return_values: bool = False,
value_residual: Float['b j dh'] | None = None,

) -> Float['b i d']:
) -> (
Float['b i d'] |
tuple[Float['b i d'], Float['b j _']]
):

q = self.to_q(seq)

context_seq = default(context, seq)
k, v = self.to_kv(context_seq).chunk(2, dim = -1)

# handle value residual

orig_v = v

if exists(value_residual):
v = 0.5 * (v + value_residual)

# split heads

q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

# attention
@@ -270,7 +284,14 @@

# combine heads

return self.to_out(out)
out = self.to_out(out)

# maybe return values

if not return_values:
return out

return out, orig_v

# the main attention function

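Going by the new keyword arguments above, a call site looks roughly like the following sketch. It assumes `blocks` is a list of `Attention` instances constructed elsewhere and `seq` is a `[b, n, d]` tensor; only `return_values` and `value_residual` come from this commit.

```python
# rough usage of the new kwargs: the first block hands back its pre-mixing
# values, later blocks fold them in (assumes `blocks` and `seq` exist)

value_residual = None

for attn in blocks:
    out, values = attn(
        seq,
        return_values = True,            # also return the unmixed values
        value_residual = value_residual  # None for the first block
    )

    seq = seq + out

    # keep only the first block's values, as in the stacks above
    value_residual = value_residual if value_residual is not None else values
```
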
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.6.5"
version = "0.6.6"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },