Commit 58acf96 (1 parent: 6a77872)

* Add openfold triton code
* Add copyright headers
* Mention dependency on einops in README
* Remove _test_fused_adam_swa.py

Showing 10 changed files with 2,847 additions and 0 deletions.
@@ -0,0 +1,123 @@

# OpenFold Triton kernels

This subpackage is a collection of Triton kernels written specifically for the initial-training mode of the OpenFold model architecture.

To use this subpackage, you must install an additional dependency:

```bash
pip install einops
```

The following sections list the main features and show how to use them.

## Multi-Head Attention

```python
from typing import Optional

import torch
from torch import nn

import apex.contrib.openfold_triton.mha as mha
from apex.contrib.openfold_triton import AttnBiasJIT, AttnNoBiasJIT, AttnTri, CanSchTriMHA

# Integration with an Attention module:
class SelfAttentionWithGate(nn.Module):
    # ...

    def _attention_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if self.chunk_size is None:
            if mha.is_enabled() and CanSchTriMHA(
                list(query.shape),
                bias is not None,
                inf=self.inf,
                training=self.training,
            ):
                if mask is not None:
                    mask = mask.contiguous()
                if bias is not None:
                    bias = bias.contiguous()
                return AttnTri(
                    query, key, value, mask, bias, self.inf, torch.is_grad_enabled()
                )
            elif mha.is_enabled() and bias is not None and self.training:
                return AttnBiasJIT(query, key, value, mask, bias, self.inf)
            elif mha.is_enabled() and bias is None and self.training:
                return AttnNoBiasJIT(query, key, value, mask, self.inf)
            # ... otherwise fall back to the module's default attention path

# Switch the MHA kernels on/off dynamically at runtime via:
mha.enable()
mha.disable()
```

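For example, a minimal sketch of toggling the Triton MHA path around a training run; `model`, `train_one_epoch`, and `evaluate` are hypothetical placeholders for your own loop:

```python
import apex.contrib.openfold_triton.mha as mha

mha.enable()            # attention modules wired as above take the Triton kernel path
train_one_epoch(model)  # hypothetical training-loop helper
mha.disable()           # revert to the default attention implementation
evaluate(model)         # hypothetical evaluation helper
```
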
## LayerNorm

```python
import torch
import torch.nn.functional as F
from torch import nn

from apex.contrib.openfold_triton import LayerNormSmallShapeOptImpl

# Integration with a LayerNorm module:
class LayerNorm(nn.Module):
    # ...

    def _should_use_triton_kernels(self, x: torch.Tensor) -> bool:
        ln_triton_shapes = (
            (256, 128),
            (256, 256),
        )
        ln_triton_dim = 4
        return (
            self.training
            and x.dim() == ln_triton_dim
            and x.shape[-2:] in ln_triton_shapes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._should_use_triton_kernels(x):
            return LayerNormSmallShapeOptImpl.apply(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        else:
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

# To load an auto-tuned cache:
from apex.contrib.openfold_triton._layer_norm_config_ampere import _auto_tuned_config_ampere
from apex.contrib.openfold_triton._layer_norm_config_hopper import _auto_tuned_config_hopper
from apex.contrib.openfold_triton import _tuneable_triton_kernels

def load_triton_auto_tuned_cache(dap_size: int, arch_type: str) -> None:
    auto_tuned_config = {
        "hopper": _auto_tuned_config_hopper,
        "ampere": _auto_tuned_config_ampere,
    }[arch_type]
    config_for_current_dap = auto_tuned_config[dap_size]
    for func_name, cache in config_for_current_dap.items():
        _tuneable_triton_kernels[func_name].cache = cache

load_triton_auto_tuned_cache(
    dap_size=4,  # supported values: 0, 1, 2, 4, 8
    arch_type="hopper",
)
```

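The auto-tuning caches populated at runtime can also be saved, restored, and synchronized directly. A minimal sketch, assuming the (private) helpers defined in this commit's package `__init__` are importable from `apex.contrib.openfold_triton`; the cache file path is hypothetical:

```python
from apex.contrib.openfold_triton import (
    _load_triton_auto_tune_cache,
    _save_triton_auto_tune_cache,
    sync_triton_auto_tune_cache_across_gpus,
)

# After warm-up iterations have populated the kernel caches:
with open("triton_auto_tune_cache.pkl", "wb") as f:
    _save_triton_auto_tune_cache(f)

# In a later run, before training starts:
with open("triton_auto_tune_cache.pkl", "rb") as f:
    _load_triton_auto_tune_cache(f)

# In a multi-GPU job, broadcast rank 0's caches to all other ranks:
sync_triton_auto_tune_cache_across_gpus()
```
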
## FusedAdamSWA

```python
from apex.contrib.openfold_triton.fused_adam_swa import FusedAdamSWA

fused_optimizer = FusedAdamSWA.from_optim(
    adam_optimizer=adam_optimizer,  # standard PyTorch optimizer
    fp32_params=fp32_params,        # FP32 params used in the weight update
    bf16_params=bf16_params,        # BF16 params used in forward, backward, and reduction
    swa_params=swa_params,          # SWA params used for evaluation
    swa_decay_rate=swa_decay_rate,  # for example: 0.9, 0.99, 0.999
)

fused_optimizer.step()  # fused optimizer step: BF16/FP32 casting + param updates + SWA
```

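The parameter sets above come from the training script. One rough, hypothetical way they might be constructed is sketched below; the exact dtypes, ordering, and gradient wiring that `FusedAdamSWA` expects depend on your setup, and `build_openfold_model` is a placeholder:

```python
import torch

model = build_openfold_model().cuda().bfloat16()  # hypothetical model constructor
bf16_params = list(model.parameters())                            # BF16 weights for forward/backward
fp32_params = [p.detach().clone().float() for p in bf16_params]  # FP32 master copies for updates
swa_params = [p.detach().clone().float() for p in bf16_params]   # averaged copies for evaluation

adam_optimizer = torch.optim.Adam(fp32_params, lr=1e-3)
```
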
@@ -0,0 +1,118 @@

# © 2023 NVIDIA CORPORATION & AFFILIATES

import pickle
from collections import OrderedDict
from copy import deepcopy
from io import BytesIO
from typing import BinaryIO, Union

import torch
from triton.runtime.autotuner import Autotuner, Heuristics
from triton.runtime.jit import JITFunction

from apex.contrib.openfold_triton._layer_norm_backward_kernels import (
    _layer_norm_backward_dw_db_partial,
    _layer_norm_backward_dw_db_partial_strided,
    _layer_norm_backward_dx,
    _layer_norm_backward_dx_strided,
)
from apex.contrib.openfold_triton._layer_norm_forward_kernels import (
    _layer_norm_forward,
    _layer_norm_forward_strided,
)
from apex.contrib.openfold_triton.layer_norm import LayerNormSmallShapeOptImpl
from apex.contrib.openfold_triton.mha import (
    AttnBiasJIT,
    AttnNoBiasJIT,
    AttnTri,
    CanSchTriMHA,
)

__all__ = (
    "LayerNormSmallShapeOptImpl",
    "sync_triton_auto_tune_cache_across_gpus",
    "CanSchTriMHA",
    "AttnTri",
    "AttnBiasJIT",
    "AttnNoBiasJIT",
)


def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction]) -> str:
    if isinstance(f, JITFunction):
        return f.__name__
    else:
        return _get_tuneable_triton_func_name(f.fn)


_tuneable_triton_kernels = OrderedDict(
    (_get_tuneable_triton_func_name(func), func)
    for func in (
        _layer_norm_backward_dw_db_partial,
        _layer_norm_backward_dw_db_partial_strided,
        _layer_norm_backward_dx,
        _layer_norm_backward_dx_strided,
        _layer_norm_forward,
        _layer_norm_forward_strided,
    )
)


def _save_triton_auto_tune_cache(f: BinaryIO, verbose: bool = False) -> None:
    caches = OrderedDict()
    for func_name, func in _tuneable_triton_kernels.items():
        if len(func.cache) < 1:
            raise ValueError(
                f"Triton JIT kernel {func.__name__} didn't have tuning cache"
            )
        caches[func_name] = deepcopy(func.cache)
    pickle.dump(caches, f)
    if verbose:
        print(f"Triton kernel auto-tuning caches written to {f}")


def _load_triton_auto_tune_cache(
    f: BinaryIO, strict: bool = True, verbose: bool = False
) -> None:
    caches = pickle.load(f)
    if strict:
        loaded_func_name = set(caches.keys())
        tuneable_func_name = set(_tuneable_triton_kernels.keys())
        if loaded_func_name != tuneable_func_name:
            raise ValueError(
                f"Tuneable Triton kernels don't match with provided auto-tuning cache file {f}\n"
                f"Missing kernel caches: {tuneable_func_name - loaded_func_name}\n"
                f"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}"
            )
    for func_name, cache in caches.items():
        if func_name not in _tuneable_triton_kernels:
            raise ValueError(
                f"{func_name} from {f} doesn't match any tuneable Triton kernels"
            )
        _tuneable_triton_kernels[func_name].cache = cache
    if verbose:
        print(f"Triton kernel auto-tuning caches loaded from {f}")


def sync_triton_auto_tune_cache_across_gpus() -> None:
    if not torch.distributed.is_initialized():
        return
    if torch.distributed.get_rank() == 0:
        print("Broadcasting Triton auto-tuning cache from rank 0 to other ranks...")
        cache = BytesIO()
        _save_triton_auto_tune_cache(cache)
        cache.seek(0)
        cache_list = [cache]
    else:
        print(
            f"Rank {torch.distributed.get_rank()} is waiting for Triton auto-tuning cache from rank 0..."
        )
        cache_list = [None]
    torch.distributed.broadcast_object_list(cache_list)
    cache = cache_list[0]
    _load_triton_auto_tune_cache(cache)
    print("Succeed!")