Skip to content

Commit

Permalink
Fix pallas GLU kernel on large inputs and update documentation for co…
Browse files Browse the repository at this point in the history
…mpilation buckets.

PiperOrigin-RevId: 696513320
  • Loading branch information
jacobjinkelly committed Nov 14, 2024
1 parent 1d3e173 commit 2ffe43f
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 69 deletions.
15 changes: 0 additions & 15 deletions docs/known_issues.md
Original file line number Diff line number Diff line change
@@ -1,16 +1 @@
# Known Issues

## Numerical Accuracy above 5,120 Tokens

AlphaFold 3 does not currently support inference on inputs larger than 5,120
tokens. An error will be raised if the input is larger than this threshold.

This is due to a numerical issue with the custom Pallas kernel implementing the
Gated Linear Unit. The numerical issue only occurs at inputs above the 5,120
tokens threshold, and results in degraded accuracy in the predicted structure.

This numerical issue is unique to the single GPU configuration used in this
repository, and does not affect the results in the
[AlphaFold 3 paper](https://www.nature.com/articles/s41586-024-07487-w).

We hope to resolve this issue soon and remove this check on input size.
51 changes: 43 additions & 8 deletions docs/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,41 @@ V100 using the flag `--flash_attention_implementation=xla` in
`run_alphafold.py`, this configuration has not been tested for numerical
accuracy or throughput efficiency, so please proceed with caution.

## Compilation Buckets

To avoid excessive re-compilation of the model, AlphaFold 3 implements
compilation buckets: ranges of input sizes using a single compilation of the
model.

When featurising an input, AlphaFold 3 determines the smallest bucket the input
fits into, then adds any necessary padding. This may avoid re-compiling the
model when running inference on the input if it belongs to the same bucket as a
previously processed input.

The configuration of bucket sizes involves a trade-off: more buckets leads to
more re-compilations of the model, but less padding.

By default, the largest bucket size is 5,120 tokens. Processing inputs larger
than this maximum bucket size triggers the creation of a new bucket for exactly
that input size, and a re-compilation of the model. In this case, you may wish
to redefine the compilation bucket sizes via the `--buckets` flag in
`run_alphafold.py` to add additional larger bucket sizes. For example, suppose
you are running inference on inputs with token sizes: `5132, 5280, 5342`. Using
the default bucket sizes configured in `run_alphafold.py` will trigger three
separate model compilations, one for each unique token size. If instead you pass
in the following flag to `run_alphafold.py`

```
--buckets 256,512,768,1024,1280,1536,2048,2560,3072,3584,4096,4608,5120,5376
```
when running inference on the above three input sizes, the model will be
compiled only once for the bucket size `5376`. **Note:** for this specific
example with input sizes `5132, 5280, 5342`, passing in `--buckets 5376` is
sufficient to achieve the desired compilation behaviour. The provided example
with multiple buckets illustrates a more general solution suitable for diverse
input sizes.
## Additional Flags
### Compilation Time Workaround with XLA Flags
Expand All @@ -109,8 +144,8 @@ ENV XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"
### GPU Memory

The following environment variables (set by default in the `Dockerfile`) enable
folding a single input of size up to 5,120 tokens on a single A100 with 80 GB of
memory:
folding a single input of size up to 5,120 tokens on a single A100 (80 GB) or a
single H100 (80 GB):

```sh
ENV XLA_PYTHON_CLIENT_PREALLOCATE=true
Expand All @@ -119,12 +154,12 @@ ENV XLA_CLIENT_MEM_FRACTION=0.95

#### Unified Memory

If you would like to run AlphaFold 3 on a GPU with less memory (an A100 with 40
GB of memory, for instance), we recommend enabling unified memory. Enabling
unified memory allows the program to spill GPU memory to host memory if there
isn't enough space. This prevents an OOM, at the cost of making the program
slower by accessing host memory instead of device memory. To learn more, check
out the
If you would like to run AlphaFold 3 on inputs larger than 5,120 tokens, or on a
GPU with less memory (an A100 with 40 GB of memory, for instance), we recommend
enabling unified memory. Enabling unified memory allows the program to spill GPU
memory to host memory if there isn't enough space. This prevents an OOM, at the
cost of making the program slower by accessing host memory instead of device
memory. To learn more, check out the
[NVIDIA blog post](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/).

You can enable unified memory by setting the following environment variables in
Expand Down
30 changes: 13 additions & 17 deletions run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import textwrap
import time
import typing
from typing import Final, Protocol, Self, TypeVar, overload
from typing import Protocol, Self, TypeVar, overload

from absl import app
from absl import flags
Expand Down Expand Up @@ -203,27 +203,23 @@
' beyond 8 CPUs provides very little additional speedup.',
)

# Compilation cache
# Compilation cache.
_JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
'jax_compilation_cache_dir',
None,
'Path to a directory for the JAX compilation cache.',
)

_BUCKETS: Final[tuple[int, ...]] = (
256,
512,
768,
1024,
1280,
1536,
2048,
2560,
3072,
3584,
4096,
4608,
5120,
# Compilation buckets.
_BUCKETS = flags.DEFINE_list(
'buckets',
# pyformat: disable
['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
'3584', '4096', '4608', '5120'],
# pyformat: enable
'Strictly increasing order of token sizes for which to cache compilations.'
' For any input with more tokens than the largest bucket size, a new bucket'
' is created for exactly that number of tokens.',
)


Expand Down Expand Up @@ -665,7 +661,7 @@ def main(_):
data_pipeline_config=data_pipeline_config,
model_runner=model_runner,
output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()),
buckets=_BUCKETS,
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
)

print(f'Done processing {len(fold_inputs)} fold inputs.')
Expand Down
60 changes: 36 additions & 24 deletions src/alphafold3/jax/common/array_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Self, TypeAlias, TypeVar

import jax
import jax.experimental
from jax.experimental import pallas as pl
import jax.numpy as jnp
from jax.typing import ArrayLike # pylint: disable=g-importing-member
Expand Down Expand Up @@ -91,11 +92,17 @@ def ndim(self) -> int:
def T(self) -> Self: # pylint: disable=invalid-name
return self.transpose()

@property
def _index_dtype(self) -> jax.typing.DTypeLike:
i32_max = jnp.iinfo(jnp.int32).max
return jnp.int32 if (self.base.size <= i32_max) else jnp.int64

@property
def offsets(self) -> jax.Array:
"""Returns array of offsets into `base` for each element."""
idxs = jnp.indices(self.shape, sparse=True)
return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
with jax.experimental.enable_x64():
idxs = jnp.indices(self.shape, sparse=True, dtype=self._index_dtype)
return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))

def astype(self, dtype: jax.typing.DTypeLike) -> Self:
return self._replace(base=self.base.astype(dtype))
Expand Down Expand Up @@ -255,29 +262,34 @@ def __getitem__(self, idxs: Indexer | tuple[Indexer, ...]) -> Self:

shape = []
strides = []
offset = self.offset

for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
if isinstance(idx, int):
if not (-dim <= idx < dim):
raise ValueError("Slice index out of range.")
offset += stride * (idx % dim)
elif isinstance(idx, ScalarInt):
offset += stride * idx
elif isinstance(idx, slice):
start, stop, step = idx.indices(dim)
if step >= 0:
shape.append(pl.cdiv(stop - start, step))
with jax.experimental.enable_x64():

def as_index(x):
return x.astype(self._index_dtype) if isinstance(x, jax.Array) else x

offset = as_index(self.offset)

for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
if isinstance(idx, int):
if not (-dim <= idx < dim):
raise ValueError("Slice index out of range.")
offset += stride * (idx % dim)
elif isinstance(idx, ScalarInt):
offset += stride * as_index(idx)
elif isinstance(idx, slice):
start, stop, step = idx.indices(dim)
if step >= 0:
shape.append(pl.cdiv(stop - start, step))
else:
shape.append(pl.cdiv(start - stop, -step))
strides.append(stride * step)
offset += stride * start
elif isinstance(idx, pl.Slice):
shape.append(idx.size)
strides.append(stride * idx.stride)
offset += stride * as_index(idx.start)
else:
shape.append(pl.cdiv(start - stop, -step))
strides.append(stride * step)
offset += stride * start
elif isinstance(idx, pl.Slice):
shape.append(idx.size)
strides.append(stride * idx.stride)
offset += stride * idx.start
else:
raise ValueError(f"Unexpected indexer: {idx}")
raise ValueError(f"Unexpected indexer: {idx}")

return self._replace(shape=shape, strides=strides, offset=offset)

Expand Down
7 changes: 5 additions & 2 deletions src/alphafold3/jax/gated_linear_unit/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from alphafold3.jax.common import array_view
import jax
import jax.experimental
from jax.experimental import pallas as pl
import jax.numpy as jnp
import jaxtyping
Expand Down Expand Up @@ -43,7 +44,8 @@ def load_block(
idx = ref[idx].offsets
ref = ref.base
other = None if mask is None else other
return pl.load(ref, idx, mask=mask, other=other, **kwargs)
with jax.experimental.enable_x64():
return pl.load(ref, idx, mask=mask, other=other, **kwargs)


@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
Expand All @@ -62,7 +64,8 @@ def store_block(
if isinstance(ref, array_view.ArrayView):
idx = ref[idx].offsets
ref = ref.base
pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
with jax.experimental.enable_x64():
pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)


def in_bounds_mask(
Expand Down
21 changes: 18 additions & 3 deletions src/alphafold3/model/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def calculate_bucket_size(
bucket_idx = bisect.bisect_left(buckets, num_tokens)

if bucket_idx == len(buckets):
raise ValueError(
f'Number of tokens {num_tokens} is more than the largest currently'
f' supported bucket size {buckets[-1]}.'
logging.warning(
'Creating a new bucket of size %d since the input has more tokens than'
' the largest bucket size %d. This may trigger a re-compilation of the'
' model. Consider additional large bucket sizes to avoid excessive'
' re-compilation.',
num_tokens,
buckets[-1],
)
return num_tokens

return buckets[bucket_idx]

Expand Down Expand Up @@ -250,9 +255,19 @@ def process_item(
f'({total_tokens} < {self._config.min_total_residues})'
)

logging.info(
'Calculating bucket size for input with %d tokens.', total_tokens
)
padded_token_length = calculate_bucket_size(
total_tokens, self._config.buckets
)
logging.info(
'Got bucket size %d for input with %d tokens, resulting in %d padded'
' tokens.',
padded_token_length,
total_tokens,
padded_token_length - total_tokens,
)

# Padding shapes for all features.
num_atoms = padded_token_length * self._config.average_num_atoms_per_token
Expand Down

0 comments on commit 2ffe43f

Please sign in to comment.