Skip to content

Commit

Permalink
fix add export
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Feb 28, 2025
1 parent c4f7c64 commit e7cbede
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tritonbench/components/export/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .export import export_data
from .export import export_data
19 changes: 15 additions & 4 deletions tritonbench/components/export/export.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
"""
Serialize pickled tensors to directory.
"""
from pathlib import Path

import pickle
from pathlib import Path

from typing import Any, Callable

from tritonbench.utils.input import input_cast

from typing import Callable, Any

def get_input_gradients(inputs):
all_inputs = []
input_cast(lambda x: True, lambda y: all_inputs.append(y), inputs)
return [x.grad for x in all_inputs]

def export_data(x_val: str, inputs: Any, fn_mode: str, fn: Callable, export_type: str, export_dir: str):

def export_data(
x_val: str,
inputs: Any,
fn_mode: str,
fn: Callable,
export_type: str,
export_dir: str,
):
# pickle naming convention
# x_<x_val>-input.pkl
# x_<x_val>-<fn_name>-fwd-output.pkl
# x_<x_val>-<fn_name>-bwd-grad.pkl
assert export_dir, f"Export dir must be specified."
export_path = Path(export_dir)
assert export_path.exists(), f"Export path {export_dir} must exist."
if export_type == "input" or export_type =="both":
if export_type == "input" or export_type == "both":
input_file_name = f"x_{x_val}-input.pkl"
input_file_path = export_path.joinpath(input_file_name)
with open(input_file_path, "wb") as ifp:
Expand Down
1 change: 1 addition & 0 deletions tritonbench/operators/ragged_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
def _bwd(out, dout):
out.backward(dout, retain_graph=True)
return out.grad()

o = fwd_fn()
o_tensor = input_filter(
lambda x: isinstance(x, torch.Tensor),
Expand Down
4 changes: 2 additions & 2 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ def get_parser(args=None):
"--export",
default=None,
choices=["in", "out", "both"],
help="Export input or output. Must be used together with --export-dir."
help="Export input or output. Must be used together with --export-dir.",
)
parser.add_argument(
"--export-dir",
default=None,
type=str,
help="The directory to store input or output."
help="The directory to store input or output.",
)

if IS_FBCODE:
Expand Down
4 changes: 2 additions & 2 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import triton

from tritonbench.components.do_bench import do_bench_wrapper, Latency
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
from tritonbench.components.export import export_data
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
from tritonbench.utils.env_utils import apply_precision, set_env, set_random_seed
from tritonbench.utils.input import input_cast
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter
Expand Down Expand Up @@ -1363,7 +1363,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
)
from tritonbench.components.ncu import do_bench_in_task

do_bench_in_task(
fn=fn,
grad_to_none=self.get_grad_to_none(self.example_inputs),
Expand Down

0 comments on commit e7cbede

Please sign in to comment.