Skip to content

Reduce unnecessary GreedyPerfPartitioner calls from MemoryBalancedPartitioner #2914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions torchrec/distributed/planner/partitioners.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -599,17 +599,28 @@ def partition(
default_plan = copy.deepcopy(default_plan)
original_plan_perf = _perf_model.rate(default_plan)

max_hbm_per_device: int = _topology.devices[0].storage.hbm
# compute shard and default plan HBM stats
hbm_by_rank = [0] * _topology.world_size
hbm_requirement: int = 0
max_shard_hbm: int = 0
for sharding_option in default_plan:
for shard in sharding_option.shards:
if shard.storage is not None and shard.rank is not None:
hbm_used = shard.storage.hbm
rank = shard.rank
hbm_by_rank[rank] += hbm_used
hbm_requirement += hbm_used
max_shard_hbm = max(max_shard_hbm, hbm_used)

# Upper bound for the search is the default plan's max HBM usage
max_hbm_per_device: int = max(hbm_by_rank)
logger.info(
f"Default plan uses {round(bytes_to_gb(max_hbm_per_device), 3)} GB per device."
f"Default plan max HBM is {round(bytes_to_gb(max_hbm_per_device), 3)} GB."
)

hbm_requirement: int = 0
for sharding_option in proposal:
for shard in sharding_option.shards:
if shard.storage is not None:
hbm_requirement += shard.storage.hbm
min_hbm_per_device: int = int(hbm_requirement / _topology.world_size)
# Lower bound for the search is the larger of the average HBM usage and the biggest shard
avg_hbm_usage: int = int(hbm_requirement / _topology.world_size)
min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm)
logger.info(
"Searching in the range (min_hbm_per_device, max_hbm_per_device): "
f"({round(bytes_to_gb(min_hbm_per_device), 3)}, "
Expand Down Expand Up @@ -660,7 +671,7 @@ def partition(
max_hbm_per_device = mid_hbm_per_device
except PlannerError:
logger.info(
f"Couldn't find a plan with {round(bytes_to_gb(max_hbm_per_device), 3)} "
f"Couldn't find a plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} "
f"GB per device for embedding tables."
)
min_hbm_per_device = mid_hbm_per_device
Expand Down