Commit 6dc4fc8

emlinmeta-codesync[bot] authored and committed
Back out "change from first element to a random element for cache missing items" (#5048)
Summary:
Pull Request resolved: #5048
X-link: https://github.com/facebookresearch/FBGEMM/pull/2058

Original commit changeset: 23e7f0d1e249
Original Phabricator Diff: D83612329

Since D83612329 was released, the model has started crashing intermittently. Revert the change and, in parallel, look for better optimization approaches.

Reviewed By: jma99fb

Differential Revision: D85404797

fbshipit-source-id: 5dad6196ead14cb1e5e8845bab12a30a058cbf75
1 parent dda8b12 · commit 6dc4fc8

File tree: 2 files changed, +4 −158 lines

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h

Lines changed: 4 additions & 34 deletions
@@ -22,7 +22,6 @@
 #include <thrift/lib/cpp2/protocol/CompactProtocol.h>
 #include <thrift/lib/cpp2/protocol/Serializer.h>
 #include <torch/script.h>
-#include <random>
 #include "common/time/Time.h"
 
 #include "../ssd_split_embeddings_cache/initializer.h"
@@ -435,36 +434,9 @@ class DramKVInferenceEmbedding {
                   before_read_lock_ts;
 
               if (!wlmap->empty() && !disable_random_init_) {
-                // Simple block-based randomization using get_block with
-                // cursor
-                auto* pool = kv_store_.pool_by(shard_id);
-
-                // Random starting cursor based on map size for good
-                // entropy
-                size_t random_start =
-                    folly::Random::rand32(wlmap->size());
-
-                // Try to find a used block starting from random
-                // position
-                weight_type* block = nullptr;
-                for (int attempts = 0; attempts < 16; ++attempts) {
-                  block = pool->template get_block<weight_type>(
-                      random_start + attempts);
-                  if (block != nullptr) {
-                    // Block is used (not null)
-                    row_storage_data_ptr =
-                        FixedBlockPool::data_ptr<weight_type>(block);
-                    break;
-                  }
-                }
-
-                // Fallback: if no used block found, use first element
-                // from map
-                if (block == nullptr) {
-                  row_storage_data_ptr =
-                      FixedBlockPool::data_ptr<weight_type>(
-                          wlmap->begin()->second);
-                }
+                row_storage_data_ptr =
+                    FixedBlockPool::data_ptr<weight_type>(
+                        wlmap->begin()->second);
               } else {
                 const auto& init_storage =
                     initializers_[shard_id]->row_storage_;
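For context on what is being backed out: the deleted branch tried to diversify cache-miss initialization by probing the shard's block pool from a random cursor, falling back to the map's first element when no used block turned up within 16 probes. Below is a minimal standalone sketch of that probe-with-fallback pattern; Pool, get_block, pick_source_row, and the int-to-row map are simplified hypothetical stand-ins, not FBGEMM's actual FixedBlockPool or kv_store_ API.

#include <cstddef>
#include <map>
#include <random>
#include <vector>

// Hypothetical, simplified stand-in for a fixed-block pool: a slot is
// "used" when it holds a non-null row pointer.
struct Pool {
  std::vector<float*> slots;
  float* get_block(std::size_t cursor) const {
    return slots.empty() ? nullptr : slots[cursor % slots.size()];
  }
};

// Sketch of the backed-out policy: start at a random cursor, probe up
// to 16 slots for a used block, and fall back to the first element of
// the key->row map if every probe comes back null.
float* pick_source_row(const Pool& pool, const std::map<int, float*>& rows) {
  if (rows.empty()) {
    return nullptr;  // the real code only reaches this path when non-empty
  }
  static std::mt19937 rng{std::random_device{}()};
  std::uniform_int_distribution<std::size_t> dist(0, rows.size() - 1);
  const std::size_t random_start = dist(rng);
  for (int attempts = 0; attempts < 16; ++attempts) {
    if (float* block = pool.get_block(random_start + attempts)) {
      return block;  // found a used block near the random cursor
    }
  }
  return rows.begin()->second;  // fallback: deterministic first element
}

int main() {
  float a = 1.0f, b = 2.0f;
  Pool pool{{nullptr, &a, nullptr, &b}};
  std::map<int, float*> rows{{7, &a}, {9, &b}};
  float* row = pick_source_row(pool, rows);
  return row == nullptr;  // expect a valid row pointer here
}

The revert collapses all of this back to the deterministic wlmap->begin()->second, trading that initialization entropy for the stability called out in the summary.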
@@ -569,9 +541,7 @@ class DramKVInferenceEmbedding {
               read_lookup_cache_total_duration / num_shards_;
           read_acquire_lock_avg_duration_ +=
               read_acquire_lock_total_duration / num_shards_;
-          LOG_EVERY_MS(INFO, 5000)
-              << "get_kv_db_async total read_missing_load per batch: "
-              << read_missing_load;
+          read_missing_load_avg_ += read_missing_load / num_shards_;
           return std::vector<folly::Unit>(results.size());
         });
       };
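The last hunk also changes how read_missing_load is reported: rather than a rate-limited log line, each of the num_shards_ shard tasks now adds read_missing_load / num_shards_ into an accumulator, so the accumulated value ends up as the mean missing load across shards. A tiny sketch of that accumulation identity, with hypothetical per-shard values standing in for the real counters:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-shard cache-miss loads for one batch; in the real
  // code each shard task computes its own read_missing_load.
  const std::vector<double> per_shard_load{12, 7, 9, 4, 11, 6, 8, 10};
  const double num_shards = static_cast<double>(per_shard_load.size());

  // Same shape as the added line in the diff: every shard contributes
  // load / num_shards, so the running sum is the cross-shard mean.
  double read_missing_load_avg = 0.0;
  for (double load : per_shard_load) {
    read_missing_load_avg += load / num_shards;
  }
  std::printf("average read_missing_load per shard: %.3f\n",
              read_missing_load_avg);  // prints 8.375 for these values
  return 0;
}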

fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py

Lines changed: 0 additions & 124 deletions
@@ -261,127 +261,3 @@ def reader_thread() -> None:  # pyre-ignore
         self.assertTrue(equal_one_of(embs[5, :4], possible_embs))
         reader_thread.join()
         self.assertFalse(reader_failed_event.is_set())
-
-    def test_randomized_cache_miss_initialization(self) -> None:
-        """Test that cache misses use randomized data from existing blocks."""
-        num_shards = 8
-        uniform_init_lower: float = -0.01
-        uniform_init_upper: float = 0.01
-
-        # Create DRAM KV inference cache
-        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards,
-            uniform_init_lower,
-            uniform_init_upper,
-            False,  # disable_random_init
-        )
-        kv_embedding_cache.init(
-            [(32, 4, SparseType.FP16.as_int())],
-            32,
-            4,
-            torch.tensor([0, 100], dtype=torch.int64),
-        )
-
-        # Setup: Populate the cache with many initial values for better randomization diversity
-        # Use 400 setup items to ensure each shard (8 shards) gets ~50 entries for good randomization
-        setup_indices = torch.arange(0, 400, dtype=torch.int64)  # 400 setup items
-        setup_weights = torch.randint(
-            1, 255, (400, 32), dtype=torch.uint8
-        )  # Non-zero values to ensure randomization source
-        print(f"setup_weights: {setup_weights}")
-
-        # Populate cache
-        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
-
-        # Execute: Request cache misses multiple times - these should get randomized initialization
-        # Use indices outside the range [0, 399] to ensure they are actual cache misses
-        miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64)
-
-        # Get the cache miss results multiple times to check for randomization
-        results = []
-        for _ in range(5):
-            current_output = kv_embedding_cache.get_embeddings(miss_indices)
-            results.append(current_output.clone())
-
-        # Assert: Verify that randomization occurs
-        # The results should not all be identical if randomization is working
-        all_identical = True
-        for i in range(1, len(results)):
-            if not torch.equal(
-                results[0][:, :4], results[i][:, :4]
-            ):  # Only check first 4 columns (actual data)
-                all_identical = False
-                break
-
-        # Since we're using randomization, results should be different
-        # Note: There's a small chance they could be identical by random chance,
-        # but with 5 trials of 5 vectors of 4 bytes, this is extremely unlikely
-        self.assertFalse(
-            all_identical,
-            "Randomized cache miss initialization should produce different results",
-        )
-
-        # All results should be non-zero (since we populated the cache with non-zero random values)
-        for result in results:
-            # Check that at least some values are non-zero (indicating data came from existing blocks)
-            self.assertTrue(
-                torch.any(result[:, :4] != 0),
-                "Cache miss results should contain non-zero values when cache has data",
-            )
-
-    def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None:
-        """Test that cache misses return all zero values when embedding_cache_mode=True."""
-        num_shards = 8
-        uniform_init_lower: float = -0.01
-        uniform_init_upper: float = 0.01
-
-        # Setup: Create DRAM KV inference cache with embedding_cache_mode=True (zero initialization)
-        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards,
-            uniform_init_lower,
-            uniform_init_upper,
-            True,  # embedding_cache_mode=True for zero initialization
-        )
-        kv_embedding_cache.init(
-            [(32, 4, SparseType.FP16.as_int())],
-            32,
-            4,
-            torch.tensor([0, 100], dtype=torch.int64),
-        )
-
-        # Populate the cache with some initial non-zero values to ensure zero initialization
-        # is not just due to empty cache
-        setup_indices = torch.arange(0, 50, dtype=torch.int64)
-        setup_weights = torch.randint(
-            1, 255, (50, 32), dtype=torch.uint8
-        )  # Non-zero values
-        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
-
-        # Execute: Request cache misses - these should get zero initialization due to embedding_cache_mode=True
-        # Use indices outside the range [0, 49] to ensure they are actual cache misses
-        miss_indices = torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64)
-        results = []
-
-        # Get cache miss results multiple times to ensure consistent behavior
-        for _ in range(3):
-            current_output = kv_embedding_cache.get_embeddings(miss_indices)
-            results.append(current_output.clone())
-
-        # Assert: Verify that all cache miss results are zeros when embedding_cache_mode=True
-        expected_zeros = torch.zeros((5, 32), dtype=torch.uint8)
-
-        for i, result in enumerate(results):
-            # Check that all cache miss results are zero
-            self.assertTrue(
-                torch.equal(result, expected_zeros),
-                f"Cache miss results should be all zeros when embedding_cache_mode=True, "
-                f"but got non-zero values in iteration {i}: {result[:, :4]}",
-            )
-
-        # Additional verification: all results should be identical since they're all zeros
-        for i in range(1, len(results)):
-            self.assertTrue(
-                torch.equal(results[0], results[i]),
-                f"All zero cache miss results should be identical across calls, "
-                f"but results[0] != results[{i}]",
-            )
