Commit 0481a36

Use first bad_words as extra parameters, and implement min-p

1 parent b777bd6 | commit 0481a36
18 files changed: +86 -36 lines changed
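
For context on the feature being added (editor's note, not part of the commit): min-p sampling keeps only tokens whose probability is at least a caller-chosen fraction (min-p) of the most likely token's probability. A minimal host-side sketch of the rule, for illustration only:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Min-p in one line of math: keep token i iff p_i >= minP * max_j p_j.
// Softmax is monotone, so p_i / p_max == exp(logit_i - logit_max) and the
// test can run on raw logits without normalizing first.
std::vector<bool> minPMask(std::vector<float> const& logits, float minP)
{
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<bool> keep(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        keep[i] = std::exp(logits[i] - maxLogit) >= minP;
    }
    return keep;
}
```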

cpp/include/tensorrt_llm/runtime/decodingInput.h

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,7 @@ class DecodingInput
     using TensorPtr = std::shared_ptr<ITensor const>;
 
     DecodingInput(SizeType32 maxLength, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength,
-        SizeType32 maxBatchSize, TensorPtr logits, TensorPtr endIds)
+        SizeType32 maxBatchSize, TensorPtr logits, TensorPtr endIds, TensorPtr minP)
         : step{maxLength}
         , maxLength{maxLength}
         , maxAttentionWindow{maxAttentionWindow}
@@ -40,6 +40,7 @@ class DecodingInput
         , maxBadWordsLen{0}
         , logits{std::move(logits)}
         , endIds{std::move(endIds)}
+        , minP{std::move(minP)}
     {
         TLLM_CHECK_WITH_INFO(static_cast<bool>(this->logits), "Invalid logits tensor");
         TLLM_CHECK_WITH_INFO(static_cast<bool>(this->endIds), "Invalid endIds tensor");
@@ -57,6 +58,7 @@ class DecodingInput
     std::optional<std::vector<TensorPtr>>
         logitsVec;    // vector of size [batchSize] contains logits of size [beamWidth, vocabSizePadded], on gpu
     TensorPtr endIds; // [maxBatchSize * beamWidth], on gpu
+    TensorPtr minP;   // [maxBatchSize * beamWidth], on gpu
 
     // optional parameters
     TensorPtr finished; // [maxBatchSize, beamWidth], finished states at current iteration.

cpp/tensorrt_llm/kernels/banBadWords.cu

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ __global__ void ban_bad_words(T* logits, TokenIdType const** output_ids_ptr, Siz
     SizeType32 const* bad_words_lens, SizeType32 vocab_size_padded, SizeType32 const* sequence_lengths,
     SizeType32 max_seq_len)
 {
-    auto const id = blockIdx.x * blockDim.x + threadIdx.x;
+    auto const id = blockIdx.x * blockDim.x + threadIdx.x + 1;
     auto const batch_idx = blockIdx.y / beam_width;
     auto const beam_idx = blockIdx.y % beam_width;
     auto const batch_slot = batch_slots != nullptr ? batch_slots[batch_idx] : batch_idx;
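
The `+ 1` above pairs with the commit message: the first bad_words entry is repurposed to carry extra parameters (here, min-p), so the ban kernel must never treat slot 0 as a real bad word. A host-side simulation of the new index mapping; the slot-0 convention is an assumption inferred from the commit message, not shown in this diff:

```cpp
#include <cstdio>

// Assumed convention (inferred from the commit message): bad_words slot 0 is a
// side channel for extra parameters, so each thread's id is shifted by one and
// only entries 1..N-1 are treated as bannable words.
int main()
{
    int const numThreads = 4;
    for (int tid = 0; tid < numThreads; ++tid)
    {
        int const id = tid + 1; // mirrors `blockIdx.x * blockDim.x + threadIdx.x + 1`
        std::printf("thread %d -> bad_words entry %d\n", tid, id);
    }
    return 0;
}
```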

cpp/tensorrt_llm/kernels/beamSearchKernels.h

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ struct BeamHypotheses
     // Pointers from input
     int const* inputLengths{nullptr}; // [BS, BM] %% context_length
     int const* endIds{nullptr};       // [BS, BM] %% self.end_ids
+    float const* minP{nullptr};       // [BS, BM] %% self.min_p
 
     // Pointers for output
     int* outputIds{nullptr}; // [BS, BM, MSL] %% self.output_ids only used in gather_tree

cpp/tensorrt_llm/kernels/decodingCommon.cu

Lines changed: 21 additions & 6 deletions
@@ -66,7 +66,8 @@ void invokeCurandBatchInitialize(curandState_t* states, int const* batchSlots, c
 template <typename T>
 __global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
     FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
-    int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits)
+    int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
+    float const* minPs)
 {
     auto const batchIdx = blockIdx.x;
     auto const beamIdx = blockIdx.y;
@@ -114,6 +115,12 @@ __global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bia
         logitsPtr[tid] = logit;
     }
 
+    float minP = 0.0f;
+    if (minPs != nullptr)
+    {
+        minP = minPs[batchSlot];
+    }
+
     if (!skipSoftMax)
     {
         maxVal = blockReduceMax<float>((float) maxVal);
@@ -123,10 +130,18 @@ __global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bia
     }
     __syncthreads();
 
+    // min_p : probability of token proportional to the max token
+    // compare min_p against exp(logit - maxVal) / exp(maxVal - maxVal) = exp(logit - maxVal)
+
     float sumVal = 0.0f;
     for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
     {
-        probs[offset + tid] = __expf((float) logitsPtr[tid] - sMaxVal);
+        float rel_prob = __expf((float) logitsPtr[tid] - sMaxVal);
+        if (rel_prob < minP) {
+            rel_prob = 0.0;
+            logitsPtr[tid] = -MAX_T_VAL;
+        }
+        probs[offset + tid] = rel_prob;
         sumVal += (float) probs[offset + tid];
     }
 
@@ -148,7 +163,7 @@ template <typename T>
 void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
     FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
     int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
-    cudaStream_t stream)
+    float const* minPs, cudaStream_t stream)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
@@ -157,20 +172,20 @@ void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, in
     dim3 block(min(vocabRoundedToWarp, 1024));
     // vocabSize, e.g., 30000, 7000.... vocabSize is usually very big.
     addBiasSoftMax<<<grid, block, 0, stream>>>(logits, logitsPtrs, probs, bias, endIds, finished, batchSlots, batchSize,
-        maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits);
+        maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits, minPs);
 
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
 template void invokeAddBiasSoftMax(float* logits, float** logitsPtrs, float* probs, float const* bias,
     int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
     int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
-    bool batchSlotsLogits, cudaStream_t stream);
+    bool batchSlotsLogits, float const* minPs, cudaStream_t stream);
 
 template void invokeAddBiasSoftMax(half* logits, half** logitsPtrs, half* probs, half const* bias,
     int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
     int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
-    bool batchSlotsLogits, cudaStream_t stream);
+    bool batchSlotsLogits, float const* minPs, cudaStream_t stream);
 
 template <typename T>
 __global__ void scatterDecodingParamsKernel(T const* src, T* dst, int const* batchSlots, int batchSize)
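
The comment added in the hunk carries the key identity: relative to the maximum, softmax probabilities are exp(logit − maxVal), so the min-p cut can be applied before normalization. A host-side reference of the same computation, a sketch mirroring the kernel (with -infinity standing in for -MAX_T_VAL):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Host-side reference for the kernel's min-p step: tokens whose probability,
// relative to the max token, falls below minP are zeroed in probs, and their
// logits are floored so any later logit-based stage also skips them.
std::vector<float> minPSoftmax(std::vector<float> logits, float minP)
{
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        float relProb = std::exp(logits[i] - maxLogit); // == p_i / p_max
        if (relProb < minP)
        {
            relProb = 0.0f;
            logits[i] = -std::numeric_limits<float>::infinity();
        }
        probs[i] = relProb;
        sum += relProb;
    }
    for (float& p : probs)
    {
        p /= sum; // renormalize over the surviving tokens
    }
    return probs;
}

int main()
{
    // With minP = 0.1: exp(-2) ~ 0.135 survives, exp(-5) ~ 0.0067 is cut,
    // leaving normalized probs ~ {0.881, 0.119, 0}.
    for (float p : minPSoftmax({2.0f, 0.0f, -3.0f}, 0.1f))
    {
        std::printf("%.4f\n", p);
    }
}
```

Zeroing the probability and flooring the logit together make the filtered tokens impossible both for the probability-based sampling path and for any stage that re-reads the logits.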

cpp/tensorrt_llm/kernels/decodingCommon.h

Lines changed: 2 additions & 1 deletion
@@ -185,12 +185,13 @@ void invokeCurandBatchInitialize(curandState_t* states, int const* batchSlots, c
 //! \param vocabSizePadded padded vocab size
 //! \param skipSoftMax flag to skip softmax computation
 //! \param batchSlotsLogits flag to use batchSlot as index for logits and probs
+//! \param minPs input buffer [maxBatchSize]. minimum ratio of probability to maximum probability for token consideration.
 //! \param stream stream
 template <typename T>
 void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
     FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
     int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
-    cudaStream_t stream);
+    float const* minPs, cudaStream_t stream);
 
 //! \brief Distributes values located in src to dst according to the indieces from batchSlots
 //!

cpp/tensorrt_llm/kernels/decodingKernels.cu

Lines changed: 2 additions & 2 deletions
@@ -758,11 +758,11 @@ void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs,
     invokeAddBiasSoftMax(draftLogits, static_cast<T**>(nullptr), draftProbs, static_cast<T*>(nullptr), nullptr,
         finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
         /* skip softmax */ false,
-        /* batchSlotLogits */ true, stream);
+        /* batchSlotLogits */ true, (float*) (nullptr), stream);
     invokeAddBiasSoftMax(static_cast<T*>(nullptr), targetLogits, targetProbs, static_cast<T*>(nullptr), nullptr,
         finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
         /* skip softmax */ false,
-        /* batchSlotLogits */ true, stream);
+        /* batchSlotLogits */ true, (float*) (nullptr), stream);
 }
 {
     dim3 block(1024);

cpp/tensorrt_llm/layers/beamSearchLayer.cu

Lines changed: 5 additions & 1 deletion
@@ -137,6 +137,7 @@ void BeamSearchLayer<T>::forwardAsyncSingleRequest(
     bh.earlyStoppings = mEarlyStoppingDevice;
     bh.inputLengths = ip->input_lengths->template getPtr<int const>();
     bh.endIds = ip->end_ids.template getPtr<int const>();
+    bh.minP = ip->min_p.template getPtr<float const>();
     bh.logProbsTiled = (op->output_log_probs) ? op->output_log_probs->template getPtr<float>() : nullptr;
     bh.sequenceLengths = op->sequence_length->template getPtr<int>();
     bh.cumLogProbs = op->cum_log_probs->template getPtr<float>();
@@ -183,6 +184,7 @@ void BeamSearchLayer<T>::forwardAsync(
 
     // common inputs
     auto const& endIds = params->end_ids;
+    auto const& minP = params->min_p;
     auto const localBatchSize = static_cast<std::size_t>(params->local_batch_size);
 
     TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() > 1,
@@ -209,8 +211,10 @@ void BeamSearchLayer<T>::forwardAsync(
         = params->logits->slice({dynamic_decode_batch_size, params->logits->shape[1], params->logits->shape[2]},
             dynamic_decode_vocab_size_units_offset);
     auto const end_id_offset = endIds.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
+    auto const min_p_offset = minP.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
 
-    auto forwardParams = std::make_shared<BeamSearchInputParams>(step, ite, logits_offset, end_id_offset,
+
+    auto forwardParams = std::make_shared<BeamSearchInputParams>(step, ite, logits_offset, end_id_offset, min_p_offset,
         *params->src_cache_indirection, static_cast<std::int32_t>(params->max_attention_window),
         static_cast<std::int32_t>(params->sink_token_length), static_cast<std::int32_t>(maxSeqLen));
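
min_p is sliced with the same offset arithmetic as endIds: dynamic-decode iteration ite covers batch entries [ite * dynamic_decode_batch_size, (ite + 1) * dynamic_decode_batch_size). A tiny illustration of that partitioning, not code from the commit:

```cpp
#include <cstdio>

// Illustration of how a flat per-request parameter buffer (endIds, minP) is
// partitioned across dynamic-decode iterations by slice(size, offset).
int main()
{
    int const dynamicDecodeBatchSize = 4;
    for (int ite = 0; ite < 3; ++ite)
    {
        int const offset = ite * dynamicDecodeBatchSize;
        std::printf("ite %d -> entries [%d, %d)\n", ite, offset, offset + dynamicDecodeBatchSize);
    }
    return 0;
}
```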

cpp/tensorrt_llm/layers/beamSearchLayer.h

Lines changed: 3 additions & 3 deletions
@@ -48,9 +48,9 @@ class BeamSearchInputParams : public BaseInputParams
 {
 public:
     explicit BeamSearchInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor logits,
-        tc::Tensor endIds, tc::Tensor src_cache_indirection, runtime::SizeType32 max_attention_window,
-        runtime::SizeType32 sink_token_length, runtime::SizeType32 max_seq_len)
-        : BaseInputParams(step, ite, std::move(endIds))
+        tc::Tensor endIds, tc::Tensor minPs, tc::Tensor src_cache_indirection, runtime::SizeType32 max_attention_window,
+        runtime::SizeType32 sink_token_length, runtime::SizeType32 max_seq_len)
+        : BaseInputParams(step, ite, std::move(endIds), std::move(minPs))
         , logits{std::move(logits)}
         , max_attention_window{max_attention_window}
         , sink_token_length{sink_token_length}

cpp/tensorrt_llm/layers/decodingLayer.cpp

Lines changed: 4 additions & 2 deletions
@@ -206,6 +206,7 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
     auto const localDecoderDomain = getLocalDecoderDomain(params);
     auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1];
     auto const& endIds = params->end_ids;
+    auto const& minP = params->min_p;
 
     std::shared_ptr<BaseOutputParams> preparedOutputs;
     std::shared_ptr<BaseInputParams> preparedInputs;
@@ -230,8 +231,9 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
         Tensor const logitsSlice{params->logits->slice(
             {localBatchSize, static_cast<size_t>(localDecoderDomain.getBeamWidth()), params->logits->shape[2]}, 0)};
         Tensor const endIdSlice{endIds.slice({localBatchSize}, 0)};
+        Tensor const minPSlice{minP.slice({localBatchSize}, 0)};
         auto decodeInputs = std::make_shared<SamplingInputParams>(
-            step, ite, logitsSlice, endIdSlice, static_cast<SizeType32>(maxSeqLen));
+            step, ite, logitsSlice, endIdSlice, minPSlice, static_cast<SizeType32>(maxSeqLen));
 
         decodeInputs->finished = params->finished;
 
@@ -274,7 +276,7 @@ std::tuple<std::shared_ptr<BaseOutputParams>, std::shared_ptr<BaseInputParams>>
         TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
             "Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
 
-        auto medusaInputParams = std::make_shared<MedusaInputParams>(params->logits.value(), endIds);
+        auto medusaInputParams = std::make_shared<MedusaInputParams>(params->logits.value(), endIds, minP);
         medusaInputParams->finished = outputs->finished.value();
         medusaInputParams->batch_slots = params->batch_slots;
         medusaInputParams->paths = params->medusaInputs->medusaPaths;

cpp/tensorrt_llm/layers/decodingParams.h

Lines changed: 5 additions & 3 deletions
@@ -164,10 +164,11 @@ class DynamicDecodeSetupParams : public BaseSetupParams
 class BaseInputParams
 {
 public:
-    explicit BaseInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor endIds)
+    explicit BaseInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, tc::Tensor endIds, tc::Tensor minPs)
         : step{step}
         , ite{ite}
         , end_ids{std::move(endIds)}
+        , min_p{std::move(minPs)}
     {
     }
 
@@ -177,6 +178,7 @@ class BaseInputParams
     runtime::SizeType32 step;
     runtime::SizeType32 ite;
     tc::Tensor end_ids; // [maxBatchSize]
+    tc::Tensor min_p;   // [maxBatchSize]
     std::optional<tc::Tensor> batch_slots; // [forwardBatchSize], on pinned memory
     std::optional<tc::Tensor> finished;    // [maxBatchSize, maxBeamWidth]
 };
@@ -186,8 +188,8 @@ class DynamicDecodeInputParams : public BaseInputParams
 public:
     DynamicDecodeInputParams(runtime::SizeType32 step, runtime::SizeType32 ite, runtime::SizeType32 maxInputLength,
         runtime::SizeType32 maxAttentionWindow, runtime::SizeType32 sinkTokenLength, runtime::SizeType32 localBatchSize,
-        tc::Tensor endIds)
-        : BaseInputParams(step, ite, std::move(endIds))
+        tc::Tensor endIds, tc::Tensor minPs)
+        : BaseInputParams(step, ite, std::move(endIds), std::move(minPs))
         , max_input_length{maxInputLength}
         , max_attention_window{maxAttentionWindow}
        , sink_token_length{sinkTokenLength}
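
A miniature of the pattern this hunk extends (illustration only; std::string stands in for tc::Tensor): each params class takes the new tensor by value and forwards it to its base with std::move, so a single owning copy travels down the hierarchy.

```cpp
#include <string>
#include <utility>

// Illustration of the by-value + std::move forwarding used above.
// std::string is a stand-in for tc::Tensor.
struct BaseInputParamsSketch
{
    BaseInputParamsSketch(std::string endIds, std::string minPs)
        : end_ids{std::move(endIds)}
        , min_p{std::move(minPs)}
    {
    }

    std::string end_ids;
    std::string min_p;
};

struct DynamicDecodeInputParamsSketch : BaseInputParamsSketch
{
    DynamicDecodeInputParamsSketch(std::string endIds, std::string minPs)
        : BaseInputParamsSketch(std::move(endIds), std::move(minPs))
    {
    }
};
```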
