
Commit

remove duplication
Differential Revision: D65236267

Pull Request resolved: pytorch#7186
skrtskrtfb authored Dec 18, 2024
1 parent 44e31fb commit 6b72663
Showing 2 changed files with 42 additions and 2 deletions.
34 changes: 34 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -374,3 +374,37 @@ python_unittest(
"//executorch/exir/dialects:lib",
],
)


python_library(
name = "memory_planning",
srcs = [
"memory_planning.py",
],
deps = [
"fbsource//third-party/pypi/tabulate:tabulate",
":memory_constraints",
":pass_utils",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:memory_planning",
"//executorch/exir:tensor",
"//executorch/exir/passes:lib",
],
)


python_library(
name = "memory_constraints",
srcs = [
"memory_constraints.py",
],
deps = [
":pass_utils",
":utils",
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir:pass_manager",
"//executorch/exir:tensor",
],
)
10 changes: 8 additions & 2 deletions backends/cadence/aot/memory_planning.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import collections
import itertools
import logging
@@ -331,14 +333,15 @@ def find_peak_memory_usage(
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
# pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
executorch_prog: ExecutorchProgramManager,
memory_config: MemoryConfig,
opt_level: int,
alloc_graph_input: bool,
alloc_graph_output: bool,
) -> None:
# Get the peak memory usages per memory space
mem_constraints = MemConstraints(
opt_level=opt_level,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
)
@@ -406,6 +409,7 @@ class CadenceMemoryPlanning:
def __init__(
self,
memory_config: MemoryConfig,
opt_level: int,
mem_algo: int,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
@@ -421,6 +425,7 @@ def __init__(
self._init_mem_algos()

self.memory_config = memory_config
self.opt_level = opt_level
self.mem_algo = mem_algo
self.alloc_graph_input = alloc_graph_input
self.alloc_graph_output = alloc_graph_output
@@ -434,6 +439,7 @@ def _init_mem_algos(self) -> None:

def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
mem_constraints = MemConstraints(
opt_level=self.opt_level,
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
)
@@ -448,7 +454,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
# True.
mem_planning = MemoryPlanningPass(
algo,
allow_lifetime_and_storage_overlap=False,
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
)
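For context, a minimal usage sketch of the updated pass follows. It assumes the module is importable as executorch.backends.cadence.aot.memory_planning and that the caller already has a MemoryConfig and an exported torch.fx.GraphModule; the concrete argument values are placeholders, not part of this commit. The sketch mirrors the constructor shown in the diff above, where opt_level is now threaded through to MemConstraints and opt_level >= 2 enables lifetime/storage overlap in the underlying MemoryPlanningPass.

# Hypothetical usage sketch; memory_config and graph_module come from the
# caller's own export pipeline and are not defined here.
from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning

planning = CadenceMemoryPlanning(
    memory_config=memory_config,  # MemoryConfig describing the target's memory spaces
    opt_level=2,                  # opt_level >= 2 now permits lifetime/storage overlap
    mem_algo=0,                   # index selecting one of the registered planning algorithms
    alloc_graph_input=True,       # plan buffers for graph inputs
    alloc_graph_output=True,      # plan buffers for graph outputs
)
result = planning(graph_module)   # runs the pass; returns a PassResult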
