diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fdde7ea01..e6ad83da5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1625,6 +1625,8 @@ def create_chat_completion( format = llama_chat_format.get_chat_format(self.chat_format) result = format( messages=messages, + functions=functions, + function_call=function_call, ) prompt = result.prompt if result.stop is not None: diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 9a09a28ee..a3e850ef6 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,107 +1,206 @@ +""" +llama_cpp/llama_chat_format.py + +This module provides a chat formatting system that allows for custom templates and HuggingFace's jinja2-based chat templating. + +To extend or customize, simply inherit from the ChatFormatter class and override the necessary methods. Registered formatters can be accessed using the ChatFormatterFactory. + +NOTE: The system message is always assumed to be the first element in a sequence. + +NOTE: Users should avoid tampering with special tokens to prevent model issues. + +--- + +# IMPORTANT NOTES: + +- The use of the merge operator (|) for dictionaries requires Python 3.9 or higher. Keep in mind that llama-cpp-python supports Python 3.8 and later versions. If you are working with an earlier Python version, consider alternatives such as `dict.update()` or creating a custom function to merge dictionaries. For Python 3.9 or higher, the merge operator simplifies dictionary merging. +Source: https://docs.python.org/3/library/stdtypes.html?highlight=dict#dict + +- Special tokens are crucial for the model's underlying operations, impacting pre-training, fine-tuning, and low-level inference processes. Users should avoid modifying special tokens to prevent issues in the model's output during inference. These issues may manifest as token fixation, repetitive language patterns, contextual derailment, and hallucinations. Improper use of separators and templates can exacerbate these problems. + +Example using the llama-2 model and its templating schema: + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 3 Hello User, my name is Llama. Nice to meet you!$ +# 4 [INST] What can you do? [/INST]$ +# 5 I can assist you with various tasks, including providing structured output for certain queries.$ +# 6 [INST] How can you assist me in my programming projects? [/INST]$ +# 7 $ + +This initial example is a proper template format that the model understands. It results in proper output and does not confuse the model. + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 3 Hello User, my name is Llama. Nice to meet you!$ +# 4 [INST] What can you do? [/INST]$ +# 5 I can assist you with various tasks, including providing structured output for certain queries.$ +# 6 [INST] How can you assist me in my programming projects? [/INST]$ +# 7 $ + +This example includes the use of special tokens, and the model may or may not use these tokens as a result. The model is not expecting them during inference, which causes unexpected behavior. + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 $ +# 3 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 4 Hello User, my name is Llama. Nice to meet you!$ +# 5 $ +# 6 [INST] What can you do? [/INST]$ +# 7 I can assist you with various tasks, including providing structured output for certain queries.$ +# 8 $ +# 9 [INST] How can you assist me in my programming projects? [/INST]$ +# 10 $ + +This example is improperly formatted and causes the model to become confused. The model begins to fixate on tokens, uses language repetition, and eventually derails. + +--- + +# Usage example: +# Registering a custom formatter +@ChatFormatterFactory.register_predefined_model("llama-2") +class Llama2Formatter(ChatFormatter): + def __init__(self): + super().__init__(llama2_template) + +# Obtaining a registered formatter +chat_formatter_factory = ChatFormatterFactory() +llama2_formatter = chat_formatter_factory.get_formatter_by_name("llama-2") + +# Formatting messages +messages = [{"role": "user", "content": "Hello, World!"}] +response = llama2_formatter(messages) +print(response) +""" import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Union, Protocol -from . import llama_types +import os +from typing import Any, Dict, List, Optional, Protocol, Type, Union +import huggingface_hub +from transformers import AutoTokenizer -def _get_system_message( - messages: List[llama_types.ChatCompletionRequestMessage], -) -> str: - """Get the first system message.""" - for message in messages: - if message["role"] == "system": - return message["content"] or "" - return "" - - -def _map_roles( - messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str] -) -> List[Tuple[str, Optional[str]]]: - """Map the message roles.""" - output: List[Tuple[str, Optional[str]]] = [] - for message in messages: - role = message["role"] - if role in role_map: - output.append((role_map[role], message["content"])) - return output - - -def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the llama2 style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += message + " " - else: - ret += role + " " - return ret - - -def _format_add_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ":" - return ret - - -def _format_add_colon_two( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str -) -> str: - """Format the prompt with the add-colon-two style.""" - seps = [sep, sep2] - ret = system_message + seps[0] - for i, (role, message) in enumerate(messages): - if message: - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - return ret - - -def _format_no_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the no-colon-single style.""" - ret = system_message - for role, message in messages: - if message: - ret += role + message + sep - else: - ret += role - return ret - - -def _format_add_colon_space_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-space-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ": " # must be end with a space - return ret - - -def _format_chatml( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the chatml style.""" - ret = "" if system_message == "" else system_message + sep + "\n" - for role, message in messages: - if message: - ret += role + "\n" + message + sep + "\n" - else: - ret += role + "\n" - return ret +from . import llama_types + +# Default chat formatting templates for reusability. +# These templates can be reused or modified on a model-by-model basis. + +# Template for HuggingFace-based models. +huggingface_template = { + "model": "meta-llama/Llama-2-7b-chat-hf", + "jinja": None, + "tokenize": False, +} + +# Common formatting settings applicable to all roles in chat models. +common_template: llama_types.CommonTemplate = { + "separators": { + "after_system": "\n", + "between_messages": "\n", + "end_of_response": "", + }, + "default_termination": { + "role": "assistant", # Default role for termination + "message": None, # Default termination message (None for assistant) + }, + "include_prompt": False, # Whether to include user prefix/postfix in prompts +} + +# Template for Llama-2 model. +llama2_template: llama_types.ChatMLTemplate = { + "roles": { + "system": { + "prefix": "<>", # System message prefix + "postfix": "<>", # System message postfix + "format": None, # Optionally specify a custom format + }, + "user": { + "prefix": "[INST] ", + "postfix": " [/INST]", # Model generates from here + "format": None, + }, + "assistant": { + "prefix": "", # No prefix for assistant role by default + "postfix": "", # No postfix for assistant role by default + "format": None, # Custom format for assistant (if needed) + }, + } +} +# Merge common settings into the llama2_template to reduce code duplication. +llama2_template |= common_template + +# Template for Alpaca model. +alpaca_template: llama_types.ChatMLTemplate = { + "roles": { + "system": { + "prefix": "", + "postfix": "\n", + "format": None, + }, + "user": { + "prefix": "### Instruction:\n", + "postfix": "\n", + "format": None, + }, + "input": { + "prefix": "### Input:\n", + "postfix": "\n", + "format": None, + }, + "assistant": { + "prefix": "### Response:\n", + "postfix": "", # Model generates from here + "format": None, + }, + } +} +alpaca_template |= common_template + +# Template for Vicuna model. +# NOTE: The v0 template differs from the v1.1, v1.3, and v1.5. +# This is the v1.5 Vicuna Template. +vicuna_template: llama_types.ChatMLTemplate = { + "roles": { + "system": { + "prefix": "", + "postfix": "\n", + "format": None, + }, + "user": { + "prefix": "USER: ", + "postfix": "", + "format": None, + }, + "assistant": { + "prefix": "ASSISTANT: ", # Model generates from here + "postfix": "", + "format": None, + }, + } +} +vicuna_template |= common_template + +# NOTE: Open Assistant uses multiple custom prompts. +# The oasst-llama hybrids utilize ChatML templates. +# The base template is defined here for convenience. +oasst_template: llama_types.ChatMLTemplate = { + "roles": { + "system": { + "prefix": "<|system|>", + "postfix": "<|endoftext|>", + "format": None, + }, + "user": { + "prefix": "<|prompter|>", + "postfix": "<|endoftext|>", + "format": None, + }, + "assistant": { + "prefix": "<|assistant|>", # Model generates from here + "postfix": "<|endoftext|>", + "format": None, + }, + } +} +oasst_template |= common_template @dataclasses.dataclass @@ -110,95 +209,169 @@ class ChatFormatterResponse: stop: Optional[Union[str, List[str]]] = None -class ChatFormatter(Protocol): +# Base Chat Formatter Protocol +class ChatFormatterInterface(Protocol): + def __init__(self, template: Optional[Dict[str, Any]] = None): + raise NotImplementedError + def __call__( self, - messages: List[llama_types.ChatCompletionRequestMessage], + messages: List[Dict[str, str]], + **kwargs, + ) -> ChatFormatterResponse: + raise NotImplementedError + + +# Core Chat Formatter class +# NOTE: Methods can be overridden as needed on a model-by-model basis. +class ChatFormatter(ChatFormatterInterface): + def __init__(self, template: Optional[Dict[str, Any]] = None): + self.template = template or llama2_template + + def __call__( + self, + messages: List[Dict[str, str]], **kwargs: Any, ) -> ChatFormatterResponse: - ... + formatted_messages = [ + self.format_message(msg["content"], msg["role"]) for msg in messages + ] + separator = self.format_separator("between_messages") + formatted_sequence = separator.join(formatted_messages) + # NOTE: Optionally include a prompt at the end + if self.template["include_prompt"]: + formatted_sequence += self.get_prompt() + # NOTE: `stop` is handled within completion methods + return ChatFormatterResponse(prompt=formatted_sequence) + + def format_message(self, message, role) -> str: + """Format a message based on the specified role.""" + try: + role_info = self.template["roles"][role] + except KeyError: + raise KeyError( + f"The role '{role}' is not defined in the template. Please check your template configuration." + ) + + prefix = role_info.get("prefix", "") + postfix = role_info.get("postfix", "") + formatted_message = f"{prefix}{message}{postfix}" + return formatted_message + + def format_separator(self, separator_type) -> str: + """Format separators based on the specified type.""" + return self.template["separators"].get(separator_type, "") + + def get_prompt(self) -> str: + # Implement logic to generate a prompt, if needed + return self.template["roles"]["user"]["prefix"] + + +class TokenizerCache: + _cache: Dict[str, AutoTokenizer] = {} + + @classmethod + def get_tokenizer(cls, model_name: str) -> AutoTokenizer: + if model_name not in cls._cache: + cls._cache[model_name] = AutoTokenizer.from_pretrained(model_name) + return cls._cache[model_name] + + +class AutoTokenizerFormatter(ChatFormatterInterface): + def __init__(self, template: Optional[Dict[str, str]] = None): + self.template = template or huggingface_template + self.huggingface_login() + self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"]) + def __call__( + self, + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs, + ) -> ChatFormatterResponse: + formatted_content = self.format_messages(messages) + return ChatFormatterResponse( + prompt=formatted_content, stop=[self.tokenizer.eos_token] + ) -_CHAT_FORMATS: Dict[str, ChatFormatter] = {} + def huggingface_login(self) -> None: + token = os.getenv("HF_TOKEN") + if token is None: + raise AttributeError( + "Failed to login to huggingface. " + "Did you forget to set the `HF_TOKEN` environment variable with your huggingface token?" + ) + huggingface_hub.login(token) + + def format_messages( + self, messages: List[llama_types.ChatCompletionRequestMessage] + ) -> str: + # If a custom template is provided, override the tokenizer's default template + if self.template.get("jinja"): + self.tokenizer.chat_template = self.template["jinja"] + + return self.tokenizer.apply_chat_template( + messages, tokenize=self.template.get("tokenize", False) + ) -def register_chat_format(name: str): - def decorator(f: ChatFormatter): - _CHAT_FORMATS[name] = f - return f +# NOTE: Template registration is currently a WIP (work in progress). +class FormatterNotFoundException(Exception): + pass - return decorator +# External developers can now use the `@ChatFormatter.register_predefined_model` +# method to register their own custom formatters. +class ChatFormatterFactory: + _chat_formatters: Dict[str, ChatFormatterInterface] = {} -def get_chat_format(name: str): - try: - return _CHAT_FORMATS[name] - except KeyError: - raise ValueError( - f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})" - ) + @staticmethod + def register_predefined_model(name: str): + def decorator(cls: Type[ChatFormatterInterface]): + ChatFormatterFactory._chat_formatters[name] = cls() + return cls + return decorator -@register_chat_format("llama-2") -def format_llama2( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="[INST]", assistant="[/INST]") - _sep = "\n\n" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_llama2(system_message, _messages, _sep) - return ChatFormatterResponse(prompt=_prompt) + def register_custom_model(self, name: str, formatter: ChatFormatterInterface): + self._chat_formatters[name] = formatter + def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: + try: + return self._chat_formatters[name] + except KeyError: + raise FormatterNotFoundException( + f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" + ) -@register_chat_format("alpaca") -def format_alpaca( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _roles = dict(user="### Instruction", assistant="### Response") - _sep = "\n\n" - _sep2 = "" - system_message = _get_system_message(messages) - _messages = _map_roles(messages, _roles) - _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) - return ChatFormatterResponse(prompt=_prompt) +# Define a chat format class and register it +@ChatFormatterFactory.register_predefined_model("llama-2") +class Llama2Formatter(ChatFormatter): + def __init__(self): + super().__init__(llama2_template) -@register_chat_format("vicuna") -def format( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." - _roles = dict(user="USER", assistant="ASSISTANT") - _sep = " " - _sep2 = "" - system_message = _system_message - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) - return ChatFormatterResponse(prompt=_prompt) +# Define a chat format class and register it +@ChatFormatterFactory.register_predefined_model("alpaca") +class AlpacaFormatter(ChatFormatter): + def __init__(self): + # Define the Alpaca template + super().__init__(alpaca_template) -@register_chat_format("oasst_llama") -def format_oasst_llama( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="<|prompter|>", assistant="<|assistant|>") - _sep = "" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_no_colon_single(system_message, _messages, _sep) - return ChatFormatterResponse(prompt=_prompt) + +@ChatFormatterFactory.register_predefined_model("vicuna") +class VicunaFormatter(ChatFormatter): + def __init__(self): + # Define the Vicuna template + super().__init__(vicuna_template) + + +# NOTE: Refer to `oasst_template` note for more information. +@ChatFormatterFactory.register_predefined_model("oasst") +class OpenAssistantFormatter(ChatFormatter): + def __init__(self): + # Define the Open Assistant template + super().__init__(oasst_template) @register_chat_format("openbuddy") @@ -320,3 +493,124 @@ def format_chatml( _messages.append((_roles["assistant"], None)) _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("functionary") +def format_functionary( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + **kwargs: Any, +) -> ChatFormatterResponse: + SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" + + def generate_schema_from_functions( + functions: List[llama_types.ChatCompletionFunctions], + namespace: str = "functions", + ): + """ + Convert functions schema to a schema that language models can understand. + """ + + schema = ( + "// Supported function definitions that should be called when necessary.\n" + ) + schema += f"namespace {namespace} {{\n\n" + + for function in functions: + # Convert a Function object to dict, if necessary + function_name = function["name"] + description = function.get("description", "") + schema += f"// {description}\n" + schema += f"type {function_name}" + + parameters = function.get("parameters", None) + schema += " = (_: {\n" + required_params = parameters.get("required", []) + for param_name, param in parameters.get("properties", {}).items(): + # Param Description + description = param.get("description") + if description is not None: + schema += f"// {description}\n" + + # Param Name + schema += f"{param_name}" + if param_name not in required_params: + schema += "?" + + # Param Type + param_type = param.get("type", "any") + if param_type == "integer": + param_type = "number" + if "enum" in param: + param_type = " | ".join([f'"{v}"' for v in param["enum"]]) + schema += f": {param_type},\n" + + schema += "}) => any;\n\n" + + schema += f"}} // namespace {namespace}" + + return schema + + def prepare_messages_for_inference( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + ): + all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + if functions is not None: + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=generate_schema_from_functions(functions) + ) + ) + + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=SYSTEM_MESSAGE + ) + ) + + for message in messages: + # Function call responses + if message["role"] == "function" and "name" in message: + message["name"] = f"functions.{message['name']}" + # Function call requests by assistant + if "function_call" in message: + message["function_call"][ + "name" + ] = f"functions.{message['function_call']['name']}" + all_messages.append(message) + + all_messages.append( + llama_types.ChatCompletionRequestMessage(role="assistant", content=None) + ) + + def message_to_str(msg: llama_types.ChatCompletionRequestMessage): + if msg["role"] == "system": + return f"system:\n{msg['content']}\n" + + elif msg["role"] == "function" and "name" in msg: + return f"function name={msg['name']}:\n{msg['content']}\n" + elif msg["role"] == "user": + if msg["content"] is None: + return "user:\n" + else: + return f"user:\n{msg['content']}\n" + elif msg["role"] == "assistant": + if msg["content"] is not None and "function_call" in msg: + return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif "function_call" in msg: + return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif msg["content"] is None: + return "assistant" + else: + return f"assistant:\n{msg['content']}\n" + else: + raise ValueError(f"Unsupported role: {msg['role']}") + + return "".join([message_to_str(msg) for msg in all_messages]) + + prompt = prepare_messages_for_inference(messages, functions) + return ChatFormatterResponse( + prompt=prompt, + stop=["user:", ""], + ) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8ff15658e..29431d957 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,4 +1,5 @@ -"""C++ implementation of the llama grammar parser.""" +"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" + # flake8: noqa from pathlib import Path import sys @@ -1056,8 +1057,7 @@ def print_rule( # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1102,9 +1102,7 @@ def print_rule( for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None: f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, ) + + +"""llama.cpp gbnf rules from vendor/llama.cpp/grammars""" + +ARITHMETIC_GBNF = """\ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* +""" + +C_GBNF = """\ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) +""" + +CHESS_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JAPANESE_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JSON_ARR_GBNF = """\ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + + +JSON_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)?""" + +LIST_GBNF = """\ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" +""" + +"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" +import json +import re +from typing import List, Optional + +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + +PRIMITIVE_RULES = { + "boolean": '("true" | "false") space', + "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', + "string": r""" "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" space """, + "null": '"null" space', +} + +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} + + +class SchemaConverter: + def __init__(self, prop_order): + self._prop_order = prop_order + self._rules = {"space": SPACE_RULE} + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + ) + return f'"{escaped}"' + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f"{esc_name}{i}" in self._rules: + i += 1 + key = f"{esc_name}{i}" + self._rules[key] = rule + return key + + def visit(self, schema, name): + schema_type = schema.get("type") + rule_name = name or "root" + + if "oneOf" in schema or "anyOf" in schema: + rule = " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') + for i, alt_schema in enumerate( + schema.get("oneOf") or schema["anyOf"] + ) + ) + ) + return self._add_rule(rule_name, rule) + + elif "const" in schema: + return self._add_rule(rule_name, self._format_literal(schema["const"])) + + elif "enum" in schema: + rule = " | ".join((self._format_literal(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) + + elif schema_type == "object" and "properties" in schema: + # TODO: `required` keyword + prop_order = self._prop_order + prop_pairs = sorted( + schema["properties"].items(), + # sort by position in prop_order (if specified) then by key + key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), + ) + + rule = '"{" space' + for i, (prop_name, prop_schema) in enumerate(prop_pairs): + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + if i > 0: + rule += ' "," space' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += ' "}" space' + + return self._add_rule(rule_name, rule) + + elif schema_type == "array" and "items" in schema: + # TODO `prefixItems` keyword + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' + ) + rule = ( + f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space' + ) + return self._add_rule(rule_name, rule) + + else: + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + return self._add_rule( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) + + def format_grammar(self): + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + + +def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): + prop_order = prop_order or [] + schema = json.load(schema) + prop_order = {name: idx for idx, name in enumerate(prop_order)} + converter = SchemaConverter(prop_order) + converter.visit(schema, "") + return converter.format_grammar() diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 6ee7ef914..70c79c321 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -4,8 +4,9 @@ https://github.com/openai/openai-openapi/blob/master/openapi.yaml """ -from typing import Any, List, Optional, Dict, Union -from typing_extensions import TypedDict, NotRequired, Literal +from typing import Any, Dict, List, Optional, Union + +from typing_extensions import Literal, NotRequired, TypedDict class EmbeddingUsage(TypedDict): @@ -170,3 +171,19 @@ class ChatCompletionRequestMessage(TypedDict): content: Optional[str] name: NotRequired[str] funcion_call: NotRequired[ChatCompletionFunctionCall] + + +class RoleTemplate(TypedDict, total=False): + prefix: str + postfix: str + format: Optional[str] + + +class CommonTemplate(TypedDict): + separators: Dict[str, str] + default_termination: Dict[str, Optional[str]] + include_prompt: bool + + +class ChatMLTemplate(TypedDict): + roles: Dict[str, RoleTemplate] diff --git a/tests/test_llama_chat_formatters.py b/tests/test_llama_chat_formatters.py new file mode 100644 index 000000000..30e042bfa --- /dev/null +++ b/tests/test_llama_chat_formatters.py @@ -0,0 +1,36 @@ +from typing import List + +import pytest + +from llama_cpp import ChatCompletionMessage +from llama_cpp.llama_chat_format import Llama2Formatter + + +@pytest.fixture +def sequence_of_messages() -> List[ChatCompletionMessage]: + return [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage( + role="user", content="Let's go with a recursive solution." + ), + ] + + +def test_llama2_formatter(sequence_of_messages): + prompt = """<>Welcome to CodeHelp Bot!<>\n[INST] Hi there! I need some help with Python. [/INST]\nOf course! What do you need help with in Python?\n[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\nI can help with that! Would you like a recursive or iterative solution?\n[INST] Let's go with a recursive solution. [/INST]""" + llama2formatter = Llama2Formatter() + assert prompt == llama2formatter._format_messages(sequence_of_messages)