diff --git a/examples/common.cpp b/examples/common.cpp index 5400f6b01f9d7..9770f83fda220 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -164,6 +164,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.antiprompt.push_back(argv[i]); + } else if (arg == "--stop") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.stop_keywords.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { @@ -209,8 +215,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); - fprintf(stderr, " specified more than once for multiple prompts).\n"); + fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT\n"); + fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n"); + fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); + fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); diff --git a/examples/common.h b/examples/common.h index 1505aa927eb3a..9e37ec88e959d 100644 --- a/examples/common.h +++ b/examples/common.h @@ -35,6 +35,7 @@ struct gpt_params { std::vector antiprompt; // string upon seeing which more user input is prompted + std::vector stop_keywords; // string upon seeing which the model will stop bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 67a34e6677edf..b74c2a62dacb2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -209,6 +209,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } + + if (params.stop_keywords.size()) { + for (auto stop_keyword : params.stop_keywords) { + fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str()); + } + } + fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); @@ -344,13 +351,28 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && (int) embd_inp.size() <= n_consumed) { - // check for reverse prompt - if (params.antiprompt.size()) { - std::string last_output; + std::string last_output; + if (params.antiprompt.size() || params.stop_keywords.size()) { for (auto id : last_n_tokens) { last_output += llama_token_to_str(ctx, id); } + } + + // Check for stop keywords, a configurable alternative to the end-of-text token + // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + // check for reverse prompt + if (params.antiprompt.size()) { is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. for (std::string & antiprompt : params.antiprompt) { @@ -430,6 +452,24 @@ int main(int argc, char ** argv) { } } + // Check for stop keywords, a configurable alternative to the end-of-text token + if (!params.interactive && params.stop_keywords.size() && !is_interacting) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + } + // end of text token if (!embd.empty() && embd.back() == llama_token_eos()) { if (params.instruct) {