
Commit 3cf1bc2

JacobSzwejbka authored and facebook-github-bot committed
memory_planning algos take the specs as inputs instead of calculating them themselves
Summary: Refactor the memory planning algorithms so they no longer compute the active set of nodes themselves; the specs are now passed in. Future work should refactor apply_algo further so that it is not recursive and handles lifespans across control flow properly. The memory planning suite is now a class so that the algorithm list is easily configurable without having to write a wrapper function.

Differential Revision: D72600295
1 parent 2cce2db commit 3cf1bc2
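
In short, planning algorithms now receive an already-collected set of tensor specs instead of walking the graph themselves. A minimal before/after sketch of the calling convention, using names from this diff (greedy returns a MemoryAlgoResult in both versions); gm, sig, alignment, and planned_specs are hypothetical placeholders for a graph module, its export signature, an alignment value, and the filtered spec set that apply_algo now computes:

# Old interface: greedy collected specs from the graph on its own.
result = greedy(gm, alignment, sig, alloc_graph_input=True, alloc_graph_output=True)

# New interface: the caller (apply_algo) passes the specs and any extra padding in.
result = greedy(alignment, planned_specs, gm, sig, extra_padding=0)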

File tree

3 files changed (+159, -113 lines)


exir/memory_planning.py

Lines changed: 153 additions & 105 deletions
@@ -731,53 +731,43 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
 
 
 def greedy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int = 0,
+    *,
     allow_overlapping_allocations: bool = True,
 ) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
-    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
-    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
-    allow_overlapping_allocations: If set to true, allows for allocations that overlap
-    in their lifetime but are at different offsets in the storage. By default true.
-    This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
-    allocations disabled
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+        allow_overlapping_allocations: If set to true, allows for allocations that overlap
+            in their lifetime but are at different offsets in the storage. By default true.
+            This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
+            allocations disabled.
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
     """
     greedy_result = MemoryAlgoResult({}, [])
-    # padding allocation with 64 bytes.
-    # this requirement is really for XNNPACK backend which can read tensors
-    # beyond the end of the tensor. This is done for performance
-    # optimizations in XNNPACK.
-    # While accounting for backend specific requirement is not the right choice
-    # in backend agnostic memory planning, we do it here as it seems most appropriate.
-    # Right now this applies to greedy only so any other
-    # algorithm that plans memory for XNNPACK backend will
-    # not have this.
-    extra_padded_bytes = 0
-    if _contains_xnnpack_delegate(graph_module):
-        extra_padded_bytes = 64
     spec2obj = {}
     shared_objects = defaultdict(list)
-    # Don't do assertion in collect_specs_from_nodes if we have already encountered
-    # and ignored some to_out_variant errors.
-    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
+
     # For each tensor, pick the available shared object with closest size to
     # the tensor. If there are no available shared object left, create a new
     # one.
     import bisect
 
     sorted_specs = []
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=do_assertion,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+    for spec in specs:
         bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+
     sorted_specs.reverse()
 
     for spec in sorted_specs:
@@ -806,15 +796,13 @@ def greedy(
     for mem_id in shared_objects:
         input_total_size = 0
         if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-            # pyre-fixme[6]: For 1st argument expected
-            #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+            assert isinstance(bufsizes, list)
             if len(bufsizes) > mem_id:
-                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                 input_total_size = bufsizes[mem_id]
         total_sizes[mem_id] = materialize_buffer(
             shared_objects[mem_id], input_total_size
         )
-        total_sizes[mem_id] += extra_padded_bytes
+        total_sizes[mem_id] += extra_padding
 
     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
@@ -837,73 +825,101 @@ def greedy(
     greedy_result.bufsizes = total_sizes
     return greedy_result
 
+class MemoryPlanningAlgorithmSuite:
+    def __init__(self, algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None) -> None:
+        if algo_list is None:
+            algo_list = [greedy]
+        self.algo_list: List[Callable[..., MemoryAlgoResult]] = algo_list
 
-def memory_planning_algorithm_suite(
-    graph_module: torch.fx.GraphModule,
-    alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
-    allow_overlapping_allocations: bool = True,
-    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
-) -> List[int]:
-    r"""
-    Memory planning algorithm suite that runs a list of memory planning algorithms
-    and returns the result of the algorithm that minimizes the total memory usage.
-    """
-    if algo_list is None:
-        algo_list = [greedy]
-    mem_algo_results = {}
-    for algo in algo_list:
-        if isinstance(algo, functools.partial):
-            name = algo.func.__name__
-        else:
-            name = getattr(algo, "__name__", None)
-        # Run this memory planning algorithm and store the result in mem_algo_results
-        # with the name of the algorithm as the key.
-        mem_algo_results[name] = algo(
-            graph_module,
-            alignment,
-            graph_signature,
-            alloc_graph_input,
-            alloc_graph_output,
-        )
+    def __call__(
+        self,
+        alignment: int,
+        specs: Set[TensorSpec],
+        graph_module: torch.fx.GraphModule,
+        graph_signature: ExportGraphSignature,
+        extra_padding: int,
+    ) -> List[int]:
+        r"""
+        Memory planning algorithm suite that runs a list of memory planning algorithms
+        and returns the result of the algorithm that minimizes the total memory usage.
+
+        Args:
+            alignment: Memory alignment requirement
+            specs: Set of TensorSpec objects with updated lifetimes
+            graph_module: The graph module to allocate memory for
+            graph_signature: Graph signature
+            extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+        Returns:
+            List of buffer sizes for each memory hierarchy
+        """
+
+        mem_algo_results = {}
+        for algo in self.algo_list:
+            if isinstance(algo, functools.partial):
+                name = algo.func.__name__
+            else:
+                name = getattr(algo, "__name__", None)
+
+            mem_algo_results[name] = algo(
+                alignment,
+                specs,
+                graph_module,
+                graph_signature,
+                extra_padding,
+            )
 
-    # All the algorithms should have the same number of buffers allocated.
-    assert (
-        len(
-            {
-                len(mem_algo_result.bufsizes)
-                for mem_algo_result in mem_algo_results.values()
-            }
-        )
-        == 1
-    ), "Different memory planning algorithms should have the same number of buffers allocated."
-
-    # Find the algorithm that minimizes the total memory usage.
-    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
-    logging.debug(f"Best memory planning algo for this model is {best_algo}")
-    bufsizes = mem_algo_results[best_algo].bufsizes
-
-    # Update the mem_id and mem_offset for each spec in the graph module based on the
-    # values provided by the best memory planning algorithm.
-    for spec in mem_algo_results[best_algo].spec_dict:
-        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
-        spec.mem_id = spec_alloc_result.mem_id
-        spec.mem_offset = spec_alloc_result.mem_offset
-        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+        # All the algorithms should have the same number of buffers allocated.
+        assert (
+            len(
+                {
+                    len(mem_algo_result.bufsizes)
+                    for mem_algo_result in mem_algo_results.values()
+                }
+            )
+            == 1
+        ), "Different memory planning algorithms should have the same number of buffers allocated."
 
-    return bufsizes
+        # Find the algorithm that minimizes the total memory usage.
+        best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+        logging.debug(f"Best memory planning algo for this model is {best_algo}")
+        bufsizes = mem_algo_results[best_algo].bufsizes
 
+        # Update the mem_id and mem_offset for each spec in the graph module based on the
+        # values provided by the best memory planning algorithm.
+        for spec in mem_algo_results[best_algo].spec_dict:
+            spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+            spec.mem_id = spec_alloc_result.mem_id
+            spec.mem_offset = spec_alloc_result.mem_offset
+            spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+        return bufsizes
 
 def naive(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int,
 ) -> MemoryAlgoResult:
-
+    """Naive algorithm to allocate memory for tensors in the graph.
+
+    This algorithm simply allocates memory for each tensor sequentially without reusing memory.
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
+    """
     naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
@@ -918,14 +934,9 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-
     bufsizes = typing.cast(List[int], bufsizes)
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+
+    for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
@@ -1027,7 +1038,7 @@ def insert_calls_to_free(
 
 def apply_algo(
     algo: Callable[
-        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
+        ...,
         List[int],
     ],
     graph_module: torch.fx.GraphModule,
@@ -1048,10 +1059,46 @@
     TODO: make these optimizations once we have some baseline working.
     """
 
-    specs = update_all_tensors_lifetime(graph_module, graph_signature)
+    # Extract the nodes and their lifespans from the graph_module
+    specs = update_all_tensors_lifetime(
+        graph_module,
+        graph_signature,
+    )
+
+    # Filter specs based on alloc_graph_input and alloc_graph_output
+    filtered_specs = set()
+    graph_input_tensors = get_graph_input_tensors(graph_module.graph.nodes, graph_signature)
+    graph_output_tensors = get_graph_output_tensors(graph_module.graph.nodes)
+
+    for spec in specs:
+        # Apply the same filtering as collect_specs_from_nodes
+        if not alloc_graph_input and spec in graph_input_tensors:
+            continue
+        if not alloc_graph_output and spec in graph_output_tensors:
+            continue
+        if spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
+            continue
+        # In training we flag weights with an associated gradient, as these
+        # need to be memory planned since their value will be updated each
+        # step of training.
+        if spec.const and not getattr(spec, "weight_has_gradient", False):
+            continue
+        filtered_specs.add(spec)
+
+    # Get extra padding for XNNPACK if needed
+    extra_padding = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padding = 64
+
+    # Pass the filtered specs to the algorithm
     bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        alignment,
+        filtered_specs,
+        graph_module,
+        graph_signature,
+        extra_padding,
     )
+
     insert_calls_to_free(graph_module, specs)
 
     def handle_submodule(
@@ -1063,6 +1110,7 @@ def handle_submodule(
         # memory planning for submodule need to be aware of the amount of
         # buffer already allocated.
         submodule.input_mem_buffer_sizes = bufsizes
+
         bufsizes = apply_algo(
             algo,
             submodule,
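
The suite object is itself the callable that apply_algo receives, so invoking it directly mirrors what the pass does internally. A minimal sketch under the assumption that gm, sig, and planned_specs are a graph module, its ExportGraphSignature, and a spec set whose lifetimes have already been updated (all hypothetical placeholders); naive, greedy, and MemoryPlanningAlgorithmSuite are the names defined in this file:

suite = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])

# The suite forwards (alignment, specs, graph_module, graph_signature, extra_padding)
# to each algorithm, keeps the plan with the smallest total footprint, writes the
# winning mem_id/mem_offset values back onto the specs, and returns the buffer sizes.
bufsizes = suite(16, planned_specs, gm, sig, 0)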

exir/passes/memory_planning_pass.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
     _is_out_var_node,
     apply_algo,
     get_node_tensor_specs,
-    memory_planning_algorithm_suite,
+    MemoryPlanningAlgorithmSuite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -42,7 +42,7 @@ def __init__(
         self,
         memory_planning_algo: Callable[
             ..., List[int]
-        ] = memory_planning_algorithm_suite,
+        ] = MemoryPlanningAlgorithmSuite(),
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
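
With the suite as a class, configuring the pass with a custom algorithm list no longer needs a wrapper function. A hedged usage sketch; the import paths are assumed from the import blocks above, and the keyword arguments shown are the ones visible in this diff:

from executorch.exir.memory_planning import MemoryPlanningAlgorithmSuite, greedy, naive
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

# Default behaviour: MemoryPlanningPass() falls back to MemoryPlanningAlgorithmSuite(),
# whose default algo_list is just [greedy].
default_pass = MemoryPlanningPass()

# Custom behaviour: pass a suite instance carrying whichever algorithms you want;
# the instance is callable, so it slots in wherever a planning callable was expected.
custom_pass = MemoryPlanningPass(
    memory_planning_algo=MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy]),
    alloc_graph_input=True,
    alloc_graph_output=True,
)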
