diff --git a/.github/workflows/_linux-benchmark-h100.yml b/.github/workflows/_linux-benchmark-h100.yml
index bd47a5f4..9ee1ba41 100644
--- a/.github/workflows/_linux-benchmark-h100.yml
+++ b/.github/workflows/_linux-benchmark-h100.yml
@@ -6,14 +6,6 @@ on:
         required: True
         description: |
           Tritonbench Scribe Graph Access Token
-      AWS_ACCESS_KEY_ID:
-        required: True
-        description: |
-          AWS S3 bucket access key
-      AWS_SECRET_ACCESS_KEY:
-        required: True
-        description: |
-          AWS S3 bucket secret access key
     inputs:
       benchmark_name:
         required: True
diff --git a/.github/workflows/compile-time.yaml b/.github/workflows/compile-time.yaml
index 768da357..4ad2c95f 100644
--- a/.github/workflows/compile-time.yaml
+++ b/.github/workflows/compile-time.yaml
@@ -18,8 +18,6 @@ jobs:
       benchmark_name: "compile_time"
     secrets:
       TRITONBENCH_SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.TRITONBENCH_SCRIBE_GRAPHQL_ACCESS_TOKEN }}
-      AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
-      AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py
index 56c59c73..2f7f60dc 100644
--- a/test/test_gpu/main.py
+++ b/test/test_gpu/main.py
@@ -1,8 +1,7 @@
-import argparse
 import logging
 import unittest
 
-from typing import List, Optional
+from typing import Dict, List
 
 import yaml
@@ -36,6 +35,8 @@
 # Ops that we run forward only
 FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])
+# Ops that require special arguments in backwards
+BWD_ARGS_OPS: Dict[str, List[str]] = skip_tests.get("bwd_args", {})
 
 TEST_OPERATORS = set(list_operators_by_collection(op_collection="default"))
@@ -77,6 +78,8 @@ def _run_one_operator(args: List[str]):
     if op.has_bwd():
         del op
         tb_args.mode = "bwd"
+        if tb_args.op in BWD_ARGS_OPS:
+            extra_args.extend(BWD_ARGS_OPS[tb_args.op])
         op = Operator(tb_args=tb_args, extra_args=extra_args)
         op.run()
         check_ci_output(op)
@@ -101,6 +104,8 @@ def _run_operator_in_task(op: str, args: List[str]):
     if task.get_attribute("has_bwd", method=True):
         task.del_op_instance()
         args.extend(["--bwd"])
+        if op in BWD_ARGS_OPS:
+            args.extend(BWD_ARGS_OPS[op])
         task.make_operator_instance(args=args)
         task.run()
         task.check_output()
diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index 3d742596..c1e1112a 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -40,3 +40,7 @@ test_op:
   # TODO: decoding attention requires updated xformers and flash_attn
   # Which will RAM OOM on the CI machine
   decoding_attention:
+bwd_args:
+  # flash_attention/triton_tutorial_flash_v2 does not support non-causal in backward
+  flash_attention:
+    - --causal
diff --git a/test/test_gpu/skip_tests_h100_triton_main.yaml b/test/test_gpu/skip_tests_h100_triton_main.yaml
index d3020f08..44b4b0fb 100644
--- a/test/test_gpu/skip_tests_h100_triton_main.yaml
+++ b/test/test_gpu/skip_tests_h100_triton_main.yaml
@@ -38,10 +38,11 @@ test_op:
   # TODO: decoding attention requires updated xformers and flash_attn
   # Which will RAM OOM on the CI machine
   decoding_attention:
-  # FIXME: PT2 is broken with Triton-main
-  launch_latency:
   addmm:
   gemm:
-  flash_attention:
   gather_gemv:
   layer_norm:
+bwd_args:
+  # flash_attention/triton_tutorial_flash_v2 does not support non-causal in backward
+  flash_attention:
+    - --causal