diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
index 73513492e..a0f916377 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
@@ -40,8 +40,7 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
     softmax_lse = torch.empty(
         (batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
     )
-    # p = None
-    # if return_softmax:
+    # if not return_softmax: p = None
     p = torch.empty(
         (batch_size, num_heads, seqlen_q, seqlen_k), dtype=torch.float,
diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index de913e58e..ab98d1022 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -20,11 +20,11 @@
 pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
-pip3 install flash_attn==2.5.3

 source scripts/prepare.sh
 install_mhlo_tools

 # numerical test
+# pip3 install flash_attn==2.5.3
 python3 tests/numerical_test/main.py --target all
 rm -rf ./local_test
diff --git a/tests/numerical_test/torch_dynamo_e2e_testing/backend.py b/tests/numerical_test/torch_dynamo_e2e_testing/backend.py
index 8fbce14a8..4f2dfac74 100644
--- a/tests/numerical_test/torch_dynamo_e2e_testing/backend.py
+++ b/tests/numerical_test/torch_dynamo_e2e_testing/backend.py
@@ -96,18 +96,18 @@ def byteir_compile_fx_inner(
     backend_legal_ops = BYTEIR_CUSTOM_OPS + GENERIC_CUSTOM_OPS
     with maybe_disable_fake_tensor_mode():
         compiled_graph = torch_frontend.compile_dynamo_model(
-            fx_graph, compile_type, backend_legal_ops=backend_legal_ops, verbose=True
+            fx_graph, compile_type, backend_legal_ops=backend_legal_ops
         )
     # print(compiled_graph)

     model_name = "test"
-    TEMP_FOLDER = "./temp"
+    TEMP_FOLDER = "./local_test"
     category_name = f"{category}_{next(g_graph_id)}"
     os.makedirs(TEMP_FOLDER, exist_ok=True)
     os.makedirs(TEMP_FOLDER + f"/{model_name}_{category_name}", exist_ok=True)
-    mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}.{compile_type}.mlir"
+    mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.{compile_type}.mlir"
     output_mlir_file_name = (
-        f"{TEMP_FOLDER}/{model_name}_{category}/{model_name}_{category_name}.rt.mlir"
+        f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.rt.mlir"
     )
     with open(mlir_file_name, "w+") as fout:
         compiled_graph.operation.print(file=fout, large_elements_limit=None)
@@ -159,7 +159,7 @@ def byteir_compile_fx(model_: torch.fx.GraphModule, example_inputs_):

     def partition_fn(graph, joint_inputs, **kwargs):
         joint_graph_passes(graph)
         return min_cut_rematerialization_partition(
             graph, joint_inputs, **kwargs, compiler="inductor"
         )
diff --git a/tests/numerical_test/torch_dynamo_e2e_testing/execute.py b/tests/numerical_test/torch_dynamo_e2e_testing/execute.py
index fe9c106f6..30a9d910c 100644
--- a/tests/numerical_test/torch_dynamo_e2e_testing/execute.py
+++ b/tests/numerical_test/torch_dynamo_e2e_testing/execute.py
@@ -13,7 +13,10 @@
 # ==============================================================================
 from .test_suite.test_flash_attn import test_flash_attn_unit, test_flash_attn_functional_unit, test_flash_attn_kvcache

-SM80_PLUS_DYNAMO_TESTS = [test_flash_attn_unit, test_flash_attn_functional_unit, test_flash_attn_kvcache]
+SM80_PLUS_DYNAMO_TESTS = [test_flash_attn_unit,
+                          test_flash_attn_functional_unit,
+                          # test_flash_attn_kvcache,
+                          ]
 def run_torch_dynamo_tests(arch):
     if arch >= 80:
         for test in SM80_PLUS_DYNAMO_TESTS:
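
For reference, a minimal sketch (not part of the patch) of the artifact layout the corrected path handling in byteir_compile_fx_inner produces. The values "forward", graph id 0, and compile_type "mhlo" are assumed for illustration only. The key fix is that both paths now interpolate category_name and include the per-graph subfolder, so they point inside the directory that os.makedirs actually creates; the old output path interpolated category and referenced a folder that never existed.

import os

TEMP_FOLDER = "./local_test"
model_name = "test"
category_name = "forward_0"   # assumed: f"{category}_{next(g_graph_id)}" with category="forward", id 0
compile_type = "mhlo"         # assumed example value

# Mirrors the per-graph subfolder created in byteir_compile_fx_inner.
os.makedirs(f"{TEMP_FOLDER}/{model_name}_{category_name}", exist_ok=True)
mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.{compile_type}.mlir"
output_mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.rt.mlir"
print(mlir_file_name)         # ./local_test/test_forward_0/test_forward_0.mhlo.mlir
print(output_mlir_file_name)  # ./local_test/test_forward_0/test_forward_0.rt.mlir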