Skip to content

Commit 6452f4a

Browse files
gchalumpfacebook-github-bot
authored andcommitted
Add get_unique_indices on CPU (#5096)
Summary: X-link: facebookresearch/FBGEMM#2103 Add `get_unique_indices` on CPU Add test to compare `get_unique_indices` from CPU with GPU Differential Revision: D85736286
1 parent 99b6fd1 commit 6452f4a

File tree

4 files changed

+505
-0
lines changed

4 files changed

+505
-0
lines changed

fbgemm_gpu/src/split_embeddings_cache/common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,23 @@ Tensor direct_mapped_lxu_cache_lookup_cpu(
120120
bool gather_cache_stats,
121121
std::optional<Tensor> uvm_cache_stats);
122122

123+
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
124+
get_unique_indices_cpu_impl(
125+
const Tensor& linear_indices,
126+
const int64_t max_indices,
127+
const bool compute_count,
128+
const bool compute_inverse_indices);
129+
130+
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
131+
const Tensor& linear_indices,
132+
const int64_t max_indices,
133+
const bool compute_count);
134+
135+
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
136+
get_unique_indices_with_inverse_cpu(
137+
const Tensor& linear_indices,
138+
const int64_t max_indices,
139+
const bool compute_count,
140+
const bool compute_inverse_indices);
141+
123142
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,158 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
3939
return at::empty_like(indices, indices.options().dtype(at::kLong));
4040
}
4141

