@@ -752,8 +752,10 @@ def __init__(
752752 (low_priority , high_priority ) = torch .cuda .Stream .priority_range ()
753753 # GPU stream for SSD cache eviction
754754 self .ssd_eviction_stream = torch .cuda .Stream (priority = low_priority )
755- # GPU stream for SSD memory copy
755+ # GPU stream for SSD memory copy (also reused for feature score D2H)
756756 self .ssd_memcpy_stream = torch .cuda .Stream (priority = low_priority )
757+ # GPU stream for asynchronous feature score metadata writes
758+ self .feature_score_stream = torch .cuda .Stream (priority = low_priority )
757759
758760 # SSD get completion event
759761 self .ssd_event_get = torch .cuda .Event ()
@@ -1665,6 +1667,56 @@ def _update_cache_counter_and_pointers(
16651667 unique_indices_length_curr = curr_data .actions_count_gpu ,
16661668 )
16671669
1670+ def _update_feature_score_metadata (
1671+ self ,
1672+ linear_cache_indices : Tensor ,
1673+ weights : Tensor ,
1674+ d2h_stream : torch .cuda .Stream ,
1675+ write_stream : torch .cuda .Stream ,
1676+ pre_event_for_write : torch .cuda .Event ,
1677+ post_event : Optional [torch .cuda .Event ] = None ,
1678+ ) -> None :
1679+ """
1680+ Copy cache indices and feature scores to the host, then write feature score metadata to DRAM
1681+
1682+ The D2H copies are issued on d2h_stream; the metadata write is issued on write_stream after it waits on pre_event_for_write and on d2h_stream.
1683+ The caller is responsible for ensuring d2h_stream doesn't compete with other D2H operations and is ordered after the producers of the inputs; no wait is issued on d2h_stream here.
1684+
1685+ Args:
1686+ linear_cache_indices: GPU tensor containing cache indices
1687+ weights: GPU tensor containing feature scores; its first-dimension size is passed to the write as the count
1688+ d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
1689+ write_stream: Stream for metadata write operation
1690+ pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
1691+ post_event: Optional event recorded on write_stream after the write is issued; nothing is recorded when None
1692+ """
1693+ # Issue the D2H copies on d2h_stream
1694+ with torch .cuda .stream (d2h_stream ):
1695+ # Mark both tensors as in use on d2h_stream so their memory is not reused before the copies finish
1696+ linear_cache_indices .record_stream (d2h_stream )
1697+ weights .record_stream (d2h_stream )
1698+ # Copy both tensors to the host via to_pinned_cpu (presumably into pinned memory - TODO confirm)
1699+ linear_cache_indices_cpu = self .to_pinned_cpu (linear_cache_indices )
1700+ score_weights_cpu = self .to_pinned_cpu (weights )
1701+
1702+ # Write feature score metadata to DRAM once pre_event_for_write and the D2H copies have completed
1703+ with record_function ("## ssd_write_feature_score_metadata ##" ):
1704+ with torch .cuda .stream (write_stream ):
1705+ write_stream .wait_event (pre_event_for_write )
1706+ write_stream .wait_stream (d2h_stream )
1707+ self .record_function_via_dummy_profile (
1708+ "## ssd_write_feature_score_metadata ##" ,
1709+ self .ssd_db .set_feature_score_metadata_cuda ,
1710+ linear_cache_indices_cpu ,
1711+ torch .tensor (
1712+ [score_weights_cpu .shape [0 ]], device = "cpu" , dtype = torch .long
1713+ ),
1714+ score_weights_cpu ,
1715+ )
1716+
1717+ if post_event is not None :
1718+ write_stream .record_event (post_event )
1719+
16681720 def prefetch (
16691721 self ,
16701722 indices : Tensor ,
@@ -1737,12 +1789,6 @@ def _prefetch( # noqa C901
17371789
17381790 self .timestep += 1
17391791 self .timesteps_prefetched .append (self .timestep )
1740- if self .backend_type == BackendType .DRAM and weights is not None :
1741- # DRAM backend supports feature score eviction, if there is weights available
1742- # in the prefetch call, we will set metadata for feature score eviction asynchronously
1743- cloned_linear_cache_indices = linear_cache_indices .clone ()
1744- else :
1745- cloned_linear_cache_indices = None
17461792
17471793 # Lookup and virtually insert indices into L1. After this operator,
17481794 # we know:
@@ -2104,16 +2150,19 @@ def _prefetch( # noqa C901
21042150 name = "cache" ,
21052151 is_bwd = False ,
21062152 )
2107- if self .backend_type == BackendType .DRAM and weights is not None :
2108- # Write feature score metadata to DRAM
2109- self .record_function_via_dummy_profile (
2110- "## ssd_write_feature_score_metadata ##" ,
2111- self .ssd_db .set_feature_score_metadata_cuda ,
2112- cloned_linear_cache_indices .cpu (),
2113- torch .tensor (
2114- [weights .shape [0 ]], device = "cpu" , dtype = torch .long
2115- ),
2116- weights .cpu (),
2153+ if (
2154+ self .backend_type == BackendType .DRAM
2155+ and weights is not None
2156+ and linear_cache_indices .numel () > 0
2157+ ):
2158+ # Reuse ssd_memcpy_stream for feature score D2H since critical D2H is done
2159+ # (we've already waited for it earlier in _prefetch)
2160+ self ._update_feature_score_metadata (
2161+ linear_cache_indices = linear_cache_indices ,
2162+ weights = weights ,
2163+ d2h_stream = self .ssd_memcpy_stream ,
2164+ write_stream = self .feature_score_stream ,
2165+ pre_event_for_write = self .ssd_event_cache_evict ,
21172166 )
21182167
21192168 # Generate row addresses (pointing to either L1 or the current
0 commit comments