diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 9a09a28ee..8774d05f8 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -1,5 +1,6 @@
 import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+
 from . import llama_types
@@ -144,7 +145,7 @@ def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
+    _system_template = "<<SYS>>\n{system_message}\n<</SYS>>\n\n"
     _roles = dict(user="[INST]", assistant="[/INST]")
     _sep = "\n\n"
     system_message = _get_system_message(messages)
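For reference, a minimal standalone sketch of the prompt shape the corrected template yields for a single system + user turn. The `build_prompt` helper below is hypothetical and is not how llama-cpp-python assembles the prompt (it uses its own role-mapping helpers); it only illustrates why the leading `"[INST] "` is dropped from the system template: the user role tag already contributes it, and the stripped `<>` markers in the extracted diff are assumed to be the Llama-2 `<<SYS>>`/`<</SYS>>` delimiters.

```python
# Illustrative sketch only, not the library's implementation.
_system_template = "<<SYS>>\n{system_message}\n<</SYS>>\n\n"
_roles = dict(user="[INST]", assistant="[/INST]")


def build_prompt(system_message: str, user_message: str) -> str:
    """Wrap a single system + user turn in Llama-2 style tags."""
    system_block = _system_template.format(system_message=system_message)
    # The user role tag already supplies "[INST]", which is why the template
    # itself no longer starts with "[INST] " after this change.
    return f"{_roles['user']} {system_block}{user_message} {_roles['assistant']}"


print(build_prompt("You are a helpful assistant.", "Hello!"))
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hello! [/INST]
```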