Skip to content

Commit 9eb8aa5

Browse files
committed
Shift all values by the max value before applying logsoftmax
1 parent f70c885 commit 9eb8aa5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

llama.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2120,10 +2120,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
21202120

21212121
template<typename T, typename LogitAccessor>
21222122
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
2123+
T* element = std::max_element(
2124+
array, array + size,
2125+
[&logit_accessor](T& lhs, T& rhs) {
2126+
return logit_accessor(lhs) < logit_accessor(rhs);
2127+
}
2128+
);
2129+
2130+
float max_l = logit_accessor(*element);
21232131
float sum = 0.f;
21242132
for (int i = 0; i < size; ++i) {
21252133
float& logit = logit_accessor(array[i]);
2126-
float p = expf(logit);
2134+
float p = expf(logit - max_l);
21272135
sum += p;
21282136
logit = p;
21292137
}

0 commit comments

Comments
 (0)