
Commit

move torch tensors to cpu as graph is interpreted (#215)
shivadbhavsar authored Nov 22, 2024
1 parent 75fdc3e commit 5c76dae
Showing 2 changed files with 19 additions and 6 deletions.
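The new deallocate option frees GPU memory while the FX graph is interpreted into a MIGraphX program: each parameter is moved to CPU, and the CUDA caching allocator is emptied, right before its value is captured as a literal. A minimal usage sketch, assuming the package registers the "migraphx" torch.compile backend and forwards entries of the options dict as the kwargs read by lower_subgraph below; the model and input here are placeholders:

    import torch
    import torch_migraphx  # assumed to register the "migraphx" dynamo backend

    model = torch.nn.Linear(1024, 1024).cuda().eval()   # placeholder model
    x = torch.randn(8, 1024, device="cuda")

    # deallocate=True moves each weight to CPU as it is converted, lowering
    # peak GPU memory while the graph is compiled.
    compiled = torch.compile(model, backend="migraphx",
                             options={"deallocate": True})
    out = compiled(x)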
py/torch_migraphx/dynamo/lower_dynamo.py: 7 additions & 5 deletions
@@ -62,9 +62,9 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule,
         print_graph_info('Traced Model', gm, example_inputs)

     optim_gm = pre_partition_pass(gm)
-    patitioned_gm = partition(optim_gm, verbose=verbose)
+    partition(optim_gm, verbose=verbose)

-    for name, mod in patitioned_gm.named_children():
+    for name, mod in optim_gm.named_children():
         # Const folded params can show up as "child objects"
         if not isinstance(mod, torch.fx.GraphModule):
             continue
@@ -77,8 +77,6 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule,
         mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)

         setattr(optim_gm, name, mgx_mod)
-        del mod
-        del partition_inputs

     return optim_gm

@@ -98,6 +96,7 @@ def lower_subgraph(module: torch.fx.GraphModule,

     verbose = kwargs['verbose'] if 'verbose' in kwargs else False
     fp16 = kwargs['fp16'] if 'fp16' in kwargs else False
+    deallocate = kwargs['deallocate'] if 'deallocate' in kwargs else False
     exhaustive_tune = kwargs[
         'exhaustive_tune'] if 'exhaustive_tune' in kwargs else False
     save_mxr = kwargs['save_mxr'] if 'save_mxr' in kwargs else False
@@ -106,7 +105,10 @@
     print_compiled = (kwargs['print_compiled_program']
                       if 'print_compiled_program' in kwargs else False)

-    interpreter = MGXInterpreter(module, inputs, verbose_log=verbose)
+    interpreter = MGXInterpreter(module,
+                                 inputs,
+                                 deallocate=deallocate,
+                                 verbose_log=verbose)
     interpreter.run()

     if save_mxr:
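For orientation, a simplified, hypothetical sketch of the lowering loop that the first hunk above modifies: partition() now annotates optim_gm in place, so the loop walks optim_gm's own children and swaps each supported FX subgraph for a compiled MGX module (get_partition_inputs is a placeholder helper name, not the repository's exact API):

    import torch.fx

    def lower_all_children(optim_gm, get_partition_inputs, lower_subgraph, **kwargs):
        # Replace each supported FX child module with its compiled MGX module.
        for name, mod in optim_gm.named_children():
            if not isinstance(mod, torch.fx.GraphModule):
                continue  # const-folded params can show up as child objects
            partition_inputs = get_partition_inputs(optim_gm, mod, name)  # placeholder
            mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)
            setattr(optim_gm, name, mgx_mod)
        return optim_gm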
py/torch_migraphx/fx/fx2mgx.py: 12 additions & 1 deletion
@@ -40,7 +40,11 @@

 class MGXInterpreter(torch.fx.Interpreter):

-    def __init__(self, module, sample_inputs, verbose_log=False):
+    def __init__(self,
+                 module,
+                 sample_inputs,
+                 deallocate=False,
+                 verbose_log=False):
         super().__init__(module)

         self.program = migraphx.program()
@@ -55,6 +59,7 @@ def __init__(self, module, sample_inputs, verbose_log=False):
             warnings.warn(
                 'Torch model contains the following unsupported operations: \n'
                 + '\n'.join(f'{i}' for i in self.unsupported_ops))
+        self.deallocate = deallocate

     def validate_conversion(self):
         missing_converters = set()
@@ -142,6 +147,9 @@ def get_attr(self, node, args, kwargs):
         if isinstance(attr, torch.nn.ParameterList):
             mgx_attrs = []
             for a in attr:
+                if self.deallocate:
+                    a.data = a.data.cpu()
+                    torch.cuda.empty_cache()
                 t, qparams = get_qparams(a)
                 mgx_attrs.append(
                     MGXInstruction(
@@ -151,6 +159,9 @@
                     ))
             return tuple(mgx_attrs)

+        if self.deallocate:
+            attr.data = attr.data.cpu()
+            torch.cuda.empty_cache()
         t, qparams = get_qparams(attr)
         return MGXInstruction(self.mm.add_literal(t.cpu().detach().numpy()),
                               torch_attr_value=attr,
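The fx2mgx.py hunks above are where memory is actually reclaimed: once get_attr is about to capture a parameter into the MIGraphX program as a literal (always built from a detached CPU numpy copy), the GPU copy of the tensor is no longer needed, so it is moved to CPU and the cache is released. A minimal standalone sketch of that pattern; capture_literal is an illustrative name, not the module's API:

    import torch

    def capture_literal(param: torch.Tensor, deallocate: bool = False):
        # Drop the GPU copy of the weight before converting it, so memory is
        # freed while the rest of the graph is still being interpreted.
        if deallocate and param.is_cuda:
            param.data = param.data.cpu()
            torch.cuda.empty_cache()  # release cached blocks back to the GPU
        # The literal itself is always built from a detached CPU numpy copy.
        return param.data.cpu().detach().numpy()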
