[ci] Enable more tests #157

Closed · wants to merge 5 commits
8 changes: 0 additions & 8 deletions .github/workflows/_linux-benchmark-h100.yml
@@ -6,14 +6,6 @@ on:
required: True
description: |
Tritonbench Scribe Graph Access Token
AWS_ACCESS_KEY_ID:
required: True
description: |
AWS S3 bucket access key
AWS_SECRET_ACCESS_KEY:
required: True
description: |
AWS S3 bucket secret access key
inputs:
benchmark_name:
required: True
2 changes: 0 additions & 2 deletions .github/workflows/compile-time.yaml
@@ -18,8 +18,6 @@ jobs:
benchmark_name: "compile_time"
secrets:
TRITONBENCH_SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.TRITONBENCH_SCRIBE_GRAPHQL_ACCESS_TOKEN }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}



9 changes: 7 additions & 2 deletions test/test_gpu/main.py
@@ -1,8 +1,7 @@
import argparse
import logging
import unittest

from typing import List, Optional
from typing import Dict, List

import yaml

@@ -36,6 +35,8 @@

# Ops that we run forward only
FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])
# Ops that require special arguments in backwards
BWD_ARGS_OPS: Dict[str, List[str]] = skip_tests.get("bwd_args", {})

TEST_OPERATORS = set(list_operators_by_collection(op_collection="default"))

@@ -77,6 +78,8 @@ def _run_one_operator(args: List[str]):
if op.has_bwd():
del op
tb_args.mode = "bwd"
if tb_args.op in BWD_ARGS_OPS:
extra_args.extend(BWD_ARGS_OPS[tb_args.op])
op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
check_ci_output(op)
@@ -101,6 +104,8 @@ def _run_operator_in_task(op: str, args: List[str]):
if task.get_attribute("has_bwd", method=True):
task.del_op_instance()
args.extend(["--bwd"])
if op in BWD_ARGS_OPS:
args.extend(BWD_ARGS_OPS[op])
task.make_operator_instance(args=args)
task.run()
task.check_output()
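For readers skimming the diff: `BWD_ARGS_OPS` is a per-operator list of extra CLI flags that the test harness appends before re-instantiating an operator in backward mode, so backward-only constraints (such as flash_attention requiring `--causal`) can live in the skip-tests YAML rather than in code. A minimal sketch of that flow, assuming the skip file is loaded with `yaml` as imported above; the helper names `load_bwd_args` and `backward_extra_args` are illustrative, not from the PR:

```python
from typing import Dict, List

import yaml


def load_bwd_args(skip_file: str) -> Dict[str, List[str]]:
    """Read the per-operator backward-only flags from a skip-tests YAML file."""
    with open(skip_file) as f:
        skip_tests = yaml.safe_load(f) or {}
    # A missing "bwd_args" key simply means no operator needs extra backward flags.
    return skip_tests.get("bwd_args", {}) or {}


def backward_extra_args(
    op: str, extra_args: List[str], bwd_args_ops: Dict[str, List[str]]
) -> List[str]:
    """Mirror the diff above: extend the forward args with the op's backward-only flags."""
    args = list(extra_args)
    if op in bwd_args_ops:
        args.extend(bwd_args_ops[op])
    return args
```

With the YAML hunks added below, `backward_extra_args("flash_attention", [], load_bwd_args("test/test_gpu/skip_tests_h100_pytorch.yaml"))` would return `["--causal"]`, so the backward run only exercises the causal path that `triton_tutorial_flash_v2` supports.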
4 changes: 4 additions & 0 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -40,3 +40,7 @@ test_op:
# TODO: decoding attention requires updated xformers and flash_attn
# Which will RAM OOM on the CI machine
decoding_attention:
bwd_args:
# flash_attention/triton_tutorial_flash_v2 does not support non-causal in backward
flash_attention:
- --causal
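As a quick sanity check of the config shape (not part of the PR), parsing this new block with `yaml.safe_load` yields a plain mapping from operator name to flag list, which is exactly what `skip_tests.get("bwd_args", {})` in `test/test_gpu/main.py` consumes:

```python
import yaml

snippet = """
bwd_args:
  flash_attention:
    - --causal
"""
parsed = yaml.safe_load(snippet)
print(parsed["bwd_args"])  # {'flash_attention': ['--causal']}
```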
7 changes: 4 additions & 3 deletions test/test_gpu/skip_tests_h100_triton_main.yaml
@@ -38,10 +38,11 @@ test_op:
# TODO: decoding attention requires updated xformers and flash_attn
# Which will RAM OOM on the CI machine
decoding_attention:
# FIXME: PT2 is broken with Triton-main
launch_latency:
addmm:
gemm:
flash_attention:
gather_gemv:
layer_norm:
bwd_args:
# flash_attention/triton_tutorial_flash_v2 does not support non-causal in backward
flash_attention:
- --causal
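Put together, the flash_attention backward case ends up running with an argument list along these lines; this is a hypothetical reconstruction, where `--bwd` and `--causal` come straight from the diff and `--op` is assumed from tritonbench's standard argument parser:

```python
# Hypothetical reconstruction of the effective args for the newly enabled backward test.
args = ["--op", "flash_attention"]
args.extend(["--bwd"])     # appended by _run_operator_in_task when has_bwd is true
args.extend(["--causal"])  # contributed by the bwd_args entry in the skip-tests YAML
```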