Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed Nov 21, 2023
1 parent 422ebc8 commit 7a3f878
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 71 deletions.
10 changes: 6 additions & 4 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,9 +1036,9 @@ def eval(self, tokens: Sequence[int]):
offset = (
0 if self.context_params.logits_all else n_tokens - 1
) # NOTE: Only save the last token logits if logits_all is False
self.scores[n_past + offset : n_past + n_tokens, :].reshape(
-1
)[:] = self._ctx.get_logits()[offset * cols: rows * cols]
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
:
] = self._ctx.get_logits()[offset * cols : rows * cols]
# Update n_tokens
self.n_tokens += n_tokens

Expand Down Expand Up @@ -1135,7 +1135,9 @@ def sample(
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
self._ctx.sample_typical(
candidates=self._candidates, p=typical_p, min_keep=1
)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
Expand Down
110 changes: 73 additions & 37 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def format_phind(
_prompt = _format_add_colon_single(_system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt)


@register_chat_format("intel")
def format_intel(
messages: List[llama_types.ChatCompletionRequestMessage],
Expand Down Expand Up @@ -588,6 +589,7 @@ def format_mistrallite(
_prompt = _format_no_colon_single(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt)


@register_chat_format("chatml")
def format_chatml(
messages: List[llama_types.ChatCompletionRequestMessage],
Expand All @@ -604,6 +606,7 @@ def format_chatml(
_prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)


@register_chat_format("openchat")
def format_openchat(
messages: List[llama_types.ChatCompletionRequestMessage],
Expand All @@ -612,7 +615,9 @@ def format_openchat(
system_template = "{system_message}<|end_of_turn|>"
system_message = _get_system_message(messages)
system_message = system_template.format(system_message=system_message)
_roles = dict(user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: ")
_roles = dict(
user="GPT4 Correct User: ", assistant="<|end_of_turn|>GPT4 Correct Assistant: "
)
_sep = "<|end_of_turn|>"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
Expand Down Expand Up @@ -651,46 +656,60 @@ def functionary_chat_handler(
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

def generate_type_definition(param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs) -> str:
indent = ' ' * indent_level
if '$ref' in param:
def generate_type_definition(
param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
) -> str:
indent = " " * indent_level
if "$ref" in param:
# Reference to a shared definition
ref_name = param['$ref'].split('/')[-1] # Extract the type name from the reference
ref_name = param["$ref"].split("/")[
-1
] # Extract the type name from the reference
return ref_name
elif param.get('type') == 'array':
items = param.get('items', {})
elif param.get("type") == "array":
items = param.get("items", {})
item_type = generate_type_definition(items, indent_level + 1, shared_defs)
return f"Array<{item_type}>"
elif param.get('type') == 'object':
properties = param.get('properties', {})
elif param.get("type") == "object":
properties = param.get("properties", {})
nested_schema = "{\n"
for nested_param_name, nested_param in properties.items():
nested_param_type = generate_type_definition(nested_param, indent_level + 1, shared_defs)
nested_schema += f"{indent} {nested_param_name}: {nested_param_type},\n"
nested_param_type = generate_type_definition(
nested_param, indent_level + 1, shared_defs
)
nested_schema += (
f"{indent} {nested_param_name}: {nested_param_type},\n"
)
nested_schema += indent + "}"
return nested_schema
elif 'enum' in param:
elif "enum" in param:
# Enum type
return " | ".join([f'"{enum_value}"' for enum_value in param['enum']])
return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
else:
# Simple type
return param.get('type', 'any')
return param.get("type", "any")

def generate_shared_definitions(shared_defs, indent_level: int) -> str:
indent = ' ' * indent_level
indent = " " * indent_level
shared_definitions = ""
for def_name, def_properties in shared_defs.items():
shared_definitions += f"{indent}type {def_name} = "
if def_properties.get('type') == 'object':
shared_definitions += generate_type_definition(def_properties, indent_level, shared_defs)
elif 'enum' in def_properties:
if def_properties.get("type") == "object":
shared_definitions += generate_type_definition(
def_properties, indent_level, shared_defs
)
elif "enum" in def_properties:
# Enum type
shared_definitions += " | ".join([f'"{enum_value}"' for enum_value in def_properties['enum']])
shared_definitions += " | ".join(
[f'"{enum_value}"' for enum_value in def_properties["enum"]]
)
shared_definitions += ";\n"
return shared_definitions

def generate_schema_from_functions(functions, namespace="functions") -> str:
schema = "// Supported function definitions that should be called when necessary.\n"
schema = (
"// Supported function definitions that should be called when necessary.\n"
)
schema += f"namespace {namespace} {{\n\n"

# Generate shared definitions
Expand All @@ -706,10 +725,10 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
description = function.get("description", "")
parameters = function.get("parameters", {})
required_params = parameters.get("required", [])

schema += f" // {description}\n"
schema += f" type {function_name} = (_: {{\n"

for param_name, param in parameters.get("properties", {}).items():
param_description = param.get("description", "")
param_type = generate_type_definition(param, 2, shared_definitions)
Expand All @@ -733,13 +752,18 @@ def prepare_messages_for_inference(
role="system", content=generate_schema_from_functions(functions)
)
)

if tools is not None:
all_messages.append(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=generate_schema_from_functions(
[tool["function"] for tool in tools if tool["type"] == "function"]
)
role="system",
content=generate_schema_from_functions(
[
tool["function"]
for tool in tools
if tool["type"] == "function"
]
),
)
)

Expand Down Expand Up @@ -790,7 +814,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
elif "function_call" in msg:
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
for tool_call in msg["tool_calls"]: # NOTE: probably doesn't work with the functionary model
for tool_call in msg[
"tool_calls"
]: # NOTE: probably doesn't work with the functionary model
return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
elif msg["content"] is None:
return "assistant"
Expand All @@ -800,12 +826,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
raise ValueError(f"Unsupported role: {msg['role']}")

return "".join([message_to_str(msg) for msg in all_messages])

if tools is not None:
functions = [tool["function"] for tool in tools if tool["type"] == "function"]

if tool_choice is not None:
function_call = tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
function_call = (
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
)

prompt = prepare_messages_for_inference(messages, functions, tools)

Expand Down Expand Up @@ -861,19 +889,27 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
if tool["type"] == "function" and tool["function"]["name"] == function_call:
function_body = tool["function"]["parameters"]
break

if function_body is not None:
try:
with suppress_stdout_stderr(disable=llama.verbose):
grammar_text = llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.json_schema_to_gbnf(json.dumps(function_body)))
grammar_text = llama_grammar.json_schema_to_gbnf(
json.dumps(function_body)
)
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
)
print(grammar_text)
except Exception as e:
if llama.verbose:
print("Failed to parse function body as JSON schema, falling back to default grammar")
print(
"Failed to parse function body as JSON schema, falling back to default grammar"
)
print(e)
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
else:
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
Expand Down Expand Up @@ -929,9 +965,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
"function": {
"name": function_call,
"arguments": completion["choices"][0]["text"],
}
},
}
]
],
},
"finish_reason": "tool_calls",
}
Expand Down
55 changes: 25 additions & 30 deletions llama_cpp/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


