Skip to content

Commit

Permalink
rebase from main
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreymeetkai committed Aug 14, 2024
2 parents 1be6c31 + 2041dad commit cd2dfde
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 77 deletions.
13 changes: 13 additions & 0 deletions functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ def update_fsm_gen_state(
"""
raise NotImplementedError

@abstractmethod
def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
    """Get the candidate strings grammar sampling may force next, given the gen state.

    Subclasses return the options relevant to their prompt format (e.g.
    function names, separator/stop tokens) for the current FSM stage; the
    base implementation imposes no constraint.

    Args:
        gen_state (Dict): finite-state-machine generation state; subclasses
            read at least the "stage" key from it.
        tools_or_functions (List): tool/function schemas available for this
            request (dicts with at least a "name" key).

    Returns:
        List: candidate strings to constrain token generation with; empty
        when no constraint applies.
    """
    return []

def get_force_text_generation_prefix(self):
    """Prefix prepended to force plain-text (non-function) generation.

    Defaults to the empty string; templates that need an explicit marker
    override this method.
    """
    return ""
Expand Down
15 changes: 7 additions & 8 deletions functionary/prompt_template/llama31_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def stream_delta_text(
)

responses = []
options = []
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

if gen_state["stage"] == "text-gen":
if gen_state["gen_empty_text"]:
Expand Down Expand Up @@ -344,14 +346,11 @@ def update_fsm_gen_state(

return gen_state

def get_chat_template_jinja(self) -> str:
if self._chat_template is None:
with open(
f"./functionary/prompt_template/jinja_templates/{self.version}.txt", "r"
) as f:
self._chat_template = f.read()
def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
    """Return grammar-sampling options for the current gen state.

    The Llama 3.1 template does not constrain generation via grammar
    sampling options, so this always returns an empty list regardless of
    stage or available tools.
    """
    return []

return self._chat_template
def get_chat_template_jinja(self):
    # Pure pass-through to the base-class template loader.
    # NOTE(review): this override adds no behavior — presumably kept after
    # removing a custom file-based loader; confirm whether it can be deleted.
    return super().get_chat_template_jinja()

def get_force_function_call_prefix(self, function_name: str):
    """Return the prompt prefix that forces the model to call *function_name*.

    For the Llama 3.1 template this is the ``<function=NAME>`` opening tag.

    Args:
        function_name (str): name of the function the model must call.

    Returns:
        str: the forced-call prefix, e.g. ``"<function=get_weather>"``.
    """
    return f"<function={function_name}>"
Expand Down
37 changes: 19 additions & 18 deletions functionary/prompt_template/llama3_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,9 @@ def grammar_sample(
grammar_sampled_token_id = model_sampled_token_id
grammar_sampled_token = tokenizer.decode([model_sampled_token_id])

options = []
# Form the pre-function options (<|reserved_token_249|> or <|eot_id|>) to update gen_state
if gen_state["stage"] == "pre-function":
options = [self.function_separator, "<|eot_id|>"]
# Form the functions options for grammar sampling
elif gen_state["stage"] == "function":
options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
if gen_state["add_code_interpreter"]:
options.append("python")
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

# Loop through the list of token ids sorted in descending order. For "function"
# stage, form a mask made up of booleans where the index of the mask == index
Expand Down Expand Up @@ -229,15 +223,9 @@ def stream_delta_text(
responses = []

# Form the options for the following stages
options = []
# Form the pre-function options (<|reserved_token_249|> or <|eot_id|>) to update gen_state
if gen_state["stage"] == "pre-function":
options = [self.function_separator, "<|eot_id|>"]
# Form the functions options for grammar sampling
elif gen_state["stage"] == "function":
options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
if gen_state["add_code_interpreter"]:
options.append("python")
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

if gen_state["stage"] == "text-gen":
if gen_state["gen_empty_text"]:
Expand Down Expand Up @@ -360,6 +348,19 @@ def update_fsm_gen_state(
gen_state = self._reset_fsm_curr_text_and_tokens(gen_state=gen_state)
return gen_state

def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
    """Return the grammar-sampling candidate strings for the current stage.

    - "pre-function": either another call follows (the function separator)
      or the turn ends (``<|eot_id|>``).
    - "function": the available tool/function names, plus ``"python"`` when
      the code interpreter is enabled.
    - any other stage: no constraint (empty list).
    """
    stage = gen_state["stage"]
    if stage == "pre-function":
        # Next token either starts another function call or ends the turn.
        return [self.function_separator, "<|eot_id|>"]
    if stage == "function":
        candidates = [entry["name"] for entry in tools_or_functions]
        if gen_state["add_code_interpreter"]:
            candidates.append("python")
        return candidates
    return []

def get_force_text_generation_prefix(self):
    # No prefix is needed to force plain-text generation for this template.
    # NOTE(review): identical to the base-class default — presumably kept for
    # explicitness; confirm it can be removed.
    return ""

Expand Down
45 changes: 22 additions & 23 deletions functionary/prompt_template/llama3_prompt_template_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,9 @@ def grammar_sample(
grammar_sampled_token_id, grammar_sampled_token = None, None

# Form the options for the following stages
options = []
if gen_state["stage"] == "pre-function":
options = [self.function_separator, self.eos_token]
elif gen_state["stage"] == "function":
options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
if gen_state["add_all_recipient"]:
options.append("all")
if gen_state["add_code_interpreter"]:
options.append("python")
elif gen_state["stage"] == "pre-parameter":
options = [self.fn_param_sep_token]
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

# No grammar sampling needed if gen_state not in the following stages. Return model_sampled_token_id
if gen_state["stage"] not in ["pre-function", "function", "pre-parameter"]:
Expand Down Expand Up @@ -249,18 +241,9 @@ def stream_delta_text(

responses = []

# Form the options for the following stages
options = []
if gen_state["stage"] == "pre-function":
options = [self.function_separator, self.eos_token]
elif gen_state["stage"] == "function":
options = [(tool_or_func["name"]) for tool_or_func in tools_or_functions]
if gen_state["add_all_recipient"]:
options.append("all")
if gen_state["add_code_interpreter"]:
options.append("python")
elif gen_state["stage"] == "pre-parameter":
options = [self.fn_param_sep_token]
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

if gen_state["stage"] == "text-gen":
if delta_text != self.function_separator:
Expand Down Expand Up @@ -379,3 +362,19 @@ def update_fsm_gen_state(
gen_state["func_name"] = ""

return gen_state

def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
    """Return the grammar-sampling candidate strings for the current stage.

    - "pre-function": another call follows (separator) or generation ends
      (EOS token).
    - "function": the available tool/function names, plus ``"all"`` when the
      all-recipient is allowed and ``"python"`` when the code interpreter is
      enabled.
    - "pre-parameter": only the function/parameter separator token.
    - any other stage: no constraint (empty list).
    """
    stage = gen_state["stage"]
    if stage == "pre-function":
        return [self.function_separator, self.eos_token]
    if stage == "pre-parameter":
        return [self.fn_param_sep_token]
    if stage != "function":
        return []
    candidates = [entry["name"] for entry in tools_or_functions]
    if gen_state["add_all_recipient"]:
        candidates.append("all")
    if gen_state["add_code_interpreter"]:
        candidates.append("python")
    return candidates
52 changes: 24 additions & 28 deletions functionary/prompt_template/prompt_template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,9 @@ def grammar_sample(
grammar_sampled_token_id, grammar_sampled_token = None, None

# Form the options for the following stages
options = []
if gen_state["stage"] == "pre-function":
options = [
f"\n{self.from_token} assistant\n{self.recipient_token}",
self.stop_token,
]
elif gen_state["stage"] == "function":
options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
if gen_state["add_all_recipient"]:
options.append("all")
if gen_state["add_code_interpreter"]:
options.append("python")
elif gen_state["stage"] == "pre-parameter":
options = [self.fn_param_sep_token]
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

# No grammar sampling needed if gen_state not in the following stages. Return model_sampled_token_id
if gen_state["stage"] not in ["function", "pre-parameter"]:
Expand Down Expand Up @@ -279,20 +268,9 @@ def stream_delta_text(
responses = []

# Form the options for the following stages
options = []
if gen_state["stage"] == "pre-function":
options = [
f"\n{self.from_token} assistant\n{self.recipient_token}",
self.stop_token,
]
elif gen_state["stage"] == "function":
options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
if gen_state["add_all_recipient"]:
options.append("all")
if gen_state["add_code_interpreter"]:
options.append("python")
elif gen_state["stage"] == "pre-parameter":
options = [self.fn_param_sep_token]
options = self.get_options_from_gen_state(
gen_state=gen_state, tools_or_functions=tools_or_functions
)

if gen_state["stage"] == "text-gen":
if delta_text == "\n":
Expand Down Expand Up @@ -428,6 +406,24 @@ def update_fsm_gen_state(

return gen_state

def get_options_from_gen_state(self, gen_state: Dict, tools_or_functions: List):
    """Return the grammar-sampling candidate strings for the current stage.

    - "pre-function": either a new assistant turn begins (from/recipient
      header) or generation stops (stop token).
    - "function": the available tool/function names, plus ``"all"`` when the
      all-recipient is allowed and ``"python"`` when the code interpreter is
      enabled.
    - "pre-parameter": only the function/parameter separator token.
    - any other stage: no constraint (empty list).
    """
    stage = gen_state["stage"]
    if stage == "pre-function":
        return [
            f"\n{self.from_token} assistant\n{self.recipient_token}",
            self.stop_token,
        ]
    if stage == "pre-parameter":
        return [self.fn_param_sep_token]
    if stage != "function":
        return []
    candidates = [entry["name"] for entry in tools_or_functions]
    if gen_state["add_all_recipient"]:
        candidates.append("all")
    if gen_state["add_code_interpreter"]:
        candidates.append("python")
    return candidates

def get_force_text_generation_prefix(self):
    """Return the prefix that forces the model into plain-text generation.

    In the v2 format the recipient ``all`` denotes text output, so the
    forced prefix is ``"all"`` followed by the parameter separator token.
    """
    return f"all{self.fn_param_sep_token}"

Expand Down

0 comments on commit cd2dfde

Please sign in to comment.