@@ -39,4 +39,158 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
3939 return at::empty_like (indices, indices.options ().dtype (at::kLong ));
4040}
4141
42+ /* *
43+ * CPU implementation for computing unique indices from a 1D tensor of linear
44+ * indices.
45+ *
46+ * This function processes a tensor of linear indices and returns the unique
47+ * values along with optional metadata (counts and inverse mapping). The
48+ * implementation uses stable sorting to ensure deterministic ordering of
49+ * duplicate values, matching the reference Python implementation.
50+ *
51+ * Example:
52+ * Input:
53+ * linear_indices = [20, 0, 10, 10, 0]
54+ * max_indices = 20
55+ * compute_count = true
56+ * compute_inverse_indices = true
57+ * Output:
58+ * unique_indices = [0, 10, 20, x, x] where x is uninitialized value.
59+ * unique_indices_length = [3]
60+ * unique_indices_count = [2, 2, 1, x, x] (0 appears 2 times, 10
61+ * appears 2 times, 20 appears 1 time)
62+ * linear_index_positions_sorted = [1, 4, 2, 3, 0] (positions that
63+ * sort the input: linear_indices[[1,4,2,3,0]] = [0,0,10,10,20])
64+ *
65+ * @param linear_indices 1D input tensor containing linear indices to process.
66+ * Must be 1D and have fewer than INT32_MAX elements.
67+ * @param max_indices Maximum number of unique indices expected (currently
68+ * unused, present to match GPU interface and API compatibility).
69+ * @param compute_count If true, computes and returns the count of each unique
70+ * index in the output.
71+ * @param compute_inverse_indices If true, computes the original positions of
72+ * elements in sorted order using stable sort.
73+ *
74+ * @return A tuple containing:
75+ * - unique_indices_output: Tensor of size `linear_indices` that stores
76+ * unique values in sorted order (i.e., unique values padded to input
77+ * size; first `num_unique` elements are valid)
78+ * - unique_indices_length: Tensor of size 1 containing number of unique
79+ * indices
80+ * - unique_indices_count: Optional tensor (if compute_count=true) of size
81+ * `linear_indices` that contains an occurrence count for each unique
82+ * value, else std::nullopt
83+ * - linear_index_positions_sorted: Optional tensor (if
84+ * compute_inverse_indices=true) of size `linear_indices` that contains
85+ * original positions such that
86+ * linear_indices[linear_index_positions_sorted] produces sorted indices.
87+ * Otherwise, std::nullopt. Converted to int32.
88+ *
89+ */
90+ DLL_PUBLIC
91+ std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
92+ get_unique_indices_cpu_impl (
93+ const Tensor& linear_indices,
94+ const int64_t /* max_indices*/ ,
95+ const bool compute_count,
96+ const bool compute_inverse_indices) {
97+ TORCH_CHECK (linear_indices.dim () == 1 , " linear_indices must be 1D" );
98+ TORCH_CHECK (linear_indices.numel () < std::numeric_limits<int32_t >::max ());
99+
100+ const int32_t N = linear_indices.numel ();
101+
102+ // Handle empty input
103+ if (N == 0 ) {
104+ return std::make_tuple (
105+ at::empty_like (linear_indices),
106+ at::zeros ({1 }, linear_indices.options ().dtype (at::kInt )),
107+ compute_count ? std::optional<Tensor>(at::arange (
108+ {0 }, linear_indices.options ().dtype (at::kInt )))
109+ : std::optional<Tensor>(),
110+ compute_inverse_indices
111+ ? std::optional<Tensor>(
112+ at::empty ({0 }, linear_indices.options ().dtype (at::kInt )))
113+ : std::optional<Tensor>());
114+ }
115+
116+ // Use torch::unique to get unique indices
117+ Tensor unique_indices;
118+ Tensor inverse_indices;
119+ Tensor counts;
120+
121+ if (compute_count || compute_inverse_indices) {
122+ std::tie (unique_indices, inverse_indices, counts) = at::unique_dim (
123+ linear_indices,
124+ /* dim=*/ 0 ,
125+ /* sorted=*/ true ,
126+ /* return_inverse=*/ true ,
127+ /* return_counts=*/ true );
128+ } else {
129+ unique_indices = std::get<0 >(at::unique_dim (
130+ linear_indices,
131+ /* dim=*/ 0 ,
132+ /* sorted=*/ true ,
133+ /* return_inverse=*/ false ,
134+ /* return_counts=*/ false ));
135+ }
136+
137+ // Prepare output tensors
138+ const int32_t num_unique = unique_indices.numel ();
139+ auto unique_indices_length =
140+ at::ones ({1 }, linear_indices.options ().dtype (at::kInt )) * num_unique;
141+
142+ // Resize unique_indices to match same size as input
143+ auto unique_indices_output = at::empty_like (linear_indices);
144+ unique_indices_output.slice (0 , 0 , num_unique).copy_ (unique_indices);
145+
146+ std::optional<Tensor> unique_indices_count = std::nullopt ;
147+ std::optional<Tensor> linear_index_positions_sorted;
148+
149+ if (compute_count) {
150+ // Resize counts to match same size as input
151+ unique_indices_count =
152+ at::empty ({N}, linear_indices.options ().dtype (at::kInt ));
153+ unique_indices_count->slice (0 , 0 , num_unique).copy_ (counts.to (at::kInt ));
154+ }
155+
156+ if (compute_inverse_indices) {
157+ auto sort_indices = at::argsort (
158+ linear_indices, /* stable=*/ true , /* dim=*/ 0 , /* descending=*/ false );
159+
160+ // Convert to int32
161+ linear_index_positions_sorted = sort_indices.to (at::kInt );
162+ }
163+
164+ return std::make_tuple (
165+ unique_indices_output,
166+ unique_indices_length,
167+ unique_indices_count,
168+ linear_index_positions_sorted);
169+ }
170+
171+ DLL_PUBLIC
172+ std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu (
173+ const Tensor& linear_indices,
174+ const int64_t max_indices,
175+ const bool compute_count) {
176+ const auto ret = get_unique_indices_cpu_impl (
177+ linear_indices,
178+ max_indices,
179+ compute_count,
180+ /* compute_inverse_indices=*/ false );
181+
182+ return {std::get<0 >(ret), std::get<1 >(ret), std::get<2 >(ret)};
183+ }
184+
185+ DLL_PUBLIC
186+ std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
187+ get_unique_indices_with_inverse_cpu (
188+ const Tensor& linear_indices,
189+ const int64_t max_indices,
190+ const bool compute_count,
191+ const bool compute_inverse_indices) {
192+ return get_unique_indices_cpu_impl (
193+ linear_indices, max_indices, compute_count, compute_inverse_indices);
194+ }
195+
42196} // namespace fbgemm_gpu
0 commit comments