
Commit 54da6bc

emlin authored and facebook-github-bot committed
optimization: move set_metadata to ssd_eviction_stream
Summary:
X-link: facebookresearch/FBGEMM#2090

With feature score eviction, TBE calls the backend to update feature score metadata separately in the forward pass. This update is designed to run asynchronously so it does not block the forward/backward pass. However, the CPU-blocking operation stalled the main stream, so after get_cuda the all2all could not start immediately. A dummy profile shows this trace: {F1983224804} The set-metadata operation became a blocker on the critical path, taking 217ms.

With this change, the trace becomes: {F1983224830} Overall prefetch is reduced to under 70ms, and get_cuda is followed immediately by all2all, with no extra waiting or stream synchronization.

Differential Revision: D86013406
1 parent 32786e8 commit 54da6bc
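The core of the change is a standard CUDA overlap pattern: do the device-to-host (D2H) copies on a low-priority side stream into pinned host buffers, then order the metadata write behind them with events on another side stream, leaving the main stream free. Below is a minimal, self-contained sketch of that pattern; the names (async_metadata_update, set_metadata_async) are illustrative, not the FBGEMM API, and the backend write is assumed to enqueue its work on the current stream rather than block the host.

import torch

def async_metadata_update(
    indices: torch.Tensor,            # GPU cache indices
    scores: torch.Tensor,             # GPU feature scores
    d2h_stream: torch.cuda.Stream,    # low-priority stream for D2H copies
    write_stream: torch.cuda.Stream,  # low-priority stream for the write
    evict_done: torch.cuda.Event,     # recorded when cache eviction finishes
    set_metadata_async,               # hypothetical backend call
) -> None:
    with torch.cuda.stream(d2h_stream):
        # Keep the GPU tensors alive until the side-stream copies complete.
        indices.record_stream(d2h_stream)
        scores.record_stream(d2h_stream)
        # Non-blocking D2H into page-locked (pinned) host buffers.
        indices_cpu = torch.empty(indices.shape, dtype=indices.dtype, pin_memory=True)
        scores_cpu = torch.empty(scores.shape, dtype=scores.dtype, pin_memory=True)
        indices_cpu.copy_(indices, non_blocking=True)
        scores_cpu.copy_(scores, non_blocking=True)

    with torch.cuda.stream(write_stream):
        # Order the write after eviction and after the host copies land.
        write_stream.wait_event(evict_done)
        write_stream.wait_stream(d2h_stream)
        set_metadata_async(indices_cpu, scores_cpu)
    # The main stream never waits on any of this, so all2all can launch
    # right after get_cuda instead of stalling on a blocking .cpu() call.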

File tree

1 file changed: +66 −17 lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 66 additions & 17 deletions
@@ -752,8 +752,10 @@ def __init__(
         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
         # GPU stream for SSD cache eviction
         self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
-        # GPU stream for SSD memory copy
+        # GPU stream for SSD memory copy (also reused for feature score D2H)
         self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
+        # GPU stream for async metadata operation
+        self.feature_score_stream = torch.cuda.Stream(priority=low_priority)

         # SSD get completion event
         self.ssd_event_get = torch.cuda.Event()
@@ -1665,6 +1667,56 @@ def _update_cache_counter_and_pointers(
             unique_indices_length_curr=curr_data.actions_count_gpu,
         )

+    def _update_feature_score_metadata(
+        self,
+        linear_cache_indices: Tensor,
+        weights: Tensor,
+        d2h_stream: torch.cuda.Stream,
+        write_stream: torch.cuda.Stream,
+        pre_event_for_write: torch.cuda.Event,
+        post_event: Optional[torch.cuda.Event] = None,
+    ) -> None:
+        """
+        Write feature score metadata to DRAM
+
+        This method performs D2H copy on d2h_stream, then writes to DRAM on write_stream.
+        The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
+
+        Args:
+            linear_cache_indices: GPU tensor containing cache indices
+            weights: GPU tensor containing feature scores
+            d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
+            write_stream: Stream for metadata write operation
+            pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
+            post_event: Event to record when the operation is done
+        """
+        # Start D2H copy on d2h_stream
+        with torch.cuda.stream(d2h_stream):
+            # Record streams to prevent premature deallocation
+            linear_cache_indices.record_stream(d2h_stream)
+            weights.record_stream(d2h_stream)
+            # Do the D2H copy
+            linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
+            score_weights_cpu = self.to_pinned_cpu(weights)
+
+        # Write feature score metadata to DRAM
+        with record_function("## ssd_write_feature_score_metadata ##"):
+            with torch.cuda.stream(write_stream):
+                write_stream.wait_event(pre_event_for_write)
+                write_stream.wait_stream(d2h_stream)
+                self.record_function_via_dummy_profile(
+                    "## ssd_write_feature_score_metadata ##",
+                    self.ssd_db.set_feature_score_metadata_cuda,
+                    linear_cache_indices_cpu,
+                    torch.tensor(
+                        [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
+                    ),
+                    score_weights_cpu,
+                )
+
+        if post_event is not None:
+            write_stream.record_event(post_event)
+
     def prefetch(
         self,
         indices: Tensor,
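The optional post_event parameter gives callers a handle for ordering later work against the metadata write. The call site added below does not pass it, but a hypothetical consumer (not part of this commit) could use it like this:

feature_score_done = torch.cuda.Event()
self._update_feature_score_metadata(
    linear_cache_indices=linear_cache_indices,
    weights=weights,
    d2h_stream=self.ssd_memcpy_stream,
    write_stream=self.feature_score_stream,
    pre_event_for_write=self.ssd_event_cache_evict,
    post_event=feature_score_done,  # recorded on write_stream when done
)
# Any stream that must observe the completed write waits on the event:
torch.cuda.current_stream().wait_event(feature_score_done)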
@@ -1737,12 +1789,6 @@ def _prefetch( # noqa C901

         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
-        if self.backend_type == BackendType.DRAM and weights is not None:
-            # DRAM backend supports feature score eviction, if there is weights available
-            # in the prefetch call, we will set metadata for feature score eviction asynchronously
-            cloned_linear_cache_indices = linear_cache_indices.clone()
-        else:
-            cloned_linear_cache_indices = None

         # Lookup and virtually insert indices into L1. After this operator,
         # we know:
@@ -2104,16 +2150,19 @@ def _prefetch( # noqa C901
                 name="cache",
                 is_bwd=False,
             )
-            if self.backend_type == BackendType.DRAM and weights is not None:
-                # Write feature score metadata to DRAM
-                self.record_function_via_dummy_profile(
-                    "## ssd_write_feature_score_metadata ##",
-                    self.ssd_db.set_feature_score_metadata_cuda,
-                    cloned_linear_cache_indices.cpu(),
-                    torch.tensor(
-                        [weights.shape[0]], device="cpu", dtype=torch.long
-                    ),
-                    weights.cpu(),
+            if (
+                self.backend_type == BackendType.DRAM
+                and weights is not None
+                and linear_cache_indices.numel() > 0
+            ):
+                # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
+                # (we've already waited for it at line 2066)
+                self._update_feature_score_metadata(
+                    linear_cache_indices=linear_cache_indices,
+                    weights=weights,
+                    d2h_stream=self.ssd_memcpy_stream,
+                    write_stream=self.feature_score_stream,
+                    pre_event_for_write=self.ssd_event_cache_evict,
                 )

             # Generate row addresses (pointing to either L1 or the current
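The new helper leans on self.to_pinned_cpu, which is outside this diff. A plausible implementation (an assumption about the helper, not code from this commit) is a non-blocking copy into page-locked host memory, which is what allows the D2H to overlap on ssd_memcpy_stream:

def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: allocate pinned host memory and issue an async
    # D2H copy that is ordered by the current (side) stream.
    t_cpu = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    t_cpu.copy_(t, non_blocking=True)
    return t_cpu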
