Logging #207

Open
wants to merge 8 commits into base: master
91 changes: 91 additions & 0 deletions docs/dev/env_vars.rst
@@ -0,0 +1,91 @@
Environment Variables
=====================

Logging
---------------

Accepted values for these variables are the logging levels defined by the Python ``logging`` module
(see https://docs.python.org/3/library/logging.html#logging-levels):

- NOTSET
- DEBUG
- INFO
- WARNING
- ERROR
- CRITICAL


By default, torch_migraphx logs messages at the WARNING level and above.

.. envvar:: TORCH_MIGRAPHX_LOGLEVEL

Default log level for all torch_migraphx loggers.
Defaults to WARNING.
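
For example, to raise the verbosity of the whole package, set this variable before
``torch_migraphx`` is imported, since the root logger is configured at import time
(a minimal sketch; any level name listed above works):

.. code-block:: python

    import os

    # Must happen before the first import of torch_migraphx.
    os.environ["TORCH_MIGRAPHX_LOGLEVEL"] = "DEBUG"

    import torch_migraphx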


.. envvar:: TORCH_MIGRAPHX_LOG_INTERPRETER

Log level for the interpreter.

INFO outputs:

- PyTorch graph passed to MGXInterpreter
- Parsed MIGraphX program

DEBUG outputs:

- Node info for each node in the PyTorch graph


.. envvar:: TORCH_MIGRAPHX_LOG_FX_LOWER

Log level for FX lowering.

INFO outputs:

- Name of each subgraph that is being lowered
- Node support summary (i.e. supported and unsupported nodes)

DEBUG outputs:

- Input shapes and PyTorch graph
- Parsed MIGraphX program that is to be compiled
- Compiled MIGraphX program


.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_LOWER

Log level for Dynamo lowering.

INFO outputs:

- Name of each subgraph that is being lowered
- Input shapes and PyTorch graph

DEBUG outputs:

- Parsed MIGraphX program that is to be compiled
- Compiled MIGraphX program


.. envvar:: TORCH_MIGRAPHX_LOG_DYNAMO_PASSES

Log level for the Dynamo pre-lowering passes.

INFO outputs:

- Graph info before and after pre-partitioning, partitioning and post-partitioning passes

DEBUG outputs:

- Graph info for each sub-pass


.. envvar:: TORCH_MIGRAPHX_LOG_PARTITIONER

Log level for the partitioner pass specifically.
If unset, it falls back to the value of :envvar:`TORCH_MIGRAPHX_LOG_DYNAMO_PASSES`.
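
These variables can be combined; as a usage sketch, the example below keeps the
global default (WARNING) while making the partitioning-related logs verbose:

.. code-block:: python

    import os

    # Per-component overrides, set before importing torch_migraphx.
    os.environ["TORCH_MIGRAPHX_LOG_DYNAMO_PASSES"] = "INFO"
    os.environ["TORCH_MIGRAPHX_LOG_PARTITIONER"] = "DEBUG"

    import torch_migraphx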
14 changes: 13 additions & 1 deletion py/torch_migraphx/__init__.py
@@ -4,4 +4,16 @@
from torch_migraphx import fx, _C

if version.parse(_torch_version) >= version.parse("2.1.0"):
from torch_migraphx import dynamo
from torch_migraphx import dynamo

import logging
import os
import sys

LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOGLEVEL', 'WARNING').upper()
logging.basicConfig(
level=LOGLEVEL,
stream=sys.stderr,
format=
'%(asctime)s.%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s\n',
datefmt='%Y-%m-%d:%H:%M:%S')
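
Because basicConfig only configures the root logger and its stderr handler, the
per-module loggers added in the files below propagate their records to it and
share this format. A minimal sketch of that behavior (the logger name is one of
the module loggers introduced in this PR):

    import logging

    # A child logger with no handlers of its own hands records to the root
    # handler configured by torch_migraphx/__init__.py, so they are rendered
    # with the timestamp/level/file:line format above and written to stderr.
    log = logging.getLogger("torch_migraphx.dynamo.lower_dynamo")
    log.warning("example message")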
56 changes: 31 additions & 25 deletions py/torch_migraphx/dynamo/lower_dynamo.py
@@ -28,6 +28,8 @@
#####################################################################################

from typing import Sequence
import logging
import os

import torch
from torch.fx.passes.shape_prop import ShapeProp
@@ -39,7 +41,12 @@

from .passes.pass_manager import pre_partition_pass, post_partition_pass
from .passes.partition import partition, get_partition_inputs
from .utils import print_graph_info
from .utils import get_input_info, get_graph_info, SetLogLevel

_LOGGER = logging.getLogger(__name__)
DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_LOWER', None)
if DYNAMO_LOGLEVEL:
_LOGGER.setLevel(DYNAMO_LOGLEVEL)


def lower_aten_to_mgx(gm: torch.fx.GraphModule,
Expand All @@ -58,25 +65,32 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule,
        torch.fx.GraphModule: GraphModule containing MGXModule objects for supported subgraphs
"""
verbose = kwargs['verbose'] if 'verbose' in kwargs else False
if verbose:
print_graph_info('Traced Model', gm, example_inputs)

optim_gm = pre_partition_pass(gm)
partition(optim_gm, verbose=verbose)

for name, mod in optim_gm.named_children():
# Const folded params can show up as "child objects"
if not isinstance(mod, torch.fx.GraphModule):
continue
log_level = min(_LOGGER.level, logging.INFO) if verbose else _LOGGER.level
with SetLogLevel(_LOGGER, log_level):
for name, mod in optim_gm.named_children():
# Const folded params can show up as "child objects"
if not isinstance(mod, torch.fx.GraphModule):
continue

mod = post_partition_pass(mod)
partition_inputs = get_partition_inputs(partitioned_gm, mod,
example_inputs)

mod = post_partition_pass(mod)
partition_inputs = get_partition_inputs(optim_gm, mod, example_inputs)
if verbose:
print_graph_info(name, mod, partition_inputs)
_LOGGER.info(f"Lowering subgraph: {name}")
_LOGGER.info(
f"Subgraph inputs: {get_input_info(partition_inputs)}")
_LOGGER.info(f"Subgraph:\n{get_graph_info(mod.graph)}")

mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)
mgx_mod = lower_subgraph(mod,
partition_inputs,
name=name,
**kwargs)

setattr(optim_gm, name, mgx_mod)
setattr(optim_gm, name, mgx_mod)

return optim_gm

@@ -94,34 +108,26 @@ def lower_subgraph(module: torch.fx.GraphModule,
MGXModule: Callable module that executes graph via MIGraphX
"""

verbose = kwargs['verbose'] if 'verbose' in kwargs else False
fp16 = kwargs['fp16'] if 'fp16' in kwargs else False
deallocate = kwargs['deallocate'] if 'deallocate' in kwargs else False
exhaustive_tune = kwargs[
'exhaustive_tune'] if 'exhaustive_tune' in kwargs else False
save_mxr = kwargs['save_mxr'] if 'save_mxr' in kwargs else False
print_uncompiled = (kwargs['print_parsed_program']
if 'print_parsed_program' in kwargs else False)
print_compiled = (kwargs['print_compiled_program']
if 'print_compiled_program' in kwargs else False)

interpreter = MGXInterpreter(module,
inputs,
deallocate=deallocate,
verbose_log=verbose)

interpreter = MGXInterpreter(module, inputs, deallocate=deallocate)
interpreter.run()

if save_mxr:
name = f"{kwargs['name']}.mxr" if 'name' in kwargs else "prog.mxr"
migraphx.save(interpreter.program, name)

if print_uncompiled: interpreter.program.print()
_LOGGER.debug(f"Interpreted Program:\n{interpreter.program}")

mgx_module = MGXModule(program=interpreter.program,
input_names=interpreter.get_input_names(),
quantize_fp16=fp16,
exhaustive_tune=exhaustive_tune)

if print_compiled: mgx_module.program.print()
_LOGGER.debug(f"Compiled Program:\n{mgx_module.program}")

return mgx_module
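
lower_dynamo.py also relies on a SetLogLevel context manager imported from
.utils, whose implementation is not part of this diff. A minimal sketch of such
a helper, assuming it only needs to override a logger's level temporarily and
restore the previous level on exit:

    import logging

    class SetLogLevel:
        """Hypothetical sketch, not the PR's implementation: temporarily set
        a logger to the given level and restore the old level afterwards."""

        def __init__(self, logger: logging.Logger, level: int):
            self.logger = logger
            self.level = level
            self._prev = logger.level

        def __enter__(self):
            self._prev = self.logger.level
            self.logger.setLevel(self.level)
            return self.logger

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.logger.setLevel(self._prev)
            return False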
10 changes: 9 additions & 1 deletion py/torch_migraphx/dynamo/passes/fix_tensor_meta.py
@@ -26,11 +26,19 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#####################################################################################

import logging
import os
import torch
import operator
from .utils import log_pass

_LOGGER = logging.getLogger(__name__)
DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None)
if DYNAMO_LOGLEVEL:
_LOGGER.setLevel(DYNAMO_LOGLEVEL)


@log_pass(_LOGGER, logging.DEBUG)
def fix_tensor_meta(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
# This is only true for functions with multiple outputs
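
fix_tensor_meta.py decorates its pass with log_pass from .utils, which is also
not shown in this diff. A plausible sketch of such a decorator, assuming it
simply logs the graph before and after the wrapped pass runs:

    import functools
    import logging

    def log_pass(logger: logging.Logger, level: int):
        """Hypothetical sketch, not the PR's implementation: log an FX graph
        before and after the decorated pass executes."""

        def decorator(fn):

            @functools.wraps(fn)
            def wrapper(gm, *args, **kwargs):
                logger.log(level, f"Before {fn.__name__}:\n{gm.graph}")
                result = fn(gm, *args, **kwargs)
                logger.log(level, f"After {fn.__name__}:\n{gm.graph}")
                return result

            return wrapper

        return decorator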
45 changes: 32 additions & 13 deletions py/torch_migraphx/dynamo/passes/partition.py
@@ -28,15 +28,24 @@
#####################################################################################

from typing import Dict, Optional, Sequence, Mapping
import logging
import os

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

from torch_migraphx.fx.converter_registry import CONVERTERS
from ..utils import print_graph_info
from ..utils import get_graph_info, SetLogLevel
from ...fx.utils import TYPE_MAP

_LOGGER = logging.getLogger(__name__)
DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None)
PARTITIONER_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_PARTITIONER',
DYNAMO_LOGLEVEL)
if PARTITIONER_LOGLEVEL:
_LOGGER.setLevel(PARTITIONER_LOGLEVEL)


class MGXOperatorSupport(OperatorSupport):
'''Construct OperatorSupport object used for partitioning based on registered converters'''
@@ -66,14 +75,19 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module],
self.unsupported.add(node.target)
return False

def print_support_summary(self):
print('Supported Nodes: ')
def support_summary(self):
summary = "Supported Nodes:\n"
for n in self.supported:
print(n)
summary += f"\t{n}\n"

print('\nUnsupported Nodes: ')
summary += "\nUnsupported Nodes:\n"
for n in self.unsupported:
print(n)
summary += f"\t{n}\n"

return summary

def print_support_summary(self):
print(self.support_summary())


def partition(gm: torch.fx.GraphModule,
@@ -90,20 +104,25 @@ def partition(gm: torch.fx.GraphModule,
op_support = MGXOperatorSupport()
partitioner = CapabilityBasedPartitioner(gm, op_support)

partitons = partitioner.propose_partitions()
fused_gm = partitioner.fuse_partitions(partitons)
partitions = partitioner.propose_partitions()
fused_gm = partitioner.fuse_partitions(partitions)
fused_gm.graph.eliminate_dead_code()
fused_gm.recompile()
fused_gm.delete_all_unused_submodules()

if verbose:
print_graph_info("Partitioned Module", fused_gm, None)
op_support.print_support_summary()
log_level = _LOGGER.level
if verbose and _LOGGER.level > logging.INFO:
log_level = logging.INFO

with SetLogLevel(_LOGGER, log_level):
_LOGGER.debug(f"Partitioned Module:\n{get_graph_info(fused_gm.graph)}")
_LOGGER.info(f"Node support summary:\n{op_support.support_summary()}")
_LOGGER.info(f"Number of partitions: {len(partitions)}")

# TODO: Compute number of partitions after dead code elimination
if len(partitons) > max_partitions:
if len(partitions) > max_partitions:
raise RuntimeError(
f'Found {len(partitons)} partitions, max allowed: {max_partitions}.'
f'Found {len(partitions)} partitions, max allowed: {max_partitions}.'
)

return fused_gm
9 changes: 9 additions & 0 deletions py/torch_migraphx/dynamo/passes/pass_manager.py
@@ -26,6 +26,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#####################################################################################
import logging
import os
import torch
from torch.fx.passes.pass_manager import PassManager

@@ -34,7 +36,12 @@
from .promote_types import promote_inputs
from .remove_empty_slice import remove_empty_slices
from .fix_tensor_meta import fix_tensor_meta
from ..utils import get_graph_info

_LOGGER = logging.getLogger(__name__)
DYNAMO_LOGLEVEL = os.environ.get('TORCH_MIGRAPHX_LOG_DYNAMO_PASSES', None)
if DYNAMO_LOGLEVEL:
_LOGGER.setLevel(DYNAMO_LOGLEVEL)

class MGXPassManager(PassManager):

@@ -51,6 +58,7 @@ def pre_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
const_fold,
]
pre_partition_pass_mgr = MGXPassManager(passes)
_LOGGER.info(f"Pre Partition Pass In:\n{get_graph_info(gm.graph)}")
return pre_partition_pass_mgr(gm)


@@ -59,4 +67,5 @@ def post_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
fix_tensor_meta,
]
post_partition_pass_mgr = MGXPassManager(passes)
_LOGGER.info(f"Post Partition Pass In:\n{get_graph_info(gm.graph)}")
return post_partition_pass_mgr(gm)
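
The same pattern extends to user-defined passes; a hypothetical sketch (the
pass name is illustrative and not part of this PR):

    import torch

    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Inspect or rewrite gm.graph here, then hand the module back.
        return gm

    # MGXPassManager as defined above: constructed from a list of passes and
    # called with a GraphModule to apply them in order.
    custom_mgr = MGXPassManager([my_custom_pass])
    # lowered_gm = custom_mgr(gm)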