From b646ffa1b16cf5423e33d437bed32462daf2efef Mon Sep 17 00:00:00 2001 From: Johnman Date: Mon, 20 Mar 2023 16:06:58 +0100 Subject: [PATCH 1/4] Check for reverse prompt by characters instead of tokens (#292) --- main.cpp | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/main.cpp b/main.cpp index 15903337339fb..e3bbd8c12f4c3 100644 --- a/main.cpp +++ b/main.cpp @@ -12,6 +12,7 @@ #include #include #include +#include <sstream> #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -877,16 +878,9 @@ int main(int argc, char ** argv) { params.interactive = true; params.antiprompt.push_back("### Instruction:\n\n"); } - - // tokenize the reverse prompt - std::vector> antipromptv_inp; - for (auto antiprompt : params.antiprompt) { - antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false)); - } - // enable interactive mode if reverse prompt is specified - if (antipromptv_inp.size() != 0) { + if (params.antiprompt.size() != 0) { params.interactive = true; } @@ -910,15 +904,9 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: interactive mode on.\n", __func__); - if(antipromptv_inp.size()) { - for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) { - auto antiprompt_inp = antipromptv_inp.at(apindex); - fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str()); - fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); - for (int i = 0; i < (int) antiprompt_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); - } - fprintf(stderr, "\n"); + if(params.antiprompt.size()) { + for (auto antiprompt : params.antiprompt) { + fprintf(stderr, "Antiprompt: %s\n", antiprompt); } } } @@ -1035,12 +1023,23 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && embd_inp.size() <= 
input_consumed) { // check for reverse prompt - for (auto antiprompt_inp : antipromptv_inp) { - if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { - // reverse prompt found + + std::stringstream last_output_ss; + for (auto id : last_n_tokens) { + last_output_ss << vocab.id_to_token[id]; + } + std::string last_output = last_output_ss.str(); + + for (std::string antiprompt : params.antiprompt) { + if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; break; } + /*if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { + // reverse prompt found + is_interacting = true; + break; + }*/ } if (is_interacting) { if (params.instruct) { From e9f77473dcffb477ba2a30c5dee4a5a863422ed8 Mon Sep 17 00:00:00 2001 From: tjohnman Date: Mon, 20 Mar 2023 16:10:26 +0100 Subject: [PATCH 2/4] Update main.cpp Wording. --- main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.cpp b/main.cpp index e3bbd8c12f4c3..2f212b6a4da69 100644 --- a/main.cpp +++ b/main.cpp @@ -906,7 +906,7 @@ int main(int argc, char ** argv) { if(params.antiprompt.size()) { for (auto antiprompt : params.antiprompt) { - fprintf(stderr, "Antiprompt: %s\n", antiprompt); + fprintf(stderr, "Reverse prompt: %s\n", antiprompt); } } } From 6242d1ccd5fd2b9c403ad526647176bd9edfedf0 Mon Sep 17 00:00:00 2001 From: Johnman Date: Mon, 20 Mar 2023 16:15:00 +0100 Subject: [PATCH 3/4] Cleanup. 
--- main.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/main.cpp b/main.cpp index 2f212b6a4da69..0ac3a13e48ec4 100644 --- a/main.cpp +++ b/main.cpp @@ -906,7 +906,7 @@ int main(int argc, char ** argv) { if(params.antiprompt.size()) { for (auto antiprompt : params.antiprompt) { - fprintf(stderr, "Reverse prompt: %s\n", antiprompt); + fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str()); } } } @@ -1023,23 +1023,18 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && embd_inp.size() <= input_consumed) { // check for reverse prompt - std::stringstream last_output_ss; for (auto id : last_n_tokens) { last_output_ss << vocab.id_to_token[id]; } std::string last_output = last_output_ss.str(); + // Check if each of the reverse prompts appears at the end of the output. for (std::string antiprompt : params.antiprompt) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; break; } - /*if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { - // reverse prompt found - is_interacting = true; - break; - }*/ } if (is_interacting) { if (params.instruct) { From 12280807c344f9e164be88cf5d4538c693956018 Mon Sep 17 00:00:00 2001 From: Johnman Date: Mon, 20 Mar 2023 16:31:10 +0100 Subject: [PATCH 4/4] Remove unnecessary use of std::stringstream. 
--- main.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/main.cpp b/main.cpp index 0ac3a13e48ec4..7295f017a0612 100644 --- a/main.cpp +++ b/main.cpp @@ -12,7 +12,6 @@ #include #include #include -#include <sstream> #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -1023,11 +1022,10 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && embd_inp.size() <= input_consumed) { // check for reverse prompt - std::stringstream last_output_ss; + std::string last_output; for (auto id : last_n_tokens) { - last_output_ss << vocab.id_to_token[id]; + last_output += vocab.id_to_token[id]; } - std::string last_output = last_output_ss.str(); // Check if each of the reverse prompts appears at the end of the output. for (std::string antiprompt : params.antiprompt) {