Use EvoformerFusedAttention directly to avoid all-zero bias term in column attention
christinaflo committed Nov 8, 2023
1 parent 5aa5495 commit b7f35dc
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions openfold/model/primitives.py
@@ -23,7 +23,7 @@
     import deepspeed
 
 if ds4s_is_installed:
-    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+    from deepspeed.ops.deepspeed4science import EvoformerFusedAttention
 
 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
@@ -661,18 +661,19 @@ def reshape_dims(x):
         v = reshape_dims(v)
         biases = [reshape_dims(b) for b in biases]
 
+    biases.extend([None] * (2 - len(biases)))
+
     # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
     # Cast to bf16 so kernel can be used during inference
     orig_dtype = q.dtype
     if orig_dtype not in [torch.bfloat16, torch.float16]:
-        o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
-                                      k.to(dtype=torch.bfloat16),
-                                      v.to(dtype=torch.bfloat16),
-                                      [b.to(dtype=torch.bfloat16) for b in biases])
+        inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x
+                       for x in (q, k, v, biases[0], biases[1])]
+        o = EvoformerFusedAttention.apply(*inputs_bf16)
 
         o = o.to(dtype=orig_dtype)
     else:
-        o = DS4Sci_EvoformerAttention(q, k, v, biases)
+        o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])
 
     o = o.reshape(orig_shape)
     return o
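For context, below is a minimal standalone sketch of the call pattern this commit switches to, assembled from the diff above. The EvoformerFusedAttention.apply(q, k, v, bias_1, bias_2) signature, the None-padding, and the bf16 cast are taken from the changed lines; the claim that the public wrapper would otherwise fill the missing slot with an all-zero bias tensor is an assumption inferred from the commit message, and the helper name fused_evo_attn is hypothetical.

    import torch
    from deepspeed.ops.deepspeed4science import EvoformerFusedAttention

    def fused_evo_attn(q, k, v, biases):
        # Pad the bias list to the two slots the fused kernel accepts, using
        # None for an unused slot (e.g. in column attention) rather than an
        # all-zero bias tensor.
        biases = list(biases) + [None] * (2 - len(biases))

        # The DeepSpeed kernel requires bf16/fp16 inputs; cast when needed,
        # leave None biases untouched, and cast the output back afterwards.
        orig_dtype = q.dtype
        if orig_dtype not in (torch.bfloat16, torch.float16):
            args = [x.to(dtype=torch.bfloat16) if x is not None else x
                    for x in (q, k, v, biases[0], biases[1])]
            return EvoformerFusedAttention.apply(*args).to(dtype=orig_dtype)
        return EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])

Called with a single column-attention bias, e.g. fused_evo_attn(q, k, v, [mask_bias]), the second slot stays None and no zero-filled bias tensor is allocated.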
