Merge pull request #378 from aqlaboratory/deepspeed-evo-attention
Deepspeed evoformer attention
christinaflo authored Dec 8, 2023
2 parents 2dc080c + 40d7635 commit a13c0ce
Showing 16 changed files with 595 additions and 81 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -39,13 +39,14 @@ kernels support in-place attention during inference and training. They use
implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and reduces the model's peak device memory requirement by a factor of 13. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference. To use this feature, set the `use_deepspeed_evo_attention` option to `True` in `openfold/config.py`.

## Installation (Linux)

All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on on your system. You'll need `git-lfs` to download OpenFold parameters.
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.

This package is currently supported for CUDA 11 and Pytorch 1.12
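A minimal sketch of enabling the new kernel described in the README bullet above. The preset name `"model_1"` is an assumption used for illustration; the attribute names follow the config changes introduced later in this PR.

```python
# Hedged sketch: enable the DeepSpeed DS4Sci_EvoformerAttention kernel via the
# model config. The preset name "model_1" is illustrative only.
from openfold.config import model_config

config = model_config("model_1")
config.globals.use_deepspeed_evo_attention = True  # new option added in this PR
# The three attention options are mutually exclusive, so the others stay off:
config.globals.use_lma = False
config.globals.use_flash = False
```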
7 changes: 5 additions & 2 deletions deepspeed_config.json
@@ -12,13 +12,16 @@
},
"zero_optimization": {
"stage": 2,
"cpu_offload": true,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"profile": false
},
"gradient_clipping": 0.1
"gradient_clipping": 0.1,
"zero_force_ds_cpu_optimizer": false
}
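For context, a hedged sketch of how a DeepSpeed engine might consume the updated JSON above; the stand-in module and the call site are illustrative rather than OpenFold's actual training entry point.

```python
# Illustrative only: initialize a DeepSpeed engine with the updated config,
# which now offloads optimizer state via "offload_optimizer" rather than the
# older "cpu_offload" flag.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in module; OpenFold's model would go here
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json",
)
```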
2 changes: 1 addition & 1 deletion environment.yml
@@ -30,7 +30,7 @@ dependencies:
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip:
- deepspeed==0.5.10
- deepspeed==0.12.4
- dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
32 changes: 23 additions & 9 deletions openfold/config.py
@@ -28,19 +28,28 @@ def string_to_setting(s):
(
"globals.use_lma",
"globals.use_flash",
"globals.use_deepspeed_evo_attention"
),
]

for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
for options in mutually_exclusive_bools:
option_settings = [string_to_setting(o) for o in options]
if sum(option_settings) > 1:
raise ValueError(f"Only one of {', '.join(options)} may be set at a time")

fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
if config.globals.use_flash and not fa_is_installed:
raise ValueError("use_flash requires that FlashAttention is installed")

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
raise ValueError(
"use_deepspeed_evo_attention requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)

if(
config.globals.offload_inference and
not config.model.template.average_templates
@@ -193,7 +202,8 @@ def model_config(
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
# Default to DeepSpeed memory-efficient attention kernel unless use_lma is explicitly set
c.globals.use_deepspeed_evo_attention = True if not c.globals.use_lma else False
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
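A short sketch of the behavior introduced in the hunk above: the long-sequence preset now prefers the DeepSpeed kernel unless low-memory attention was explicitly requested. The preset name and the `long_sequence_inference` keyword are assumptions inferred from this diff.

```python
# Hedged sketch of the long_sequence_inference default shown above; assumes
# DeepSpeed and its deepspeed4science ops are installed.
from openfold.config import model_config

c = model_config("model_1", long_sequence_inference=True)
# Unless use_lma was preset, the DeepSpeed kernel becomes the default:
assert c.globals.use_deepspeed_evo_attention is True
assert c.globals.use_flash is False
assert c.globals.offload_inference is True
```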
@@ -419,11 +429,15 @@ def model_config(
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually
# exclusive with use_lma and use_flash.
"use_deepspeed_evo_attention": False,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
# exclusive with use_deepspeed_evo_attention and use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
# on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
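A standalone sketch of the mutual-exclusivity rule enforced by the config check earlier in this file: enabling more than one of the three attention options raises a `ValueError`. The helper below mirrors that logic but is not the library's own function.

```python
# Standalone illustration of the constraint in openfold/config.py; the option
# names mirror the globals entries above.
def check_attention_options(use_deepspeed_evo_attention, use_lma, use_flash):
    options = {
        "globals.use_deepspeed_evo_attention": use_deepspeed_evo_attention,
        "globals.use_lma": use_lma,
        "globals.use_flash": use_flash,
    }
    if sum(options.values()) > 1:
        raise ValueError(f"Only one of {', '.join(options)} may be set at a time")

check_attention_options(True, False, False)   # fine
# check_attention_options(True, True, False)  # would raise ValueError
```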
43 changes: 35 additions & 8 deletions openfold/model/evoformer.py
@@ -87,7 +87,6 @@ def _chunk(self,
no_batch_dims=len(m.shape[:-2]),
)


def forward(
self,
m: torch.Tensor,
@@ -181,6 +180,7 @@ def forward(self,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -260,6 +260,7 @@ def forward(self,
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -279,6 +280,7 @@ def forward(self,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -339,7 +341,7 @@ def __init__(self,

# Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention
if self.no_column_attention == False:
if not self.no_column_attention:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
@@ -369,6 +371,7 @@ def forward(self,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -396,19 +399,21 @@ def forward(self,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)

# Specifically, column attention is not used in seqemb mode.
if self.no_column_attention == False:
if not self.no_column_attention:
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
),
@@ -424,7 +429,8 @@ def forward(self,
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -500,6 +506,7 @@ def forward(self,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -526,7 +533,8 @@ def forward(self,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not use_lma,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
@@ -560,6 +568,7 @@ def fn(input_tensors):
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -685,6 +694,7 @@ def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
@@ -698,6 +708,7 @@ def _prep_blocks(self,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
@@ -737,6 +748,7 @@ def _forward_offload(self,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
@@ -748,6 +760,7 @@ def _forward_offload(self,
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
@@ -779,6 +792,7 @@ def forward(self,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -797,10 +811,15 @@ def forward(self,
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory efficient kernel.
Mutually exclusive with use_lma and use_flash.
use_lma:
Whether to use low-memory attention during inference.
Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
exclusive with use_lma and use_deepspeed_evo_attention.
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
@@ -813,6 +832,7 @@ def forward(self,
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
@@ -893,6 +913,7 @@ def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
@@ -904,7 +925,8 @@ def _prep_blocks(self,
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -941,6 +963,7 @@ def clear_cache(b, *args, **kwargs):
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
@@ -953,6 +976,7 @@ def _forward_offload(self,
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
@@ -979,6 +1003,7 @@ def forward(self,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -990,6 +1015,7 @@ def forward(self,
z:
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient kernel
use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
@@ -1003,6 +1029,7 @@ def forward(self,
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
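A hedged sketch of how the new flag is threaded through an `EvoformerStack` forward call after this change. Here `stack`, the tensors, and the masks are placeholders assumed to exist; the argument list follows the signatures shown in the diff above.

```python
# Illustrative call only; `stack` is an already-constructed EvoformerStack and
# m, z, msa_mask, pair_mask are appropriately shaped tensors on a CUDA device.
import torch

with torch.no_grad():
    m_out, z_out, s = stack(
        m, z,
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        chunk_size=4,
        use_deepspeed_evo_attention=True,  # new argument added in this PR
        use_lma=False,
        use_flash=False,
        inplace_safe=True,
    )
```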
5 changes: 5 additions & 0 deletions openfold/model/model.py
@@ -178,6 +178,7 @@ def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
@@ -374,6 +375,7 @@ def iteration(self, feats, prevs, _recycle=True):
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
@@ -386,6 +388,7 @@ def iteration(self, feats, prevs, _recycle=True):
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
@@ -404,6 +407,7 @@ def iteration(self, feats, prevs, _recycle=True):
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
@@ -416,6 +420,7 @@ def iteration(self, feats, prevs, _recycle=True):
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
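Finally, a hedged end-to-end sketch tying the pieces together: the model reads the flag from `config.globals` and forwards it into the template, extra-MSA, and Evoformer stacks as shown in the diff above. The preset name and the pre-built `batch` feature dict are assumptions.

```python
# Illustrative inference sketch; `batch` is a preprocessed OpenFold feature
# dict and is not constructed here.
import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold

config = model_config("model_1")
config.globals.use_deepspeed_evo_attention = True  # needs deepspeed4science ops

model = AlphaFold(config).eval().cuda()
with torch.no_grad():
    out = model(batch)  # the flag is applied inside embed_templates and the Evoformer
```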
(Diffs for the remaining changed files are not shown.)
