@@ -261,127 +261,3 @@ def reader_thread() -> None: # pyre-ignore
         self.assertTrue(equal_one_of(embs[5, :4], possible_embs))
         reader_thread.join()
         self.assertFalse(reader_failed_event.is_set())
-
-    def test_randomized_cache_miss_initialization(self) -> None:
-        """Test that cache misses use randomized data from existing blocks."""
-        num_shards = 8
-        uniform_init_lower: float = -0.01
-        uniform_init_upper: float = 0.01
-
-        # Create DRAM KV inference cache
-        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards,
-            uniform_init_lower,
-            uniform_init_upper,
-            False,  # disable_random_init
-        )
-        kv_embedding_cache.init(
-            [(32, 4, SparseType.FP16.as_int())],
-            32,
-            4,
-            torch.tensor([0, 100], dtype=torch.int64),
-        )
-
-        # Setup: Populate the cache with many initial values for better randomization diversity
-        # Use 400 setup items to ensure each shard (8 shards) gets ~50 entries for good randomization
-        setup_indices = torch.arange(0, 400, dtype=torch.int64)  # 400 setup items
-        setup_weights = torch.randint(
-            1, 255, (400, 32), dtype=torch.uint8
-        )  # Non-zero values to ensure randomization source
-        print(f"setup_weights: {setup_weights}")
-
-        # Populate cache
-        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
-
-        # Execute: Request cache misses multiple times - these should get randomized initialization
-        # Use indices outside the range [0, 399] to ensure they are actual cache misses
-        miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64)
-
-        # Get the cache miss results multiple times to check for randomization
-        results = []
-        for _ in range(5):
-            current_output = kv_embedding_cache.get_embeddings(miss_indices)
-            results.append(current_output.clone())
-
-        # Assert: Verify that randomization occurs
-        # The results should not all be identical if randomization is working
-        all_identical = True
-        for i in range(1, len(results)):
-            if not torch.equal(
-                results[0][:, :4], results[i][:, :4]
-            ):  # Only check first 4 columns (actual data)
-                all_identical = False
-                break
-
-        # Since we're using randomization, results should be different
-        # Note: There's a small chance they could be identical by random chance,
-        # but with 5 trials of 5 vectors of 4 bytes, this is extremely unlikely
-        self.assertFalse(
-            all_identical,
-            "Randomized cache miss initialization should produce different results",
-        )
-
-        # All results should be non-zero (since we populated the cache with non-zero random values)
-        for result in results:
-            # Check that at least some values are non-zero (indicating data came from existing blocks)
-            self.assertTrue(
-                torch.any(result[:, :4] != 0),
-                "Cache miss results should contain non-zero values when cache has data",
-            )
-
-    def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None:
-        """Test that cache misses return all zero values when embedding_cache_mode=True."""
-        num_shards = 8
-        uniform_init_lower: float = -0.01
-        uniform_init_upper: float = 0.01
-
-        # Setup: Create DRAM KV inference cache with embedding_cache_mode=True (zero initialization)
-        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards,
-            uniform_init_lower,
-            uniform_init_upper,
-            True,  # embedding_cache_mode=True for zero initialization
-        )
-        kv_embedding_cache.init(
-            [(32, 4, SparseType.FP16.as_int())],
-            32,
-            4,
-            torch.tensor([0, 100], dtype=torch.int64),
-        )
-
-        # Populate the cache with some initial non-zero values to ensure zero initialization
-        # is not just due to empty cache
-        setup_indices = torch.arange(0, 50, dtype=torch.int64)
-        setup_weights = torch.randint(
-            1, 255, (50, 32), dtype=torch.uint8
-        )  # Non-zero values
-        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
-
-        # Execute: Request cache misses - these should get zero initialization due to embedding_cache_mode=True
-        # Use indices outside the range [0, 49] to ensure they are actual cache misses
-        miss_indices = torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64)
-        results = []
-
-        # Get cache miss results multiple times to ensure consistent behavior
-        for _ in range(3):
-            current_output = kv_embedding_cache.get_embeddings(miss_indices)
-            results.append(current_output.clone())
-
-        # Assert: Verify that all cache miss results are zeros when embedding_cache_mode=True
-        expected_zeros = torch.zeros((5, 32), dtype=torch.uint8)
-
-        for i, result in enumerate(results):
-            # Check that all cache miss results are zero
-            self.assertTrue(
-                torch.equal(result, expected_zeros),
-                f"Cache miss results should be all zeros when embedding_cache_mode=True, "
-                f"but got non-zero values in iteration {i}: {result[:, :4]}",
-            )
-
-        # Additional verification: all results should be identical since they're all zeros
-        for i in range(1, len(results)):
-            self.assertTrue(
-                torch.equal(results[0], results[i]),
-                f"All zero cache miss results should be identical across calls, "
-                f"but results[0] != results[{i}]",
-            )