From 9675227472ccda3ba9bf61b759a3e24a95b4aded Mon Sep 17 00:00:00 2001 From: Caner Gocmen Date: Thu, 24 Apr 2025 10:19:00 -0700 Subject: [PATCH] Reduce unnecessary GreedyPerfPartitioner calls from MemoryBalancedPartitioner Summary: MemoryBalancedPartitioner works by adjusting the max memory on devices and calling GreedyPerfPartitioner repeatedly. The max memory is adjusted with a binary search procedure to identify a more memory-efficient plan than what GreedyPerfPartitioner gives by default. The search boundaries for the binary search procedure were inefficient, which this diff addresses. * **Upper bound** * **Before:** Max device HBM (e.g. 80 GB) * **After:** Max HBM usage of the default plan, since there is no point in searching for plans that use more max memory than what the default plan uses. * **Lower bound:** * **Before:** [Avg. HBM per Device] = [Total HBM Needed Across All Shards] / [World Size] * **After:** max([Avg. HBM per Device], [Max HBM Needed Across All Shards]). A feasible solution requires at least the max HBM that the biggest shard needs, so there is no point in searching for options below that. Making these changes can have an impact in two ways: 1. The search procedure is more efficient, leading to plans with lower memory. 2. We can reduce `search_count` to get comparable plans as before while calling `GreedyPerfPartitioner` fewer times from `MemoryBalancedPartitioner`. The default impact without further changes from #1 should lead to a marginal max memory improvement. 
Differential Revision: D73598477 --- torchrec/distributed/planner/partitioners.py | 29 ++++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py index b397c5064..7b0cc3257 100644 --- a/torchrec/distributed/planner/partitioners.py +++ b/torchrec/distributed/planner/partitioners.py @@ -599,17 +599,28 @@ def partition( default_plan = copy.deepcopy(default_plan) original_plan_perf = _perf_model.rate(default_plan) - max_hbm_per_device: int = _topology.devices[0].storage.hbm + # compute shard and default plan HBM stats + hbm_by_rank = [0] * _topology.world_size + hbm_requirement: int = 0 + max_shard_hbm: int = 0 + for sharding_option in default_plan: + for shard in sharding_option.shards: + if shard.storage is not None and shard.rank is not None: + hbm_used = shard.storage.hbm + rank = shard.rank + hbm_by_rank[rank] += hbm_used + hbm_requirement += hbm_used + max_shard_hbm = max(max_shard_hbm, hbm_used) + + # Upper bound for the search is the default plan's max HBM usage + max_hbm_per_device: int = max(hbm_by_rank) logger.info( - f"Default plan uses {round(bytes_to_gb(max_hbm_per_device), 3)} GB per device." + f"Default plan max HBM is {round(bytes_to_gb(max_hbm_per_device), 3)} GB." ) - hbm_requirement: int = 0 - for sharding_option in proposal: - for shard in sharding_option.shards: - if shard.storage is not None: - hbm_requirement += shard.storage.hbm - min_hbm_per_device: int = int(hbm_requirement / _topology.world_size) + # Lower bound for the search is the maximum of avg. HBM usage or the biggest shard + avg_hbm_usage: int = int(hbm_requirement / _topology.world_size) + min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm) logger.info( "Searching in the range (min_hbm_per_device, max_hbm_per_device): " f"({round(bytes_to_gb(min_hbm_per_device), 3)}, " @@ -660,7 +671,7 @@ def partition( max_hbm_per_device = mid_hbm_per_device except PlannerError: logger.info( - f"Couldn't find a plan with {round(bytes_to_gb(max_hbm_per_device), 3)} " + f"Couldn't find a plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} " f"GB per device for embedding tables." ) min_hbm_per_device = mid_hbm_per_device