Skip to content

Commit 9668aa1

Browse files
committed
llama : distinguish pieces from decoded text + fix detokenization
1 parent 5d0ffb6 commit 9668aa1

File tree

15 files changed

+93
-68
lines changed

15 files changed

+93
-68
lines changed

common/common.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,16 +733,37 @@ std::vector<llama_token> llama_tokenize(
733733
return result;
734734
}
735735

736-
std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
736+
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
737737
std::vector<char> result(8, 0);
738-
const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
738+
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
739739
if (n_tokens < 0) {
740740
result.resize(-n_tokens);
741-
int check = llama_token_to_str(ctx, token, result.data(), result.size());
741+
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
742742
GGML_ASSERT(check == -n_tokens);
743743
} else {
744744
result.resize(n_tokens);
745745
}
746746

747747
return std::string(result.data(), result.size());
748748
}
749+
750+
// Decode a token sequence back into text.
//
// Pieces are concatenated in order. The leading space of the first non-BOS
// token is stripped — presumably because SentencePiece-style vocabularies
// prefix the first word with a space marker (TODO confirm against the
// tokenizer used).
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) {
    const llama_token bos_id = llama_token_bos(ctx);

    std::string result;

    for (size_t i = 0; i < tokens.size(); ++i) {
        std::string piece = llama_token_to_piece(ctx, tokens[i]);

        // remove the leading space of the first non-BOS token;
        // guard against an empty piece before peeking at piece[0]
        const bool is_first = (tokens[0] == bos_id) ? (i == 1) : (i == 0);
        if (is_first && !piece.empty() && piece[0] == ' ') {
            piece.erase(0, 1); // in-place, avoids the extra copy substr() makes
        }

        result += piece;
    }

    return result;
}
769+

common/common.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ std::vector<llama_token> llama_tokenize(
121121
const std::string & text,
122122
bool add_bos);
123123

124-
std::string llama_token_to_str(
124+
std::string llama_token_to_piece(
125125
const struct llama_context * ctx,
126126
llama_token token);
127+
128+
// removes the leading space from the first non-BOS token
129+
std::string llama_detokenize(
130+
llama_context * ctx,
131+
const std::vector<llama_token> & tokens);

examples/beam_search/beam_search.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct ostream_beam_view {
3535
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
3636
os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
3737
for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
38-
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
38+
os << llama_token_to_piece(obv.ctx, obv.beam_view.tokens[i]);
3939
}
4040
return os << ')';
4141
}
@@ -156,7 +156,7 @@ int main(int argc, char ** argv)
156156

157157
for( auto id : tokens_list )
158158
{
159-
std::cout << llama_token_to_str(ctx, id);
159+
std::cout << llama_token_to_piece(ctx, id);
160160
}
161161
std::cout << std::flush;
162162

@@ -175,7 +175,7 @@ int main(int argc, char ** argv)
175175

176176
std::cout << "\n\n";
177177
for (llama_token const token_id : callback_data.response) {
178-
std::cout << llama_token_to_str(ctx,token_id);
178+
std::cout << llama_token_to_piece(ctx,token_id);
179179
}
180180
std::cout << std::endl;
181181

examples/embd-input/embd-input-lib.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ const char * sampling(struct MyModel * mymodel) {
214214
if (id == llama_token_eos(ctx)) {
215215
ret = "</s>";
216216
} else {
217-
ret = llama_token_to_str(ctx, id);
217+
ret = llama_token_to_piece(ctx, id);
218218
}
219219
eval_id(mymodel, id);
220220
return ret.c_str();

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
6464
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
6565
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
6666
for (int i = 0; i < (int) embd_inp.size(); i++) {
67-
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
67+
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
6868
}
6969
fprintf(stderr, "\n");
7070
}

