Add openfold triton code (#1734)
* Add openfold triton code

* Add copyright headers

* Mention dependency on einops in README

* Remove _test_fused_adam_swa.py
ar-nowaczynski authored Oct 3, 2023
1 parent 6a77872 commit 58acf96
Showing 10 changed files with 2,847 additions and 0 deletions.
123 changes: 123 additions & 0 deletions apex/contrib/openfold_triton/README.md
@@ -0,0 +1,123 @@
# OpenFold Triton kernels

This subpackage is a collection of Triton kernels written specifically for the initial training mode of the OpenFold model architecture.

To use this subpackage, you must install an additional dependency:

```bash
pip install einops
```
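
A quick import check (a minimal sketch) confirms that the kernels and their runtime switches resolve in the current environment:

```python
import apex.contrib.openfold_triton.mha as mha
from apex.contrib.openfold_triton import LayerNormSmallShapeOptImpl  # noqa: F401

# Report whether the Triton MHA path is currently switched on (see the sections below):
print("Triton MHA enabled:", mha.is_enabled())
```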

The following sections describe the main features and show how to use them.

## Multi-Head Attention

```python
from typing import Optional

import torch
from torch import nn

import apex.contrib.openfold_triton.mha as mha
from apex.contrib.openfold_triton import AttnBiasJIT, AttnNoBiasJIT, AttnTri, CanSchTriMHA

# Integration with the attention module:
class SelfAttentionWithGate(nn.Module):
    # ...

    def _attention_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if self.chunk_size is None:
            if mha.is_enabled() and CanSchTriMHA(
                list(query.shape),
                bias is not None,
                inf=self.inf,
                training=self.training,
            ):
                if mask is not None:
                    mask = mask.contiguous()
                if bias is not None:
                    bias = bias.contiguous()
                return AttnTri(
                    query, key, value, mask, bias, self.inf, torch.is_grad_enabled()
                )
            elif mha.is_enabled() and bias is not None and self.training:
                return AttnBiasJIT(query, key, value, mask, bias, self.inf)
            elif mha.is_enabled() and bias is None and self.training:
                return AttnNoBiasJIT(query, key, value, mask, self.inf)
        # ... (non-Triton fallback elided; see the sketch below)

# Switch the Triton MHA path on/off dynamically at runtime via:
mha.enable()
mha.disable()

```
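
The snippet above shows only the Triton integration points and elides the non-Triton fallback. For orientation, a hypothetical eager fallback could look like the sketch below; `reference_attention` is an illustrative helper name, not part of the apex or OpenFold API, and the mask convention (1 = keep, 0 = masked) is an assumption.

```python
import math
from typing import Optional

import torch


def reference_attention(
    query: torch.Tensor,  # [..., Q, C]
    key: torch.Tensor,    # [..., K, C]
    value: torch.Tensor,  # [..., K, C]
    mask: torch.Tensor,   # broadcastable to [..., Q, K]; assumed 1 = keep, 0 = masked
    bias: Optional[torch.Tensor],
    inf: float,
) -> torch.Tensor:
    # Standard scaled-dot-product attention with an additive bias and mask penalty.
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.shape[-1])
    if bias is not None:
        scores = scores + bias
    scores = scores - (1.0 - mask) * inf
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


# Tiny shape check with illustrative sizes only:
q = k = v = torch.randn(1, 4, 16, 8)
m = torch.ones(1, 4, 16, 16)
out = reference_attention(q, k, v, m, bias=None, inf=1e9)
assert out.shape == q.shape
```

Gating on `CanSchTriMHA(...)` keeps such a fallback valid for shapes, dtypes, or modes that the Triton kernels do not cover.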

## LayerNorm

```python
import torch
from torch import nn
from torch.nn import functional as F

from apex.contrib.openfold_triton import LayerNormSmallShapeOptImpl

# Integration with the LayerNorm module:
class LayerNorm(nn.Module):
    # ...

    def _should_use_triton_kernels(self, x: torch.Tensor) -> bool:
        ln_triton_shapes = (
            (256, 128),
            (256, 256),
        )
        ln_triton_dim = 4
        return (
            self.training
            and x.dim() == ln_triton_dim
            and x.shape[-2:] in ln_triton_shapes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._should_use_triton_kernels(x):
            return LayerNormSmallShapeOptImpl.apply(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        else:
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

# To load an auto-tuned cache:
from apex.contrib.openfold_triton._layer_norm_config_ampere import _auto_tuned_config_ampere
from apex.contrib.openfold_triton._layer_norm_config_hopper import _auto_tuned_config_hopper
from apex.contrib.openfold_triton import _tuneable_triton_kernels

def load_triton_auto_tuned_cache(dap_size: int, arch_type: str) -> None:
    auto_tuned_config = {
        "hopper": _auto_tuned_config_hopper,
        "ampere": _auto_tuned_config_ampere,
    }[arch_type]
    config_for_current_dap = auto_tuned_config[dap_size]
    for func_name, cache in config_for_current_dap.items():
        _tuneable_triton_kernels[func_name].cache = cache

load_triton_auto_tuned_cache(
    dap_size=4,  # supported values: 0, 1, 2, 4, 8
    arch_type="hopper",
)

```
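
Beyond these pre-tuned Ampere/Hopper configurations, the subpackage's `__init__.py` provides utilities for synchronizing and persisting auto-tuning caches at runtime. A brief sketch follows; the filename is arbitrary, and the underscore-prefixed helpers are private implementation details that may change.

```python
from apex.contrib.openfold_triton import (
    _load_triton_auto_tune_cache,
    _save_triton_auto_tune_cache,
    sync_triton_auto_tune_cache_across_gpus,
)

# After warm-up iterations have populated the auto-tuning caches on rank 0,
# broadcast them so all ranks use identical kernel configurations
# (this is a no-op unless torch.distributed is initialized):
sync_triton_auto_tune_cache_across_gpus()

# Optionally persist the tuned configurations between runs; saving raises
# if any tuneable kernel has not been tuned yet:
with open("triton_auto_tuned_cache.pkl", "wb") as f:
    _save_triton_auto_tune_cache(f, verbose=True)

with open("triton_auto_tuned_cache.pkl", "rb") as f:
    _load_triton_auto_tune_cache(f, strict=True, verbose=True)
```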

## FusedAdamSWA

```python
from apex.contrib.openfold_triton.fused_adam_swa import FusedAdamSWA

fused_optimizer = FusedAdamSWA.from_optim(
    adam_optimizer=adam_optimizer,  # standard PyTorch optimizer
    fp32_params=fp32_params,        # FP32 params used in the weight update
    bf16_params=bf16_params,        # BF16 params used in forward, backward, reduction
    swa_params=swa_params,          # SWA params used for evaluation
    swa_decay_rate=swa_decay_rate,  # for example: 0.9, 0.99, 0.999
)

fused_optimizer.step() # fused optimizer step: casting BF16/FP32 + param updates + SWA

```
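
For orientation, below is a hypothetical end-to-end wiring with a tiny stand-in module. The real OpenFold training pipeline manages the FP32/BF16/SWA parameter copies itself; the model, learning rate, decay rate, and shapes here are illustrative assumptions, and apex must be built with its CUDA extensions for the fused kernel to be available.

```python
import torch
from torch import nn

from apex.contrib.openfold_triton.fused_adam_swa import FusedAdamSWA

# Keep three aligned parameter sets: BF16 compute params, FP32 master params,
# and an FP32 SWA copy used for evaluation (the exact alignment contract is
# defined by FusedAdamSWA itself; this is only an illustrative wiring).
model = nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
bf16_params = list(model.parameters())
fp32_params = [p.detach().clone().float() for p in bf16_params]
swa_params = [p.detach().clone().float() for p in bf16_params]

adam_optimizer = torch.optim.Adam(fp32_params, lr=1e-3)

fused_optimizer = FusedAdamSWA.from_optim(
    adam_optimizer=adam_optimizer,
    fp32_params=fp32_params,
    bf16_params=bf16_params,
    swa_params=swa_params,
    swa_decay_rate=0.999,
)

x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
model(x).float().pow(2).mean().backward()  # gradients accumulate on the BF16 params
fused_optimizer.step()                     # casting BF16/FP32 + param updates + SWA
```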
118 changes: 118 additions & 0 deletions apex/contrib/openfold_triton/__init__.py
@@ -0,0 +1,118 @@
# © 2023 NVIDIA CORPORATION & AFFILIATES

import pickle
from collections import OrderedDict
from copy import deepcopy
from io import BytesIO
from typing import BinaryIO, Union

import torch
from triton.runtime.autotuner import Autotuner, Heuristics
from triton.runtime.jit import JITFunction

from apex.contrib.openfold_triton._layer_norm_backward_kernels import (
    _layer_norm_backward_dw_db_partial,
    _layer_norm_backward_dw_db_partial_strided,
    _layer_norm_backward_dx,
    _layer_norm_backward_dx_strided,
)
from apex.contrib.openfold_triton._layer_norm_forward_kernels import (
    _layer_norm_forward,
    _layer_norm_forward_strided,
)
from apex.contrib.openfold_triton.layer_norm import LayerNormSmallShapeOptImpl
from apex.contrib.openfold_triton.mha import (
    AttnBiasJIT,
    AttnNoBiasJIT,
    AttnTri,
    CanSchTriMHA,
)

__all__ = (
    "LayerNormSmallShapeOptImpl",
    "sync_triton_auto_tune_cache_across_gpus",
    "CanSchTriMHA",
    "AttnTri",
    "AttnBiasJIT",
    "AttnNoBiasJIT",
)


def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction]) -> str:
    # Unwrap Autotuner/Heuristics wrappers until the underlying JITFunction is reached.
    if isinstance(f, JITFunction):
        return f.__name__
    else:
        return _get_tuneable_triton_func_name(f.fn)


# Maps kernel name -> Triton function whose auto-tuning cache can be saved and restored.
_tuneable_triton_kernels = OrderedDict(
    (_get_tuneable_triton_func_name(func), func)
    for func in (
        _layer_norm_backward_dw_db_partial,
        _layer_norm_backward_dw_db_partial_strided,
        _layer_norm_backward_dx,
        _layer_norm_backward_dx_strided,
        _layer_norm_forward,
        _layer_norm_forward_strided,
    )
)


def _save_triton_auto_tune_cache(f: BinaryIO, verbose: bool = False) -> None:
    caches = OrderedDict()
    for func_name, func in _tuneable_triton_kernels.items():
        if len(func.cache) < 1:
            raise ValueError(
                f"Triton JIT kernel {func.__name__} didn't have tuning cache"
            )
        caches[func_name] = deepcopy(func.cache)
    pickle.dump(caches, f)
    if verbose:
        print(f"Triton kernel auto-tuning caches written to {f}")


def _load_triton_auto_tune_cache(
    f: BinaryIO, strict: bool = True, verbose: bool = False
) -> None:
    caches = pickle.load(f)
    if strict:
        loaded_func_name = set(caches.keys())
        tuneable_func_name = set(_tuneable_triton_kernels.keys())
        if loaded_func_name != tuneable_func_name:
            raise ValueError(
                f"Tuneable Triton kernels don't match with provided auto-tuning cache file {f}\n"
                f"Missing kernel caches: {tuneable_func_name - loaded_func_name}\n"
                f"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}"
            )
    for func_name, cache in caches.items():
        if func_name not in _tuneable_triton_kernels:
            raise ValueError(
                f"{func_name} from {f} doesn't match any tuneable Triton kernels"
            )
        _tuneable_triton_kernels[func_name].cache = cache
    if verbose:
        print(f"Triton kernel auto-tuning caches loaded from {f}")


def sync_triton_auto_tune_cache_across_gpus() -> None:
    # Rank 0 serializes its auto-tuning caches and broadcasts them to all other ranks.
    if not torch.distributed.is_initialized():
        return
    if torch.distributed.get_rank() == 0:
        print("Broadcasting Triton auto-tuning cache from rank 0 to other ranks...")
        cache = BytesIO()
        _save_triton_auto_tune_cache(cache)
        cache.seek(0)
        cache_list = [
            cache,
        ]
    else:
        print(
            f"Rank {torch.distributed.get_rank()} is waiting for Triton auto-tuning cache from rank 0..."
        )
        cache_list = [
            None,
        ]
    torch.distributed.broadcast_object_list(cache_list)
    cache = cache_list[0]
    _load_triton_auto_tune_cache(cache)
    print("Succeed!")