From 72d77a10630b8cdcccdb2aff7432b74407843c3a Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 9 Oct 2024 16:22:43 -0700 Subject: [PATCH 1/7] WIP - initial logging updates --- py/torch_migraphx/__init__.py | 14 ++++- py/torch_migraphx/fx/fx2mgx.py | 17 ++++-- py/torch_migraphx/fx/lower.py | 55 ++++++++++++------- py/torch_migraphx/fx/lower_setting.py | 1 - .../fx/tracer/acc_tracer/acc_tracer.py | 5 +- py/torch_migraphx/fx/utils.py | 47 +++++++++++++--- 6 files changed, 103 insertions(+), 36 deletions(-) diff --git a/py/torch_migraphx/__init__.py b/py/torch_migraphx/__init__.py index 332cbeec..02ed7ef5 100644 --- a/py/torch_migraphx/__init__.py +++ b/py/torch_migraphx/__init__.py @@ -4,4 +4,16 @@ from torch_migraphx import fx, _C if version.parse(_torch_version) >= version.parse("2.1.0"): - from torch_migraphx import dynamo \ No newline at end of file + from torch_migraphx import dynamo + +import logging +import os +import sys + +LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOGLEVEL', 'WARNING').upper() +logging.basicConfig( + level=LOGLEVEL, + stream=sys.stderr, + format= + '%(asctime)s.%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s\n', + datefmt='%Y-%m-%d:%H:%M:%S') diff --git a/py/torch_migraphx/fx/fx2mgx.py b/py/torch_migraphx/fx/fx2mgx.py index 2e6d315b..b2d2c72d 100644 --- a/py/torch_migraphx/fx/fx2mgx.py +++ b/py/torch_migraphx/fx/fx2mgx.py @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### from typing import Iterable -import warnings +import os import torch import torch.fx import migraphx @@ -36,11 +36,17 @@ from .utils import * from .mgx_module import MGXInstruction from .converters.utils import convert_arg +import logging + +_LOGGER = logging.getLogger(__name__) +INTERPRETER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_INTERPRETER', None) +if INTERPRETER_LOGLEVEL: + _LOGGER.setLevel(INTERPRETER_LOGLEVEL) class MGXInterpreter(torch.fx.Interpreter): - def __init__(self, module, sample_inputs, verbose_log=False): + def __init__(self, module, sample_inputs): super().__init__(module) self.program = migraphx.program() @@ -52,7 +58,7 @@ def __init__(self, module, sample_inputs, verbose_log=False): self._outputs = [] self.unsupported_ops = self.validate_conversion() if self.unsupported_ops: - warnings.warn( + _LOGGER.warning( 'Torch model contains the following unsupported operations: \n' + '\n'.join(f'{i}' for i in self.unsupported_ops)) @@ -73,12 +79,15 @@ def validate_conversion(self): return missing_converters def run(self): + _LOGGER.info(f"Running MGXInterpreter for:\n{self.module.graph}") super().run() output_instr_refs = [i.instr_ref for i in self._outputs] self.mm.add_return(output_instr_refs) + _LOGGER.info(f"Parsed MIGraphX Program:\n{self.program}") return self.program def run_node(self, n): + _LOGGER.debug(f"MGXInterpreter running node:\n{get_node_info(n)}") args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) assert isinstance(kwargs, dict) @@ -101,9 +110,7 @@ def placeholder(self, node, args, kwargs): def call_module(self, node, args, kwargs): assert isinstance(node.target, str) - # print(f'call module: {args}') submod = self.fetch_attr(node.target) - # submod_type = getattr(submod, '_base_class_origin', type(submod)) submod_type = type(submod) converter = CONVERTERS.get(submod_type) diff --git a/py/torch_migraphx/fx/lower.py b/py/torch_migraphx/fx/lower.py index d0b45e57..581bc9a6 100644 --- a/py/torch_migraphx/fx/lower.py +++ b/py/torch_migraphx/fx/lower.py @@ -28,6 +28,7 @@ ##################################################################################### import dataclasses as dc import logging +import os from typing import Any, Callable, Optional, Sequence import migraphx @@ -46,9 +47,12 @@ from .tracer.acc_tracer import acc_tracer from .tracer.aten_tracer import aten_tracer from .mgx_module import MGXModule -from .utils import LowerPrecision +from .utils import LowerPrecision, SuppressPrints, SetLogLevel, get_graph_info -logger = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) +LOWERER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_LOWER', None) +if LOWERER_LOGLEVEL: + _LOGGER.setLevel(LOWERER_LOGLEVEL) Input = Sequence[Any] @@ -88,18 +92,24 @@ def lower_to_mgx(module: nn.Module, """ module = module.cpu().eval() input = [to_device(x, "cpu") for x in input] - lower_setting = LowerSetting( - lower_precision=lower_precision, - verbose_log=verbose_log, - min_acc_module_size=min_acc_module_size, - suppress_accuracy_check=suppress_accuracy_check, - save_subgraph_programs=save_subgraph_programs, - tracer_base_cls=tracer_base_cls, - leaf_module_list=leaf_modules, - use_aten=use_aten, - ) - lowerer = Lowerer.create(lower_setting=lower_setting) - return lowerer(module, input) + + if verbose_log and _LOGGER.level > logging.INFO: + log_level = logging.INFO + else: + log_level = _LOGGER.level + + with SetLogLevel(_LOGGER, log_level): + lower_setting = LowerSetting( + lower_precision=lower_precision, + min_acc_module_size=min_acc_module_size, + suppress_accuracy_check=suppress_accuracy_check, + save_subgraph_programs=save_subgraph_programs, + tracer_base_cls=tracer_base_cls, + leaf_module_list=leaf_modules, + use_aten=use_aten, + ) + lowerer = Lowerer.create(lower_setting=lower_setting) + return lowerer(module, input) @dc.dataclass @@ -111,10 +121,12 @@ def create(cls, lower_setting): return LowerMgxInterpreter(lower_setting) def __call__(self, mod, input, split_name) -> MGXInterpreter: - logger.info(f"split_name={split_name}") + _LOGGER.info(f"Running MGXInterpreter for {split_name}") AccShapeProp(mod).propagate(*input) - interpreter = MGXInterpreter( - mod, input, verbose_log=self.lower_setting.verbose_log) + input_shapes = [(i.shape, i.dtype) for i in input] + _LOGGER.debug(f"Input Shapes: {input_shapes}") + _LOGGER.debug(f"{split_name} Graph:\n{get_graph_info(mod.graph)}") + interpreter = MGXInterpreter(mod, input) interpreter.run() return interpreter @@ -125,8 +137,9 @@ def default_split_function(model: fx.GraphModule, inputs: Input, splitter_setting = MGXSplitterSetting() splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size splitter = MGXSplitter(model, inputs, settings=splitter_setting) - if lower_setting.verbose_log: - splitter.node_support_preview() + with SuppressPrints(): + node_preview = splitter.node_support_preview() + _LOGGER.info(f"\n{node_preview}") return splitter.generate_split_results() @@ -147,6 +160,8 @@ def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, """ interpreter = create_mgx_interpreter(lower_setting) interp_res: MGXInterpreter = interpreter(mod, input, module_name) + + _LOGGER.debug(f"Interpreted MIGraphX Program:\n{interp_res.program}") if lower_setting.save_subgraph_programs: migraphx.save(interp_res.program, f'{module_name}.mxr') @@ -158,6 +173,8 @@ def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, input_names=interp_res.get_input_names(), quantize_fp16=fp16_mode, ) + + _LOGGER.debug(f"Compiled MIGraphX Program:\n{mgx_module.program}") return mgx_module return lower_pass diff --git a/py/torch_migraphx/fx/lower_setting.py b/py/torch_migraphx/fx/lower_setting.py index e5a1d984..e237ca57 100644 --- a/py/torch_migraphx/fx/lower_setting.py +++ b/py/torch_migraphx/fx/lower_setting.py @@ -75,7 +75,6 @@ class LowerSetting(LowerSettingBasic): instance of Lowerer. """ - verbose_log: bool = False explicit_precision: bool = False preset_lowerer: str = "" suppress_accuracy_check: bool = False diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py index 2f8213cc..4a1bba9f 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py @@ -591,9 +591,8 @@ def trace( or a callable (e.g. for op == "call_function"). """ if mod.training: - warnings.warn( - "acc_tracer does not support currently support models for training." - " Calling eval on model before tracing.") + _LOGGER.warning("acc_tracer does not support currently support models for training." + " Calling eval on model before tracing.") mod.eval() assert isinstance(sample_inputs, (list, tuple)) diff --git a/py/torch_migraphx/fx/utils.py b/py/torch_migraphx/fx/utils.py index 098cea3b..f3f1fcd7 100644 --- a/py/torch_migraphx/fx/utils.py +++ b/py/torch_migraphx/fx/utils.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### import os +import sys from enum import Enum from typing import List, Callable from packaging import version @@ -41,6 +42,31 @@ HIPSTREAMTYPE = 'ihipStream_t' +class SuppressPrints: + + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + + +class SetLogLevel: + + def __init__(self, logger, level): + self.logger = logger + self.level = level + + def __enter__(self): + self.original_level = self.logger.level + self.logger.setLevel(self.level) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.logger.setLevel(self.original_level) + + class LowerPrecision(Enum): FP32 = "fp32" FP16 = "fp16" @@ -183,16 +209,23 @@ def mgx_program_from_bytearray(barray: bytearray) -> migraphx.program: return prog -def print_graph(graph: torch.fx.Graph) -> None: +def get_node_info(node: torch.fx.Node) -> str: + node_info = 'Return' if node.op == 'output' else node.format_node() + out_str = f"{node_info}, args: {node.args}, kwargs: {node.kwargs}" + if 'tensor_meta' in node.meta: + out_str += tensor_meta_str(node.meta['tensor_meta']) + return out_str + + +def get_graph_info(graph: torch.fx.Graph) -> str: + out_str = "" for node in graph.nodes: - node_info = 'Return' if node.op == 'output' else node.format_node() - out_str = f"{node_info}, args: {node.args}, kwargs: {node.kwargs}" - if 'tensor_meta' in node.meta: - out_str += tensor_meta_str(node.meta['tensor_meta']) + out_str += f"\n\t{get_node_info(node)}\n" + return out_str - print(out_str) - print() +def print_graph(graph: torch.fx.Graph) -> None: + print(get_graph_info(graph)) def tensor_meta_str(tm) -> str: From 9baf1984dbc143d4ffce34be87a54f0a36b2e76d Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 10 Oct 2024 16:03:31 -0700 Subject: [PATCH 2/7] add logging for dynamo --- py/torch_migraphx/dynamo/lower_dynamo.py | 59 +++++++++++-------- .../dynamo/passes/fix_tensor_meta.py | 10 +++- py/torch_migraphx/dynamo/passes/partition.py | 45 ++++++++++---- .../dynamo/passes/pass_manager.py | 9 +++ py/torch_migraphx/dynamo/passes/remove_ops.py | 15 ++++- py/torch_migraphx/dynamo/passes/utils.py | 52 ++++++++++++++++ py/torch_migraphx/dynamo/utils.py | 11 ++-- py/torch_migraphx/fx/lower.py | 2 +- 8 files changed, 158 insertions(+), 45 deletions(-) create mode 100644 py/torch_migraphx/dynamo/passes/utils.py diff --git a/py/torch_migraphx/dynamo/lower_dynamo.py b/py/torch_migraphx/dynamo/lower_dynamo.py index e500a234..3933f3cc 100644 --- a/py/torch_migraphx/dynamo/lower_dynamo.py +++ b/py/torch_migraphx/dynamo/lower_dynamo.py @@ -28,6 +28,8 @@ ##################################################################################### from typing import Sequence +import logging +import os import torch from torch.fx.passes.shape_prop import ShapeProp @@ -39,7 +41,12 @@ from .passes.pass_manager import pre_partition_pass, post_partition_pass from .passes.partition import partition, get_partition_inputs -from .utils import print_graph_info +from .utils import get_input_info, get_graph_info, SetLogLevel + +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_LOWERING', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) def lower_aten_to_mgx(gm: torch.fx.GraphModule, @@ -58,29 +65,36 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule, torch.fx.GraphModule: GraphModule contatning MGXModule objects for supported subgraphs """ verbose = kwargs['verbose'] if 'verbose' in kwargs else False - if verbose: - print_graph_info('Traced Model', gm, example_inputs) optim_gm = pre_partition_pass(gm) - patitioned_gm = partition(optim_gm, verbose=verbose) + partitioned_gm = partition(optim_gm, verbose=verbose) + + log_level = min(_LOGGER.level, logging.INFO) if verbose else _LOGGER.level + with SetLogLevel(_LOGGER, log_level): + for name, mod in partitioned_gm.named_children(): + # Const folded params can show up as "child objects" + if not isinstance(mod, torch.fx.GraphModule): + continue - for name, mod in patitioned_gm.named_children(): - # Const folded params can show up as "child objects" - if not isinstance(mod, torch.fx.GraphModule): - continue + mod = post_partition_pass(mod) + partition_inputs = get_partition_inputs(partitioned_gm, mod, + example_inputs) - mod = post_partition_pass(mod) - partition_inputs = get_partition_inputs(optim_gm, mod, example_inputs) - if verbose: - print_graph_info(name, mod, partition_inputs) + _LOGGER.info(f"Lowering subgraph: {name}") + _LOGGER.info( + f"Subgraph inputs: {get_input_info(partition_inputs)}") + _LOGGER.info(f"Subgraph:\n{get_graph_info(mod.graph)}") - mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs) + mgx_mod = lower_subgraph(mod, + partition_inputs, + name=name, + **kwargs) - setattr(optim_gm, name, mgx_mod) - del mod - del partition_inputs + setattr(partitioned_gm, name, mgx_mod) + del mod + del partition_inputs - return optim_gm + return partitioned_gm # @validate_inference(0.1, 0.1) @@ -96,30 +110,25 @@ def lower_subgraph(module: torch.fx.GraphModule, MGXModule: Callable module that executes graph via MIGraphX """ - verbose = kwargs['verbose'] if 'verbose' in kwargs else False fp16 = kwargs['fp16'] if 'fp16' in kwargs else False exhaustive_tune = kwargs[ 'exhaustive_tune'] if 'exhaustive_tune' in kwargs else False save_mxr = kwargs['save_mxr'] if 'save_mxr' in kwargs else False - print_uncompiled = (kwargs['print_parsed_program'] - if 'print_parsed_program' in kwargs else False) - print_compiled = (kwargs['print_compiled_program'] - if 'print_compiled_program' in kwargs else False) - interpreter = MGXInterpreter(module, inputs, verbose_log=verbose) + interpreter = MGXInterpreter(module, inputs) interpreter.run() if save_mxr: name = f"{kwargs['name']}.mxr" if 'name' in kwargs else "prog.mxr" migraphx.save(interpreter.program, name) - if print_uncompiled: interpreter.program.print() + _LOGGER.debug(f"Interpreted Program:\n{interpreter.program}") mgx_module = MGXModule(program=interpreter.program, input_names=interpreter.get_input_names(), quantize_fp16=fp16, exhaustive_tune=exhaustive_tune) - if print_compiled: mgx_module.program.print() + _LOGGER.debug(f"Compiled Program:\n{mgx_module.program}") return mgx_module diff --git a/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py b/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py index 2491a097..8c99fd89 100644 --- a/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py +++ b/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py @@ -26,11 +26,19 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### - +import logging +import os import torch import operator +from .utils import log_pass + +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) +@log_pass(_LOGGER, logging.DEBUG) def fix_tensor_meta(gm: torch.fx.GraphModule): for node in gm.graph.nodes: # This is only true for functions with multiple outputs diff --git a/py/torch_migraphx/dynamo/passes/partition.py b/py/torch_migraphx/dynamo/passes/partition.py index fdfafea3..2bd2ce06 100644 --- a/py/torch_migraphx/dynamo/passes/partition.py +++ b/py/torch_migraphx/dynamo/passes/partition.py @@ -28,15 +28,24 @@ ##################################################################################### from typing import Dict, Optional, Sequence, Mapping +import logging +import os import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch_migraphx.fx.converter_registry import CONVERTERS -from ..utils import print_graph_info +from ..utils import get_graph_info, SetLogLevel from ...fx.utils import TYPE_MAP +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +PARTITIONER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_PARTITIONER', + DYNAMO_LOGLEVEL) +if PARTITIONER_LOGLEVEL: + _LOGGER.setLevel(PARTITIONER_LOGLEVEL) + class MGXOperatorSupport(OperatorSupport): '''Construct OperatorSupport object used for partitioning based on registered converters''' @@ -66,14 +75,19 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], self.unsupported.add(node.target) return False - def print_support_summary(self): - print('Supported Nodes: ') + def support_summary(self): + summary = "Supported Nodes:\n" for n in self.supported: - print(n) + summary += f"\t{n}\n" - print('\nUnsupported Nodes: ') + summary += "\nUnsupported Nodes:\n" for n in self.unsupported: - print(n) + summary += f"\t{n}\n" + + return summary + + def print_support_summary(self): + print(self.support_summary()) def partition(gm: torch.fx.GraphModule, @@ -90,20 +104,25 @@ def partition(gm: torch.fx.GraphModule, op_support = MGXOperatorSupport() partitioner = CapabilityBasedPartitioner(gm, op_support) - partitons = partitioner.propose_partitions() - fused_gm = partitioner.fuse_partitions(partitons) + partitions = partitioner.propose_partitions() + fused_gm = partitioner.fuse_partitions(partitions) fused_gm.graph.eliminate_dead_code() fused_gm.recompile() fused_gm.delete_all_unused_submodules() - if verbose: - print_graph_info("Partitioned Module", fused_gm, None) - op_support.print_support_summary() + log_level = _LOGGER.level + if verbose and _LOGGER.level > logging.INFO: + log_level = logging.INFO + + with SetLogLevel(_LOGGER, log_level): + _LOGGER.debug(f"Partitioned Module:\n{get_graph_info(fused_gm.graph)}") + _LOGGER.info(f"Node support summary:\n{op_support.support_summary()}") + _LOGGER.info(f"Number of partitions: {len(partitions)}") # TODO: Compute number of partitions after dead code elimination - if len(partitons) > max_partitions: + if len(partitions) > max_partitions: raise RuntimeError( - f'Found {len(partitons)} partitions, max allowed: {max_partitions}.' + f'Found {len(partitions)} partitions, max allowed: {max_partitions}.' ) return fused_gm diff --git a/py/torch_migraphx/dynamo/passes/pass_manager.py b/py/torch_migraphx/dynamo/passes/pass_manager.py index 64e71cea..ea55b65b 100644 --- a/py/torch_migraphx/dynamo/passes/pass_manager.py +++ b/py/torch_migraphx/dynamo/passes/pass_manager.py @@ -26,6 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### +import logging +import os import torch from torch.fx.passes.pass_manager import PassManager @@ -34,7 +36,12 @@ from .promote_types import promote_inputs from .remove_empty_slice import remove_empty_slices from .fix_tensor_meta import fix_tensor_meta +from ..utils import get_graph_info +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) class MGXPassManager(PassManager): @@ -51,6 +58,7 @@ def pre_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: const_fold, ] pre_partition_pass_mgr = MGXPassManager(passes) + _LOGGER.info(f"Pre Partition Pass In:\n{get_graph_info(gm.graph)}") return pre_partition_pass_mgr(gm) @@ -59,4 +67,5 @@ def post_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: fix_tensor_meta, ] post_partition_pass_mgr = MGXPassManager(passes) + _LOGGER.info(f"Post Partition Pass In:\n{get_graph_info(gm.graph)}") return post_partition_pass_mgr(gm) diff --git a/py/torch_migraphx/dynamo/passes/remove_ops.py b/py/torch_migraphx/dynamo/passes/remove_ops.py index e4cd899b..1a3189cc 100644 --- a/py/torch_migraphx/dynamo/passes/remove_ops.py +++ b/py/torch_migraphx/dynamo/passes/remove_ops.py @@ -26,11 +26,19 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### - +import logging +import os import torch import torch.fx +from .utils import log_pass + +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) +@log_pass(_LOGGER, logging.DEBUG) def remove_clone_ops(gm: torch.fx.GraphModule): clone_ops = [ torch.ops.aten.clone.default, @@ -46,6 +54,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule): return gm +@log_pass(_LOGGER, logging.DEBUG) def remove_view_ops(gm: torch.fx.GraphModule): view_ops = [ torch.ops.aten._unsafe_view.default, @@ -69,8 +78,10 @@ def remove_view_ops(gm: torch.fx.GraphModule): return gm +@log_pass(_LOGGER, logging.DEBUG) def remove_const_ops(gm: torch.fx.GraphModule, device: str = "cuda"): + @log_pass(_LOGGER, logging.DEBUG) def _remove_new_const_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.new_zeros.default: torch.zeros, @@ -98,6 +109,7 @@ def _remove_new_const_ops(gm: torch.fx.GraphModule): gm.graph.eliminate_dead_code() gm.recompile() + @log_pass(_LOGGER, logging.DEBUG) def _remove_const_like_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.full_like.default: torch.full, @@ -129,6 +141,7 @@ def _remove_const_like_ops(gm: torch.fx.GraphModule): gm.graph.eliminate_dead_code() gm.recompile() + @log_pass(_LOGGER, logging.DEBUG) def _remove_range_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.arange.start: torch.arange, diff --git a/py/torch_migraphx/dynamo/passes/utils.py b/py/torch_migraphx/dynamo/passes/utils.py new file mode 100644 index 00000000..5df50354 --- /dev/null +++ b/py/torch_migraphx/dynamo/passes/utils.py @@ -0,0 +1,52 @@ +##################################################################################### +# Copyright (c) 2022-present, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +##################################################################################### +import logging +import os +from ..utils import get_graph_info + + +def log_pass(logger: logging.Logger, level: int): + + def pass_wrapper(func): + + def log_func(gm, *args, **kwargs): + logger.log( + level, + f"Pass: {func.__name__}\nIn Graph:\n{get_graph_info(gm.graph)}" + ) + out = func(gm, *args, **kwargs) + logger.log( + level, + f"Pass: {func.__name__}\nOut Graph:\n{get_graph_info(gm.graph)}" + ) + return out + + return log_func + + return pass_wrapper diff --git a/py/torch_migraphx/dynamo/utils.py b/py/torch_migraphx/dynamo/utils.py index c483669f..d5664703 100644 --- a/py/torch_migraphx/dynamo/utils.py +++ b/py/torch_migraphx/dynamo/utils.py @@ -29,14 +29,17 @@ from typing import Sequence, Union import torch +from ..fx.utils import get_node_info, get_graph_info, SetLogLevel + + +def get_input_info(tensors: Sequence[torch.Tensor]): + return f'Input Sizes: {[tuple(i.size()) if isinstance(i, torch.Tensor) else f"scalar: {i}" for i in tensors]}' def print_graph_info(name: str, gm: torch.fx.GraphModule, inputs: Union[Sequence[torch.Tensor], None]) -> None: print(f'\n{name}') if inputs: - print( - f'Input Sizes: {[tuple(i.size()) if isinstance(i, torch.Tensor) else f"scalar: {i}" for i in inputs]}' - ) - gm.graph.print_tabular() + print(get_input_info(inputs)) + print(get_graph_info(gm.graph)) print() diff --git a/py/torch_migraphx/fx/lower.py b/py/torch_migraphx/fx/lower.py index 581bc9a6..4b06e5e4 100644 --- a/py/torch_migraphx/fx/lower.py +++ b/py/torch_migraphx/fx/lower.py @@ -50,7 +50,7 @@ from .utils import LowerPrecision, SuppressPrints, SetLogLevel, get_graph_info _LOGGER = logging.getLogger(__name__) -LOWERER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_LOWER', None) +LOWERER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_FX_LOWER', None) if LOWERER_LOGLEVEL: _LOGGER.setLevel(LOWERER_LOGLEVEL) Input = Sequence[Any] From 562165b741783f6166b50c73209e0c89c078aeba Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 10 Oct 2024 16:25:30 -0700 Subject: [PATCH 3/7] consistent env var naming --- py/torch_migraphx/dynamo/lower_dynamo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_migraphx/dynamo/lower_dynamo.py b/py/torch_migraphx/dynamo/lower_dynamo.py index 3933f3cc..13f1af4a 100644 --- a/py/torch_migraphx/dynamo/lower_dynamo.py +++ b/py/torch_migraphx/dynamo/lower_dynamo.py @@ -44,7 +44,7 @@ from .utils import get_input_info, get_graph_info, SetLogLevel _LOGGER = logging.getLogger(__name__) -DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_LOWERING', None) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_LOWER', None) if DYNAMO_LOGLEVEL: _LOGGER.setLevel(DYNAMO_LOGLEVEL) From 4289d7fe60a85fa5bbfb547dce86f1afe07f0e92 Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 10 Oct 2024 16:31:31 -0700 Subject: [PATCH 4/7] doc env vars --- docs/dev/env_vars.rst | 68 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 docs/dev/env_vars.rst diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst new file mode 100644 index 00000000..fd8395d7 --- /dev/null +++ b/docs/dev/env_vars.rst @@ -0,0 +1,68 @@ +Environment Variables +===================== + +Logging +--------------- + +Accepted values for these variables are logging levels defined by the logging module, +refer to: https://docs.python.org/3/library/logging.html#logging-levels +NOTSET +DEBUG +INFO +WARNING +ERROR +CRITICAL + +By default, torch_migraphx logs levels higher than WARNING (inclusive) + +.. envvar:: TORCH_MIGRAPHX_LOGLEVEL + +Default log level for all purposes. +Default behavior is WARNING + + +.. envvar:: TORCH_MIGRAPHX_LOG_INTERPRETER + +Log level for interpreter. +INFO outputs: + - PyTorch graph passed to MGXInterpreter + - Parsed MIGraphX program +DEBUG outputs: + - Node info for each node in pytorch graph + + +.. envvar:: TORCH_MIGRAPHX_LOG_FX_LOWER + +Log level for fx lowering. +INFO outputs: + - Name of each subgraph that is being lowered + - Node support summary (ie. supported, unsupported nodes) +DEBUG outputs: + - Input shapes and pytorch graph + - Parsed MIGraphX program that is to be compiled + - Compiled MIGraphX program + + +.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_LOWER + +Log level for dynamo lowering. +INFO outputs: + - Name of each subgraph that is being lowered + - Input shapes and pytorch graph +DEBUG outputs: + - Parsed MIGraphX program that is to be compiled + - Compiled MIGraphX program + + +.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_PASSES + +Log level for dynamo pre lowering dynamo passes +INFO outputs: + - Graph info before and after pre-partitioning, partitioning and post-partitioning passes +DEBUG outputs: + - Graph info for each sub pass + + +.. envvar:: TORCH_MIGRAPHX_LOG_PARTITIONER + +Log level for partitioner pass specifically \ No newline at end of file From 42820c9aefc19ea9a281f91c6637046490cec82e Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 10 Oct 2024 16:34:00 -0700 Subject: [PATCH 5/7] rst format --- docs/dev/env_vars.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index fd8395d7..6fda5779 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -6,13 +6,20 @@ Logging Accepted values for these variables are logging levels defined by the logging module, refer to: https://docs.python.org/3/library/logging.html#logging-levels + NOTSET + DEBUG + INFO + WARNING + ERROR + CRITICAL + By default, torch_migraphx logs levels higher than WARNING (inclusive) .. envvar:: TORCH_MIGRAPHX_LOGLEVEL @@ -24,9 +31,12 @@ Default behavior is WARNING .. envvar:: TORCH_MIGRAPHX_LOG_INTERPRETER Log level for interpreter. + INFO outputs: + - PyTorch graph passed to MGXInterpreter - Parsed MIGraphX program + DEBUG outputs: - Node info for each node in pytorch graph @@ -34,9 +44,12 @@ DEBUG outputs: .. envvar:: TORCH_MIGRAPHX_LOG_FX_LOWER Log level for fx lowering. + INFO outputs: + - Name of each subgraph that is being lowered - Node support summary (ie. supported, unsupported nodes) + DEBUG outputs: - Input shapes and pytorch graph - Parsed MIGraphX program that is to be compiled @@ -46,9 +59,11 @@ DEBUG outputs: .. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_LOWER Log level for dynamo lowering. + INFO outputs: - Name of each subgraph that is being lowered - Input shapes and pytorch graph + DEBUG outputs: - Parsed MIGraphX program that is to be compiled - Compiled MIGraphX program @@ -57,8 +72,10 @@ DEBUG outputs: .. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_PASSES Log level for dynamo pre lowering dynamo passes + INFO outputs: - Graph info before and after pre-partitioning, partitioning and post-partitioning passes + DEBUG outputs: - Graph info for each sub pass From 42ba4c5b949ece5e1693ac767e9e20bdd42dee2d Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 10 Oct 2024 16:35:31 -0700 Subject: [PATCH 6/7] rst format --- docs/dev/env_vars.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 6fda5779..d2a74721 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -38,6 +38,7 @@ INFO outputs: - Parsed MIGraphX program DEBUG outputs: + - Node info for each node in pytorch graph @@ -51,6 +52,7 @@ INFO outputs: - Node support summary (ie. supported, unsupported nodes) DEBUG outputs: + - Input shapes and pytorch graph - Parsed MIGraphX program that is to be compiled - Compiled MIGraphX program @@ -61,10 +63,12 @@ DEBUG outputs: Log level for dynamo lowering. INFO outputs: + - Name of each subgraph that is being lowered - Input shapes and pytorch graph DEBUG outputs: + - Parsed MIGraphX program that is to be compiled - Compiled MIGraphX program @@ -74,9 +78,11 @@ DEBUG outputs: Log level for dynamo pre lowering dynamo passes INFO outputs: + - Graph info before and after pre-partitioning, partitioning and post-partitioning passes DEBUG outputs: + - Graph info for each sub pass From dcdeadfedb53254d67da28cc06c8e24cac291834 Mon Sep 17 00:00:00 2001 From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:46:03 -0800 Subject: [PATCH 7/7] fix merge --- py/torch_migraphx/dynamo/lower_dynamo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_migraphx/dynamo/lower_dynamo.py b/py/torch_migraphx/dynamo/lower_dynamo.py index 6ecae893..fe52698a 100644 --- a/py/torch_migraphx/dynamo/lower_dynamo.py +++ b/py/torch_migraphx/dynamo/lower_dynamo.py @@ -92,7 +92,7 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule, setattr(optim_gm, name, mgx_mod) - return partitioned_gm + return optim_gm # @validate_inference(0.1, 0.1)