examples/main/main.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -280,22 +280,22 @@ int main(int argc, char ** argv) {
280280
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
281281
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
282282
for (int i = 0; i < (int) embd_inp.size(); i++) {
283-
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
283+
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
284284
}
285285

286286
if (ctx_guidance) {
287287
fprintf(stderr, "\n");
288288
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
289289
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
290290
for (int i = 0; i < (int) guidance_inp.size(); i++) {
291-
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
291+
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
292292
}
293293
}
294294

295295
if (params.n_keep > 0) {
296296
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
297297
for (int i = 0; i < params.n_keep; i++) {
298-
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
298+
fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
299299
}
300300
fprintf(stderr, "'\n");
301301
}
@@ -451,7 +451,7 @@ int main(int argc, char ** argv) {
451451
//printf("\n---\n");
452452
//printf("resetting: '");
453453
//for (int i = 0; i < (int) embd.size(); i++) {
454-
// printf("%s", llama_token_to_str(ctx, embd[i]));
454+
// printf("%s", llama_token_to_piece(ctx, embd[i]));
455455
//}
456456
//printf("'\n");
457457
//printf("\n---\n");
@@ -504,7 +504,7 @@ int main(int argc, char ** argv) {
504504
input_size = embd_guidance.size();
505505
//fprintf(stderr, "\n---------------------\n");
506506
//for (int i = 0; i < (int) embd_guidance.size(); i++) {
507-
//fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
507+
//fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_guidance[i]));
508508
//}
509509
//fprintf(stderr, "\n---------------------\n");
510510
} else {
@@ -663,7 +663,7 @@ int main(int argc, char ** argv) {
663663
// display text
664664
if (input_echo) {
665665
for (auto id : embd) {
666-
printf("%s", llama_token_to_str(ctx, id).c_str());
666+
printf("%s", llama_token_to_piece(ctx, id).c_str());
667667
}
668668
fflush(stdout);
669669
}
@@ -679,7 +679,7 @@ int main(int argc, char ** argv) {
679679
if (params.antiprompt.size()) {
680680
std::string last_output;
681681
for (auto id : last_n_tokens) {
682-
last_output += llama_token_to_str(ctx, id);
682+
last_output += llama_token_to_piece(ctx, id);
683683
}
684684

685685
is_antiprompt = false;

examples/save-load-state/save-load-state.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ int main(int argc, char ** argv) {
8787
}
8888
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
8989
auto next_token = llama_sample_token(ctx, &candidates_p);
90-
auto next_token_str = llama_token_to_str(ctx, next_token);
90+
auto next_token_str = llama_token_to_piece(ctx, next_token);
9191
last_n_tokens_data.push_back(next_token);
9292

9393
printf("%s", next_token_str.c_str());
@@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
147147
}
148148
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
149149
auto next_token = llama_sample_token(ctx2, &candidates_p);
150-
auto next_token_str = llama_token_to_str(ctx2, next_token);
150+
auto next_token_str = llama_token_to_piece(ctx2, next_token);
151151
last_n_tokens_data.push_back(next_token);
152152

153153
printf("%s", next_token_str.c_str());

examples/server/server.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
9494
std::string ret;
9595
for (; begin != end; ++begin)
9696
{
97-
ret += llama_token_to_str(ctx, *begin);
97+
ret += llama_token_to_piece(ctx, *begin);
9898
}
9999
return ret;
100100
}
@@ -123,7 +123,7 @@ static void server_log(const char *level, const char *function, int line,
123123
// format incomplete utf-8 multibyte character for output
124124
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
125125
{
126-
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
126+
std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
127127
// if the size is 1 and first bit is 1, meaning it's a partial character
128128
// (size > 1 meaning it's already a known token)
129129
if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -566,7 +566,7 @@ struct llama_server_context
566566

567567
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
568568
{
569-
// stopping_word = llama_token_to_str(ctx, embd.back());
569+
// stopping_word = llama_token_to_piece(ctx, embd.back());
570570
has_next_token = false;
571571
stopped_eos = true;
572572
LOG_VERBOSE("eos token found", {});
@@ -613,7 +613,7 @@ struct llama_server_context
613613
{
614614
const completion_token_output token_with_probs = nextToken();
615615

616-
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
616+
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
617617
generated_text += token_text;
618618

619619
if (params.n_probs > 0)
@@ -1248,7 +1248,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
12481248

12491249
struct token_translator {
12501250
llama_context * ctx;
1251-
std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
1251+
std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
12521252
std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
12531253
};
12541254

@@ -1358,7 +1358,7 @@ int main(int argc, char **argv)
13581358

13591359
while (llama.has_next_token) {
13601360
const completion_token_output token_with_probs = llama.doCompletion();
1361-
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
1361+
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
13621362

13631363
stop_pos = llama.findStoppingStrings(llama.generated_text,
13641364
token_text.size(), STOP_FULL);
@@ -1389,7 +1389,7 @@ int main(int argc, char **argv)
13891389
if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
13901390
continue;
13911391
}
1392-
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
1392+
const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
13931393

13941394
size_t pos = std::min(sent_count, llama.generated_text.size());
13951395

examples/simple/simple.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ int main(int argc, char ** argv) {
6363
fprintf(stderr, "\n\n");
6464

6565
for (auto id : tokens_list) {
66-
fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str());
66+
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
6767
}
6868

6969
fflush(stderr);
@@ -112,7 +112,7 @@ int main(int argc, char ** argv) {
112112
}
113113

114114
// print the new token :
115-
printf("%s", llama_token_to_str(ctx, new_token_id).c_str());
115+
printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
116116
fflush(stdout);
117117

118118
// push this new token for next evaluation

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,7 +1964,7 @@ void print_matrix(struct ggml_tensor * probs) {
19641964

19651965

19661966
// Write the decoded text piece of a single token to stdout (no newline).
void print_token(struct llama_context * ctx, llama_token token) {
    const std::string piece = llama_token_to_piece(ctx, token);
    fputs(piece.c_str(), stdout);
}
19691969

19701970
void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
@@ -2202,7 +2202,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
22022202
const char * in = buf.data();
22032203
const char * end = buf.data() + buf.size();
22042204
for (int i = 0; i < (int) out.size(); ++i) {
2205-
std::string s = llama_token_to_str(lctx, out[i]);
2205+
std::string s = llama_token_to_piece(lctx, out[i]);
22062206
int len = s.length();
22072207
if (in >= end) {
22082208
printf("%s: unexpected end of original text.\n", __func__);

0 commit comments

Comments
 (0)