Introduce mem_obj_id to TensorSpec (pytorch#1868)
Summary:
Pull Request resolved: pytorch#1868

## Background

Using [Efficient Memory Planning for Deep Neural Networks](https://arxiv.org/pdf/2001.03288.pdf) as a reference, there are generally two approaches to solving the memory planning problem for a neural network:

* **Shared Objects Approach**
    * Match tensors to “shared objects” (i.e., shared memory allocations), based on tensor sizes and lifetimes
    * The goal is then to minimize the total size of each shared object
* **Memory Offset Calculation Approach**
    * Assign each tensor to a specific offset in a large memory arena shared by all tensors, based on tensor sizes and lifetimes
    * The goal is then to minimize the total size of the overall memory arena

Though the **Memory Offset Calculation Approach** can produce more optimal solutions, it cannot be used for GPU textures because texture memory cannot be subdivided. **To plan memory for GPU textures, memory planning must be solved using the Shared Objects Approach.** Note that a solution to the Shared Objects problem can be converted to a solution for the Memory Offsets problem by materializing the shared objects as buffers within a memory arena.
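
For intuition, here is a minimal sketch of that conversion (hypothetical helper names, not code from this changeset): lay the shared objects out back-to-back in a single arena, then give every tensor the offset of the shared object it was assigned to.

```python
from typing import Dict, List, Tuple


def shared_objects_to_offsets(
    tensor_to_obj: Dict[str, int],  # tensor name -> shared object id
    obj_sizes: List[int],  # size in bytes of each shared object, indexed by id
) -> Tuple[Dict[str, int], int]:
    """Materialize a Shared Objects solution as a Memory Offsets solution.

    Each shared object becomes a buffer in one memory arena; every tensor
    assigned to an object inherits that object's offset in the arena.
    """
    obj_offsets: List[int] = []
    arena_size = 0
    for size in obj_sizes:
        obj_offsets.append(arena_size)
        arena_size += size
    tensor_offsets = {name: obj_offsets[obj] for name, obj in tensor_to_obj.items()}
    return tensor_offsets, arena_size


# Tensors "a" and "b" share object 0 (their lifetimes do not overlap); "c" uses object 1.
offsets, arena_size = shared_objects_to_offsets({"a": 0, "b": 0, "c": 1}, [128, 64])
assert offsets == {"a": 0, "b": 0, "c": 128} and arena_size == 192
```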

## Context

Currently, memory planning algorithms implemented for `exir`'s `MemoryPlanningPass` output solutions to the memory planning problem in the **Memory Offsets** format. The `greedy` algorithm solves memory planning as a Shared Objects problem, but then converts the solution to the Memory Offsets format.

This changeset introduces the `mem_obj_id` field to `TensorSpec`, which memory planning algorithms can use to record shared object ids.

## Review Guide

* The `greedy` memory planning algorithm now records `mem_obj_id` in addition to `mem_offset`
* `verify_storage_reuse()` of the `Verifier` class now checks that `mem_obj_id` assignments are valid when they are set
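
The new consistency check can be summarized with a small standalone sketch (a simplified illustration with hypothetical names; the actual logic lives in `Verifier.verify_storage_reuse()` in the diff below, which also accounts for `mem_id` and lifetime overlap): specs must agree on whether `mem_obj_id` is set, overlapping storage implies matching ids, and disjoint storage implies differing ids.

```python
from typing import Optional


def check_mem_obj_id_consistency(
    lhs_offset: int, lhs_nbytes: int, lhs_obj_id: Optional[int],
    rhs_offset: int, rhs_nbytes: int, rhs_obj_id: Optional[int],
) -> None:
    """Raise if mem_obj_id assignments disagree with actual storage overlap."""
    # Both specs must agree on whether mem_obj_id was assigned at all.
    if (lhs_obj_id is None) != (rhs_obj_id is None):
        raise AssertionError("Specs do not agree on whether mem_obj_id is defined.")
    if lhs_obj_id is None:
        return  # Planned without shared object ids; nothing more to check.

    # Half-open byte ranges overlap iff the later start precedes the earlier end.
    overlap = max(lhs_offset, rhs_offset) < min(
        lhs_offset + lhs_nbytes, rhs_offset + rhs_nbytes
    )
    if overlap and lhs_obj_id != rhs_obj_id:
        raise AssertionError("Storage overlaps but mem_obj_id differs.")
    if not overlap and lhs_obj_id == rhs_obj_id:
        raise AssertionError("mem_obj_id matches but storage does not overlap.")
```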

Reviewed By: ydwu4

Differential Revision: D53496787

fbshipit-source-id: 0c2f81a00c9254b5af3d94e42a2dc3d34da1da6e
SS-JIA authored and facebook-github-bot committed Feb 7, 2024
1 parent b76d409 commit 544d296
Showing 2 changed files with 57 additions and 5 deletions.
61 changes: 56 additions & 5 deletions exir/memory_planning.py
@@ -52,6 +52,24 @@ def __init__(
self.alloc_graph_input = alloc_graph_input
self.alloc_graph_output = alloc_graph_output

@classmethod
def mem_obj_id_match(
cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec, accept_both_none: bool = True
) -> bool:
"""
Given two `TensorSpec`, return if their `mem_obj_id` are the same. Note that if
both are None, this function will return True if `accept_both_none` is True and
False otherwise.
"""
if lhs_spec.mem_id != rhs_spec.mem_id:
return False

# both are None
if lhs_spec.mem_obj_id is None and rhs_spec.mem_obj_id is None:
return accept_both_none

return lhs_spec.mem_obj_id == rhs_spec.mem_obj_id

@classmethod
def has_overlap(cls, lhs_ivl: List[int], rhs_ivl: List[int]) -> bool:
r"""
@@ -95,9 +113,11 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
f"{spec} should have specified memory offset",
)
intervals.append(
[spec.mem_offset, spec.mem_offset + spec.allocated_memory - 1]
[spec.mem_offset, spec.mem_offset + max(spec.allocated_memory - 1, 0)]
)
return cls.has_overlap(*intervals)
has_overlap = cls.has_overlap(*intervals)

return has_overlap

def verify_storage_reuse(
self, allow_lifetime_and_storage_overlap: bool = False
@@ -126,15 +146,41 @@ def verify_storage_reuse(

for lhs_spec_idx, lhs_spec in enumerate(all_specs):
for rhs_spec in all_specs[lhs_spec_idx + 1 :]:
if not self.storage_overlap(lhs_spec, rhs_spec):
# Check that both specs are consistent about whether mem_obj_id is defined
if (lhs_spec.mem_obj_id is None) != (rhs_spec.mem_obj_id is None):
raise InternalError(
"Specs do not agree on whether mem_obj_id is defined."
)

has_storage_overlap = Verifier.storage_overlap(lhs_spec, rhs_spec)
if not has_storage_overlap:
# Check that each mem_obj_id is consistent with whether the tensors
# have storage overlap
if Verifier.mem_obj_id_match(
lhs_spec, rhs_spec, accept_both_none=False
):
raise InternalError(
f"Unexpected mem_obj_id match: "
f"lhs {lhs_spec} with id {lhs_spec.mem_obj_id} "
f"rhs {rhs_spec} with id {rhs_spec.mem_obj_id}"
)
continue

if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
lhs_spec, rhs_spec
):
raise InternalError(
f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
)
num_reuse_pairs += Verifier.storage_overlap(lhs_spec, rhs_spec)

# Check that each mem_obj_id is consistent with whether the tensors have
# storage overlap
if not Verifier.mem_obj_id_match(lhs_spec, rhs_spec):
raise InternalError(
f"Unexpected mem_obj_id mismatch: lhs {lhs_spec}, rhs {rhs_spec}"
)

num_reuse_pairs += 1

return num_reuse_pairs

@@ -386,6 +432,8 @@ class SharedObject:
last_used_index attribute. The shared object will be available for nodes
with index greater than last_used_index.
"""
# index of the shared object in the list of shared objects, used as a unique id
idx: int
# offset in the memory buffer
offset: int
# size of this shared object in bytes
@@ -435,7 +483,9 @@ def pick_shared_obj(
sobj.last_used_index = spec.lifetime[1]
sobj.size = max(sobj.size, spec.allocated_memory)
if picked is None:
picked = SharedObject(-1, spec.allocated_memory, spec.lifetime[1])
picked = SharedObject(
len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
)
shared_objects.append(picked)

return picked
@@ -503,6 +553,7 @@ def greedy(
# each shared object, we can assign offset in the memory buffer for each
# shared object.
for spec, sobj in spec2obj.items():
spec.mem_obj_id = sobj.idx
spec.mem_offset = sobj.offset

logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
1 change: 1 addition & 0 deletions exir/tensor.py
@@ -177,6 +177,7 @@ def from_tensor(cls, tensor: torch.Tensor, const: bool = False) -> TensorSpec:
def init_mem_planning_fields(self) -> None:
self.lifetime = [None, None]
self.mem_id = None
self.mem_obj_id = None
self.mem_offset = None

@property
