diff --git a/turbo_alignment/pipelines/train/sft_rm.py b/turbo_alignment/pipelines/train/sft_rm.py
index 71aa8d5..04ba7bd 100755
--- a/turbo_alignment/pipelines/train/sft_rm.py
+++ b/turbo_alignment/pipelines/train/sft_rm.py
@@ -1,7 +1,5 @@
 from typing import Callable
 
-import torch
-import numpy as np
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import (
@@ -37,36 +35,21 @@ class MultiheadModel(PreTrainedModel, GenerationMixin):
     def __init__(self, config, model_settings, tokenizer):
         super().__init__(config)
 
-        self.decoder = load_model(model_settings=model_settings, tokenizer=tokenizer)
+        model = load_model(model_settings=model_settings, tokenizer=tokenizer)
+        self.decoder = model.model
 
-        self.lm_head = nn.Linear(self.decoder.norm.weight.shape[0], len(tokenizer), bias=False)
+        self.lm_head = model.lm_head
         self.rm_head = nn.Linear(self.decoder.norm.weight.shape[0], 1, bias=False)
 
-        reward_token_ids = tokenizer.encode('', add_special_tokens=False)
-        if len(reward_token_ids) != 1:
-            raise ValueError(' token is not found in the tokenizer')
-
-        self.reward_token_ids = reward_token_ids[0]
-
     def forward(self, batch):
         outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
         outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
 
-        reward_token_pos_w = np.where(batch['inputs_w']['input_ids'][0].cpu() == self.reward_token_ids)[0]
-        reward_token_pos_l = np.where(batch['inputs_l']['input_ids'][0].cpu() == self.reward_token_ids)[0]
-
-        if len(reward_token_pos_w) != 1 or len(reward_token_pos_l) != 1:
-            raise ValueError('More than one token detected in replica')
-
-        outputs_w_1 = outputs_w[: reward_token_pos_w[0]]
-        outputs_w_2 = outputs_w[reward_token_pos_w[0] + 1 :]
-        outputs_w_cat = torch.cat((outputs_w_1, outputs_w_2), dim=0)
-
-        lm_logits = self.lm_head(outputs_w_cat)
-        rm_logits_w = self.rm_head(outputs_w[reward_token_pos_w[0]])
-        rm_logits_l = self.rm_head(outputs_l[reward_token_pos_l[0]])
+        lm_logits = self.lm_head(outputs_w)
+        rm_logits_w = self.rm_head(outputs_w[-1])
+        rm_logits_l = self.rm_head(outputs_l[-1])
 
-        return lm_logits, rm_logits_w, rm_logits_l, reward_token_pos_w
+        return lm_logits, rm_logits_w, rm_logits_l
 
 
 class TrainMultiheadStrategy(BaseTrainStrategy[RMTrainExperimentSettings]):
diff --git a/turbo_alignment/trainers/sft_with_rm.py b/turbo_alignment/trainers/sft_with_rm.py
index 5fc6dec..a957798 100644
--- a/turbo_alignment/trainers/sft_with_rm.py
+++ b/turbo_alignment/trainers/sft_with_rm.py
@@ -16,17 +16,11 @@ class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
     def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
-        sft_logits, rewards_w, rewards_l, reward_token_pos_w = model.forward(inputs)
-
-        sft_logits = sft_logits.view(-1, sft_logits.size(-1))
-        sft_labels = inputs['inputs_w']['input_ids']
-
-        sft_labels_1 = sft_labels.view(-1)[: reward_token_pos_w[0]]
-        sft_labels_2 = sft_labels.view(-1)[reward_token_pos_w[0] + 1 :]
-        sft_labels_cat = torch.cat((sft_labels_1, sft_labels_2), dim=0)
+        sft_logits, rewards_w, rewards_l = model.forward(inputs)
+        sft_labels = inputs['inputs_w']['input_ids'][0]
 
         loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(
-            sft_logits, sft_labels_cat
+            sft_logits, sft_labels
         )
         if return_outputs:
            return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
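
Summary of the change: the reward head now scores the hidden state at the last position of each sequence instead of a dedicated reward token, so the reward-token lookup (and the torch/numpy imports it needed) is dropped, and the LM head is reused from the loaded model rather than re-initialized. The sketch below is not repository code; the function and tensor names are illustrative assumptions. It only restates the objective the patched compute_loss computes: a pairwise -logsigmoid(r_w - r_l) reward loss plus token-level cross-entropy on the chosen ('w') sequence.

# Illustrative sketch (not repository code) of the objective in the patched
# SFTwithRMTrainer.compute_loss. Tensor names and shapes are assumptions.
import torch
import torch.nn.functional as F


def sft_rm_loss(
    lm_logits: torch.Tensor,  # (seq_len, vocab_size): LM logits for the chosen ("w") sequence
    labels: torch.Tensor,     # (seq_len,): token ids of the chosen sequence
    reward_w: torch.Tensor,   # rm_head output for the last hidden state of the chosen sequence
    reward_l: torch.Tensor,   # rm_head output for the last hidden state of the rejected sequence
) -> torch.Tensor:
    # Pairwise ranking (Bradley-Terry) term on the two reward logits.
    rm_loss = -F.logsigmoid(reward_w - reward_l).mean()
    # Token-level language-modelling term on the chosen sequence,
    # mirroring cross_entropy(sft_logits, sft_labels) in the patch.
    sft_loss = F.cross_entropy(lm_logits, labels)
    return rm_loss + sft_loss

Note on the design choice: since forward indexes last_hidden_state[0] and then takes position -1, the patch appears to assume one unpadded sequence per batch side; with right-padded batches the last position would belong to a pad token.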