feat: rebased and allowed more dynamic roles via config
NanoCode012 committed Feb 22, 2024
1 parent 85ddde2 commit 33bcf57
Showing 4 changed files with 61 additions and 234 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -597,9 +597,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_inputs
output: # Optional[List[str]].

# Custom user instruction prompt
- path: repo
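For illustration, a complete `datasets` entry using the new `roles` mapping might look like the sketch below; the dataset path and the extra role names (`funccaller`, `tool_caller`, `tool_response`) are hypothetical placeholders, not part of this commit:

```yaml
datasets:
  - path: my-org/multi-role-sharegpt   # hypothetical dataset
    type: sharegpt
    conversation: vicuna_v1.1
    field_human: human
    field_model: gpt
    roles:
      input: [human, funccaller, tool_response]  # masked unless train_on_inputs is set
      output: [gpt, tool_caller]
```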
166 changes: 4 additions & 162 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -1,21 +1,11 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import copy
import logging
from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

from axolotl.prompt_tokenizers import (
InvalidDataException,
ShareGPTPromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
from axolotl.prompters import (
IGNORE_TOKEN_ID,
ShareGPTPrompterV2,
ShareGPTPrompterV2MultiRole,
)
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2

LOG = logging.getLogger("axolotl")

@@ -40,11 +30,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
@@ -90,28 +82,6 @@ def load_guanaco(tokenizer, cfg):
)


def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strategy = MultiRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2MultiRole(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]

return strategy


class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
@@ -178,131 +148,3 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns


class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy for support of multi-role
"""

def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
)
user, assistant = conversation.roles

input_roles = {
"human",
"funcresponse",
"funccaller",
"tool",
"tool_response",
user,
}
output_roles = {"gpt", "tool_caller", assistant}

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]

try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue

role, content = part

# Uses "in" because role contains extra characters
input_turn = any(r in role.lower() for r in input_roles)
output_turn = any(r in role.lower() for r in output_roles)

if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query; mask it from the labels unless train_on_inputs is set
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token = not (
conversation.name == "chatml"
and conversation.sep == self.tokenizer.eos_token
)
res = self._tokenize(
turn,
add_eos_token=add_eos_token,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
labels = copy.deepcopy(res["input_ids"])
if not self.train_on_inputs:
# mask out role tokens from the labels
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
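
With the dedicated multi-role strategy and its `load_multirole` entry point removed, the same behaviour is reachable through the plain `load()` path. A minimal sketch, assuming a `ds_cfg` shaped like the README example above; note that `load()` calls `.to_dict()` on the roles value, so the config is built as an axolotl `DictDefault` rather than a plain dict:

```python
from axolotl.utils.dict import DictDefault

# Sketch: drive the unified loader purely from config. Role names are
# illustrative; load() calls ds_cfg["roles"].to_dict(), which is why the
# config is wrapped in DictDefault instead of a plain dict.
ds_cfg = DictDefault(
    {
        "conversation": "vicuna_v1.1",
        "field_human": "human",
        "field_model": "gpt",
        "roles": {
            "input": ["human", "funccaller", "tool_response"],
            "output": ["gpt", "tool_caller"],
        },
    }
)
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)  # replaces load_multirole()
```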
34 changes: 26 additions & 8 deletions src/axolotl/prompt_tokenizers.py
@@ -11,7 +11,7 @@
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

LOG = logging.getLogger("axolotl")

@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
@@ -340,6 +340,19 @@ def tokenize_prompt(self, prompt):
self.prompter._conversation.copy() # pylint: disable=protected-access
)

input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}

# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)

if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
@@ -360,11 +373,19 @@
LOG.warning(f"expected tuple, got {part}")
continue

user, assistant = conversation.roles
role, content = part

# Uses "in" because role contains extra characters
if user in role:
input_turn = any(r in role.lower() for r in input_roles)
output_turn = any(r in role.lower() for r in output_roles)
empty_role = role.strip() == ""

if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue

# Uses "in" because role contains extra characters
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
Expand All @@ -384,7 +405,7 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -426,9 +447,6 @@
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
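
The net effect of the changes above is a small routing step at the top of each turn in `tokenize_prompt()`. Extracted as a standalone sketch, with the role sets passed explicitly for clarity:

```python
# Simplified sketch of the turn-routing logic added above. Matching uses
# substring checks ("in") because the rendered role string carries extra
# separator characters around the role name.
def classify_turn(role: str, input_roles: set, output_roles: set) -> str:
    lowered = role.lower()
    if any(r in lowered for r in input_roles):
        return "input"      # labels masked unless train_on_inputs is set
    if any(r in lowered for r in output_roles):
        return "output"     # tokens contribute to the training labels
    if role.strip() == "":
        return "empty"      # first part: BOS token plus the user query
    return "unhandled"      # logged as a warning and skipped
```

This reproduces what the deleted MultiRoleShareGPTPromptTokenizingStrategy hard-coded, with the role sets now supplied from the dataset config.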