@@ -66,7 +66,8 @@ void invokeCurandBatchInitialize(curandState_t* states, int const* batchSlots, c
template <typename T>
__global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
    FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
-    int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits)
+    int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
+    float const* minPs)
{
    auto const batchIdx = blockIdx.x;
    auto const beamIdx = blockIdx.y;
@@ -114,6 +115,12 @@ __global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bia
        logitsPtr[tid] = logit;
    }

+    float minP = 0.0f;
+    if (minPs != nullptr)
+    {
+        minP = minPs[batchSlot];
+    }
+
    if (!skipSoftMax)
    {
        maxVal = blockReduceMax<float>((float) maxVal);
@@ -123,10 +130,18 @@ __global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bia
    }
    __syncthreads();

+    // min_p: each token's probability is measured relative to that of the most likely token,
+    // so compare minP against exp(logit - maxVal) / exp(maxVal - maxVal) = exp(logit - maxVal).
+
    float sumVal = 0.0f;
    for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
    {
-        probs[offset + tid] = __expf((float) logitsPtr[tid] - sMaxVal);
+        float relProb = __expf((float) logitsPtr[tid] - sMaxVal);
+        if (relProb < minP) {
+            relProb = 0.0f;
+            logitsPtr[tid] = -MAX_T_VAL;
+        }
+        probs[offset + tid] = relProb;
        sumVal += (float) probs[offset + tid];
    }

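For reference: because the softmax denominator cancels out of the ratio between a token's probability and the top token's probability, a token survives the min-p filter exactly when exp(logit - maxLogit) >= minP. A minimal host-side sketch of the same computation (illustrative only, not part of this patch; `minPSoftmax` is a hypothetical helper name):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference implementation of the filter in the kernel loop above: keep a
// token only if its probability is at least minP times the probability of
// the most likely token, then renormalize over the survivors.
std::vector<float> minPSoftmax(std::vector<float> const& logits, float minP)
{
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        float relProb = std::exp(logits[i] - maxLogit); // in (0, 1]; == 1 for the max token
        probs[i] = relProb < minP ? 0.0f : relProb;     // same test as the kernel
        sum += probs[i];
    }
    for (float& p : probs)
    {
        p /= sum; // sum >= 1 because the max token always survives
    }
    return probs;
}
```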
@@ -148,7 +163,7 @@ template <typename T>
void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
    FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
    int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
-    cudaStream_t stream)
+    float const* minPs, cudaStream_t stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -157,20 +172,20 @@ void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, in
    dim3 block(min(vocabRoundedToWarp, 1024));
    // vocabSize, e.g., 30000, 7000.... vocabSize is usually very big.
    addBiasSoftMax<<<grid, block, 0, stream>>>(logits, logitsPtrs, probs, bias, endIds, finished, batchSlots, batchSize,
-        maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits);
+        maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits, minPs);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template void invokeAddBiasSoftMax(float* logits, float** logitsPtrs, float* probs, float const* bias,
    int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
    int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
-    bool batchSlotsLogits, cudaStream_t stream);
+    bool batchSlotsLogits, float const* minPs, cudaStream_t stream);

template void invokeAddBiasSoftMax(half* logits, half** logitsPtrs, half* probs, half const* bias,
    int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
    int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
-    bool batchSlotsLogits, cudaStream_t stream);
+    bool batchSlotsLogits, float const* minPs, cudaStream_t stream);

template <typename T>
__global__ void scatterDecodingParamsKernel(T const* src, T* dst, int const* batchSlots, int batchSize)
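A hypothetical call site for the extended launcher (a sketch, not part of this patch: `softmaxWithMinP` and `dMinPs` are illustrative names, and it assumes nullptr is accepted for the optional bias/endIds/finished inputs, as the kernel already tolerates for `minPs`):

```cpp
// Run plain softmax with min-p filtering for a beamWidth-1 configuration.
// dMinPs points at maxBatchSize floats on the device, indexed by batch slot;
// passing nullptr disables filtering, since minP then stays 0.0f and
// exp(logit - sMaxVal) is never below it, so no token is ever dropped.
void softmaxWithMinP(float* logits, float* probs, int32_t const* batchSlots,
    int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t vocabSizePadded,
    float const* dMinPs, cudaStream_t stream)
{
    invokeAddBiasSoftMax(logits, /*logitsPtrs=*/static_cast<float**>(nullptr), probs,
        /*bias=*/nullptr, /*endIds=*/nullptr, /*finished=*/nullptr, batchSlots, batchSize,
        maxBatchSize, /*beamWidth=*/1, vocabSize, vocabSizePadded,
        /*skipSoftMax=*/false, /*batchSlotsLogits=*/false, dMinPs, stream);
}
```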