Skip to content

Commit 325fc88

Browse files
committed
Shift all values by the max value before applying log-softmax
1 parent 8e66e59 commit 325fc88

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

llama.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
21432143

21442144
template<typename T, typename LogitAccessor>
21452145
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
2146+
T* element = std::max_element(
2147+
array, array + size,
2148+
[&logit_accessor](T& lhs, T& rhs) {
2149+
return logit_accessor(lhs) < logit_accessor(rhs);
2150+
}
2151+
);
2152+
2153+
float max_l = logit_accessor(*element);
21462154
float sum = 0.f;
21472155
for (int i = 0; i < size; ++i) {
21482156
float& logit = logit_accessor(array[i]);
2149-
float p = expf(logit);
2157+
float p = expf(logit - max_l);
21502158
sum += p;
21512159
logit = p;
21522160
}

0 commit comments

Comments
 (0)