Merge branch 'dev_upstream' of https://github.com/ROCm/xformers into develop
tenpercent committed Feb 20, 2024
2 parents 89fb7d6 + 5d3247f commit 9be7f8d
Showing 31 changed files with 2,438 additions and 1,174 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/wheels_upload_pip.yml
@@ -44,6 +44,12 @@ jobs:
path: "."
fetch-depth: 0 # for tags

# inspired by https://github.com/jlumbroso/free-disk-space/blob/main/action.yml
- name: Free disk space
run: |
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
- name: Setup twine config
if: inputs.pypirc
run: |
6 changes: 6 additions & 0 deletions .github/workflows/wheels_upload_s3.yml
@@ -48,6 +48,12 @@ jobs:
path: "."
fetch-depth: 0 # for tags

# inspired by https://github.com/jlumbroso/free-disk-space/blob/main/action.yml
- name: Free disk space
run: |
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /usr/share/dotnet || true
- uses: actions/download-artifact@v3
with:
path: dist
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.0.25] - TBD
### Added
- New merge_attentions function
### Improved
- fMHA: Updated Flash-Attention to v2.5.2: this has a performance improvement for multiquery.
- fMHA: triton_splitk changed and expanded. It now combines the partial results of the splits using their LSE (log-sum-exp), can autotune, and supports causal attention with a small number of queries, not just 1. Experimental support for paged attention.
### Removed

## [0.0.24] - 2024-01-31
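The merge_attentions addition and the LSE-based combination in triton_splitk rest on the same identity: attention computed over disjoint key/value chunks can be recombined exactly, provided each partial result is returned together with its log-sum-exp (LSE). The sketch below shows that recombination step in plain PyTorch; the function name and tensor shapes are illustrative only and do not reflect the actual xformers merge_attentions signature.

```python
import torch

def merge_partial_attentions(outs, lses):
    """Recombine per-chunk attention outputs via their log-sum-exp values.

    outs: list of partial outputs, each of shape [B, M, H, K]
    lses: list of matching log-sum-exp tensors, each of shape [B, H, M]
    (shapes are illustrative, not the xformers merge_attentions signature)
    """
    lse_stack = torch.stack(lses, dim=0)                  # [N, B, H, M]
    lse_max = lse_stack.max(dim=0, keepdim=True).values   # subtract max for numerical stability
    weights = torch.exp(lse_stack - lse_max)
    weights = weights / weights.sum(dim=0, keepdim=True)  # per-chunk softmax weights
    merged = torch.zeros_like(outs[0])
    for out, w in zip(outs, weights.unbind(0)):
        # w: [B, H, M] -> align with out's [B, M, H, K] layout before broadcasting
        merged += out * w.permute(0, 2, 1).unsqueeze(-1)
    return merged
```

Because only the LSE is needed, the split kernels can process their chunks independently (for example across key/value pages) and defer normalization to this cheap final reduction.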
14 changes: 13 additions & 1 deletion docs/source/components/ops.rst
@@ -22,13 +22,25 @@ Available implementations
:member-order: bysource

.. automodule:: xformers.ops.fmha.triton
:members: FwOp, BwOp
:members: FwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.small_k
:members: FwOp, BwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.ck
:members: FwOp, BwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.ck_decoder
:members: FwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.ck_splitk
:members: FwOp
:member-order: bysource

Attention biases
~~~~~~~~~~~~~~~~~~~~

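The newly documented ck, ck_decoder and ck_splitk modules expose the Composable Kernel (ROCm) ops alongside the existing backends. As with the other fMHA implementations, a specific backend can be pinned through the op argument of memory_efficient_attention. The snippet below is a usage sketch only; which ops are actually available depends on how xformers was built (the CK ops exist only in ROCm builds).

```python
import torch
import xformers.ops as xops
from xformers.ops import fmha

# Inputs are [batch, seq_len, heads, head_dim].
q, k, v = (
    torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Default: let the dispatcher pick the best available backend.
out = xops.memory_efficient_attention(q, k, v)

# Explicitly request the CK backend (ROCm builds only).
out_ck = xops.memory_efficient_attention(q, k, v, op=(fmha.ck.FwOp, fmha.ck.BwOp))
```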
147 changes: 36 additions & 111 deletions setup.py
@@ -125,6 +125,23 @@ def get_cuda_version(cuda_dir) -> int:
return bare_metal_major * 100 + bare_metal_minor


def get_hip_version(rocm_dir) -> str:
hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
try:
raw_output = subprocess.check_output(
[hipcc_bin, "--version"], universal_newlines=True
)
except Exception as e:
print(
f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
)
return None
for line in raw_output.split("\n"):
if "HIP version" in line:
return line.split()[-1]
return None


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
# XXX: Not supported on windows for cuda<12
# https://github.com/Dao-AILab/flash-attention/issues/345
@@ -186,12 +203,13 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
sources = ["csrc/flash_attn/flash_api.cpp"]
for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
sources.append(str(Path(f).relative_to(flash_root)))
common_extra_compile_args = ["-DFLASHATTENTION_DISABLE_ALIBI"]
return [
CUDAExtension(
name="xformers._C_flashattention",
sources=[os.path.join(flash_root, path) for path in sources],
extra_compile_args={
**extra_compile_args,
"cxx": extra_compile_args.get("cxx", []) + common_extra_compile_args,
"nvcc": extra_compile_args.get("nvcc", [])
+ [
"-O3",
@@ -207,6 +225,7 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
]
+ nvcc_archs_flags
+ nvcc_windows_flags
+ common_extra_compile_args
+ get_extra_nvcc_flags_for_build_type(cuda_version),
},
include_dirs=[
@@ -229,118 +248,19 @@ def rename_cpp_cu(cpp_files):
def get_extensions():
extensions_dir = os.path.join("xformers", "csrc")

sources = glob.glob(
os.path.join(extensions_dir, "attention", "*.cpp"), recursive=False
)
sources += glob.glob(
os.path.join(extensions_dir, "attention", "autograd", "**", "*.cpp"),
recursive=True,
)
sources += glob.glob(
os.path.join(extensions_dir, "attention", "cpu", "**", "*.cpp"), recursive=True
)
sources += glob.glob(
os.path.join(extensions_dir, "indexing", "**", "*.cpp"), recursive=True
)
sources += glob.glob(
os.path.join(extensions_dir, "swiglu", "**", "*.cpp"), recursive=True
)

# avoid the temporary .cu file under xformers/csrc/attention/hip_fmha are included
source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu"), recursive=False)
source_cuda += glob.glob(
os.path.join(extensions_dir, "attention", "cuda", "**", "*.cu"), recursive=True
)
source_cuda += glob.glob(
os.path.join(extensions_dir, "indexing", "**", "*.cu"), recursive=True
)
source_cuda += glob.glob(
os.path.join(extensions_dir, "swiglu", "**", "*.cu"), recursive=True
)

sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
source_hip = glob.glob(
os.path.join(extensions_dir, "attention", "hip_fmha", "ck_fmha_test.cpp"),
recursive=False,
)
source_hip += glob.glob(
os.path.join(
extensions_dir, "attention", "hip_fmha", "attention_forward_decoder.cpp"
),
recursive=False,
)

source_hip_decoder = [
*glob.glob(
os.path.join(
extensions_dir, "attention", "hip_fmha", "attention_forward_decoder.cpp"
),
recursive=False,
),
*glob.glob(
os.path.join(
extensions_dir, "attention", "hip_fmha", "attention_forward_splitk.cpp"
),
recursive=False,
),
]

source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"attention_forward_generic_ck_tiled.cpp",
),
recursive=False,
)
source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"ck_tiled_fmha_batched_infer_*.cpp",
),
recursive=False,
)
source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"ck_tiled_fmha_grouped_infer_*.cpp",
),
recursive=False,
)
source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"ck_tiled_fmha_batched_forward_*.cpp",
),
recursive=False,
)
source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"ck_tiled_fmha_grouped_forward_*.cpp",
),
recursive=False,
os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"),
recursive=True,
)
source_hip += glob.glob(
os.path.join(
extensions_dir,
"attention",
"hip_fmha",
"instances",
"ck_tiled_fmha_*.cpp",
),
recursive=False,
source_hip_generated = glob.glob(
os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cu"),
recursive=True,
)

source_hip += source_hip_decoder
# avoid the temporary .cu files generated under xformers/csrc/attention/hip_fmha
source_cuda = list(set(source_cuda) - set(source_hip_generated))
sources = list(set(sources) - set(source_hip))

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
@@ -366,6 +286,7 @@ def get_extensions():
include_dirs = [extensions_dir]
ext_modules = []
cuda_version = None
hip_version = None
flash_version = "0.0.0"

if (
@@ -422,6 +343,9 @@ def get_extensions():
]
elif torch.cuda.is_available() and torch.version.hip:
rename_cpp_cu(source_hip)
rocm_home = os.getenv("ROCM_PATH")
hip_version = get_hip_version(rocm_home)

source_hip_cu = []
for ff in source_hip:
source_hip_cu += [ff.replace(".cpp", ".cu")]
@@ -467,6 +391,7 @@ def get_extensions():
return ext_modules, {
"version": {
"cuda": cuda_version,
"hip": hip_version,
"torch": torch.__version__,
"python": platform.python_version(),
"flash": flash_version,
@@ -475,6 +400,7 @@ def get_extensions():
k: os.environ.get(k)
for k in [
"TORCH_CUDA_ARCH_LIST",
"PYTORCH_ROCM_ARCH",
"XFORMERS_BUILD_TYPE",
"XFORMERS_ENABLE_DEBUG_ASSERTIONS",
"NVCC_FLAGS",
@@ -530,7 +456,6 @@ def copy_extensions_to_source(self) -> None:


if __name__ == "__main__":

if os.getenv("BUILD_VERSION"): # In CI
version = os.getenv("BUILD_VERSION", "0.0.0")
else:
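For reference, the new get_hip_version helper shells out to hipcc --version and returns the last token of the line containing "HIP version", or None when hipcc cannot be executed; setup.py then records the value in the build metadata next to the CUDA, torch, Python and flash versions. A small usage sketch, assuming the helper from the hunk above is in scope:

```python
import os

# ROCM_PATH may be unset, in which case hipcc is resolved from PATH.
rocm_home = os.environ.get("ROCM_PATH")
hip_version = get_hip_version(rocm_home)  # helper defined in setup.py above

if hip_version is None:
    print("hipcc not found; no HIP version recorded in the build metadata")
else:
    print(f"HIP toolchain detected: {hip_version}")
```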
14 changes: 7 additions & 7 deletions tests/test_attentions.py
@@ -22,10 +22,6 @@
build_attention,
)

disable_on_rocm = pytest.mark.skipif(
not not torch.version.hip, reason="could not be done on ROCM"
)

DEVICES = (
[torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)
@@ -95,7 +91,6 @@ def noop(x):
return multi_head


@disable_on_rocm
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@@ -112,6 +107,13 @@ def test_order_invariance(
causal: bool,
device: torch.device,
):
if (
torch.version.hip
and device == torch.device("cuda")
and attention_name == "local"
):
# Backend calls into Sputnik library which isn't built on ROCm
device = torch.device("cpu")

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
@@ -166,7 +168,6 @@ def test_order_invariance(
_ = multi_head(inputs, inputs_shuffled, inputs)


@disable_on_rocm
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
@@ -210,7 +211,6 @@ def test_kqv_ordering(
assert torch.allclose(res_false[0, :, :], res_false[1, :, :])


@disable_on_rocm
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
4 changes: 0 additions & 4 deletions tests/test_checkpoint.py
@@ -20,9 +20,6 @@
)

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
disable_on_rocm = pytest.mark.skipif(
not not torch.version.hip, reason="could not be done on ROCM"
)
_devices = ["cpu"]
cuda_cap = (0, 0)

@@ -39,7 +36,6 @@ def _all_policy(func, *args, **kwargs):
return True


@disable_on_rocm
@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])