diff --git a/README.md b/README.md
index 3c9f030007..754c017719 100644
--- a/README.md
+++ b/README.md
@@ -597,9 +597,13 @@ datasets:
     train_on_split: train # Optional[str] name of dataset split to load from
 
     # Optional[str] fastchat conversation type, only used with type: sharegpt
-    conversation:  # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+    conversation:  # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
     field_human: # Optional[str]. Human key to use for conversation.
     field_model: # Optional[str]. Assistant key to use for conversation.
+    # Add additional keys from your dataset as input or output roles
+    roles:
+      input: # Optional[List[str]]. These will be masked based on train_on_inputs
+      output: # Optional[List[str]].
 
   # Custom user instruction prompt
   - path: repo
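For reference, the parsed form of this `roles` block is a plain mapping handed to the sharegpt loader below. A minimal sketch — the dataset path and role names here are hypothetical, not from this PR:

```python
# Hypothetical parsed dataset entry; mirrors the YAML keys documented above.
ds_cfg = {
    "path": "my-org/my-sharegpt-dataset",  # hypothetical dataset
    "type": "sharegpt",
    "conversation": "chatml",
    "roles": {
        # turns from these roles are masked unless train_on_inputs is set
        "input": ["human", "tool_response"],
        # turns from these roles always contribute to the loss
        "output": ["gpt", "tool_caller"],
    },
}
```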
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index dbd46e82bc..bc5f16acb5 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,21 +1,11 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
-import copy
 import logging
 from typing import Any, Dict, Optional
 
 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
 
-from axolotl.prompt_tokenizers import (
-    InvalidDataException,
-    ShareGPTPromptTokenizingStrategy,
-    parse_tokenized_to_result,
-    tokenize_prompt_default,
-)
-from axolotl.prompters import (
-    IGNORE_TOKEN_ID,
-    ShareGPTPrompterV2,
-    ShareGPTPrompterV2MultiRole,
-)
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
 
 LOG = logging.getLogger("axolotl")
@@ -40,11 +30,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
     strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
             role_key_human=field_human,
+            roles=roles,
         ),
         tokenizer,
         cfg.train_on_inputs,
@@ -90,28 +82,6 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
-def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    strategy = MultiRoleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2MultiRole(
-            conversation=conversation,
-            role_key_model=field_model,
-            role_key_human=field_human,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-
-    return strategy
-
-
 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
     basic sharegpt strategy to grab conversations from the sample row
@@ -178,131 +148,3 @@ def get_conversation_thread(self, prompt):
             {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
         ]
         return turns
-
-
-class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy for support of multi-role
-    """
-
-    def tokenize_prompt(self, prompt):
-        # Initial values. We will append to these as we go through the conversation.
-        result, current_len = tokenize_prompt_default()
-        conversation: Conversation = (
-            self.prompter._conversation.copy()  # pylint: disable=protected-access
-        )
-        user, assistant = conversation.roles
-
-        input_roles = {
-            "human",
-            "funcresponse",
-            "funccaller",
-            "tool",
-            "tool_response",
-            user,
-        }
-        output_roles = {"gpt", "tool_caller", assistant}
-
-        # support for custom roles from the dataset, only useful for vicuna style prompts/roles
-        role_remap = []
-        if (
-            conversation.name == "vicuna_v1.1"
-            and "roles" in prompt
-            and len(prompt["roles"]) >= 2
-        ):
-            role_remap = [
-                {"from": conversation.roles[0], "to": prompt["roles"][0]},
-                {"from": conversation.roles[1], "to": prompt["roles"][1]},
-            ]
-
-        try:
-            for _, part in enumerate(
-                self.prompter.build_prompt(self.get_conversation_thread(prompt))
-            ):
-                if not isinstance(part, tuple):
-                    LOG.warning(f"expected tuple, got {part}")
-                    continue
-
-                role, content = part
-
-                # Uses "in" because role contains extra characters
-                input_turn = any(r in role.lower() for r in input_roles)
-                output_turn = any(r in role.lower() for r in output_roles)
-
-                if input_turn:
-                    role = (
-                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
-                        if role_remap
-                        else role
-                    )
-                    turn = role + content
-                    # this is still the user query, we should
-                    if not content.strip():
-                        LOG.warning(f"user turn has empty text: {prompt}")
-                    res = self._tokenize(
-                        turn,
-                        add_eos_token=False,
-                        strip_bos_token=True,
-                    )
-                    if self.train_on_inputs:
-                        labels = copy.deepcopy(res["input_ids"])
-                    else:
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                elif output_turn:
-                    role = (
-                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
-                        if role_remap
-                        else role
-                    )
-                    turn = role + content
-                    # this should be the assistant response, should end with an eos token
-                    if not content.strip():
-                        LOG.warning(f"assistant turn has empty text: {prompt}")
-                    add_eos_token = not (
-                        conversation.name == "chatml"
-                        and conversation.sep == self.tokenizer.eos_token
-                    )
-                    res = self._tokenize(
-                        turn,
-                        add_eos_token=add_eos_token,
-                        strip_bos_token=True,
-                    )
-                    role_res = self._tokenize(
-                        role.rstrip(),
-                        add_eos_token=False,
-                        strip_bos_token=True,
-                    )
-                    labels = copy.deepcopy(res["input_ids"])
-                    if not self.train_on_inputs:
-                        # mask out role tokens from the labels
-                        len_role = len(role_res["input_ids"])
-                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
-                            len_role, len(labels)
-                        )
-                elif role == "":
-                    turn = content
-                    # this is only ever the first part, should include the bos token and the user query
-                    res = self._tokenize(
-                        turn, add_eos_token=False, strip_bos_token=False
-                    )
-                    if self.train_on_inputs:
-                        labels = copy.deepcopy(res["input_ids"])
-                    else:
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                else:
-                    LOG.warning(f"unhandled role: {role}")
-                    continue
-
-                # pylint: disable=duplicate-code
-                result, current_len = parse_tokenized_to_result(
-                    result,
-                    current_len,
-                    res,
-                    labels,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                )
-            return result
-        except (KeyError, AssertionError, IndexError) as err:
-            raise InvalidDataException(str(err)) from err
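A sketch of how the `roles` mapping reaches the prompter through the simplified `load()` above. This assumes axolotl's `DictDefault` wrapper (which provides the `.to_dict()` the loader calls) and a standard HF tokenizer; the model id and role names are placeholders:

```python
from transformers import AutoTokenizer

from axolotl.prompt_strategies.sharegpt import load
from axolotl.utils.dict import DictDefault  # assumed: axolotl's addict-based dict

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")  # any HF tokenizer
cfg = DictDefault({"train_on_inputs": False, "sequence_len": 2048})
ds_cfg = DictDefault(
    {
        "conversation": "chatml",
        "roles": {"input": ["human", "tool_response"], "output": ["gpt", "tool_caller"]},
    }
)

# The returned strategy carries prompter.roles == {"input": [...], "output": [...]}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
```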
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index a5c243f7e6..e2dc756297 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -11,7 +11,7 @@
 from axolotl.monkeypatch.fastchat_conversation_turns import (
     add_get_turns_to_conversation,
 )
-from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 
 LOG = logging.getLogger("axolotl")
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
 
     def __init__(
         self,
-        prompter,
+        prompter: Prompter,
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
@@ -340,6 +340,19 @@ def tokenize_prompt(self, prompt):
             self.prompter._conversation.copy()  # pylint: disable=protected-access
         )
 
+        input_roles = {conversation.roles[0]}
+        output_roles = {conversation.roles[1]}
+
+        # Add roles from the config
+        if self.prompter.roles:
+            if "input" in self.prompter.roles and self.prompter.roles["input"]:
+                for role in self.prompter.roles["input"]:
+                    input_roles.add(role)
+
+            if "output" in self.prompter.roles and self.prompter.roles["output"]:
+                for role in self.prompter.roles["output"]:
+                    output_roles.add(role)
+
         # support for custom roles from the dataset, only useful for vicuna style prompts/roles
         role_remap = []
         if (
@@ -360,11 +373,19 @@ def tokenize_prompt(self, prompt):
                     LOG.warning(f"expected tuple, got {part}")
                     continue
 
-                user, assistant = conversation.roles
                 role, content = part
 
                 # Uses "in" because role contains extra characters
-                if user in role:
+                input_turn = any(r in role.lower() for r in input_roles)
+                output_turn = any(r in role.lower() for r in output_roles)
+                empty_role = role.strip() == ""
+
+                if not any([input_turn, output_turn, empty_role]):
+                    LOG.warning(f"unhandled role: {role}")
+                    continue
+
+                # Uses "in" because role contains extra characters
+                if input_turn:
                     role = (
                         role.replace(role_remap[0]["from"], role_remap[0]["to"])
                         if role_remap
@@ -384,7 +405,7 @@ def tokenize_prompt(self, prompt):
                     else:
                         # everything from this is masked out from the labels
                         labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                elif assistant in role:
+                elif output_turn:
                     role = (
                         role.replace(role_remap[1]["from"], role_remap[1]["to"])
                         if role_remap
@@ -426,9 +447,6 @@ def tokenize_prompt(self, prompt):
                     else:
                         # everything from this is masked out from the labels
                         labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                else:
-                    LOG.warning(f"unhandled role: {role}")
-                    continue
 
                 # pylint: disable=duplicate-code
                 result, current_len = parse_tokenized_to_result(
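The net effect of the rewritten turn loop above is that membership in `input_roles` vs `output_roles` decides label masking. A self-contained sketch of that rule (a simplification, not the actual `_tokenize` pipeline):

```python
IGNORE_TOKEN_ID = -100  # same sentinel axolotl.prompters uses; ignored by the loss

def labels_for_turn(input_ids, is_input_turn, train_on_inputs):
    """Sketch of the per-turn masking decision made in tokenize_prompt above."""
    if is_input_turn and not train_on_inputs:
        # input-role turns provide context only; mask every token
        return [IGNORE_TOKEN_ID] * len(input_ids)
    # output-role turns (and inputs when train_on_inputs=True) are trained on
    return list(input_ids)

assert labels_for_turn([5, 6, 7], is_input_turn=True, train_on_inputs=False) == [-100, -100, -100]
assert labels_for_turn([5, 6, 7], is_input_turn=False, train_on_inputs=False) == [5, 6, 7]
```

(For output turns the real code additionally masks just the role-prefix tokens when `train_on_inputs` is off.)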
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 05bf53510e..cba057ca87 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -259,6 +259,11 @@ def __repr__(self) -> str:
     "Role did not alternate between turns (gpt and human). Please check your data."
 )
 
+CONVERSATION_ROLE_FORMAT = {
+    "chatml": "<|im_start|>{ROLE}",
+    "zephyr": "<|{ROLE}|>",
+}
+
 
 class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
     """
@@ -274,6 +279,7 @@ def __init__(
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         if conversation:
             if isinstance(conversation, Conversation):
@@ -287,6 +293,8 @@ def __init__(
         if role_key_model:
             self.role_key_model = role_key_model
 
+        self.roles = roles
+
     def _build_result(self, source):
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
@@ -315,11 +323,23 @@ def _build_result(self, source):
         conv.messages = []
         for _, sentence in enumerate(source):
-            role = roles[sentence["from"]]
-            if len(conv.messages) > 0 and (
-                (role == conv.messages[-1][0]) or (role not in conv.roles)
-            ):
+            from_role = sentence["from"]
+            if from_role in roles:
+                role = roles[from_role]
+            else:
+                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
+                    raise NotImplementedError(
+                        f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet. "
+                        "Please help us by creating an Issue to add support for this role."
+                    )
+
+                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
+                    ROLE=from_role
+                )
+
+            if len(conv.messages) > 0 and (role == conv.messages[-1][0]):
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+
             conv.append_message(role, sentence["value"])
 
         return conv.get_turns()
@@ -347,73 +367,16 @@ def __init__(
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        roles: Optional[dict] = None,
     ):
         super().__init__(
             conversation=conversation,
             role_key_human=role_key_human,
             role_key_model=role_key_model,
+            roles=roles,
         )
 
 
-CONVERSATION_ROLE_FORMAT = {
-    "chatml": "<|im_start|>{ROLE}",
-    "zephyr": "<|{ROLE}|>",
-}
-
-
-class ShareGPTPrompterV2MultiRole(ShareGPTPrompterV2):
-    """
-    An multi-role V2 prompter that generates prompts for the ShareGPT that supports multi-role
-    """
-
-    def _build_result(self, source):
-        if len(source) < 2:
-            # If there isn't a back and forth conversation, ignore it
-            # also happens on the data splitting leaving empty conversations
-            raise IndexError(
-                f"A conversation entry has less than 2 messages :\n{source}"
-            )
-
-        conv = self._conversation.copy()
-
-        # Add the conversation system prompt if provided, otherwise use the default one
-        if source[0]["from"] == "system":
-            conv.set_system_message(source[0]["value"])
-            source.pop(0)
-
-        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
-
-        try:
-            # Apply prompt templates
-            if source[0]["from"] not in roles:
-                # Skip the first one if it is not from human
-                source = source[1:]
-        except IndexError as err:
-            # sometimes there is a bing or system chat
-            raise err
-
-        conv.messages = []
-        for _, sentence in enumerate(source):
-            from_role = sentence["from"]
-            if from_role in roles:
-                role = roles[from_role]
-            else:
-                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
-                    raise NotImplementedError(
-                        f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
-                    )
-
-                role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
-                    ROLE=from_role
-                )
-
-            if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
-                LOG.warning(f"Roles did not alternate: {sentence}")
-            conv.append_message(role, sentence["value"])
-
-        return conv.get_turns()
-
-
 class UnsupportedPrompter(Prompter):
     """
     A dummy class for custom prompters
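To illustrate the remapping path added in `_build_result` above: a `from` key that is neither the human nor the model role falls through to `CONVERSATION_ROLE_FORMAT`, so templates that encode the role inline (chatml, zephyr) can render arbitrary roles, while any other template raises `NotImplementedError`. A quick check of the format strings from the diff, with a hypothetical "tool" role:

```python
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
}

# An unknown "tool" role is rendered with the template's own role header:
assert CONVERSATION_ROLE_FORMAT["chatml"].format(ROLE="tool") == "<|im_start|>tool"
assert CONVERSATION_ROLE_FORMAT["zephyr"].format(ROLE="tool") == "<|tool|>"
```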