@@ -731,53 +731,43 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
 
 
 def greedy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int = 0,
+    *,
     allow_overlapping_allocations: bool = True,
 ) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
-    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
-    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
-    allow_overlapping_allocations: If set to true, allows for allocations that overlap
-    in their lifetime but are at different offsets in the storage. By default true.
-    This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
-    allocations disabled
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+        allow_overlapping_allocations: If set to true, allows for allocations that overlap
+            in their lifetime but are at different offsets in the storage. By default true.
+            This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
+            allocations disabled
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
     """
     greedy_result = MemoryAlgoResult({}, [])
-    # padding allocation with 64 bytes.
-    # this requirement is really for XNNPACK backend which can read tensors
-    # beyond the end of the tensor. This is done for performance
-    # optimizations in XNNPACK.
-    # While accounting for backend specific requirement is not the right choice
-    # in backend agnostic memory planning, we do it here as it seems most appropriate.
-    # Right now this applies to greedy only so any other
-    # algorithm that plans memory for XNNPACK backend will
-    # not have this.
-    extra_padded_bytes = 0
-    if _contains_xnnpack_delegate(graph_module):
-        extra_padded_bytes = 64
     spec2obj = {}
     shared_objects = defaultdict(list)
-    # Don't do assertion in collect_specs_from_nodes if we have already encountered
-    # and ignored some to_out_variant errors.
-    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
+
     # For each tensor, pick the available shared object with closest size to
     # the tensor. If there are no available shared object left, create a new
     # one.
     import bisect
 
     sorted_specs = []
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=do_assertion,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+    for spec in specs:
         bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+
     sorted_specs.reverse()
 
     for spec in sorted_specs:
@@ -806,15 +796,13 @@ def greedy(
     for mem_id in shared_objects:
         input_total_size = 0
         if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-            # pyre-fixme[6]: For 1st argument expected
-            # `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+            assert isinstance(bufsizes, list)
             if len(bufsizes) > mem_id:
-                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                 input_total_size = bufsizes[mem_id]
         total_sizes[mem_id] = materialize_buffer(
             shared_objects[mem_id], input_total_size
         )
-        total_sizes[mem_id] += extra_padded_bytes
+        total_sizes[mem_id] += extra_padding
 
     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
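Taken together, the two hunks above invert greedy's contract: instead of walking the graph, asserting, and hard-coding the 64-byte XNNPACK pad itself, greedy now receives a pre-collected spec set and an explicit extra_padding from its caller. A minimal sketch of the new calling convention follows; the import path, the 16/64 values, and the surrounding variable names are assumptions for illustration, not part of this change, and in the real pass apply_algo filters the spec set first (see the apply_algo hunk further below).

# Sketch only: exercises the parameter order introduced above.
from executorch.exir.memory_planning import greedy, update_all_tensors_lifetime

specs = update_all_tensors_lifetime(graph_module, graph_signature)
result = greedy(
    alignment=16,                 # memory alignment requirement (illustrative value)
    specs=specs,                  # TensorSpecs with lifetimes already updated
    graph_module=graph_module,
    graph_signature=graph_signature,
    extra_padding=64,             # e.g. XNNPACK over-read pad, now chosen by the caller
    allow_overlapping_allocations=True,
)
bufsizes = result.bufsizes        # MemoryAlgoResult: buffer sizes plus per-spec placement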
@@ -837,73 +825,101 @@ def greedy(
     greedy_result.bufsizes = total_sizes
     return greedy_result
 
+class MemoryPlanningAlgorithmSuite:
+    def __init__(self, algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None) -> None:
+        if algo_list is None:
+            algo_list = [greedy]
+        self.algo_list: List[Callable[..., MemoryAlgoResult]] = algo_list
 
-def memory_planning_algorithm_suite(
-    graph_module: torch.fx.GraphModule,
-    alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
-    allow_overlapping_allocations: bool = True,
-    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
-) -> List[int]:
-    r"""
-    Memory planning algorithm suite that runs a list of memory planning algorithms
-    and returns the result of the algorithm that minimizes the total memory usage.
-    """
-    if algo_list is None:
-        algo_list = [greedy]
-    mem_algo_results = {}
-    for algo in algo_list:
-        if isinstance(algo, functools.partial):
-            name = algo.func.__name__
-        else:
-            name = getattr(algo, "__name__", None)
-        # Run this memory planning algorithm and store the result in mem_algo_results
-        # with the name of the algorithm as the key.
-        mem_algo_results[name] = algo(
-            graph_module,
-            alignment,
-            graph_signature,
-            alloc_graph_input,
-            alloc_graph_output,
-        )
+    def __call__(
+        self,
+        alignment: int,
+        specs: Set[TensorSpec],
+        graph_module: torch.fx.GraphModule,
+        graph_signature: ExportGraphSignature,
+        extra_padding: int,
+    ) -> List[int]:
+        r"""
+        Memory planning algorithm suite that runs a list of memory planning algorithms
+        and returns the result of the algorithm that minimizes the total memory usage.
+
+        Args:
+            alignment: Memory alignment requirement
+            specs: Set of TensorSpec objects with updated lifetimes
+            graph_module: The graph module to allocate memory for
+            graph_signature: Graph signature
+            extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+        Returns:
+            List of buffer sizes for each memory hierarchy
+        """
+
+        mem_algo_results = {}
+        for algo in self.algo_list:
+            if isinstance(algo, functools.partial):
+                name = algo.func.__name__
+            else:
+                name = getattr(algo, "__name__", None)
+
+            mem_algo_results[name] = algo(
+                alignment,
+                specs,
+                graph_module,
+                graph_signature,
+                extra_padding,
+            )
 
-    # All the algorithms should have the same number of buffers allocated.
-    assert (
-        len(
-            {
-                len(mem_algo_result.bufsizes)
-                for mem_algo_result in mem_algo_results.values()
-            }
-        )
-        == 1
-    ), "Different memory planning algorithms should have the same number of buffers allocated."
-
-    # Find the algorithm that minimizes the total memory usage.
-    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
-    logging.debug(f"Best memory planning algo for this model is {best_algo}")
-    bufsizes = mem_algo_results[best_algo].bufsizes
-
-    # Update the mem_id and mem_offset for each spec in the graph module based on the
-    # values provided by the best memory planning algorithm.
-    for spec in mem_algo_results[best_algo].spec_dict:
-        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
-        spec.mem_id = spec_alloc_result.mem_id
-        spec.mem_offset = spec_alloc_result.mem_offset
-        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+        # All the algorithms should have the same number of buffers allocated.
+        assert (
+            len(
+                {
+                    len(mem_algo_result.bufsizes)
+                    for mem_algo_result in mem_algo_results.values()
+                }
+            )
+            == 1
+        ), "Different memory planning algorithms should have the same number of buffers allocated."
 
-    return bufsizes
+        # Find the algorithm that minimizes the total memory usage.
+        best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+        logging.debug(f"Best memory planning algo for this model is {best_algo}")
+        bufsizes = mem_algo_results[best_algo].bufsizes
 
+        # Update the mem_id and mem_offset for each spec in the graph module based on the
+        # values provided by the best memory planning algorithm.
+        for spec in mem_algo_results[best_algo].spec_dict:
+            spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+            spec.mem_id = spec_alloc_result.mem_id
+            spec.mem_offset = spec_alloc_result.mem_offset
+            spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+        return bufsizes
 
 
 def naive(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int,
 ) -> MemoryAlgoResult:
-
+    """Naive algorithm to allocate memory for tensors in the graph.
+
+    This algorithm simply allocates memory for each tensor sequentially without reusing memory.
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
+    """
     naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
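The net effect of this hunk is that the module-level memory_planning_algorithm_suite function becomes a configurable callable class: the list of candidate planners is fixed at construction time, while the per-run inputs (specs, graph, padding) arrive through __call__, which still picks the algorithm with the smallest total footprint and writes the winning placements back onto the specs. A hedged usage sketch follows; the import path is assumed, and everything outside the suite itself is taken to already be in scope.

# Sketch only: builds on the constructor and __call__ added above.
from executorch.exir.memory_planning import (  # module path assumed
    MemoryPlanningAlgorithmSuite,
    greedy,
    naive,
)

suite = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])
# functools.partial variants also work; the suite unwraps partials and keys
# its results dictionary on algo.func.__name__.

bufsizes = suite(
    alignment,        # int
    specs,            # Set[TensorSpec] with lifetimes already updated
    graph_module,     # torch.fx.GraphModule
    graph_signature,  # ExportGraphSignature
    extra_padding,    # e.g. 64 when an XNNPACK delegate is present
)  # -> List[int]: one planned size per memory buffer, from the cheapest algorithm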
@@ -918,14 +934,9 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-
     bufsizes = typing.cast(List[int], bufsizes)
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+
+    for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
@@ -1027,7 +1038,7 @@ def insert_calls_to_free(
 
 def apply_algo(
     algo: Callable[
-        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
+        ...,
         List[int],
     ],
     graph_module: torch.fx.GraphModule,
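Relaxing the annotation to Callable[..., List[int]] drops the old five-argument positional contract from the type itself; the de facto contract is now whatever the call site in the next hunk passes in (alignment, filtered specs, graph_module, graph_signature, extra_padding) and a List[int] of buffer sizes back. A hedged sketch of a custom planner that would satisfy that shape, with placeholder logic that is not an ExecuTorch algorithm:

# Sketch only: relies on nothing but spec.allocated_memory and mirrors naive's
# [0, ...] buffer layout, where buffer id 0 is left empty.
from typing import List

import torch

def my_planner(
    alignment: int,
    specs,            # Set[TensorSpec]
    graph_module: torch.fx.GraphModule,
    graph_signature,  # ExportGraphSignature
    extra_padding: int,
) -> List[int]:
    total = 0
    for spec in specs:
        # round every tensor up to the alignment and stack them end to end (no reuse)
        total += -(-spec.allocated_memory // alignment) * alignment
    # buffer id 0 stays empty, everything lands in buffer 1, plus backend padding
    return [0, total + extra_padding]

A real replacement would also have to record each spec's mem_id and mem_offset, which is what MemoryPlanningAlgorithmSuite does when it copies the winning MemoryAlgoResult back onto the specs.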
@@ -1048,10 +1059,46 @@ def apply_algo(
     TODO: make these optimizations once we have some baseline working.
     """
 
-    specs = update_all_tensors_lifetime(graph_module, graph_signature)
+    # Extract the nodes and their lifespans from the graph_module
+    specs = update_all_tensors_lifetime(
+        graph_module,
+        graph_signature,
+    )
+
+    # Filter specs based on alloc_graph_input and alloc_graph_output
+    filtered_specs = set()
+    graph_input_tensors = get_graph_input_tensors(graph_module.graph.nodes, graph_signature)
+    graph_output_tensors = get_graph_output_tensors(graph_module.graph.nodes)
+
+    for spec in specs:
+        # Apply the same filtering as collect_specs_from_nodes
+        if not alloc_graph_input and spec in graph_input_tensors:
+            continue
+        if not alloc_graph_output and spec in graph_output_tensors:
+            continue
+        if spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
+            continue
+        # In training we flag weights with an associated gradient,
+        # as these need to be memory planned since their value will
+        # be updated at each step of training.
+        if spec.const and not getattr(spec, "weight_has_gradient", False):
+            continue
+        filtered_specs.add(spec)
+
+    # Get extra padding for XNNPACK if needed
+    extra_padding = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padding = 64
+
+    # Pass the filtered specs to the algorithm
     bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        alignment,
+        filtered_specs,
+        graph_module,
+        graph_signature,
+        extra_padding,
     )
+
     insert_calls_to_free(graph_module, specs)
 
     def handle_submodule(
@@ -1063,6 +1110,7 @@ def handle_submodule(
         # memory planning for submodule need to be aware of the amount of
         # buffer already allocated.
         submodule.input_mem_buffer_sizes = bufsizes
+
         bufsizes = apply_algo(
             algo,
             submodule,