From 52e319050b917d694ae0542eb3a1ee757cd8e887 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Tue, 18 Apr 2023 02:55:40 -0700 Subject: [PATCH 1/5] Add author mode and other related QOL improvements --- examples/common.cpp | 153 ++++++++++++++++++++++++++++++++++++++++- examples/common.h | 19 +++-- examples/main/main.cpp | 60 ++++++---------- 3 files changed, 184 insertions(+), 48 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 97eded6eccd64..cd37f96f4f09f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -27,7 +27,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int const wchar_t * lpWideCharStr, int cchWideChar, char * lpMultiByteStr, int cbMultiByte, const char * lpDefaultChar, bool * lpUsedDefaultChar); +#define ENABLE_LINE_INPUT 0x0002 +#define ENABLE_ECHO_INPUT 0x0004 #define CP_UTF8 65001 +#define CONSOLE_CHAR_TYPE wchar_t +#define CONSOLE_GET_CHAR() getwchar() +#define CONSOLE_EOF WEOF +#else +#include +#define CONSOLE_CHAR_TYPE char +#define CONSOLE_GET_CHAR() getchar() +#define CONSOLE_EOF EOF #endif int32_t get_num_physical_cores() { @@ -264,6 +274,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.embedding = true; } else if (arg == "--interactive-first") { params.interactive_first = true; + } else if (arg == "--author-mode") { + params.author_mode = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; } else if (arg == "--color") { @@ -356,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -477,7 +490,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { } /* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { +void console_set_color(console_state & con_st, console_color_t color) { if (con_st.use_color && con_st.color != color) { switch(color) { case CONSOLE_COLOR_DEFAULT: @@ -494,8 +507,9 @@ void set_console_color(console_state & con_st, console_color_t color) { } } +void console_init(console_state & con_st) { #if defined (_WIN32) -void win32_console_init(bool enable_color) { + // Windows-specific console initialization unsigned long dwMode = 0; void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { @@ -506,7 +520,7 @@ void win32_console_init(bool enable_color) { } if (hConOut) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { + if (con_st.use_color && !(dwMode & 0x4)) { SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) } // Set console output codepage to UTF8 @@ -516,9 +530,46 @@ void win32_console_init(bool enable_color) { if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); + } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); +#endif +} + +void console_cleanup(console_state & con_st) { +#if !defined(_WIN32) + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif + + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); +} + +// Helper function to remove the last UTF-8 character from a string +void remove_last_utf8_char(std::string & line) { + if (line.empty()) return; + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character } + line.erase(pos); } +#if defined (_WIN32) // Convert a wide Unicode string to an UTF8 string void win32_utf8_encode(const std::wstring & wstr, std::string & str) { int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); @@ -527,3 +578,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) { str = strTo; } #endif + +bool console_readline(console_state & con_st, std::string & line) { + line.clear(); + bool is_special_char = false; + bool end_of_stream = false; + + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + + CONSOLE_CHAR_TYPE input_char; + while (true) { + fflush(stdout); // Ensure all output is displayed before waiting for input + input_char = CONSOLE_GET_CHAR(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + putchar('\b'); + putchar(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR(); + if (code == '[') { + // Discard the rest of the escape sequence + while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!line.empty()) { + fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + remove_last_utf8_char(line); + } + } else if (input_char < 32) { + // Ignore control characters + } else { +#if defined(_WIN32) + std::string utf8_char; + win32_utf8_encode(std::wstring(1, input_char), utf8_char); + line += utf8_char; + fputs(utf8_char.c_str(), stdout); +#else + line += input_char; + putchar(input_char); +#endif + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + putchar('\b'); + putchar(line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.author_mode; + if (is_special_char) { + fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + putchar('\n'); + has_more = !has_more; + } else { + // llama doesn't seem to process a single space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + putchar('\b'); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + putchar('\n'); + } + } + + fflush(stdout); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516ffe05..cb1e384e2f39b 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,10 @@ #include #include +#if !defined (_WIN32) +#include +#endif + // // CLI argument parsing // @@ -56,6 +60,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool author_mode = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +109,15 @@ enum console_color_t { }; struct console_state { + bool author_mode = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; +#if !defined (_WIN32) + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 43dca8eb5ea82..5124b8aa94ecb 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.author_mode = params.author_mode; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.author_mode) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; } From 94dd17247af7bbf28098bc29d4a71849d8554e1f Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sun, 23 Apr 2023 07:05:02 -0700 Subject: [PATCH 2/5] author mode -> multiline input --- examples/common.cpp | 8 ++++---- examples/common.h | 4 ++-- examples/main/main.cpp | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index cd37f96f4f09f..f7973fc5299bd 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -274,10 +274,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.embedding = true; } else if (arg == "--interactive-first") { params.interactive_first = true; - } else if (arg == "--author-mode") { - params.author_mode = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -368,7 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); - fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n"); + fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -644,7 +644,7 @@ bool console_readline(console_state & con_st, std::string & line) { } } - bool has_more = con_st.author_mode; + bool has_more = con_st.multiline_input; if (is_special_char) { fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again diff --git a/examples/common.h b/examples/common.h index cb1e384e2f39b..0950fc7c334c3 100644 --- a/examples/common.h +++ b/examples/common.h @@ -60,7 +60,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately - bool author_mode = false; // reverse the usage of `\` + bool multiline_input = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -109,7 +109,7 @@ enum console_color_t { }; struct console_state { - bool author_mode = false; + bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; #if !defined (_WIN32) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5124b8aa94ecb..10ae61422918a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,7 +59,7 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - con_st.author_mode = params.author_mode; + con_st.multiline_input = params.multiline_input; console_init(con_st); atexit([]() { console_cleanup(con_st); }); @@ -275,7 +275,7 @@ int main(int argc, char ** argv) { if (params.interactive) { const char *control_message; - if (con_st.author_mode) { + if (con_st.multiline_input) { control_message = " - To return control to LLaMa, end your input with '\\'.\n" " - To return control without starting a new line, end your input with '/'.\n"; } else { From 8f9f962d4dcbfb0fc7f8a1c80a1de27e3a102bd8 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sat, 29 Apr 2023 17:48:52 -0700 Subject: [PATCH 3/5] Signed variable to unsigned variable cast --- examples/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index f7973fc5299bd..248c37e5a6b5b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -622,7 +622,7 @@ bool console_readline(console_state & con_st, std::string & line) { fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again remove_last_utf8_char(line); } - } else if (input_char < 32) { + } else if (static_cast(input_char) < 32) { // Ignore control characters } else { #if defined(_WIN32) From 534c89e76610ab00c52bc12a08a00aa3d2961867 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Fri, 5 May 2023 08:38:58 -0700 Subject: [PATCH 4/5] Track character width --- examples/common.cpp | 157 ++++++++++++++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 51 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 248c37e5a6b5b..f6b2d6b1351ca 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -15,29 +15,13 @@ #endif #if defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); -#define ENABLE_LINE_INPUT 0x0002 -#define ENABLE_ECHO_INPUT 0x0004 -#define CP_UTF8 65001 -#define CONSOLE_CHAR_TYPE wchar_t -#define CONSOLE_GET_CHAR() getwchar() -#define CONSOLE_EOF WEOF #else #include -#define CONSOLE_CHAR_TYPE char -#define CONSOLE_GET_CHAR() getchar() -#define CONSOLE_EOF EOF #endif int32_t get_num_physical_cores() { @@ -545,6 +529,7 @@ void console_init(console_state & con_st) { new_termios.c_cc[VTIME] = 0; tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); #endif + setlocale(LC_ALL, ""); } void console_cleanup(console_state & con_st) { @@ -557,9 +542,80 @@ void console_cleanup(console_state & con_st) { console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } +#if defined (_WIN32) +int puts_get_width(_In_z_ CONST CHAR* lpBuffer) { + DWORD nNumberOfCharsToWrite = strlen(lpBuffer); + + HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { + // Make a guess + return 1; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + + DWORD written = 0; + WriteConsole(hConsole, lpBuffer, nNumberOfCharsToWrite, &written, nullptr); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (newBufferInfo.dwCursorPosition.Y > initialPosition.Y) { + width += (newBufferInfo.dwSize.X - initialPosition.X); + } + + return width; +} +#endif + +char32_t getchar32() { + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } +#endif + + return static_cast(wc); +} + +void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } +} + // Helper function to remove the last UTF-8 character from a string -void remove_last_utf8_char(std::string & line) { - if (line.empty()) return; +void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + size_t pos = line.length() - 1; // Find the start of the last UTF-8 character (checking up to 4 bytes back) @@ -569,33 +625,24 @@ void remove_last_utf8_char(std::string & line) { line.erase(pos); } -#if defined (_WIN32) -// Convert a wide Unicode string to an UTF8 string -void win32_utf8_encode(const std::wstring & wstr, std::string & str) { - int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); - std::string strTo(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); - str = strTo; -} -#endif - bool console_readline(console_state & con_st, std::string & line) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + line.clear(); + std::vector widths; bool is_special_char = false; bool end_of_stream = false; - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - - CONSOLE_CHAR_TYPE input_char; + char32_t input_char; while (true) { fflush(stdout); // Ensure all output is displayed before waiting for input - input_char = CONSOLE_GET_CHAR(); + input_char = getchar32(); if (input_char == '\r' || input_char == '\n') { break; } - if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) { + if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { end_of_stream = true; break; } @@ -608,31 +655,39 @@ bool console_readline(console_state & con_st, std::string & line) { } if (input_char == '\033') { // Escape sequence - CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR(); - if (code == '[') { + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { // Discard the rest of the escape sequence - while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) { + while ((code = getchar32()) != WEOF) { if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { break; } } } } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace - if (!line.empty()) { - fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again - remove_last_utf8_char(line); + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print spaces, and move cursor back again + for (int i = 0; i < count; i++) { + fputs("\b \b", stdout); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); } - } else if (static_cast(input_char) < 32) { + } else if (input_char < 32) { // Ignore control characters } else { -#if defined(_WIN32) - std::string utf8_char; - win32_utf8_encode(std::wstring(1, input_char), utf8_char); - line += utf8_char; - fputs(utf8_char.c_str(), stdout); + int offset = line.length(); + append_utf8(input_char, line); +#if defined (_WIN32) + int width = puts_get_width(line.c_str() + offset); + widths.push_back(width); #else - line += input_char; - putchar(input_char); + fputs(line.c_str() + offset, stdout); + widths.push_back(wcwidth(input_char)); #endif } @@ -655,7 +710,7 @@ bool console_readline(console_state & con_st, std::string & line) { putchar('\n'); has_more = !has_more; } else { - // llama doesn't seem to process a single space + // llama will just eat the single space if (line.length() == 1 && line.back() == ' ') { line.clear(); putchar('\b'); From 5cc9085353b788c04ffd84860dd6b93bc54393c0 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sun, 7 May 2023 02:39:10 -0700 Subject: [PATCH 5/5] Works with all characters and control codes + Windows console fixes --- examples/common.cpp | 240 +++++++++++++++++++++++++++++--------------- examples/common.h | 8 +- 2 files changed, 168 insertions(+), 80 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index f6b2d6b1351ca..5eeab0cb1a02e 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,14 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) #define WIN32_LEAN_AND_MEAN #define NOMINMAX #include #include #include #else +#include #include +#include #endif int32_t get_num_physical_cores() { @@ -473,45 +475,27 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void console_set_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - void console_init(console_state & con_st) { -#if defined (_WIN32) +#if defined(_WIN32) // Windows-specific console initialization - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (con_st.use_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); @@ -528,46 +512,49 @@ void console_init(console_state & con_st) { new_termios.c_cc[VMIN] = 1; new_termios.c_cc[VTIME] = 0; tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } #endif setlocale(LC_ALL, ""); } void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + #if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } // Restore the terminal settings on POSIX systems tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); #endif - - // Reset console color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } -#if defined (_WIN32) -int puts_get_width(_In_z_ CONST CHAR* lpBuffer) { - DWORD nNumberOfCharsToWrite = strlen(lpBuffer); - - HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { - // Make a guess - return 1; - } - COORD initialPosition = bufferInfo.dwCursorPosition; - - DWORD written = 0; - WriteConsole(hConsole, lpBuffer, nNumberOfCharsToWrite, &written, nullptr); - - CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; - GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); - - int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; - if (newBufferInfo.dwCursorPosition.Y > initialPosition.Y) { - width += (newBufferInfo.dwSize.X - initialPosition.X); +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } - - return width; } -#endif char32_t getchar32() { wchar_t wc = getwchar(); @@ -590,6 +577,102 @@ char32_t getchar32() { return static_cast(wc); } +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + void append_utf8(char32_t ch, std::string & out) { if (ch <= 0x7F) { out.push_back(static_cast(ch)); @@ -627,6 +710,9 @@ void pop_back_utf8_char(std::string & line) { bool console_readline(console_state & con_st, std::string & line) { console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } line.clear(); std::vector widths; @@ -635,7 +721,7 @@ bool console_readline(console_state & con_st, std::string & line) { char32_t input_char; while (true) { - fflush(stdout); // Ensure all output is displayed before waiting for input + fflush(con_st.out); // Ensure all output is displayed before waiting for input input_char = getchar32(); if (input_char == '\r' || input_char == '\n') { @@ -649,8 +735,7 @@ bool console_readline(console_state & con_st, std::string & line) { if (is_special_char) { console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - putchar('\b'); - putchar(line.back()); + replace_last(con_st, line.back()); is_special_char = false; } @@ -670,50 +755,47 @@ bool console_readline(console_state & con_st, std::string & line) { do { count = widths.back(); widths.pop_back(); - // Move cursor back, print spaces, and move cursor back again + // Move cursor back, print space, and move cursor back again for (int i = 0; i < count; i++) { - fputs("\b \b", stdout); + replace_last(con_st, ' '); + pop_cursor(con_st); } pop_back_utf8_char(line); } while (count == 0 && !widths.empty()); } - } else if (input_char < 32) { - // Ignore control characters } else { int offset = line.length(); append_utf8(input_char, line); -#if defined (_WIN32) - int width = puts_get_width(line.c_str() + offset); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } widths.push_back(width); -#else - fputs(line.c_str() + offset, stdout); - widths.push_back(wcwidth(input_char)); -#endif } if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { console_set_color(con_st, CONSOLE_COLOR_PROMPT); - putchar('\b'); - putchar(line.back()); + replace_last(con_st, line.back()); is_special_char = true; } } bool has_more = con_st.multiline_input; if (is_special_char) { - fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again + replace_last(con_st, ' '); + pop_cursor(con_st); char last = line.back(); line.pop_back(); if (last == '\\') { line += '\n'; - putchar('\n'); + fputc('\n', con_st.out); has_more = !has_more; } else { - // llama will just eat the single space + // llama will just eat the single space, it won't act as a space if (line.length() == 1 && line.back() == ' ') { line.clear(); - putchar('\b'); + pop_cursor(con_st); } has_more = false; } @@ -722,10 +804,10 @@ bool console_readline(console_state & con_st, std::string & line) { has_more = false; } else { line += '\n'; - putchar('\n'); + fputc('\n', con_st.out); } } - fflush(stdout); + fflush(con_st.out); return has_more; } diff --git a/examples/common.h b/examples/common.h index 0950fc7c334c3..43f1cc9ef09d5 100644 --- a/examples/common.h +++ b/examples/common.h @@ -11,6 +11,7 @@ #include #if !defined (_WIN32) +#include #include #endif @@ -112,7 +113,12 @@ struct console_state { bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; -#if !defined (_WIN32) + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; termios prev_state; #endif };