From f716372df2844678f924b3b83eb8cf18325c302c Mon Sep 17 00:00:00 2001
From: Emma Lin
Date: Fri, 7 Nov 2025 17:23:23 -0800
Subject: [PATCH] optimization: move set_metadata out of main stream (#5082)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2090

With feature score eviction, TBE calls the backend to update feature score
metadata separately in the forward pass. This update is designed to run
asynchronously without blocking the forward/backward pass; however, the
blocking CPU operation stalled the main stream, so all2all could not start
immediately after get_cuda.

From a dummy profile, we can see this trace: {F1983224804}

The set-metadata operation became a blocker on the critical path, taking
217ms.

With this change, the trace is updated to: {F1983224830}

Overall prefetch is reduced to less than 70ms, and get_cuda is followed
immediately by all2all, with no extra waiting or stream synchronization.

https://www.internalfb.com/ai_infra/zoomer/profiling-run/overview?profilingRunID=1913270729575721

Reviewed By: steven1327, kathyxuyy

Differential Revision: D86013406
---
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 82 ++++++++++++++++++-----
 1 file changed, 65 insertions(+), 17 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 8425569171..601fcf6abd 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -762,8 +762,10 @@ def __init__(
         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
         # GPU stream for SSD cache eviction
         self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
-        # GPU stream for SSD memory copy
+        # GPU stream for SSD memory copy (also reused for feature score D2H)
         self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
+        # GPU stream for async metadata operation
+        self.feature_score_stream = torch.cuda.Stream(priority=low_priority)
 
         # SSD get completion event
         self.ssd_event_get = torch.cuda.Event()
@@ -1675,6 +1677,56 @@ def _update_cache_counter_and_pointers(
             unique_indices_length_curr=curr_data.actions_count_gpu,
         )
 
+    def _update_feature_score_metadata(
+        self,
+        linear_cache_indices: Tensor,
+        weights: Tensor,
+        d2h_stream: torch.cuda.Stream,
+        write_stream: torch.cuda.Stream,
+        pre_event_for_write: torch.cuda.Event,
+        post_event: Optional[torch.cuda.Event] = None,
+    ) -> None:
+        """
+        Write feature score metadata to DRAM.
+
+        This method performs a D2H copy on d2h_stream, then writes to DRAM on write_stream.
+        The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations.
+
+        Args:
+            linear_cache_indices: GPU tensor containing cache indices
+            weights: GPU tensor containing feature scores
+            d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
+            write_stream: Stream for metadata write operation
+            pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
+            post_event: Event to record when the operation is done
+        """
+        # Start D2H copy on d2h_stream
+        with torch.cuda.stream(d2h_stream):
+            # Record streams to prevent premature deallocation
+            linear_cache_indices.record_stream(d2h_stream)
+            weights.record_stream(d2h_stream)
+            # Do the D2H copy
+            linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
+            score_weights_cpu = self.to_pinned_cpu(weights)
+
+        # Write feature score metadata to DRAM
+        with record_function("## ssd_write_feature_score_metadata ##"):
+            with torch.cuda.stream(write_stream):
+                write_stream.wait_event(pre_event_for_write)
+                write_stream.wait_stream(d2h_stream)
+                self.record_function_via_dummy_profile(
+                    "## ssd_write_feature_score_metadata ##",
+                    self.ssd_db.set_feature_score_metadata_cuda,
+                    linear_cache_indices_cpu,
+                    torch.tensor(
+                        [score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
+                    ),
+                    score_weights_cpu,
+                )
+
+        if post_event is not None:
+            write_stream.record_event(post_event)
+
     def prefetch(
         self,
         indices: Tensor,
@@ -1747,12 +1799,6 @@ def _prefetch(  # noqa C901
             self.timestep += 1
             self.timesteps_prefetched.append(self.timestep)
 
-            if self.backend_type == BackendType.DRAM and weights is not None:
-                # DRAM backend supports feature score eviction, if there is weights available
-                # in the prefetch call, we will set metadata for feature score eviction asynchronously
-                cloned_linear_cache_indices = linear_cache_indices.clone()
-            else:
-                cloned_linear_cache_indices = None
 
             # Lookup and virtually insert indices into L1. After this operator,
             # we know:
@@ -2114,16 +2160,18 @@ def _prefetch(  # noqa C901
                     name="cache",
                     is_bwd=False,
                 )
-            if self.backend_type == BackendType.DRAM and weights is not None:
-                # Write feature score metadata to DRAM
-                self.record_function_via_dummy_profile(
-                    "## ssd_write_feature_score_metadata ##",
-                    self.ssd_db.set_feature_score_metadata_cuda,
-                    cloned_linear_cache_indices.cpu(),
-                    torch.tensor(
-                        [weights.shape[0]], device="cpu", dtype=torch.long
-                    ),
-                    weights.cpu(),
+            if (
+                self.backend_type == BackendType.DRAM
+                and weights is not None
+                and linear_cache_indices.numel() > 0
+            ):
+                # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
+                self._update_feature_score_metadata(
+                    linear_cache_indices=linear_cache_indices,
+                    weights=weights,
+                    d2h_stream=self.ssd_memcpy_stream,
+                    write_stream=self.feature_score_stream,
+                    pre_event_for_write=self.ssd_event_cache_evict,
                )
 
             # Generate row addresses (pointing to either L1 or the current
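
Note: the core pattern in this change is moving a blocking D2H copy plus a
CPU-side metadata write off the main CUDA stream onto low-priority side
streams, ordered by events so the write never races the eviction or the copy.
Below is a minimal, self-contained sketch of that pattern. It is an
illustration, not FBGEMM code: update_metadata_off_main_stream,
write_metadata_cpu, and pre_write_event are hypothetical names, and the real
backend call (set_feature_score_metadata_cuda) enqueues the host write on the
stream rather than synchronizing the calling helper as the sketch does.

import torch
from torch.profiler import record_function


def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
    # Async D2H copy into pinned host memory; the result is only safe to
    # read after the copying stream's work has completed.
    out = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    out.copy_(t, non_blocking=True)
    return out


def write_metadata_cpu(indices_cpu: torch.Tensor, weights_cpu: torch.Tensor) -> None:
    # Hypothetical stand-in for the backend's DRAM metadata write.
    assert indices_cpu.shape[0] == weights_cpu.shape[0]


def update_metadata_off_main_stream(
    indices: torch.Tensor,  # GPU tensor produced on the main stream
    weights: torch.Tensor,  # GPU tensor produced on the main stream
    d2h_stream: torch.cuda.Stream,
    write_stream: torch.cuda.Stream,
    pre_write_event: torch.cuda.Event,  # e.g. recorded when eviction finishes
) -> None:
    # The copy stream must observe the producers' work before copying.
    d2h_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(d2h_stream):
        # Tell the caching allocator these tensors are in use on d2h_stream
        # so their memory is not reused before the copy completes.
        indices.record_stream(d2h_stream)
        weights.record_stream(d2h_stream)
        indices_cpu = to_pinned_cpu(indices)
        weights_cpu = to_pinned_cpu(weights)

    with record_function("## write_metadata ##"):
        with torch.cuda.stream(write_stream):
            write_stream.wait_event(pre_write_event)  # ordering vs. eviction
            write_stream.wait_stream(d2h_stream)      # D2H must finish first
            # The blocking wait now stalls only this low-priority stream's
            # timeline (and this helper), leaving the main stream free to
            # launch all2all and other work immediately.
            write_stream.synchronize()
            write_metadata_cpu(indices_cpu, weights_cpu)


# Example wiring (assumes a CUDA device is available):
if torch.cuda.is_available():
    low_priority, _ = torch.cuda.Stream.priority_range()
    d2h = torch.cuda.Stream(priority=low_priority)
    writer = torch.cuda.Stream(priority=low_priority)
    evict_done = torch.cuda.Event()
    torch.cuda.current_stream().record_event(evict_done)  # stand-in for eviction
    idx = torch.arange(8, device="cuda")
    scores = torch.rand(8, device="cuda")
    update_metadata_off_main_stream(idx, scores, d2h, writer, evict_done)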