diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst
new file mode 100644
index 00000000..d2a74721
--- /dev/null
+++ b/docs/dev/env_vars.rst
@@ -0,0 +1,91 @@
+Environment Variables
+=====================
+
+Logging
+---------------
+
+Accepted values for these variables are the logging levels defined by the Python logging module;
+refer to: https://docs.python.org/3/library/logging.html#logging-levels
+
+NOTSET
+
+DEBUG
+
+INFO
+
+WARNING
+
+ERROR
+
+CRITICAL
+
+
+By default, torch_migraphx logs messages at level WARNING and above.
+
+.. envvar:: TORCH_MIGRAPHX_LOGLEVEL
+
+Default log level for all torch_migraphx components.
+Defaults to WARNING.
+
+
+.. envvar:: TORCH_MIGRAPHX_LOG_INTERPRETER
+
+Log level for the interpreter.
+
+INFO outputs:
+
+    - PyTorch graph passed to MGXInterpreter
+    - Parsed MIGraphX program
+
+DEBUG outputs:
+
+    - Node info for each node in the PyTorch graph
+
+
+.. envvar:: TORCH_MIGRAPHX_LOG_FX_LOWER
+
+Log level for fx lowering.
+
+INFO outputs:
+
+    - Name of each subgraph that is being lowered
+    - Node support summary (i.e. supported and unsupported nodes)
+
+DEBUG outputs:
+
+    - Input shapes and PyTorch graph
+    - Parsed MIGraphX program that is to be compiled
+    - Compiled MIGraphX program
+
+
+.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_LOWER
+
+Log level for dynamo lowering.
+
+INFO outputs:
+
+    - Name of each subgraph that is being lowered
+    - Input shapes and PyTorch graph
+
+DEBUG outputs:
+
+    - Parsed MIGraphX program that is to be compiled
+    - Compiled MIGraphX program
+
+
+.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_PASSES
+
+Log level for the dynamo pre-lowering passes.
+
+INFO outputs:
+
+    - Graph info before and after the pre-partitioning, partitioning and post-partitioning passes
+
+DEBUG outputs:
+
+    - Graph info for each sub-pass
+
+
+.. envvar:: TORCH_MIGRAPHX_LOG_PARTITIONER
+
+Log level for the partitioner pass specifically.
\ No newline at end of file
diff --git a/py/torch_migraphx/__init__.py b/py/torch_migraphx/__init__.py
index 332cbeec..02ed7ef5 100644
--- a/py/torch_migraphx/__init__.py
+++ b/py/torch_migraphx/__init__.py
@@ -4,4 +4,16 @@
 from torch_migraphx import fx, _C
 
 if version.parse(_torch_version) >= version.parse("2.1.0"):
-    from torch_migraphx import dynamo
\ No newline at end of file
+    from torch_migraphx import dynamo
+
+import logging
+import os
+import sys
+
+LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOGLEVEL', 'WARNING').upper()
+logging.basicConfig(
+    level=LOGLEVEL,
+    stream=sys.stderr,
+    format=
+    '%(asctime)s.%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s\n',
+    datefmt='%Y-%m-%d:%H:%M:%S')
diff --git a/py/torch_migraphx/dynamo/lower_dynamo.py b/py/torch_migraphx/dynamo/lower_dynamo.py
index e02c73a9..fe52698a 100644
--- a/py/torch_migraphx/dynamo/lower_dynamo.py
+++ b/py/torch_migraphx/dynamo/lower_dynamo.py
@@ -28,6 +28,8 @@
 #####################################################################################
 from typing import Sequence
+import logging
+import os
 
 import torch
 from torch.fx.passes.shape_prop import ShapeProp
@@ -39,7 +41,12 @@
 from .passes.pass_manager import pre_partition_pass, post_partition_pass
 from .passes.partition import partition, get_partition_inputs
-from .utils import print_graph_info
+from .utils import get_input_info, get_graph_info, SetLogLevel
+
+_LOGGER = logging.getLogger(__name__)
+DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_LOWER', None)
+if DYNAMO_LOGLEVEL:
+    _LOGGER.setLevel(DYNAMO_LOGLEVEL)
 
 
 def lower_aten_to_mgx(gm: torch.fx.GraphModule,
                       example_inputs: Sequence[torch.Tensor],
                       **kwargs) -> torch.fx.GraphModule:
@@ -58,25 +65,32 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule,
         torch.fx.GraphModule: GraphModule contatning MGXModule objects for supported subgraphs
     """
     verbose = kwargs['verbose'] if 'verbose' in kwargs else False
-    if verbose:
-        print_graph_info('Traced Model', gm, example_inputs)
 
     optim_gm = pre_partition_pass(gm)
     partition(optim_gm, verbose=verbose)
 
-    for name, mod in optim_gm.named_children():
-        # Const folded params can show up as "child objects"
-        if not isinstance(mod, torch.fx.GraphModule):
-            continue
+
+    log_level = min(_LOGGER.level, logging.INFO) if verbose else _LOGGER.level
+    with SetLogLevel(_LOGGER, log_level):
+        for name, mod in optim_gm.named_children():
+            # Const folded params can show up as "child objects"
+            if not isinstance(mod, torch.fx.GraphModule):
+                continue
+
+            mod = post_partition_pass(mod)
+            partition_inputs = get_partition_inputs(optim_gm, mod,
+                                                    example_inputs)
 
-        mod = post_partition_pass(mod)
-        partition_inputs = get_partition_inputs(optim_gm, mod, example_inputs)
-        if verbose:
-            print_graph_info(name, mod, partition_inputs)
+            _LOGGER.info(f"Lowering subgraph: {name}")
+            _LOGGER.info(
+                f"Subgraph inputs: {get_input_info(partition_inputs)}")
+            _LOGGER.info(f"Subgraph:\n{get_graph_info(mod.graph)}")
 
-        mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)
+            mgx_mod = lower_subgraph(mod,
+                                     partition_inputs,
+                                     name=name,
+                                     **kwargs)
 
-        setattr(optim_gm, name, mgx_mod)
+            setattr(optim_gm, name, mgx_mod)
 
     return optim_gm
@@ -94,34 +108,26 @@ def lower_subgraph(module: torch.fx.GraphModule,
         MGXModule: Callable module that executes graph via MIGraphX
     """
-    verbose = kwargs['verbose'] if 'verbose' in kwargs else False
     fp16 = kwargs['fp16'] if 'fp16' in kwargs else False
     deallocate = kwargs['deallocate'] if 'deallocate' in kwargs else False
     exhaustive_tune = kwargs[
         'exhaustive_tune'] if 'exhaustive_tune' in kwargs else False
     save_mxr = kwargs['save_mxr'] if 'save_mxr' in kwargs else False
-    print_uncompiled = (kwargs['print_parsed_program']
-                        if 'print_parsed_program' in kwargs else False)
-    print_compiled = (kwargs['print_compiled_program']
-                      if 'print_compiled_program' in kwargs else False)
-
-    interpreter = MGXInterpreter(module,
-                                 inputs,
-                                 deallocate=deallocate,
-                                 verbose_log=verbose)
+
+    interpreter = MGXInterpreter(module, inputs, deallocate=deallocate)
     interpreter.run()
 
     if save_mxr:
         name = f"{kwargs['name']}.mxr" if 'name' in kwargs else "prog.mxr"
         migraphx.save(interpreter.program, name)
 
-    if print_uncompiled: interpreter.program.print()
+    _LOGGER.debug(f"Interpreted Program:\n{interpreter.program}")
 
     mgx_module = MGXModule(program=interpreter.program,
                            input_names=interpreter.get_input_names(),
                            quantize_fp16=fp16,
                            exhaustive_tune=exhaustive_tune)
 
-    if print_compiled: mgx_module.program.print()
+    _LOGGER.debug(f"Compiled Program:\n{mgx_module.program}")
 
     return mgx_module
diff --git a/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py b/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py
index 2491a097..8c99fd89 100644
--- a/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py
+++ b/py/torch_migraphx/dynamo/passes/fix_tensor_meta.py
@@ -26,11 +26,19 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
##################################################################################### - +import logging +import os import torch import operator +from .utils import log_pass + +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) +@log_pass(_LOGGER, logging.DEBUG) def fix_tensor_meta(gm: torch.fx.GraphModule): for node in gm.graph.nodes: # This is only true for functions with multiple outputs diff --git a/py/torch_migraphx/dynamo/passes/partition.py b/py/torch_migraphx/dynamo/passes/partition.py index fdfafea3..2bd2ce06 100644 --- a/py/torch_migraphx/dynamo/passes/partition.py +++ b/py/torch_migraphx/dynamo/passes/partition.py @@ -28,15 +28,24 @@ ##################################################################################### from typing import Dict, Optional, Sequence, Mapping +import logging +import os import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch_migraphx.fx.converter_registry import CONVERTERS -from ..utils import print_graph_info +from ..utils import get_graph_info, SetLogLevel from ...fx.utils import TYPE_MAP +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +PARTITIONER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_PARTITIONER', + DYNAMO_LOGLEVEL) +if PARTITIONER_LOGLEVEL: + _LOGGER.setLevel(PARTITIONER_LOGLEVEL) + class MGXOperatorSupport(OperatorSupport): '''Construct OperatorSupport object used for partitioning based on registered converters''' @@ -66,14 +75,19 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], self.unsupported.add(node.target) return False - def print_support_summary(self): - print('Supported Nodes: ') + def support_summary(self): + summary = "Supported Nodes:\n" for n in self.supported: - print(n) + summary += f"\t{n}\n" - print('\nUnsupported Nodes: ') + summary += "\nUnsupported Nodes:\n" for n in self.unsupported: - print(n) + summary += f"\t{n}\n" + + return summary + + def print_support_summary(self): + print(self.support_summary()) def partition(gm: torch.fx.GraphModule, @@ -90,20 +104,25 @@ def partition(gm: torch.fx.GraphModule, op_support = MGXOperatorSupport() partitioner = CapabilityBasedPartitioner(gm, op_support) - partitons = partitioner.propose_partitions() - fused_gm = partitioner.fuse_partitions(partitons) + partitions = partitioner.propose_partitions() + fused_gm = partitioner.fuse_partitions(partitions) fused_gm.graph.eliminate_dead_code() fused_gm.recompile() fused_gm.delete_all_unused_submodules() - if verbose: - print_graph_info("Partitioned Module", fused_gm, None) - op_support.print_support_summary() + log_level = _LOGGER.level + if verbose and _LOGGER.level > logging.INFO: + log_level = logging.INFO + + with SetLogLevel(_LOGGER, log_level): + _LOGGER.debug(f"Partitioned Module:\n{get_graph_info(fused_gm.graph)}") + _LOGGER.info(f"Node support summary:\n{op_support.support_summary()}") + _LOGGER.info(f"Number of partitions: {len(partitions)}") # TODO: Compute number of partitions after dead code elimination - if len(partitons) > max_partitions: + if len(partitions) > max_partitions: raise RuntimeError( - f'Found {len(partitons)} partitions, max allowed: {max_partitions}.' + f'Found {len(partitions)} partitions, max allowed: {max_partitions}.' 
) return fused_gm diff --git a/py/torch_migraphx/dynamo/passes/pass_manager.py b/py/torch_migraphx/dynamo/passes/pass_manager.py index 64e71cea..ea55b65b 100644 --- a/py/torch_migraphx/dynamo/passes/pass_manager.py +++ b/py/torch_migraphx/dynamo/passes/pass_manager.py @@ -26,6 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### +import logging +import os import torch from torch.fx.passes.pass_manager import PassManager @@ -34,7 +36,12 @@ from .promote_types import promote_inputs from .remove_empty_slice import remove_empty_slices from .fix_tensor_meta import fix_tensor_meta +from ..utils import get_graph_info +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) class MGXPassManager(PassManager): @@ -51,6 +58,7 @@ def pre_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: const_fold, ] pre_partition_pass_mgr = MGXPassManager(passes) + _LOGGER.info(f"Pre Partition Pass In:\n{get_graph_info(gm.graph)}") return pre_partition_pass_mgr(gm) @@ -59,4 +67,5 @@ def post_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: fix_tensor_meta, ] post_partition_pass_mgr = MGXPassManager(passes) + _LOGGER.info(f"Post Partition Pass In:\n{get_graph_info(gm.graph)}") return post_partition_pass_mgr(gm) diff --git a/py/torch_migraphx/dynamo/passes/remove_ops.py b/py/torch_migraphx/dynamo/passes/remove_ops.py index e4cd899b..1a3189cc 100644 --- a/py/torch_migraphx/dynamo/passes/remove_ops.py +++ b/py/torch_migraphx/dynamo/passes/remove_ops.py @@ -26,11 +26,19 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
##################################################################################### - +import logging +import os import torch import torch.fx +from .utils import log_pass + +_LOGGER = logging.getLogger(__name__) +DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None) +if DYNAMO_LOGLEVEL: + _LOGGER.setLevel(DYNAMO_LOGLEVEL) +@log_pass(_LOGGER, logging.DEBUG) def remove_clone_ops(gm: torch.fx.GraphModule): clone_ops = [ torch.ops.aten.clone.default, @@ -46,6 +54,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule): return gm +@log_pass(_LOGGER, logging.DEBUG) def remove_view_ops(gm: torch.fx.GraphModule): view_ops = [ torch.ops.aten._unsafe_view.default, @@ -69,8 +78,10 @@ def remove_view_ops(gm: torch.fx.GraphModule): return gm +@log_pass(_LOGGER, logging.DEBUG) def remove_const_ops(gm: torch.fx.GraphModule, device: str = "cuda"): + @log_pass(_LOGGER, logging.DEBUG) def _remove_new_const_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.new_zeros.default: torch.zeros, @@ -98,6 +109,7 @@ def _remove_new_const_ops(gm: torch.fx.GraphModule): gm.graph.eliminate_dead_code() gm.recompile() + @log_pass(_LOGGER, logging.DEBUG) def _remove_const_like_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.full_like.default: torch.full, @@ -129,6 +141,7 @@ def _remove_const_like_ops(gm: torch.fx.GraphModule): gm.graph.eliminate_dead_code() gm.recompile() + @log_pass(_LOGGER, logging.DEBUG) def _remove_range_ops(gm: torch.fx.GraphModule): const_ops = { torch.ops.aten.arange.start: torch.arange, diff --git a/py/torch_migraphx/dynamo/passes/utils.py b/py/torch_migraphx/dynamo/passes/utils.py new file mode 100644 index 00000000..5df50354 --- /dev/null +++ b/py/torch_migraphx/dynamo/passes/utils.py @@ -0,0 +1,52 @@ +##################################################################################### +# Copyright (c) 2022-present, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+##################################################################################### +import logging +import os +from ..utils import get_graph_info + + +def log_pass(logger: logging.Logger, level: int): + + def pass_wrapper(func): + + def log_func(gm, *args, **kwargs): + logger.log( + level, + f"Pass: {func.__name__}\nIn Graph:\n{get_graph_info(gm.graph)}" + ) + out = func(gm, *args, **kwargs) + logger.log( + level, + f"Pass: {func.__name__}\nOut Graph:\n{get_graph_info(gm.graph)}" + ) + return out + + return log_func + + return pass_wrapper diff --git a/py/torch_migraphx/dynamo/utils.py b/py/torch_migraphx/dynamo/utils.py index c483669f..d5664703 100644 --- a/py/torch_migraphx/dynamo/utils.py +++ b/py/torch_migraphx/dynamo/utils.py @@ -29,14 +29,17 @@ from typing import Sequence, Union import torch +from ..fx.utils import get_node_info, get_graph_info, SetLogLevel + + +def get_input_info(tensors: Sequence[torch.Tensor]): + return f'Input Sizes: {[tuple(i.size()) if isinstance(i, torch.Tensor) else f"scalar: {i}" for i in tensors]}' def print_graph_info(name: str, gm: torch.fx.GraphModule, inputs: Union[Sequence[torch.Tensor], None]) -> None: print(f'\n{name}') if inputs: - print( - f'Input Sizes: {[tuple(i.size()) if isinstance(i, torch.Tensor) else f"scalar: {i}" for i in inputs]}' - ) - gm.graph.print_tabular() + print(get_input_info(inputs)) + print(get_graph_info(gm.graph)) print() diff --git a/py/torch_migraphx/fx/fx2mgx.py b/py/torch_migraphx/fx/fx2mgx.py index ba38b1d3..e442e7fd 100644 --- a/py/torch_migraphx/fx/fx2mgx.py +++ b/py/torch_migraphx/fx/fx2mgx.py @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ##################################################################################### from typing import Iterable -import warnings +import os import torch import torch.fx import migraphx @@ -36,15 +36,17 @@ from .utils import * from .mgx_module import MGXInstruction from .converters.utils import convert_arg +import logging + +_LOGGER = logging.getLogger(__name__) +INTERPRETER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_INTERPRETER', None) +if INTERPRETER_LOGLEVEL: + _LOGGER.setLevel(INTERPRETER_LOGLEVEL) class MGXInterpreter(torch.fx.Interpreter): - def __init__(self, - module, - sample_inputs, - deallocate=False, - verbose_log=False): + def __init__(self, module, sample_inputs, deallocate=False): super().__init__(module) self.program = migraphx.program() @@ -56,7 +58,7 @@ def __init__(self, self._outputs = [] self.unsupported_ops = self.validate_conversion() if self.unsupported_ops: - warnings.warn( + _LOGGER.warning( 'Torch model contains the following unsupported operations: \n' + '\n'.join(f'{i}' for i in self.unsupported_ops)) self.deallocate = deallocate @@ -78,12 +80,15 @@ def validate_conversion(self): return missing_converters def run(self): + _LOGGER.info(f"Running MGXInterpreter for:\n{self.module.graph}") super().run() output_instr_refs = [i.instr_ref for i in self._outputs] self.mm.add_return(output_instr_refs) + _LOGGER.info(f"Parsed MIGraphX Program:\n{self.program}") return self.program def run_node(self, n): + _LOGGER.debug(f"MGXInterpreter running node:\n{get_node_info(n)}") args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) assert isinstance(kwargs, dict) @@ -106,9 +111,7 @@ def placeholder(self, node, args, kwargs): def call_module(self, node, args, kwargs): assert isinstance(node.target, str) - # print(f'call module: {args}') submod = self.fetch_attr(node.target) - # submod_type = 
getattr(submod, '_base_class_origin', type(submod)) submod_type = type(submod) converter = CONVERTERS.get(submod_type) diff --git a/py/torch_migraphx/fx/lower.py b/py/torch_migraphx/fx/lower.py index d0b45e57..4b06e5e4 100644 --- a/py/torch_migraphx/fx/lower.py +++ b/py/torch_migraphx/fx/lower.py @@ -28,6 +28,7 @@ ##################################################################################### import dataclasses as dc import logging +import os from typing import Any, Callable, Optional, Sequence import migraphx @@ -46,9 +47,12 @@ from .tracer.acc_tracer import acc_tracer from .tracer.aten_tracer import aten_tracer from .mgx_module import MGXModule -from .utils import LowerPrecision +from .utils import LowerPrecision, SuppressPrints, SetLogLevel, get_graph_info -logger = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) +LOWERER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_FX_LOWER', None) +if LOWERER_LOGLEVEL: + _LOGGER.setLevel(LOWERER_LOGLEVEL) Input = Sequence[Any] @@ -88,18 +92,24 @@ def lower_to_mgx(module: nn.Module, """ module = module.cpu().eval() input = [to_device(x, "cpu") for x in input] - lower_setting = LowerSetting( - lower_precision=lower_precision, - verbose_log=verbose_log, - min_acc_module_size=min_acc_module_size, - suppress_accuracy_check=suppress_accuracy_check, - save_subgraph_programs=save_subgraph_programs, - tracer_base_cls=tracer_base_cls, - leaf_module_list=leaf_modules, - use_aten=use_aten, - ) - lowerer = Lowerer.create(lower_setting=lower_setting) - return lowerer(module, input) + + if verbose_log and _LOGGER.level > logging.INFO: + log_level = logging.INFO + else: + log_level = _LOGGER.level + + with SetLogLevel(_LOGGER, log_level): + lower_setting = LowerSetting( + lower_precision=lower_precision, + min_acc_module_size=min_acc_module_size, + suppress_accuracy_check=suppress_accuracy_check, + save_subgraph_programs=save_subgraph_programs, + tracer_base_cls=tracer_base_cls, + leaf_module_list=leaf_modules, + use_aten=use_aten, + ) + lowerer = Lowerer.create(lower_setting=lower_setting) + return lowerer(module, input) @dc.dataclass @@ -111,10 +121,12 @@ def create(cls, lower_setting): return LowerMgxInterpreter(lower_setting) def __call__(self, mod, input, split_name) -> MGXInterpreter: - logger.info(f"split_name={split_name}") + _LOGGER.info(f"Running MGXInterpreter for {split_name}") AccShapeProp(mod).propagate(*input) - interpreter = MGXInterpreter( - mod, input, verbose_log=self.lower_setting.verbose_log) + input_shapes = [(i.shape, i.dtype) for i in input] + _LOGGER.debug(f"Input Shapes: {input_shapes}") + _LOGGER.debug(f"{split_name} Graph:\n{get_graph_info(mod.graph)}") + interpreter = MGXInterpreter(mod, input) interpreter.run() return interpreter @@ -125,8 +137,9 @@ def default_split_function(model: fx.GraphModule, inputs: Input, splitter_setting = MGXSplitterSetting() splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size splitter = MGXSplitter(model, inputs, settings=splitter_setting) - if lower_setting.verbose_log: - splitter.node_support_preview() + with SuppressPrints(): + node_preview = splitter.node_support_preview() + _LOGGER.info(f"\n{node_preview}") return splitter.generate_split_results() @@ -147,6 +160,8 @@ def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, """ interpreter = create_mgx_interpreter(lower_setting) interp_res: MGXInterpreter = interpreter(mod, input, module_name) + + _LOGGER.debug(f"Interpreted MIGraphX Program:\n{interp_res.program}") if 
lower_setting.save_subgraph_programs:
            migraphx.save(interp_res.program, f'{module_name}.mxr')
@@ -158,6 +173,8 @@
             input_names=interp_res.get_input_names(),
             quantize_fp16=fp16_mode,
         )
+
+        _LOGGER.debug(f"Compiled MIGraphX Program:\n{mgx_module.program}")
         return mgx_module
 
     return lower_pass
diff --git a/py/torch_migraphx/fx/lower_setting.py b/py/torch_migraphx/fx/lower_setting.py
index e5a1d984..e237ca57 100644
--- a/py/torch_migraphx/fx/lower_setting.py
+++ b/py/torch_migraphx/fx/lower_setting.py
@@ -75,7 +75,6 @@ class LowerSetting(LowerSettingBasic):
         instance of Lowerer.
     """
-    verbose_log: bool = False
     explicit_precision: bool = False
     preset_lowerer: str = ""
     suppress_accuracy_check: bool = False
diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py
index b7ad9618..c7aeda29 100644
--- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py
+++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py
@@ -591,9 +591,8 @@ def trace(
         or a callable (e.g. for op == "call_function").
     """
     if mod.training:
-        warnings.warn(
-            "acc_tracer does not support currently support models for training."
-            " Calling eval on model before tracing.")
+        _LOGGER.warning("acc_tracer does not currently support models for training."
+                        " Calling eval on model before tracing.")
         mod.eval()
 
     assert isinstance(sample_inputs, (list, tuple))
diff --git a/py/torch_migraphx/fx/utils.py b/py/torch_migraphx/fx/utils.py
index 098cea3b..f3f1fcd7 100644
--- a/py/torch_migraphx/fx/utils.py
+++ b/py/torch_migraphx/fx/utils.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #####################################################################################
 import os
+import sys
 from enum import Enum
 from typing import List, Callable
 from packaging import version
@@ -41,6 +42,31 @@
 HIPSTREAMTYPE = 'ihipStream_t'
 
 
+class SuppressPrints:
+
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, 'w')
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+
+
+class SetLogLevel:
+
+    def __init__(self, logger, level):
+        self.logger = logger
+        self.level = level
+
+    def __enter__(self):
+        self.original_level = self.logger.level
+        self.logger.setLevel(self.level)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.logger.setLevel(self.original_level)
+
+
 class LowerPrecision(Enum):
     FP32 = "fp32"
     FP16 = "fp16"
@@ -183,16 +209,23 @@ def mgx_program_from_bytearray(barray: bytearray) -> migraphx.program:
     return prog
 
 
-def print_graph(graph: torch.fx.Graph) -> None:
+def get_node_info(node: torch.fx.Node) -> str:
+    node_info = 'Return' if node.op == 'output' else node.format_node()
+    out_str = f"{node_info}, args: {node.args}, kwargs: {node.kwargs}"
+    if 'tensor_meta' in node.meta:
+        out_str += tensor_meta_str(node.meta['tensor_meta'])
+    return out_str
+
+
+def get_graph_info(graph: torch.fx.Graph) -> str:
+    out_str = ""
     for node in graph.nodes:
-        node_info = 'Return' if node.op == 'output' else node.format_node()
-        out_str = f"{node_info}, args: {node.args}, kwargs: {node.kwargs}"
-        if 'tensor_meta' in node.meta:
-            out_str += tensor_meta_str(node.meta['tensor_meta'])
+        out_str += f"\n\t{get_node_info(node)}\n"
+    return out_str
 
-        print(out_str)
-        print()
 
+def print_graph(graph: torch.fx.Graph) -> None:
+    print(get_graph_info(graph))
 
 def tensor_meta_str(tm) -> str:
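
Reviewer note: the snippet below is a minimal, hypothetical usage sketch of the logging controls introduced in this patch; it is not part of the diff. It assumes torch_migraphx is installed with a working ROCm/MIGraphX device and that lower_to_mgx is exported from torch_migraphx.fx as in current releases; the model and inputs are placeholders.

    # Env vars must be set before importing torch_migraphx: __init__.py reads
    # TORCH_MIGRAPHX_LOGLEVEL at import time, and each submodule reads its own
    # TORCH_MIGRAPHX_LOG_* override when it is first imported.
    import os

    os.environ["TORCH_MIGRAPHX_LOGLEVEL"] = "INFO"       # default for all components
    os.environ["TORCH_MIGRAPHX_LOG_FX_LOWER"] = "DEBUG"  # per-component override

    import torch
    import torch_migraphx

    model = torch.nn.Linear(16, 4).eval()   # placeholder model
    sample_inputs = [torch.randn(2, 16)]    # placeholder inputs

    # FX lowering path; at DEBUG this also logs the parsed and compiled
    # MIGraphX programs for each lowered subgraph.
    mgx_model = torch_migraphx.fx.lower_to_mgx(model, sample_inputs)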