Skip to content

Commit

Permalink
Remove unused tensor literals from memory when saving in dynamo (#218)
Browse files Browse the repository at this point in the history
* update loading saved model in torch backend to work with accelerate

* remove unused tensor literals from memory
  • Loading branch information
shivadbhavsar authored Jan 7, 2025
1 parent 33d4d63 commit af14cd9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
7 changes: 4 additions & 3 deletions py/torch_migraphx/dynamo/lower_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from torch_migraphx.fx.fx2mgx import MGXInterpreter
from torch_migraphx.fx.passes.pass_utils import validate_inference

from .passes.pass_manager import pre_partition_pass, post_partition_pass
from .passes.pass_manager import pre_partition_pass, post_partition_pass, post_lowering_pass
from .passes.partition import partition, get_partition_inputs
from .utils import print_graph_info

Expand Down Expand Up @@ -77,8 +77,9 @@ def lower_aten_to_mgx(gm: torch.fx.GraphModule,
mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)

setattr(optim_gm, name, mgx_mod)

return optim_gm

lowered_gm = post_lowering_pass(optim_gm)
return lowered_gm


# @validate_inference(0.1, 0.1)
Expand Down
9 changes: 9 additions & 0 deletions py/torch_migraphx/dynamo/passes/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .promote_types import promote_inputs
from .remove_empty_slice import remove_empty_slices
from .fix_tensor_meta import fix_tensor_meta
from .remove_lowered_constants import remove_lowered_constants


class MGXPassManager(PassManager):
Expand All @@ -60,3 +61,11 @@ def post_partition_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
]
post_partition_pass_mgr = MGXPassManager(passes)
return post_partition_pass_mgr(gm)


def post_lowering_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
passes = [
remove_lowered_constants,
]
post_partition_pass_mgr = MGXPassManager(passes)
return post_partition_pass_mgr(gm)
19 changes: 19 additions & 0 deletions py/torch_migraphx/dynamo/passes/remove_lowered_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import itertools
import torch

def remove_lowered_constants(gm: torch.fx.GraphModule):
used_literals = set()
for node in gm.graph.nodes:
if node.op == "get_attr":
used_literals.add(node.name)

unused_literals = set()
for name, _ in itertools.chain(gm.named_parameters(), gm.named_buffers()):
if not name in used_literals:
if hasattr(gm, name):
unused_literals.add(name)

for name in unused_literals:
delattr(gm, name)

return gm

0 comments on commit af14cd9

Please sign in to comment.