42+
/**
43+
* CPU implementation for computing unique indices from a 1D tensor of linear
44+
* indices.
45+
*
46+
* This function processes a tensor of linear indices and returns the unique
47+
* values along with optional metadata (counts and inverse mapping). The
48+
* implementation uses stable sorting to ensure deterministic ordering of
49+
* duplicate values, matching the reference Python implementation.
50+
*
51+
* Example:
52+
* Input:
53+
* linear_indices = [20, 0, 10, 10, 0]
54+
* max_indices = 20
55+
* compute_count = true
56+
* compute_inverse_indices = true
57+
* Output:
58+
* unique_indices = [0, 10, 20, x, x] where x is uninitialized value.
59+
* unique_indices_length = [3]
60+
* unique_indices_count = [2, 2, 1, x, x] (0 appears 2 times, 10
61+
* appears 2 times, 20 appears 1 time)
62+
* linear_index_positions_sorted = [1, 4, 2, 3, 0] (positions that
63+
* sort the input: linear_indices[[1,4,2,3,0]] = [0,0,10,10,20])
64+
*
65+
* @param linear_indices 1D input tensor containing linear indices to process.
66+
* Must be 1D and have fewer than INT32_MAX elements.
67+
* @param max_indices Maximum number of unique indices expected (currently
68+
* unused, present to match GPU interface and API compatibility).
69+
* @param compute_count If true, computes and returns the count of each unique
70+
* index in the output.
71+
* @param compute_inverse_indices If true, computes the original positions of
72+
* elements in sorted order using stable sort.
73+
*
74+
* @return A tuple containing:
75+
* - unique_indices_output: Tensor of size `linear_indices` that stores
76+
* unique values in sorted order (i.e., unique values padded to input
77+
* size; first `num_unique` elements are valid)
78+
* - unique_indices_length: Tensor of size 1 containing number of unique
79+
* indices
80+
* - unique_indices_count: Optional tensor (if compute_count=true) of size
81+
* `linear_indices` that contains an occurrence count for each unique
82+
* value, else std::nullopt
83+
* - linear_index_positions_sorted: Optional tensor (if
84+
* compute_inverse_indices=true) of size `linear_indices` that contains
85+
* original positions such that
86+
* linear_indices[linear_index_positions_sorted] produces sorted indices.
87+
* Otherwise, std::nullopt. Converted to int32.
88+
*
89+
*/
90+
DLL_PUBLIC
91+
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
92+
get_unique_indices_cpu_impl(
93+
const Tensor& linear_indices,
94+
const int64_t /*max_indices*/,
95+
const bool compute_count,
96+
const bool compute_inverse_indices) {
97+
TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D");
98+
TORCH_CHECK(linear_indices.numel() < std::numeric_limits<int32_t>::max());
99+
100+
const int32_t N = linear_indices.numel();
101+
102+
// Handle empty input
103+
if (N == 0) {
104+
return std::make_tuple(
105+
at::empty_like(linear_indices),
106+
at::zeros({1}, linear_indices.options().dtype(at::kInt)),
107+
compute_count ? std::optional<Tensor>(at::arange(
108+
{0}, linear_indices.options().dtype(at::kInt)))
109+
: std::optional<Tensor>(),
110+
compute_inverse_indices
111+
? std::optional<Tensor>(
112+
at::empty({0}, linear_indices.options().dtype(at::kInt)))
113+
: std::optional<Tensor>());
114+
}
115+
116+
// Use torch::unique to get unique indices
117+
Tensor unique_indices;
118+
Tensor inverse_indices;
119+
Tensor counts;
120+
121+
if (compute_count || compute_inverse_indices) {
122+
std::tie(unique_indices, inverse_indices, counts) = at::unique_dim(
123+
linear_indices,
124+
/*dim=*/0,
125+
/*sorted=*/true,
126+
/*return_inverse=*/true,
127+
/*return_counts=*/true);
128+
} else {
129+
unique_indices = std::get<0>(at::unique_dim(
130+
linear_indices,
131+
/*dim=*/0,
132+
/*sorted=*/true,
133+
/*return_inverse=*/false,
134+
/*return_counts=*/false));
135+
}
136+
137+
// Prepare output tensors
138+
const int32_t num_unique = unique_indices.numel();
139+
auto unique_indices_length =
140+
at::ones({1}, linear_indices.options().dtype(at::kInt)) * num_unique;
141+
142+
// Resize unique_indices to match same size as input
143+
auto unique_indices_output = at::empty_like(linear_indices);
144+
unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices);
145+
146+
std::optional<Tensor> unique_indices_count = std::nullopt;
147+
std::optional<Tensor> linear_index_positions_sorted;
148+
149+
if (compute_count) {
150+
// Resize counts to match same size as input
151+
unique_indices_count =
152+
at::empty({N}, linear_indices.options().dtype(at::kInt));
153+
unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt));
154+
}
155+
156+
if (compute_inverse_indices) {
157+
auto sort_indices = at::argsort(
158+
linear_indices, /*stable=*/true, /*dim=*/0, /*descending=*/false);
159+
160+
// Convert to int32
161+
linear_index_positions_sorted = sort_indices.to(at::kInt);
162+
}
163+
164+
return std::make_tuple(
165+
unique_indices_output,
166+
unique_indices_length,
167+
unique_indices_count,
168+
linear_index_positions_sorted);
169+
}
170+
171+
DLL_PUBLIC
172+
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
173+
const Tensor& linear_indices,
174+
const int64_t max_indices,
175+
const bool compute_count) {
176+
const auto ret = get_unique_indices_cpu_impl(
177+
linear_indices,
178+
max_indices,
179+
compute_count,
180+
/*compute_inverse_indices=*/false);
181+
182+
return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)};
183+
}
184+
185+
DLL_PUBLIC
186+
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
187+
get_unique_indices_with_inverse_cpu(
188+
const Tensor& linear_indices,
189+
const int64_t max_indices,
190+
const bool compute_count,
191+
const bool compute_inverse_indices) {
192+
return get_unique_indices_cpu_impl(
193+
linear_indices, max_indices, compute_count, compute_inverse_indices);
194+
}
195+
42196
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
6969
DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu);
7070
DISPATCH_TO_CPU(
7171
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu);
72+
DISPATCH_TO_CPU("get_unique_indices", get_unique_indices_cpu);
73+
DISPATCH_TO_CPU(
74+
"get_unique_indices_with_inverse", get_unique_indices_with_inverse_cpu);
7275

7376
DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta);
7477
DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);

0 commit comments

Comments
 (0)