Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Nov 23, 2024
1 parent 59893b6 commit f69b855
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
softmax_lse = torch.empty(
(batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
)
# p = None
# if return_softmax:
# if not return_softmax: p = None
p = torch.empty(
(batch_size, num_heads, seqlen_q, seqlen_k),
dtype=torch.float,
Expand Down
2 changes: 1 addition & 1 deletion tests/build_and_test_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
pip3 install flash_attn==2.5.3
source scripts/prepare.sh
install_mhlo_tools

# numerical test
# pip3 install flash_attn==2.5.3
python3 tests/numerical_test/main.py --target all
rm -rf ./local_test

Expand Down
10 changes: 5 additions & 5 deletions tests/numerical_test/torch_dynamo_e2e_testing/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,18 @@ def byteir_compile_fx_inner(
backend_legal_ops = BYTEIR_CUSTOM_OPS + GENERIC_CUSTOM_OPS
with maybe_disable_fake_tensor_mode():
compiled_graph = torch_frontend.compile_dynamo_model(
fx_graph, compile_type, backend_legal_ops=backend_legal_ops, verbose=True
fx_graph, compile_type, backend_legal_ops=backend_legal_ops
)
# print(compiled_graph)

model_name = "test"
TEMP_FOLDER = "./temp"
TEMP_FOLDER = "./local_test"
category_name = f"{category}_{next(g_graph_id)}"
os.makedirs(TEMP_FOLDER, exist_ok=True)
os.makedirs(TEMP_FOLDER + f"/{model_name}_{category_name}", exist_ok=True)
mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}.{compile_type}.mlir"
mlir_file_name = f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.{compile_type}.mlir"
output_mlir_file_name = (
f"{TEMP_FOLDER}/{model_name}_{category}/{model_name}_{category_name}.rt.mlir"
f"{TEMP_FOLDER}/{model_name}_{category_name}/{model_name}_{category_name}.rt.mlir"
)
with open(mlir_file_name, "w+") as fout:
compiled_graph.operation.print(file=fout, large_elements_limit=None)
Expand Down Expand Up @@ -159,7 +159,7 @@ def byteir_compile_fx(model_: torch.fx.GraphModule, example_inputs_):

def partition_fn(graph, joint_inputs, **kwargs):
    """Partition a joint forward/backward FX graph for AOTAutograd.

    Runs Inductor's joint-graph passes in place on ``graph`` first, then
    delegates to ``min_cut_rematerialization_partition`` to split it into
    forward and backward graphs using the min-cut recomputation heuristic.

    Args:
        graph: the joint forward+backward ``torch.fx.GraphModule``.
        joint_inputs: inputs to the joint graph, forwarded to the partitioner.
        **kwargs: extra options forwarded to the partitioner.

    Returns:
        Whatever ``min_cut_rematerialization_partition`` returns
        (the forward and backward graph modules).
    """
    # Mutates `graph` in place (fusion/simplification passes) before partitioning.
    joint_graph_passes(graph)
    # BUG FIX: the committed line read `min_cut_rematerialization_partition(c`,
    # a stray trailing `c` that makes the call a NameError/syntax defect.
    return min_cut_rematerialization_partition(
        graph, joint_inputs, **kwargs, compiler="inductor"
    )

Expand Down
5 changes: 4 additions & 1 deletion tests/numerical_test/torch_dynamo_e2e_testing/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# ==============================================================================
from .test_suite.test_flash_attn import test_flash_attn_unit, test_flash_attn_functional_unit, test_flash_attn_kvcache

SM80_PLUS_DYNAMO_TESTS = [test_flash_attn_unit, test_flash_attn_functional_unit, test_flash_attn_kvcache]
SM80_PLUS_DYNAMO_TESTS = [test_flash_attn_unit,
test_flash_attn_functional_unit,
# test_flash_attn_kvcache,
]
def run_torch_dynamo_tests(arch):
if arch >= 80:
for test in SM80_PLUS_DYNAMO_TESTS:
Expand Down

0 comments on commit f69b855

Please sign in to comment.