diff --git a/frontends/torch-frontend/scripts/build_and_test.sh b/frontends/torch-frontend/scripts/build_and_test.sh index 882ccc117..d0b177d59 100755 --- a/frontends/torch-frontend/scripts/build_and_test.sh +++ b/frontends/torch-frontend/scripts/build_and_test.sh @@ -51,7 +51,7 @@ cmake --build ./build --target all if [[ $TORCH_FRONTEND_TEST == "ON" ]]; then python3 -m pip install -r test-requirements.txt install_mhlo_tools - PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest torch-frontend/python/test + PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest -m "not attention_rewriter" torch-frontend/python/test fi popd diff --git a/frontends/torch-frontend/scripts/envsetup.sh b/frontends/torch-frontend/scripts/envsetup.sh index d9dc80c28..982e37441 100755 --- a/frontends/torch-frontend/scripts/envsetup.sh +++ b/frontends/torch-frontend/scripts/envsetup.sh @@ -33,7 +33,6 @@ function prepare_for_build_with_prebuilt() { pushd ${PROJ_DIR} # install requirements python3 -m pip install -r build-requirements.txt - # python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html # initialize submodule git submodule update --init -f $TORCH_MLIR_ROOT @@ -49,10 +48,11 @@ function prepare_for_build() { pushd ${PROJ_DIR} # install requirements python3 -m pip install -r build-requirements.txt - # python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html # initialize submodule git submodule update --init --recursive -f $TORCH_MLIR_ROOT apply_patches } + +# python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/frontends/torch-frontend/torch-cpu-requirements.txt b/frontends/torch-frontend/torch-cpu-requirements.txt index 15cd43740..57ea64eae 100644 --- a/frontends/torch-frontend/torch-cpu-requirements.txt +++ b/frontends/torch-frontend/torch-cpu-requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://download.pytorch.org/whl/cpu --pre -torch==2.1.0+cpu -torchvision==0.16.0+cpu +torch==2.4.1+cpu +torchvision==0.19.1+cpu diff --git a/frontends/torch-frontend/torch-cuda-requirements.txt b/frontends/torch-frontend/torch-cuda-requirements.txt index 7be5d08b3..5cae3892e 100644 --- a/frontends/torch-frontend/torch-cuda-requirements.txt +++ b/frontends/torch-frontend/torch-cuda-requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://download.pytorch.org/whl/cu118 --pre -torch==2.1.0+cu118 -torchvision==0.16.0+cu118 +torch==2.4.1+cu118 +torchvision==0.19.1+cu118 diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_ccl.py b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_ccl.py deleted file mode 100644 index fe3770fb7..000000000 --- a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_ccl.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch -from torch.testing import FileCheck - -import torch.distributed as dist -import torch.distributed._functional_collectives as funcol -from torch.testing._internal.common_utils import run_tests - -from utils import with_comms, DistributedTestBase, skip_unless_torch_version_bigger_than - -import torch_frontend -from torch_frontend import compile_dynamo_model - - -class 
AllReduceModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return funcol.all_reduce(x, "sum", [0, 1, 2, 3]) - - -class AllGatherModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return funcol.all_gather_tensor(x, 0, [0, 1, 2, 3]) - - -class ReduceScatterModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return funcol.reduce_scatter_tensor(x, "sum", 0, [0, 1, 2, 3]) - - -class BroadcastModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x): - return funcol.broadcast(x, 2, [0, 1, 2, 3]) - - -class DistributedCollectiveTest(DistributedTestBase): - @property - def world_size(self): - return 4 - - @with_comms - def test_reduce_scatter(self): - module = ReduceScatterModule() - inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)] - prog = torch.export.export(module, tuple(inputs)) - if dist.get_rank() == 0: - module = compile_dynamo_model(prog, "stablehlo") - ir = module.operation.get_asm() - FileCheck().check("@main").check("ccl.reduce_scatter").check( - "axis = 0" - ).check('reduction = "sum"').check("replica_groups = [[0, 1, 2, 3]]").check( - "-> tensor<1xf32>" - ).run( - ir - ) - - @with_comms - def test_all_reduce(self): - module = AllReduceModule() - inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)] - prog = torch.export.export(module, tuple(inputs)) - if dist.get_rank() == 0: - module = compile_dynamo_model(prog, "stablehlo") - ir = module.operation.get_asm() - FileCheck().check("@main").check("ccl.all_reduce").check( - 'reduction = "sum"' - ).check("replica_groups = [[0, 1, 2, 3]]").check("-> tensor<4xf32>").run(ir) - - @with_comms - def test_all_gather(self): - module = AllGatherModule() - inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)] - prog = torch.export.export(module, tuple(inputs)) - if dist.get_rank() == 0: - module = compile_dynamo_model(prog, "stablehlo") - ir = module.operation.get_asm() - FileCheck().check("@main").check("ccl.all_gather").check("axis = 0").check( - "replica_groups = [[0, 1, 2, 3]]" - ).check("-> tensor<16xf32>").run(ir) - - @with_comms - @skip_unless_torch_version_bigger_than(torch_version="2.2") - def test_broadcast(self): - module = BroadcastModule() - inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)] - prog = torch.export.export(module, tuple(inputs)) - if dist.get_rank() == 0: - module = compile_dynamo_model(prog, "stablehlo") - ir = module.operation.get_asm() - FileCheck().check("@main").check("ccl.broadcast").check( - "replica_groups = [[2, 0, 1, 3]]" - ).check("-> tensor<4xf32>").run(ir) - - # TODO: add test for send/recv - - -class MLP(torch.nn.Module): - def __init__(self, hidden_dim, world_size): - super().__init__() - self.hidden_dim = hidden_dim - self.world_size = world_size - self.fc1 = torch.nn.Linear(self.hidden_dim, self.hidden_dim * 4) - self.fc2 = torch.nn.Linear(self.hidden_dim * 4, self.hidden_dim) - - def forward(self, x): - return funcol.all_reduce( - self.fc2(self.fc1(x)), "sum", list(range(self.world_size)) - ) - - -class DistributedCollectiveE2ETest(DistributedTestBase): - @property - def world_size(self): - return 4 - - @with_comms - def test_mlp_e2e(self): - module = MLP(hidden_dim=4, world_size=self.world_size) - x = torch.rand(3, 4) - prog = torch.export.export(module, (x,)) - - module = compile_dynamo_model(prog, "stablehlo") - - if dist.get_rank() == 0: - ir = module.operation.get_asm() - print(ir) - - -if __name__ 
== "__main__": - run_tests() diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_simple_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_simple_ops.py index 73cb64442..1291d98c2 100644 --- a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_simple_ops.py +++ b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_simple_ops.py @@ -27,6 +27,7 @@ def forward(self, x): def test_nonzero(): inputs = (torch.tensor([1, 0, 0, 1, 1]),) prog = torch.export.export(AtenNonZeroModule(), inputs) + print(prog) module = compile_dynamo_model(prog, "raw") print(module.operation.get_asm()) @@ -38,9 +39,10 @@ def forward(self, x): def test_view_dtype(): inputs = (torch.rand(4, 5),) + # note: torch==2.1.0's export has bug on view.dtype prog = torch.export.export(ViewDtypeModule(), inputs) print(prog) - module = compile_dynamo_model(prog, "raw") # note: torch2.1 export has bug on view.dtype + module = compile_dynamo_model(prog, "stablehlo") print(module.operation.get_asm()) # ============================================================================== @@ -60,5 +62,8 @@ def test_mlp(): print(mlir_str) assert "dense_resource" not in mlir_str +# ============================================================================== + if __name__ == "__main__": + test_nonzero() test_view_dtype() \ No newline at end of file diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/utils.py b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/utils.py deleted file mode 100644 index e674bc3a3..000000000 --- a/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/utils.py +++ /dev/null @@ -1,121 +0,0 @@ -import datetime -import sys -from typing import Any, Callable, Dict, Tuple, TypeVar, cast -from functools import wraps - -import torch -import torch.distributed as dist -from torch.testing._internal.common_distributed import TEST_SKIPS, MultiProcessTestCase, skip_if_lt_x_gpu, TestSkip - -# add new skipped test exit code -TEST_SKIPS["torch-version-2.2"] = TestSkip(90, "Need torch version bigger than 2.2") - -TestFunc = Callable[[object], object] -T = TypeVar("T") -DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu" -PG_BACKEND = "nccl" if DEVICE_TYPE == "cuda" else "gloo" - -NUM_DEVICES = 4 - -# We use this as a proxy for "multiple GPUs exist" -if torch.cuda.is_available() and torch.cuda.device_count() > 1: - # when we actually have multiple GPUs, relax the requirement to smaller counts. 
- NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count()) - - -class DistributedTestBase(MultiProcessTestCase): - @property - def world_size(self) -> int: - return NUM_DEVICES - - @property - def backend(self) -> str: - return PG_BACKEND - - def init_pg(self) -> None: - if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: - sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - - if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "meta"]: - raise RuntimeError(f"Backend {self.backend} not supported!") - - dist.init_process_group( - backend=self.backend, - world_size=self.world_size, - rank=self.rank, # pyre-ignore[16] - init_method=f"file://{self.file_name}", # pyre-ignore[16] - timeout=datetime.timedelta(seconds=1200), - ) - - # set device for nccl pg for collectives - if "nccl" in self.backend: - torch.cuda.set_device(self.rank) - - def destroy_pg(self) -> None: - # Wait for all ranks to reach here before starting shutdown. - # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895 - # dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu")) - # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs: - # test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion - dist.barrier() - dist.destroy_process_group() - - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - -# wrapper to initialize comms (processgroup) -def with_comms(func: TestFunc) -> TestFunc: - assert func is not None - - @wraps(func) # pyre-ignore[6] - def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None: # type: ignore[misc] - # if backend not specified, and cuda available, then use nccl, else gloo - if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size: - self.device_type = "cuda" - else: - self.device_type = "cpu" - - self.init_pg() - func(self, *args, **kwargs) # type: ignore[misc] - self.destroy_pg() - - return wrapper - - -def skip_unless_torch_gpu(method: T) -> T: - """ - Test decorator which skips the test unless there's a GPU available to torch. - - >>> # xdoctest: +SKIP - >>> @skip_unless_torch_gpu - >>> def test_some_method(self) -> None: - >>> ... - """ - # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set. - return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) - - -def skip_unless_torch_version_bigger_than(torch_version: str): - """ - Test decorator which skips the test unless current torch version is - bigger than the given number. - - >>> # xdoctest: +SKIP - >>> @skip_unless_torch_version_bigger_than(torch_version="2.2") - >>> def test_some_method(self) -> None: - >>> ... 
- """ - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - current_torch_version = torch.__version__ - if current_torch_version >= torch_version: - return func(*args, **kwargs) - sys.exit(TEST_SKIPS[f"torch-version-{torch_version}"].exit_code) - - return wrapper - - return decorator diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_rewrite.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_rewrite.py index 7e7ebd30b..8f90efd26 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_rewrite.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_rewrite.py @@ -179,6 +179,51 @@ def AttnReplacement5(q, k, v, attn_mask, inv_scale): ) +# LLaMA aten attention op pattern +def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): + transpose_3 = torch.ops.aten.transpose.int(key, 2, 3) + expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim]) + clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format) + _unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim]) + expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len]) + clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format) + _unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len]) + bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4) + _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len]) + div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale) + add_5 = torch.ops.aten.add.Tensor(div, attn_mask) + maximum = torch.ops.aten.maximum.default(add_5, min_val) + _softmax = torch.ops.aten._softmax.default(maximum, -1, False) + _to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16) + expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len]) + view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None + expand_5 = torch.ops.aten.expand.default(value, [batch, num_head, seq_len, head_dim]) + clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format) + _unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim]) + bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6) + _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim]) + return _softmax, _unsafe_view_5 + + +def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): + # q, k, v needs to be transposed for flash attn v2 + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd( + query, + key, + value, + 0.0, + 1.0/inv_scale, + True, + False + ) + # output also needs to be transposed + out = out.transpose(1, 2) + return out, out + + def canonicalize_graph_before_replacement(gm): for n in gm.graph.nodes: if n.op == "call_module": @@ -243,4 +288,5 @@ def fx_replace_attn_pattern(gm: torch.fx.GraphModule): torch.fx.replace_pattern(gm, AttnPattern3, AttnReplacement3) torch.fx.replace_pattern(gm, AttnPattern4, AttnReplacement4) torch.fx.replace_pattern(gm, AttnPattern5, 
AttnReplacement5) + torch.fx.replace_pattern(gm, LLaMAAttnPattern, LLaMAAttnReplacement) return gm diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py index fdf734ac3..6ab060706 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/fx_utils.py @@ -124,50 +124,6 @@ def unsafe_index_put_pattern(self, indices, values, accumulate): def unsafe_index_put_replacement(self, indices, values, accumulate): return torch.ops.aten.index_put_.hacked_twin(self, indices, values, accumulate) -# LLaMA aten attention op pattern -def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): - transpose_3 = torch.ops.aten.transpose.int(key, 2, 3) - expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim]) - clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format) - _unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim]) - expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len]) - clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format) - _unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len]) - bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4) - _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len]) - div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale) - add_5 = torch.ops.aten.add.Tensor(div, attn_mask) - maximum = torch.ops.aten.maximum.default(add_5, min_val) - _softmax = torch.ops.aten._softmax.default(maximum, -1, False) - _to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16) - expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len]) - view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None - expand_5 = torch.ops.aten.expand.default(value, [batch, num_head, seq_len, head_dim]) - clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format) - _unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim]) - bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6) - _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim]) - return _softmax, _unsafe_view_5 - - -def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim): - # q, k, v needs to be transposed for flash attn v2 - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd( - query, - key, - value, - 0.0, - 1.0/inv_scale, - True, - False - ) - # output also needs to be transposed - out = out.transpose(1, 2) - return out, out - def get_none_indices(fx_g: torch.fx.GraphModule) -> List[int]: none_indices = [] @@ -206,7 +162,6 @@ def preprocess_fx_graph(fx_graph: torch.fx.GraphModule): torch.fx.replace_pattern(fx_graph, squeeze_dims_pattern, squeeze_dims_replacement) torch.fx.replace_pattern(fx_graph, unsafe_index_put_pattern, unsafe_index_put_replacement) - torch.fx.replace_pattern(fx_graph, LLaMAAttnPattern, LLaMAAttnReplacement) 
was_unwrapped = _unwrap_single_tuple_return(fx_graph) was_list_replaced = _list_return_to_tuple_return(fx_graph) removed_none_indexes = _remove_nones(fx_graph)
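Context for the rewrite-registration change above: `fx_replace_attn_pattern` in `fx_rewrite.py` now also registers the LLaMA attention pattern through `torch.fx.replace_pattern`, which traces the `pattern` callable, structurally matches that subgraph inside a `GraphModule`, and splices in the graph of the `replacement` callable. The following is a minimal, self-contained sketch of that mechanism only; the toy module, pattern, and replacement are illustrative and are not part of this patch.

import torch
import torch.fx as fx

# Toy module whose forward contains the subgraph we want to rewrite.
class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y).relu()

# Pattern: the shape of the subgraph to match (same signature as the replacement).
def pattern(x, y):
    return torch.add(x, y).relu()

# Replacement: what every matched subgraph is swapped for
# (a deliberately different computation so the rewrite is visible in the printed code).
def replacement(x, y):
    return torch.maximum(x, y)

gm = fx.symbolic_trace(Toy())
matches = fx.replace_pattern(gm, pattern, replacement)  # returns the list of matched/replaced sites
gm.recompile()
print(len(matches))
print(gm.code)

In the patch itself, the pattern/replacement pair is the LLaMA attention graph and the `torch.ops.byteir.flash_attn_fwd` call, moved from `fx_utils.py` (where it ran unconditionally in `preprocess_fx_graph`) into `fx_rewrite.py`. The CI command now runs `pytest -m "not attention_rewriter"`, which deselects any test carrying `@pytest.mark.attention_rewriter`; for a clean run, pytest also expects such a custom marker to be declared in the test configuration (e.g. a `markers` entry in `pytest.ini` or `pyproject.toml`) to avoid unknown-marker warnings.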