From a867c4b53a4cc8dc0f79daced3d2d4cd2722a3f5 Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Tue, 11 Nov 2025 20:18:49 -0800 Subject: [PATCH] Add `get_unique_indices` on CPU (#5096) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2103 Implements `get_unique_indices_cpu_impl()` to extract unique indices from linear index tensors on CPU, with comprehensive documentation and test coverage for both int32 and int64 dtypes. Function Description -------------------- **`get_unique_indices_cpu_impl`** processes a 1D tensor of linear indices and returns unique values with optional metadata (counts and inverse mapping for reordering). ### Example ``` Input: linear_indices = [20, 0, 10, 10, 0] Output: unique_indices = [0, 10, 20, x, x] (sorted, padded) unique_indices_length = [3] unique_indices_count = [2, 2, 1, x, x] (occurrence counts) linear_index_positions_sorted = [1, 4, 2, 3, 0] (positions that sort input: linear_indices[[1,4,2,3,0]] = [0,0,10,10,20]) ``` ### Returns 1. **unique_indices**: Sorted unique values padded to input size (first `num_unique` elements valid) 2. **unique_indices_length**: Single-element tensor with count of unique values 3. **unique_indices_count** (optional): Occurrence count for each unique value 4. **linear_index_positions_sorted** (optional): Original positions that reorder input to sorted order (int32) ### Implementation Details * Uses `at::unique_dim()` for core uniqueness computation and a stable `at::argsort()` for the sorted positions * Preserves input dtype for unique values * Converts counts and positions to int32 for consistency with CUDA implementation * Supports both `torch.int` (int32) and `torch.long` (int64) input dtypes ### Test Coverage Added dtype parameterization to `test_get_unique_indices_cpu` to validate both int32 and int64, ensuring the CPU implementation supports all dtypes that the CUDA implementation supports. 
Differential Revision: D85736286 --- fbgemm_gpu/fbgemm_gpu/__init__.py | 1 + .../src/split_embeddings_cache/common.h | 19 + .../linearize_cache_indices.cpp | 157 +++++++++ .../split_embeddings_cache_ops.cpp | 3 + fbgemm_gpu/test/tbe/cache/cache_test.py | 329 ++++++++++++++++++ 5 files changed, 509 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/__init__.py b/fbgemm_gpu/fbgemm_gpu/__init__.py index afb69cfa87..7971f84f6b 100644 --- a/fbgemm_gpu/fbgemm_gpu/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/__init__.py @@ -131,6 +131,7 @@ def _load_library(filename: str, version: str, no_throw: bool = False) -> None: "fbgemm_gpu_config", "fbgemm_gpu_tbe_utils", "fbgemm_gpu_tbe_index_select", + "fbgemm_gpu_tbe_cache", "fbgemm_gpu_tbe_optimizers", "fbgemm_gpu_tbe_inference", "fbgemm_gpu_tbe_training_forward", diff --git a/fbgemm_gpu/src/split_embeddings_cache/common.h b/fbgemm_gpu/src/split_embeddings_cache/common.h index 3bbc2b62de..6b7667019c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/common.h +++ b/fbgemm_gpu/src/split_embeddings_cache/common.h @@ -120,4 +120,23 @@ Tensor direct_mapped_lxu_cache_lookup_cpu( bool gather_cache_stats, std::optional uvm_cache_stats); +std::tuple, std::optional> +get_unique_indices_cpu_impl( + const Tensor& linear_indices, + const int64_t max_indices, + const bool compute_count, + const bool compute_inverse_indices); + +std::tuple> get_unique_indices_cpu( + const Tensor& linear_indices, + const int64_t max_indices, + const bool compute_count); + +std::tuple, std::optional> +get_unique_indices_with_inverse_cpu( + const Tensor& linear_indices, + const int64_t max_indices, + const bool compute_count, + const bool compute_inverse_indices); + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp index 5111576690..26be9c5360 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp +++ 
b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp @@ -39,4 +39,161 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta( return at::empty_like(indices, indices.options().dtype(at::kLong)); } +/** + * CPU implementation for computing unique indices from a 1D tensor of linear + * indices. + * + * This function processes a tensor of linear indices and returns the unique + * values along with optional metadata (counts and inverse mapping). The + * implementation uses stable sorting to ensure deterministic ordering of + * duplicate values, matching the reference Python implementation. + * + * Example: + * Input: + * linear_indices = [20, 0, 10, 10, 0] + * max_indices = 20 + * compute_count = true + * compute_inverse_indices = true + * Output: + * unique_indices = [0, 10, 20, x, x] (dtype: int64, x is + * uninitialized) + * unique_indices_length = [3] (dtype: int32) + * unique_indices_count = [2, 2, 1, x, x] (dtype: int32, 0 appears 2 + * times, 10 appears 2 times, 20 appears 1 time) + * linear_index_positions_sorted = [1, 4, 2, 3, 0] (dtype: int32, + * positions that sort the input: + * linear_indices[[1,4,2,3,0]] = [0,0,10,10,20]) + * + * @param linear_indices 1D input tensor containing linear indices to process + * (dtype: int32 or int64). Must be 1D and have fewer than INT32_MAX + * elements. + * @param max_indices Maximum number of unique indices expected (dtype: int64, + * currently unused, present to match GPU interface and API compatibility). + * @param compute_count If true, computes and returns the count of each unique + * index in the output (dtype: bool). + * @param compute_inverse_indices If true, computes the original positions of + * elements in sorted order using stable sort (dtype: bool). 
+ * + * @return A tuple containing: + * - unique_indices_output: Tensor of size `linear_indices` that stores + * unique values in sorted order (dtype: same as input; first `num_unique` + * elements are valid, rest are uninitialized) + * - unique_indices_length: Tensor of size 1 containing number of unique + * indices (dtype: int32) + * - unique_indices_count: Optional tensor (if compute_count=true) of size + * `linear_indices` that contains an occurrence count for each unique + * value (dtype: int32), else std::nullopt + * - linear_index_positions_sorted: Optional tensor (dtype: int32) (if + * compute_inverse_indices=true) of size `linear_indices` that contains + * original positions such that + * linear_indices[linear_index_positions_sorted[i]] is the i-th element in + * sorted order, else std::nullopt. + * + */ +DLL_PUBLIC +std::tuple, std::optional> +get_unique_indices_cpu_impl( + const Tensor& linear_indices, + const int64_t /*max_indices*/, + const bool compute_count, + const bool compute_inverse_indices) { + TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D"); + TORCH_CHECK(linear_indices.numel() < std::numeric_limits::max()); + + const int32_t N = linear_indices.numel(); + + // Handle empty input + if (N == 0) { + return std::make_tuple( + at::empty_like(linear_indices), + at::zeros({1}, linear_indices.options().dtype(at::kInt)), + compute_count ? std::optional(at::zeros( + {0}, linear_indices.options().dtype(at::kInt))) + : std::nullopt, + compute_inverse_indices + ? 
std::optional( + at::zeros({0}, linear_indices.options().dtype(at::kInt))) + : std::nullopt); + } + + // Use torch::unique to get unique indices + Tensor unique_indices; + Tensor inverse_indices; + Tensor counts; + + if (compute_count || compute_inverse_indices) { + std::tie(unique_indices, inverse_indices, counts) = at::unique_dim( + linear_indices, + /*dim=*/0, + /*sorted=*/true, + /*return_inverse=*/true, + /*return_counts=*/true); + } else { + unique_indices = std::get<0>(at::unique_dim( + linear_indices, + /*dim=*/0, + /*sorted=*/true, + /*return_inverse=*/false, + /*return_counts=*/false)); + } + + // Prepare output tensors + const int32_t num_unique = unique_indices.numel(); + auto unique_indices_length = + at::tensor({num_unique}, linear_indices.options().dtype(at::kInt)); + + // Resize unique_indices to match same size as input + auto unique_indices_output = at::empty_like(linear_indices); + unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices); + + std::optional unique_indices_count = std::nullopt; + std::optional linear_index_positions_sorted; + + if (compute_count) { + // Resize counts to match same size as input + unique_indices_count = + at::empty({N}, linear_indices.options().dtype(at::kInt)); + unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt)); + } + + if (compute_inverse_indices) { + auto sort_indices = at::argsort( + linear_indices, /*stable=*/true, /*dim=*/0, /*descending=*/false); + + // Convert to int32 + linear_index_positions_sorted = sort_indices.to(at::kInt); + } + + return std::make_tuple( + unique_indices_output, + unique_indices_length, + unique_indices_count, + linear_index_positions_sorted); +} + +DLL_PUBLIC +std::tuple> get_unique_indices_cpu( + const Tensor& linear_indices, + const int64_t max_indices, + const bool compute_count) { + const auto ret = get_unique_indices_cpu_impl( + linear_indices, + max_indices, + compute_count, + /*compute_inverse_indices=*/false); + + return {std::get<0>(ret), 
std::get<1>(ret), std::get<2>(ret)}; +} + +DLL_PUBLIC +std::tuple, std::optional> +get_unique_indices_with_inverse_cpu( + const Tensor& linear_indices, + const int64_t max_indices, + const bool compute_count, + const bool compute_inverse_indices) { + return get_unique_indices_cpu_impl( + linear_indices, max_indices, compute_count, compute_inverse_indices); +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp index b3c2a726a8..cbd995c5e9 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp @@ -69,6 +69,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu); DISPATCH_TO_CPU( "direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu); + DISPATCH_TO_CPU("get_unique_indices", get_unique_indices_cpu); + DISPATCH_TO_CPU( + "get_unique_indices_with_inverse", get_unique_indices_with_inverse_cpu); DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta); DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta); diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index 8f3a3eec52..6b292a1a0c 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -649,6 +649,335 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: self.assertEqual(unique_cache_miss_count, expect_out) self.assertLessEqual(cache_miss_forward_count, unique_cache_miss_count) + def _get_unique_indices_reference( + self, + linear_indices: torch.Tensor, + max_indices: int, + compute_count: bool, + compute_inverse_indices: bool, + ) -> tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor] + ]: + """Python reference implementation for validating get_unique_indices operations. 
+ + This function provides an independent baseline for testing CPU and GPU implementations + of get_unique_indices. It uses only pure Python operations (sorted(), set(), dict). + + Example: + Input: + linear_indices = [20, 0, 10, 10, 0] + max_indices = 20 + compute_count = True + compute_inverse_indices = True + Output: + unique_indices = [0, 10, 20, x, x] where x is the uninitialized value. + unique_indices_length = [3] + unique_indices_count = [2, 2, 1, x, x] (0 appears 2 times, 10 appears 2 times, 20 appears 1 time) + linear_index_positions_sorted = [1, 4, 2, 3, 0] (positions that sort the input: linear_indices[[1,4,2,3,0]] = [0,0,10,10,20]) + + Args: + linear_indices (Tensor): Input tensor of indices to find unique values from + max_indices (int): Maximum possible index value (not used in computation, kept for API compatibility) + compute_count (bool): If True, count occurrence for each unique index + compute_inverse_indices (bool): If True, store original positions of the indices in a sorted manner + + Returns: + A tuple containing: + - unique_indices (Tensor): Tensor of size `linear_indices` that stores unique values in sorted order (i.e., unique values padded to input size) + - unique_indices_length (Tensor): Tensor of size 1 containing number of unique values + - unique_indices_count (Optional[Tensor]): If compute_count=True, tensor of size `linear_indices` that contains an occurrence count for each unique value, else None. + - linear_index_positions_sorted (Optional[Tensor]): If compute_inverse_indices=True, tensor of size `linear_indices` that contains original positions such that linear_indices[linear_index_positions_sorted] produces sorted indices. Otherwise, None. 
+ """ + N = linear_indices.numel() + + # Convert to Python list for pure Python processing + indices_list = linear_indices.tolist() + + # Get unique values in sorted order using pure Python + unique_vals_list = sorted(set(indices_list)) + num_unique = len(unique_vals_list) + + # Prepare outputs matching the format of the ops + unique_indices = torch.empty_like(linear_indices) + if num_unique > 0: + unique_indices[:num_unique] = torch.tensor( + unique_vals_list, dtype=linear_indices.dtype + ) + + unique_indices_length = torch.tensor([num_unique], dtype=torch.int32) + + unique_indices_count = None + if compute_count: + # Count occurrences using pure Python + count_dict = {} + for val in indices_list: + count_dict[val] = count_dict.get(val, 0) + 1 + + counts_list = [count_dict[val] for val in unique_vals_list] + + unique_indices_count = torch.empty(N, dtype=torch.int32) + if num_unique > 0: + unique_indices_count[:num_unique] = torch.tensor( + counts_list, dtype=torch.int32 + ) + + linear_index_positions_sorted = None + if compute_inverse_indices: + # Create list of (value, original_position) tuples + indexed_list = [(val, idx) for idx, val in enumerate(indices_list)] + + # Sort by value (stable sort preserves order for equal values) + sorted_indexed = sorted(indexed_list, key=lambda x: x[0]) + + # Extract the original positions in sorted order + positions_list = [pos for val, pos in sorted_indexed] + + linear_index_positions_sorted = torch.tensor( + positions_list, dtype=torch.int32 + ) + + return ( + unique_indices, + unique_indices_length, + unique_indices_count, + linear_index_positions_sorted, + ) + + @given( + N=st.integers(min_value=0, max_value=1000), + max_indices=st.integers(min_value=100, max_value=10000), + compute_count=st.booleans(), + compute_inverse_indices=st.booleans(), + dtype=st.sampled_from([torch.int, torch.long]), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_get_unique_indices_cpu( + self, + N: int, + 
max_indices: int, + compute_count: bool, + compute_inverse_indices: bool, + dtype: torch.dtype, + ) -> None: + """Test get_unique_indices ops on CPU, GPU and MTIA. + + This test validates two ops: + - torch.ops.fbgemm.get_unique_indices: Returns unique indices and optionally their counts + - torch.ops.fbgemm.get_unique_indices_with_inverse: Additionally returns sorted positions for reordering + + The test uses a Python reference implementation (_get_unique_indices_reference) to ensure correctness and parity across devices. + + Test strategy: + 1. Generate random linear indices with values in [0, max_indices) + 2. Run pure Python reference implementation for ground truth + 3. Run CPU implementation via torch.ops.fbgemm.get_unique_indices[_with_inverse] + 4. Compare CPU results against reference implementation + 5. If GPU available, run the ops on GPU and compare against CPU results + 6. If MTIA available, run the ops on MTIA and compare against CPU results + + Validates: + - Unique indices: Both CPU and GPU extract the same set of unique values in sorted order + - Length: Number of unique values matches across all implementations + - Counts (if compute_count=True): Occurrence count for each unique value matches + - Positions (if compute_inverse_indices=True): Sorted positions produce identical reordering + + Args: + N: Number of random linear indices to generate (0-1000). Tests with N=0 validate empty input handling. + max_indices: Exclusive upper bound for generated index values (100-10000). Indices are in range [0, max_indices). + compute_count: If True, ops return occurrence count for each unique value in the third output. + compute_inverse_indices: If True, ops return original positions in sorted order (fourth output for + get_unique_indices_with_inverse). These positions enable reordering the input to be in sorted order. + dtype: Data type for generated indices. 
Tests both torch.int (int32) and torch.long (int64) to ensure + CPU implementation supports all dtypes that CUDA implementation supports. + """ + # Generate random linear indices with the specified dtype + linear_indices = torch.randint(0, max_indices, (N,), dtype=dtype) + + # Get reference implementation results + ( + unique_ref, + length_ref, + count_ref, + positions_ref, + ) = self._get_unique_indices_reference( + linear_indices.cpu(), max_indices, compute_count, compute_inverse_indices + ) + + # Run on CPU + if compute_inverse_indices: + ( + unique_cpu, + length_cpu, + count_cpu, + positions_cpu, + ) = torch.ops.fbgemm.get_unique_indices_with_inverse( + linear_indices, + max_indices, + compute_count, + compute_inverse_indices, + ) + else: + unique_cpu, length_cpu, count_cpu = torch.ops.fbgemm.get_unique_indices( + linear_indices, max_indices, compute_count + ) + positions_cpu = None + + def compare_output( + input_indices: torch.Tensor, + annotate1: str, + annotate2: str, + length1: int, + length2: int, + unique1: torch.Tensor, + unique2: torch.Tensor, + compute_count: bool, + compute_inverse_indices: bool, + count1: Optional[torch.Tensor] = None, + count2: Optional[torch.Tensor] = None, + positions1: Optional[torch.Tensor] = None, + positions2: Optional[torch.Tensor] = None, + ): + self.assertEqual( + length1, + length2, + f"{annotate1} unique indices length mismatch with {annotate2}", + ) + + torch.testing.assert_close( + unique1[:length1].cpu(), + unique2[:length2].cpu(), + msg=f"{annotate1} unique indices mismatch with {annotate2}", + ) + + if compute_count: + self.assertIsNotNone(count1, f"{annotate1} count should not be None") + self.assertIsNotNone(count2, f"{annotate2} count should not be None") + torch.testing.assert_close( + count1[:length1].cpu(), + count2[:length2].cpu(), + msg=f"{annotate1} unique indices count mismatch with {annotate2}", + ) + + if compute_inverse_indices: + self.assertIsNotNone( + positions1, f"{annotate1} positions should not be 
None" + ) + self.assertIsNotNone( + positions2, f"{annotate2} positions should not be None" + ) + + torch.testing.assert_close( + positions1.cpu(), + positions2.cpu(), + msg=f"{annotate1} unique indices position mismatch with {annotate2}", + ) + # Move positions to same device as input_indices before gather + reordered1 = input_indices.gather( + 0, positions1.long().to(input_indices.device) + ) + reordered2 = input_indices.gather( + 0, positions2.long().to(input_indices.device) + ) + + torch.testing.assert_close( + reordered1.cpu(), + reordered2.cpu(), + msg=f"{annotate1} reordered indices mismatch with {annotate2}", + ) + + # Test CPU op with reference + compare_output( + linear_indices, + "CPU", + "ref implementation", + length_cpu.item(), + length_ref.item(), + unique_cpu, + unique_ref, + compute_count, + compute_inverse_indices, + count_cpu, + count_ref, + positions_cpu, + positions_ref, + ) + + # Run on GPU + if not gpu_unavailable[0]: + linear_indices_gpu = linear_indices.cuda() + if compute_inverse_indices: + ( + unique_gpu, + length_gpu, + count_gpu, + positions_gpu, + ) = torch.ops.fbgemm.get_unique_indices_with_inverse( + linear_indices_gpu, + max_indices, + compute_count, + compute_inverse_indices, + ) + else: + unique_gpu, length_gpu, count_gpu = torch.ops.fbgemm.get_unique_indices( + linear_indices_gpu, max_indices, compute_count + ) + positions_gpu = None + + compare_output( + linear_indices, + "CPU", + "GPU", + length_cpu.item(), + length_gpu.item(), + unique_cpu, + unique_gpu, + compute_count, + compute_inverse_indices, + count_cpu, + count_gpu, + positions_cpu, + positions_gpu, + ) + + # Run on MTIA + if torch.mtia.is_available(): + linear_indices_mtia = linear_indices.mtia() + if compute_inverse_indices: + ( + unique_mtia, + length_mtia, + count_mtia, + positions_mtia, + ) = torch.ops.fbgemm.get_unique_indices_with_inverse( + linear_indices_mtia, + max_indices, + compute_count, + compute_inverse_indices, + ) + else: + unique_mtia, length_mtia, count_mtia 
= ( + torch.ops.fbgemm.get_unique_indices( + linear_indices_mtia, max_indices, compute_count + ) + ) + positions_mtia = None + + compare_output( + linear_indices, + "CPU", + "MTIA", + length_cpu.item(), + length_mtia.item(), + unique_cpu, + unique_mtia, + compute_count, + compute_inverse_indices, + count_cpu, + count_mtia, + positions_cpu, + positions_mtia, + ) + @unittest.skipIf(*gpu_unavailable) @given(N=st.integers(min_value=1, max_value=8)) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)