@@ -195,6 +195,19 @@ void dispatch(
   );
 }
 
+void fake_dispatch(
+    fptr_t ptr,
+    at::Tensor &outExpertNumTokens,
+    at::Tensor &outExpertX,
+    const std::optional<at::Tensor> &outExpertXScale,
+    const at::Tensor &dpX,
+    const std::optional<at::Tensor> &dpXScale,
+    const at::Tensor &indices,
+    const std::optional<at::Tensor> &boundM,
+    bool doSend,
+    bool doRecv
+) {}
+
 template <typename Kernel, typename T, typename U>
 void combineImpl(
     Kernel *all_to_all,
@@ -297,6 +310,17 @@ void combine(
   }
 }
 
+void fake_combine(
+    fptr_t ptr,
+    at::Tensor &outTokens,
+    const at::Tensor &indices,
+    const at::Tensor &weights,
+    const at::Tensor &expertY,
+    const std::optional<at::Tensor> &boundM,
+    bool doSend,
+    bool doRecv
+) {}
+
 #undef _CHECK_TENSOR
 
 } // namespace
@@ -306,11 +330,64 @@ void register_all_to_all_ops(torch::Library &m) {
   m.def("all_to_all_destroy", &destroy);
 
   m.def("all_to_all_internode_create", &create_internode);
-  m.def("all_to_all_internode_dispatch", &dispatch<AllToAllInterNode>);
-  m.def("all_to_all_internode_combine", &combine<AllToAllInterNode>);
+
+  m.def("all_to_all_internode_dispatch("
+        "int fptr,"
+        "Tensor! out_expert_num_tokens,"
+        "Tensor! out_expert_x,"
+        "Tensor!? out_expert_x_scale,"
+        "Tensor dp_x,"
+        "Tensor? dp_x_scale,"
+        "Tensor indices,"
+        "Tensor? bound_m,"
+        "bool do_send,"
+        "bool do_recv"
+        ") -> ()");
+  m.impl("all_to_all_internode_dispatch", c10::kCUDA, &dispatch<AllToAllInterNode>);
+  m.impl("all_to_all_internode_dispatch", c10::kMeta, &fake_dispatch);
+
+  m.def("all_to_all_internode_combine("
+        "int fptr,"
+        "Tensor! out_tokens,"
+        "Tensor indices,"
+        "Tensor weights,"
+        "Tensor expert_y,"
+        "Tensor? bound_m,"
+        "bool do_send,"
+        "bool do_recv"
+        ") -> ()");
+  m.impl("all_to_all_internode_combine", c10::kCUDA, &combine<AllToAllInterNode>);
+  m.impl("all_to_all_internode_combine", c10::kMeta, &fake_combine);
 
   m.def("all_to_all_intranode_create", &create_intranode);
-  m.def("all_to_all_intranode_dispatch", &dispatch<AllToAllIntraNode>);
-  m.def("all_to_all_intranode_combine", &combine<AllToAllIntraNode>);
+
+  m.def("all_to_all_intranode_dispatch("
+        "int fptr,"
+        "Tensor! out_expert_num_tokens,"
+        "Tensor! out_expert_x,"
+        "Tensor!? out_expert_x_scale,"
+        "Tensor dp_x,"
+        "Tensor? dp_x_scale,"
+        "Tensor indices,"
+        "Tensor? bound_m,"
+        "bool do_send,"
+        "bool do_recv"
+        ") -> ()");
+  m.impl("all_to_all_intranode_dispatch", c10::kCUDA, &dispatch<AllToAllIntraNode>);
+  m.impl("all_to_all_intranode_dispatch", c10::kMeta, &fake_dispatch);
+
+  m.def("all_to_all_intranode_combine("
+        "int fptr,"
+        "Tensor! out_tokens,"
+        "Tensor indices,"
+        "Tensor weights,"
+        "Tensor expert_y,"
+        "Tensor? bound_m,"
+        "bool do_send,"
+        "bool do_recv"
+        ") -> ()");
+  m.impl("all_to_all_intranode_combine", c10::kCUDA, &combine<AllToAllIntraNode>);
+  m.impl("all_to_all_intranode_combine", c10::kMeta, &fake_combine);
 }
+
 } // namespace pplx
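
Note on the pattern above: the schema strings use PyTorch's operator-schema syntax, where Tensor! marks an argument the op mutates in place, ? marks an optional argument, and -> () declares no return value. Splitting each schema-based m.def() into per-backend m.impl() registrations lets the c10::kMeta entries point at the empty fake_dispatch / fake_combine bodies, so the ops can run under meta/fake tensors (for example during FakeTensor or torch.compile tracing) without launching any real communication; since every output tensor is preallocated by the caller, there is nothing for the fake kernels to compute. Below is a minimal, self-contained sketch of the same def/impl split on a hypothetical op; the my_ops namespace, the scale_inplace op, and its kernels are illustrative assumptions, not part of this change.

// Minimal sketch of the schema-based m.def() + per-backend m.impl() pattern.
// Everything here (namespace, op name, kernels) is hypothetical.
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Real kernel: fills the caller-allocated "out" tensor ("Tensor!" in the schema).
void scale_inplace_cuda(at::Tensor &out, const at::Tensor &src, double factor) {
  out.copy_(src);
  out.mul_(factor);
}

// Fake/meta kernel: outputs are preallocated by the caller, so shape and dtype
// propagation need no work here; an empty body suffices for fake-tensor tracing.
void scale_inplace_meta(at::Tensor & /*out*/, const at::Tensor & /*src*/, double /*factor*/) {}

} // namespace

TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  // "Tensor!" = mutated in place, schema "float" = C++ double, "-> ()" = no return.
  m.def("scale_inplace(Tensor! out, Tensor src, float factor) -> ()");
  m.impl("scale_inplace", c10::kCUDA, &scale_inplace_cuda);
  m.impl("scale_inplace", c10::kMeta, &scale_inplace_meta);
}

With both registrations in place, torch.ops.my_ops.scale_inplace(out, src, 2.0) runs the CUDA kernel on real tensors and the empty meta kernel on fake/meta tensors, which mirrors what the registrations in this change provide for the dispatch and combine ops.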