
Commit

Merge branch 'main' of https://github.com/facebookresearch/xformers into dev_upstream
tenpercent committed Feb 26, 2024
2 parents 7d43238 + 9469bb5 commit 6fbb383
Showing 16 changed files with 847 additions and 735 deletions.
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
@@ -60,14 +60,14 @@ flake8 --config .flake8
mypy --ignore-missing-imports --scripts-are-modules --pretty --exclude build/ --exclude stubs/ .
```

* or you can just install [pre-commit](https://pre-commit.com/), which will make sure that all of the above is run automatically anytime you commit
in that case, you would need to
```bash
pip install pre-commit
```
then (in the xformers repository, just once)
```bash
pre-commit install
```

After these steps, each of your commits will run the same linting and formatting routines as the xformers continuous integration, which greatly helps in getting your PRs all green!
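If you want to run the same hooks on demand rather than waiting for a commit (standard pre-commit usage, nothing specific to this repository), you can run them across the whole tree:

```bash
pre-commit run --all-files
```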
1 change: 0 additions & 1 deletion examples/README.md
@@ -47,4 +47,3 @@ This is meant to be an easy introduction to using xformers in practice, mirrorin
This is very close to the MicroViT example above, but this time it illustrates the use of a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)), through a helper function which generates the required configuration given the pooling parameters. The suggested configuration has about 6.6M parameters (half of a ResNet18) and trains to about 86% top-1 accuracy on CIFAR-10 within minutes.

![Example curves](../docs/assets/metaformer.png)

2 changes: 1 addition & 1 deletion examples/llama_inference/README.md
@@ -7,7 +7,7 @@ Example runs:
$ python -m generate --ckpt_dir models/CodeLlama-7b-Instruct/
loaded SentencePiece model: #words: 32016 - bos id: 1 - eos id: 2
loaded model in 12.36 seconds
> [INST]abc[/INST]
I'm not sure I understand what you are saying with "abc". Could you explain?
---------------
> [INST]can you write a hello world program in C#[/INST]
67 changes: 67 additions & 0 deletions tests/test_ipc.py
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
import torch.distributed as dist

from xformers.ops import init_ipc

from .multiprocessing_utils import launch_subprocesses

compute_capability = (0, 0)
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm70_only = pytest.mark.skipif(
compute_capability < (7, 0), reason="requires sm70+"
)


def inner_test_ipc() -> None:
my_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
subgroup = torch.distributed.new_group()

ipcs = init_ipc(subgroup)

send_bufs = [
torch.full([1], my_rank, device="cuda", dtype=torch.int32)
for _ in range(world_size)
]
recv_bufs = send_bufs.copy()
for other_rank, conn in enumerate(ipcs):
if conn is None:
continue
conn.send(send_bufs[other_rank])
for other_rank, conn in enumerate(ipcs):
if conn is None:
continue
recv_bufs[other_rank] = conn.recv()

torch.cuda.synchronize()
dist.barrier(subgroup)

# Use the buffer to send data
for other_rank, buf in enumerate(recv_bufs):
assert buf[0].item() == other_rank
buf.fill_(my_rank)

torch.cuda.synchronize()
dist.barrier(subgroup)

# Verify we've received the data correctly
for other_rank, buf in enumerate(send_bufs):
assert (
buf[0].item() == other_rank
), f"[#{my_rank}] {other_rank=} != {buf[0].item()=}"


@cuda_sm70_only
def test_ipc() -> None:
world_size = 2
launch_subprocesses(
world_size=world_size,
fn=inner_test_ipc,
)
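For readers skimming the diff, this is the exchange pattern the new test exercises, condensed from the code above. It assumes the same setup the test provides (one CUDA device per rank and an initialized `torch.distributed` process group); `init_ipc` and the connections' `send`/`recv` are used exactly as in the test.

```python
import torch
import torch.distributed as dist

from xformers.ops import init_ipc


def exchange_ranks() -> None:
    # Each rank shares a small CUDA buffer with every peer through IPC handles.
    rank = dist.get_rank()
    group = dist.new_group()
    conns = init_ipc(group)  # one connection per rank, None for ourselves

    for conn in conns:
        if conn is not None:
            conn.send(torch.full([1], rank, device="cuda", dtype=torch.int32))

    received = {}
    for peer, conn in enumerate(conns):
        if conn is not None:
            received[peer] = conn.recv()  # aliases the peer's CUDA buffer

    torch.cuda.synchronize()
    dist.barrier(group)

    # Each received buffer was filled by its owner with that owner's rank.
    for peer, buf in received.items():
        assert buf.item() == peer
```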
18 changes: 14 additions & 4 deletions tests/test_mem_eff_attention.py
@@ -1924,6 +1924,10 @@ def test_decoder(
# kv_heads > 1: BMGHK
if dtype == "bf16" and compute_capability < (8, 0):
raise pytest.skip("BF16 is only supported on SM80+")
import triton

if dequant and triton.__version__[:4] < "3.0.":
raise pytest.skip("dequant needs triton updates")
dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
torch.manual_seed(1)
if kv_heads is not None and kv_heads > 1:
@@ -2527,8 +2531,10 @@ def test_forward_splitk(


@cuda_only
- @pytest.mark.parametrize("op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp])
- @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+ @pytest.mark.parametrize(
+ "op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp], ids=lambda op: op.NAME
+ )
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"B_Mkv_H_K",
[
@@ -2777,19 +2783,20 @@ def test_cutlassB_iter_order(
fmha.triton_splitk.FwOp_S8,
fmha.triton_splitk.FwOp_Map[48],
],
ids=lambda op: op.NAME,
)
@pytest.mark.parametrize("num_quant_groups", [0, 1, 8])
@pytest.mark.parametrize("page_size", [64, 128, 256])
def test_paged_attention(
- B, MAX_T: int, num_quant_groups: bool, page_size: int, op: Type[AttentionFwOpBase]
+ B, MAX_T: int, num_quant_groups: int, page_size: int, op: Type[AttentionFwOpBase]
):
paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False)


def paged_attention_run_inner(
B: int,
MAX_T: int,
- num_quant_groups: bool,
+ num_quant_groups: int,
page_size: int,
op: Type[AttentionFwOpBase],
bench: bool,
@@ -2814,6 +2821,9 @@ def paged_attention_run_inner(

q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda")
if num_quant_groups:
if triton.__version__[:4] < "3.0.":
raise pytest.skip("dequant needs triton updates")

# Using high=64 below, because with 256 both paged and non-paged paths
will produce NaNs - probably some quantization coefficients are NaNs
# after the bitwise cast.
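A short aside on the `ids=` arguments added in this file: `pytest.mark.parametrize` accepts a callable that maps each parameter value to a readable test ID, so operator parametrizations show up as e.g. `[triton_splitk]` instead of `[op0]` in test reports. A minimal standalone illustration (the `FakeOp` class is invented for this sketch and is not part of xformers):

```python
import pytest


class FakeOp:
    """Stand-in for an attention operator class exposing a NAME attribute."""

    def __init__(self, name: str) -> None:
        self.NAME = name


@pytest.mark.parametrize(
    "op", [FakeOp("triton_splitk"), FakeOp("flash")], ids=lambda op: op.NAME
)
def test_op_name_is_str(op: FakeOp) -> None:
    # Reported as test_op_name_is_str[triton_splitk] and test_op_name_is_str[flash].
    assert isinstance(op.NAME, str)
```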
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -34,7 +34,7 @@ def assert_allclose(
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}"
- f"/ total failing elements: {num_different}, percentage={percentage}"
+ f"/ total failing elements: {num_different} ({percentage*100:.3}%)"
)
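One note on the new format spec: `:.3` without a presentation type formats a float to three significant digits (not three decimal places), which keeps small failure rates compact. A quick check with a made-up failure count:

```python
percentage = 17 / 4096  # hypothetical fraction of failing elements
print(f"total failing elements: 17 ({percentage*100:.3}%)")
# prints: total failing elements: 17 (0.415%)
```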


2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/code/config.json
@@ -319,4 +319,4 @@
}
}
}
}
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/code/config_nystrom.json
@@ -349,4 +349,4 @@
}
}
}
}
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/code/config_orig_lra.json
@@ -328,4 +328,4 @@
}
}
}
}
2 changes: 1 addition & 1 deletion xformers/benchmarks/LRA/code/config_orig_lra_paper.json
@@ -320,4 +320,4 @@
}
}
}
}
1 change: 1 addition & 0 deletions xformers/csrc/attention/cuda/fmha/kernel_forward.h
@@ -327,6 +327,7 @@ struct AttentionKernel {
return false;
}
q_strideM = q_strideH;
bias_strideM = bias_strideH;
num_queries = num_heads;
num_heads = 1; // unused but here for intent
// remove causal since n_query = 1
3 changes: 3 additions & 0 deletions xformers/ops/__init__.py
@@ -25,6 +25,7 @@
memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .ipc import init_ipc
from .modpar_layers import ColumnParallelLinear, RowParallelLinear
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
@@ -97,6 +98,8 @@ def masked_matmul(a, b, mask=None):
# indexing
"index_select_cat",
"scaled_index_add",
# ipc
"init_ipc",
# modpar_layers
"ColumnParallelLinear",
"RowParallelLinear",
8 changes: 6 additions & 2 deletions xformers/ops/fmha/__init__.py
@@ -156,14 +156,18 @@ def memory_efficient_attention(
.. code-block:: python
- scale = 1 / query.shape[-1] ** 0.5
+ scale = 1.0 / query.shape[-1] ** 0.5
query = query * scale
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
- return attn @ value
+ attn = attn @ value
+ return attn.transpose(1, 2)
:Examples:
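The updated docstring block above is the reference math for `memory_efficient_attention` with inputs in `[batch, seq_len, num_heads, head_dim]` layout. As a sanity-check sketch (not part of this commit; the shapes, dtype, and tolerances are arbitrary choices and a CUDA device is assumed), the op can be compared against that reference directly:

```python
import torch
import torch.nn.functional as F

import xformers.ops as xops


def reference_attention(query, key, value, p=0.0):
    # Mirrors the docstring's reference implementation (BMHK layout).
    scale = 1.0 / query.shape[-1] ** 0.5
    query = query * scale
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    attn = (query @ key.transpose(-2, -1)).softmax(-1)
    attn = F.dropout(attn, p)
    return (attn @ value).transpose(1, 2)


B, M, H, K = 2, 128, 4, 64
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3)
)
out = xops.memory_efficient_attention(q, k, v)  # attn_bias=None, p=0.0 by default
torch.testing.assert_close(out, reference_attention(q, k, v), atol=2e-3, rtol=2e-2)
```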
