diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index aaa199661d..fa7e6ae10c 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -277,6 +277,17 @@ def tokenize_prompt(self, prompt): labels[:len_role] = [IGNORE_TOKEN_ID] * min( len_role, len(labels) ) + elif role == "": + turn = content + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize( + turn, add_eos_token=False, strip_bos_token=False + ) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: LOG.warning(f"unhandled role: {role}") continue