Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
richagadgil committed Mar 8, 2025
1 parent 1c891a0 commit c7f8885
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions py/torch_migraphx/dynamo/passes/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ..utils import print_graph_info
from ...fx.utils import TYPE_MAP

import operator

class MGXOperatorSupport(OperatorSupport):
'''Construct OperatorSupport object used for partitioning based on registered converters'''
Expand All @@ -57,24 +58,52 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module],

# --- If this is a get_attr node, we can do extra checks ---
if node.op == "get_attr":

if len(node.users) == 1: # and node.users[0].op == "output":
x = list(node.users.keys())[0]
if x.op == 'output':
return False

root_mod = submodules[""]
attr_name = node.target

if hasattr(root_mod, attr_name):
attr_val = getattr(root_mod, attr_name)


if isinstance(attr_val, torch.nn.ParameterList):
for param in attr_val:
if param.dtype not in self.supported_dtypes:
self.unsupported.add(f"{attr_name} : {param.dtype}")
return False
# if isinstance(attr_val, torch.nn.ParameterList):
# import pdb; pdb.set_trace()
# for param in attr_val:
# if param.dtype not in self.supported_dtypes:
# self.unsupported.add(f"{attr_name} : {param.dtype}")
# return False

if isinstance(attr_val, torch.nn.Parameter):
if attr_val.dtype not in self.supported_dtypes:
self.unsupported.add(f"{attr_name} : {attr_val.dtype}")
return False

return True


if node.target == operator.getitem:
root_mod = submodules[""]
attr_name = node.args[0].target # first input ---> target is fx const folded, etc.

if isinstance(attr_name, str):
attr_val_list = getattr(root_mod, attr_name) # paramlist

if isinstance(attr_val_list, torch.nn.ParameterList):

if len(node.users) == 1: # and node.users[0].op == "output":
x = list(node.users.keys())[0]
if x.op == 'output':
return False

attr_val = attr_val_list[node.args[1]]
if attr_val.dtype not in self.supported_dtypes:
self.unsupported.add(f"{attr_name} : {attr_val.dtype}")
return False

if node.target in CONVERTERS.keys():
if not node.is_impure():
Expand Down

0 comments on commit c7f8885

Please sign in to comment.