Feat/sharegpt multirole #1

Open
wants to merge 60 commits into main (base) from feat/sharegpt_multirole
Changes from 1 commit
Commits
60 commits
1623a50
feat(prompt): support multiple roles for sharegpt
NanoCode012 Jan 17, 2024
85ddde2
fix: add handling of empty role back
NanoCode012 Jan 18, 2024
33bcf57
feat: rebased and allowed more dynamic roles via config
NanoCode012 Feb 22, 2024
b54869c
fix: variable
NanoCode012 Feb 22, 2024
8751ffb
chore: update message
NanoCode012 Feb 22, 2024
24591e8
feat: add vicuna format
NanoCode012 Feb 23, 2024
9e171a9
fix: JSON serializable error
NanoCode012 Feb 23, 2024
2ed52bd
fix(readme): Clarify doc for tokenizer_config (#1323) [skip ci]
NanoCode012 Feb 24, 2024
5cf226e
Use yaml codeblock for config.yaml field (#1303) [skip ci]
kallewoof Feb 24, 2024
5894f0e
make mlflow optional (#1317)
winglian Feb 26, 2024
cc3cebf
Pydantic 2.x cfg (#1239)
winglian Feb 26, 2024
c6b01e0
chore: update readme to be more clear (#1326) [skip ci]
NanoCode012 Feb 26, 2024
d756534
ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
JohanWork Feb 26, 2024
7de912e
hotfix for capabilities loading (#1331)
winglian Feb 26, 2024
cf00231
hotfix for lora rank (#1332)
winglian Feb 26, 2024
e7eed20
hotfix for missing outputs params (#1333)
winglian Feb 26, 2024
269c543
hotfix to exclude_unset from pydantic config when converting back to …
winglian Feb 26, 2024
f30d062
Add StableLM 2 Example Scripts (#1327) [skip ci]
ncoop57 Feb 26, 2024
1648279
add lion-pytorch optimizer (#1299) [skip ci]
maximegmd Feb 26, 2024
1e3d530
Support user-defined prompt processing strategies for dpo (#1248)
nopperl Feb 26, 2024
3f69571
more pydantic fixes (#1338)
winglian Feb 27, 2024
0f6af36
Mps mistral lora (#1292) [skip ci]
maximegmd Feb 27, 2024
5be8b55
fix: checkpoint saving with deepspeed (#1321)
NanoCode012 Feb 27, 2024
5265cd6
Update debugging.md (#1339) [skip ci]
hamelsmu Feb 27, 2024
2c9c88b
fix steps check for anneal on first cycle (#1316)
winglian Feb 27, 2024
2b9687f
Update fastchat_conversation_turns.py (#1294) [skip ci]
eltociear Feb 27, 2024
c1a7b3d
add gemma instruct chat template (#1341)
winglian Feb 27, 2024
0f985e1
more fixes 20240228 (#1342) [skip ci]
winglian Feb 28, 2024
6d4bbb8
deprecate py 3.9 support, set min pytorch version (#1343) [skip ci]
winglian Feb 28, 2024
3a5a2d2
Fix `use_mlflow` to be bool instead of str (#1344)
chiragjn Feb 28, 2024
6b3b271
fix for protected model_ namespace w pydantic (#1345)
winglian Feb 28, 2024
0001862
run tests again on Modal (#1289) [skip ci]
winglian Feb 29, 2024
170d4d7
chore: enable sample_packing for Gemma (#1351)
NanoCode012 Mar 2, 2024
b5b4492
Fix validation for early stopping (#1358)
chiragjn Mar 4, 2024
4d09b42
plain input/output prompt strategy w/o chat templates (#1346)
winglian Mar 4, 2024
decb66e
lora+ support (#1352)
winglian Mar 5, 2024
2598c9f
allow the sharegpt handler to also better handle datasets destined fo…
winglian Mar 5, 2024
8984bf1
Update tinyllama lora.yml to fix eval packing issue (#1362)
rasbt Mar 5, 2024
e0f1895
add starcoder2 (#1349)
ehartford Mar 6, 2024
3765747
Remove unsupported python version 3.9 from README (#1364) [skip ci]
nirogu Mar 6, 2024
0cfdb2c
support for DoRA w/ PEFT (#1363)
winglian Mar 6, 2024
ed70a08
add docs for `input_output` format (#1367) [skip ci]
hamelsmu Mar 6, 2024
58b0d4b
update flash attention for gemma support: (#1368)
winglian Mar 6, 2024
638c2da
JarvisLabs (#1372)
winglian Mar 7, 2024
9b6ee83
FDSP + QLoRA (#1378)
winglian Mar 8, 2024
3fd8093
validation for fsdp and deepspeed (#1388) [skip ci]
winglian Mar 11, 2024
7659c00
support for rslora (#1387) [skip ci]
winglian Mar 11, 2024
0bc114d
Fix pydantic configuration for the max_memory input (#1385) [skip ci]
dandm1 Mar 11, 2024
b0ee9ec
Set `gradient_clipping` to `auto` in DeepSpeed configs (#1382) [skip ci]
seungduk-yanolja Mar 11, 2024
b7d8a7d
Add Glaive conversation format support (#1365)
brianfitzgerald Mar 11, 2024
4326520
chore: lint (#1389)
winglian Mar 11, 2024
58fa1ee
Merge branch 'main' into feat/sharegpt_multirole
NanoCode012 Mar 12, 2024
2b9d66c
fix: typing
NanoCode012 Mar 12, 2024
ad70d34
fix: don't remap for unknown keys
NanoCode012 Mar 12, 2024
3031632
fix: add roles to pydantic
NanoCode012 Mar 12, 2024
c2738b3
feat: add test
NanoCode012 Mar 12, 2024
320312f
chore: remove leftover print
NanoCode012 Mar 12, 2024
4b86dd8
chore: remove leftover comment
NanoCode012 Mar 12, 2024
f00f63b
chore: remove print
NanoCode012 Mar 12, 2024
cc545db
fix: update test to use chatml
NanoCode012 Mar 13, 2024
feat: rebased and allowed more dynamic roles via config
NanoCode012 committed Feb 22, 2024
commit 33bcf57a5e1e0abbae2eefc6d0048471a2df6c9f
6 changes: 5 additions & 1 deletion README.md
@@ -597,9 +597,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].

# Custom user instruction prompt
- path: repo
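
As an illustration of the new `roles` option above, here is a minimal sketch of a matching dataset config and one conversation row that uses custom roles, written as Python dicts. The dataset path and the extra role names (`tool_caller`, `tool_response`) are hypothetical examples, not taken from this PR.

# Hypothetical example: a dataset config with custom roles and one conversation
# row that uses them. Roles under "input" are masked according to train_on_inputs;
# roles under "output" are trained on.
ds_cfg = {
    "path": "my-org/my-multirole-dataset",  # hypothetical dataset path
    "type": "sharegpt",
    "conversation": "chatml",
    "roles": {
        "input": ["human", "tool_response"],
        "output": ["gpt", "tool_caller"],
    },
}

sample_row = {
    "conversations": [
        {"from": "human", "value": "What's the weather in Paris?"},
        {"from": "tool_caller", "value": '{"name": "get_weather", "args": {"city": "Paris"}}'},
        {"from": "tool_response", "value": '{"temp_c": 18}'},
        {"from": "gpt", "value": "It is currently 18°C in Paris."},
    ]
}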
166 changes: 4 additions & 162 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -1,21 +1,11 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import copy
import logging
from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

from axolotl.prompt_tokenizers import (
InvalidDataException,
ShareGPTPromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
from axolotl.prompters import (
IGNORE_TOKEN_ID,
ShareGPTPrompterV2,
ShareGPTPrompterV2MultiRole,
)
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2

LOG = logging.getLogger("axolotl")

@@ -40,11 +30,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
@@ -90,28 +82,6 @@ def load_guanaco(tokenizer, cfg):
)


def load_multirole(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strategy = MultiRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2MultiRole(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]

return strategy


class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
@@ -178,131 +148,3 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns


class MultiRoleShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy for support of multi-role
"""

def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
)
user, assistant = conversation.roles

input_roles = {
"human",
"funcresponse",
"funccaller",
"tool",
"tool_response",
user,
}
output_roles = {"gpt", "tool_caller", assistant}

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]

try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue

role, content = part

# Uses "in" because role contains extra characters
input_turn = any(r in role.lower() for r in input_roles)
output_turn = any(r in role.lower() for r in output_roles)

if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token = not (
conversation.name == "chatml"
and conversation.sep == self.tokenizer.eos_token
)
res = self._tokenize(
turn,
add_eos_token=add_eos_token,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
labels = copy.deepcopy(res["input_ids"])
if not self.train_on_inputs:
# mask out role tokens from the labels
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
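
The deletions above fold the multi-role handling into the base strategy: instead of a hard-coded role set, the prompter now carries the config-supplied roles, and each turn is classified by substring matching because the role strings emitted by the prompter include separator characters. A small self-contained sketch of that classification follows; the role values are assumed examples, not code from the PR.

def classify_turn(role, input_roles, output_roles):
    """Mirror of the per-turn classification used in tokenize_prompt."""
    role_lower = role.lower()
    if any(r in role_lower for r in input_roles):
        return "input"      # masked from labels unless train_on_inputs
    if any(r in role_lower for r in output_roles):
        return "output"     # trained on; only the role prefix is masked
    if role.strip() == "":
        return "empty"      # first turn, carries the BOS token
    return "unhandled"      # logged with a warning and skipped

# Assumed example: ChatML-style role strings plus config-supplied extras.
input_roles = {"user", "tool_response"}
output_roles = {"assistant", "tool_caller"}
print(classify_turn("<|im_start|>tool_caller\n", input_roles, output_roles))  # -> "output"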
34 changes: 26 additions & 8 deletions src/axolotl/prompt_tokenizers.py
@@ -11,7 +11,7 @@
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter

LOG = logging.getLogger("axolotl")

@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):

def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
@@ -340,6 +340,19 @@ def tokenize_prompt(self, prompt):
self.prompter._conversation.copy() # pylint: disable=protected-access
)

input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}

# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)

if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)

# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
@@ -360,11 +373,19 @@ def tokenize_prompt(self, prompt):
LOG.warning(f"expected tuple, got {part}")
continue

user, assistant = conversation.roles
role, content = part

# Uses "in" because role contains extra characters
if user in role:
input_turn = any(r in role.lower() for r in input_roles)
output_turn = any(r in role.lower() for r in output_roles)
empty_role = role.strip() == ""

if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue

# Uses "in" because role contains extra characters
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -384,7 +405,7 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -426,9 +447,6 @@ def tokenize_prompt(self, prompt):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
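
To make the masking rule this hunk preserves concrete: input and empty-role turns are fully masked unless train_on_inputs is set, while output turns keep their labels except for the role-prefix tokens. Below is a short sketch of that rule as a hypothetical helper, not part of the PR.

from copy import deepcopy

IGNORE_TOKEN_ID = -100  # ignore index used for masked label positions

def mask_labels(turn_ids, role_prefix_ids, is_output_turn, train_on_inputs):
    """Sketch of the label masking applied to one tokenized turn."""
    if is_output_turn:
        labels = deepcopy(turn_ids)
        if not train_on_inputs:
            # mask only the role prefix so training targets the response tokens
            len_role = len(role_prefix_ids)
            labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
        return labels
    # input turn or empty-role first turn
    if train_on_inputs:
        return deepcopy(turn_ids)
    return [IGNORE_TOKEN_ID] * len(turn_ids)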