Commit 54d414e: Return to regular kernel usage

christinaflo committed Nov 10, 2023
1 parent b7f35dc

Showing 2 changed files with 7 additions and 8 deletions.
environment.yml (1 addition, 1 deletion)

@@ -30,7 +30,7 @@ dependencies:
   - bioconda::kalign2==2.04
   - pytorch::pytorch=1.12.*
   - pip:
-      - git+https://github.com/microsoft/DeepSpeed.git@4388a60 # Replace when version becomes available
+      - deepspeed==0.12.2
       - dm-tree==0.1.6
       - git+https://github.com/NVIDIA/dllogger.git
       - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
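
For reference, a quick post-install sanity check (not part of the commit) that the environment resolved to the released DeepSpeed package rather than the earlier git snapshot; the exact version comparison is an assumption based on the pin above:

    import deepspeed

    # The pin above should resolve to the released package, not the git snapshot.
    assert deepspeed.__version__ == "0.12.2", deepspeed.__version__

    # The public entry point used in primitives.py below ships with this release.
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
    print("DS4Sci_EvoformerAttention is available")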
openfold/model/primitives.py (6 additions, 7 deletions)

@@ -23,7 +23,7 @@
 import deepspeed

 if ds4s_is_installed:
-    from deepspeed.ops.deepspeed4science import EvoformerFusedAttention
+    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
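
The `ds4s_is_installed` guard is defined above this hunk and is not shown in the diff; a plausible sketch, mirroring the `fa_is_installed` pattern visible here (the exact definition is an assumption, not confirmed by the commit):

    import importlib.util

    # Assumed guard, by analogy with fa_is_installed; the real definition
    # lives above this hunk and is not part of the diff.
    deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
    ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
        "deepspeed.ops.deepspeed4science") is not None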
@@ -661,19 +661,18 @@ def reshape_dims(x):
     v = reshape_dims(v)
     biases = [reshape_dims(b) for b in biases]

-    biases.extend([None] * (2 - len(biases)))
-
     # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
     # Cast to bf16 so kernel can be used during inference
     orig_dtype = q.dtype
     if orig_dtype not in [torch.bfloat16, torch.float16]:
-        inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x
-                       for x in (q, k, v, biases[0], biases[1])]
-        o = EvoformerFusedAttention.apply(*inputs_bf16)
+        o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
+                                      k.to(dtype=torch.bfloat16),
+                                      v.to(dtype=torch.bfloat16),
+                                      [b.to(dtype=torch.bfloat16) for b in biases])

         o = o.to(dtype=orig_dtype)
     else:
-        o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])
+        o = DS4Sci_EvoformerAttention(q, k, v, biases)

     o = o.reshape(orig_shape)
     return o
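
For context, a minimal standalone sketch of the new call pattern (not from the repository; the tensor shapes follow DeepSpeed's documented [batch, n_seq, n_res, n_heads, c_hidden] layout and are illustrative, and a CUDA device plus deepspeed==0.12.2 are assumed):

    import torch
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

    # Illustrative shapes: [batch, n_seq, n_res, n_heads, c_hidden].
    q = torch.randn(1, 8, 64, 4, 32, dtype=torch.bfloat16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Two biases, as in the call above: a mask bias broadcast over heads and
    # queries, and a pairwise bias broadcast over sequences.
    mask_bias = torch.randn(1, 8, 1, 1, 64, dtype=torch.bfloat16, device="cuda")
    pair_bias = torch.randn(1, 1, 4, 64, 64, dtype=torch.bfloat16, device="cuda")

    # Biases are passed as a list; inputs must already be bf16/fp16.
    o = DS4Sci_EvoformerAttention(q, k, v, [mask_bias, pair_bias])
    print(o.shape)  # torch.Size([1, 8, 64, 4, 32])

Unlike the old internal `EvoformerFusedAttention.apply`, which took exactly two positional bias arguments (hence the removed `biases.extend([None] * ...)` padding), the public function accepts the bias list directly.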