1
1
#include " all_to_all.h"
2
2
3
- #include " core/cuda_utils.h"
4
3
#include " core/utils.h"
5
4
6
- #include < cuda_runtime.h>
7
-
8
5
using namespace pplx ;
9
6
10
- namespace {
11
- template <typename T> T *mallocZeroBuffer (size_t size) {
12
- T *ptr;
13
- CUDACHECK (cudaMalloc (&ptr, size * sizeof (T)));
14
- cudaMemset (ptr, 0 , size * sizeof (T));
15
- return ptr;
16
- }
17
- } // namespace
18
-
19
7
AllToAll::AllToAll (
20
8
size_t maxNumTokens,
21
9
size_t numExperts,
@@ -37,31 +25,16 @@ AllToAll::AllToAll(
37
25
hiddenDimScaleBytes(hiddenDimScaleBytes),
38
26
rank(rank),
39
27
worldSize(worldSize),
40
- dpSize(dpSize),
41
- maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
28
+ dpSize(dpSize) {
42
29
43
30
ROSE_ASSERT (hiddenDimBytes % 16 == 0 , " invalid hidden dim bytes" );
44
31
ROSE_ASSERT (hiddenDimScaleBytes % 16 == 0 , " invalid hidden dim scale bytes" );
45
32
const size_t perTokenBytes =
46
33
round_up<size_t >(hiddenDimBytes + hiddenDimScaleBytes + sizeof (uint32_t ), 16 );
47
- const size_t maxBatchTokens = numLocalExperts * numDPGroups * maxNumTokens;
48
34
49
35
ROSE_ASSERT (numLocalExperts != 0 , " numLocalExperts is 0" );
50
36
ROSE_ASSERT (numDPGroups > 1 , " at least 2 DP groups are required" );
51
37
ROSE_ASSERT (hiddenDimScaleBytes <= hiddenDimBytes, " invalid hidden dim bytes" );
52
-
53
- // Buffers for token tracking.
54
- numTokensPerDP = mallocZeroBuffer<uint32_t >(numLocalExperts * numDPGroups);
55
- sourceIndex = mallocZeroBuffer<uint32_t >(maxBatchTokens);
56
- sourceExpert = mallocZeroBuffer<uint32_t >(maxBatchTokens);
57
- sourceOffset = mallocZeroBuffer<uint32_t >(maxBatchTokens);
58
- sourceGroup = mallocZeroBuffer<uint32_t >(maxBatchTokens);
59
38
}
60
39
61
- AllToAll::~AllToAll () {
62
- CUDACHECK (cudaFree (numTokensPerDP));
63
- CUDACHECK (cudaFree (sourceIndex));
64
- CUDACHECK (cudaFree (sourceExpert));
65
- CUDACHECK (cudaFree (sourceOffset));
66
- CUDACHECK (cudaFree (sourceGroup));
67
- }
40
+ AllToAll::~AllToAll () {}
0 commit comments