# Disable warning for model and model_alias settings
BaseSettings.model_config['protected_namespaces'] = ()
BaseSettings.model_config["protected_namespaces"] = ()


class Settings(BaseSettings):
Expand Down Expand Up @@ -68,7 +68,9 @@ class Settings(BaseSettings):
description="Use mlock.",
)
# Context Params
seed: int = Field(default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random.")
seed: int = Field(
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
)
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
Expand All @@ -83,30 +85,16 @@ class Settings(BaseSettings):
ge=0,
description="The number of threads to use when batch processing.",
)
rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED
)
rope_freq_base: float = Field(
default=0.0, description="RoPE base frequency"
)
rope_scaling_type: int = Field(default=llama_cpp.LLAMA_ROPE_SCALING_UNSPECIFIED)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
yarn_ext_factor: float = Field(
default=-1.0
)
yarn_attn_factor: float = Field(
default=1.0
)
yarn_beta_fast: float = Field(
default=32.0
)
yarn_beta_slow: float = Field(
default=1.0
)
yarn_orig_ctx: int = Field(
default=0
)
yarn_ext_factor: float = Field(default=-1.0)
yarn_attn_factor: float = Field(default=1.0)
yarn_beta_fast: float = Field(default=32.0)
yarn_beta_slow: float = Field(default=1.0)
yarn_orig_ctx: int = Field(default=0)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
Expand All @@ -122,7 +110,7 @@ class Settings(BaseSettings):
# LoRA Params
lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model."
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
)
lora_path: Optional[str] = Field(
default=None,
Expand Down Expand Up @@ -384,7 +372,9 @@ def create_app(settings: Optional[Settings] = None):
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(clip_model_path=settings.clip_model_path, verbose=settings.verbose)
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
##

llama = llama_cpp.Llama(
Expand Down Expand Up @@ -587,9 +577,10 @@ async def get_event_publisher(

grammar = Field(
default=None,
description="A CBNF grammar (as string) to be used for formatting the model's output."
description="A CBNF grammar (as string) to be used for formatting the model's output.",
)


class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="", description="The prompt to generate completions for."
Expand Down Expand Up @@ -690,7 +681,8 @@ async def create_completion(
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

iterator_or_completion: Union[
llama_cpp.CreateCompletionResponse, Iterator[llama_cpp.CreateCompletionStreamResponse]
llama_cpp.CreateCompletionResponse,
Iterator[llama_cpp.CreateCompletionStreamResponse],
] = await run_in_threadpool(llama, **kwargs)

if isinstance(iterator_or_completion, Iterator):
Expand Down Expand Up @@ -748,7 +740,9 @@ class ChatCompletionRequestMessage(BaseModel):
role: Literal["system", "user", "assistant", "function"] = Field(
default="user", description="The role of the message."
)
content: Optional[str] = Field(default="", description="The content of the message.")
content: Optional[str] = Field(
default="", description="The content of the message."
)


class CreateChatCompletionRequest(BaseModel):
Expand All @@ -770,9 +764,10 @@ class CreateChatCompletionRequest(BaseModel):
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
default=None,
description="A tool to apply to the generated completions.",
) # TODO: verify
) # TODO: verify
max_tokens: Optional[int] = Field(
default=None, description="The maximum number of tokens to generate. Defaults to inf"
default=None,
description="The maximum number of tokens to generate. Defaults to inf",
)
temperature: float = temperature_field
top_p: float = top_p_field
Expand Down

0 comments on commit 7a3f878

Please sign in to comment.