Commit

training on eos
d.taranets committed Nov 13, 2024
Parent: 31a0d37 · Commit: 1ea49e4
Showing 2 changed files with 10 additions and 33 deletions.
31 changes: 7 additions & 24 deletions turbo_alignment/pipelines/train/sft_rm.py
@@ -1,7 +1,5 @@
 from typing import Callable
-import torch
 
-import numpy as np
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import (
@@ -37,36 +35,21 @@ class MultiheadModel(PreTrainedModel, GenerationMixin):
     def __init__(self, config, model_settings, tokenizer):
         super().__init__(config)
 
-        self.decoder = load_model(model_settings=model_settings, tokenizer=tokenizer)
+        model = load_model(model_settings=model_settings, tokenizer=tokenizer)
+        self.decoder = model.model
 
-        self.lm_head = nn.Linear(self.decoder.norm.weight.shape[0], len(tokenizer), bias=False)
+        self.lm_head = model.lm_head
         self.rm_head = nn.Linear(self.decoder.norm.weight.shape[0], 1, bias=False)
 
-        reward_token_ids = tokenizer.encode('<reward>', add_special_tokens=False)
-        if len(reward_token_ids) != 1:
-            raise ValueError('<reward> token is not found in the tokenizer')
-
-        self.reward_token_ids = reward_token_ids[0]
-
     def forward(self, batch):
         outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
         outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
 
-        reward_token_pos_w = np.where(batch['inputs_w']['input_ids'][0].cpu() == self.reward_token_ids)[0]
-        reward_token_pos_l = np.where(batch['inputs_l']['input_ids'][0].cpu() == self.reward_token_ids)[0]
-
-        if len(reward_token_pos_w) != 1 or len(reward_token_pos_l) != 1:
-            raise ValueError('More than one <reward> token detected in replica')
-
-        outputs_w_1 = outputs_w[: reward_token_pos_w[0]]
-        outputs_w_2 = outputs_w[reward_token_pos_w[0] + 1 :]
-        outputs_w_cat = torch.cat((outputs_w_1, outputs_w_2), dim=0)
-
-        lm_logits = self.lm_head(outputs_w_cat)
-        rm_logits_w = self.rm_head(outputs_w[reward_token_pos_w[0]])
-        rm_logits_l = self.rm_head(outputs_l[reward_token_pos_l[0]])
+        lm_logits = self.lm_head(outputs_w)
+        rm_logits_w = self.rm_head(outputs_w[-1])
+        rm_logits_l = self.rm_head(outputs_l[-1])
 
-        return lm_logits, rm_logits_w, rm_logits_l, reward_token_pos_w
+        return lm_logits, rm_logits_w, rm_logits_l
 
 
 class TrainMultiheadStrategy(BaseTrainStrategy[RMTrainExperimentSettings]):
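For orientation, a minimal runnable sketch of the new reward-extraction path in MultiheadModel: the rm_head now reads the hidden state at the last position of each sequence (the EOS token for sequences that end with EOS) instead of at a dedicated <reward> token. ToyMultiheadModel, its embedding trunk, and the batch-of-one shapes below are illustrative assumptions, not the repository's actual MultiheadModel or load_model.

import torch
from torch import nn


class ToyMultiheadModel(nn.Module):
    """Illustrative stand-in for MultiheadModel: one trunk, an LM head and an RM head."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        # Toy trunk standing in for the transformer trunk taken from load_model(...).
        self.decoder = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.rm_head = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, batch):
        # Hidden states of the single chosen/rejected sequence, shape (seq_len, hidden_size),
        # mirroring last_hidden_state[0] in the real model.
        outputs_w = self.decoder(batch['inputs_w']['input_ids'][0])
        outputs_l = self.decoder(batch['inputs_l']['input_ids'][0])

        lm_logits = self.lm_head(outputs_w)        # (seq_len, vocab_size)
        rm_logits_w = self.rm_head(outputs_w[-1])  # reward read at the last (EOS) position
        rm_logits_l = self.rm_head(outputs_l[-1])
        return lm_logits, rm_logits_w, rm_logits_l


# Toy usage: one chosen and one rejected sequence, no padding.
batch = {
    'inputs_w': {'input_ids': torch.randint(0, 32, (1, 8))},
    'inputs_l': {'input_ids': torch.randint(0, 32, (1, 6))},
}
lm_logits, r_w, r_l = ToyMultiheadModel()(batch)

With a single unpadded sequence, index -1 is the EOS position; for right-padded batches the last non-padding position would have to be looked up from the attention mask instead.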
12 changes: 3 additions & 9 deletions turbo_alignment/trainers/sft_with_rm.py
@@ -16,17 +16,11 @@
 
 class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
     def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
-        sft_logits, rewards_w, rewards_l, reward_token_pos_w = model.forward(inputs)
-
-        sft_logits = sft_logits.view(-1, sft_logits.size(-1))
-        sft_labels = inputs['inputs_w']['input_ids']
-
-        sft_labels_1 = sft_labels.view(-1)[: reward_token_pos_w[0]]
-        sft_labels_2 = sft_labels.view(-1)[reward_token_pos_w[0] + 1 :]
-        sft_labels_cat = torch.cat((sft_labels_1, sft_labels_2), dim=0)
+        sft_logits, rewards_w, rewards_l = model.forward(inputs)
+        sft_labels = inputs['inputs_w']['input_ids'][0]
 
         loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(
-            sft_logits, sft_labels_cat
+            sft_logits, sft_labels
         )
         if return_outputs:
             return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
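For reference, a self-contained sketch of the combined objective that the updated compute_loss computes: a pairwise ranking term on the chosen/rejected rewards plus token-level cross-entropy on the chosen sequence. The function name sft_with_rm_loss and the toy tensor shapes are assumptions made for illustration.

import torch
import torch.nn.functional as F


def sft_with_rm_loss(
    sft_logits: torch.Tensor,  # (seq_len, vocab_size) LM logits for the chosen sequence
    sft_labels: torch.Tensor,  # (seq_len,) token ids of the chosen sequence
    rewards_w: torch.Tensor,   # reward for the chosen (winning) completion
    rewards_l: torch.Tensor,   # reward for the rejected (losing) completion
) -> torch.Tensor:
    # Pairwise ranking term: -log sigmoid(r_w - r_l).
    rm_loss = -F.logsigmoid(rewards_w - rewards_l).mean()
    # Language-modelling term on the chosen sequence.
    sft_loss = F.cross_entropy(sft_logits, sft_labels)
    return rm_loss + sft_loss


# Toy usage with random tensors.
loss = sft_with_rm_loss(
    sft_logits=torch.randn(8, 32),
    sft_labels=torch.randint(0, 32, (8,)),
    rewards_w=torch.tensor([1.2]),
    rewards_l=torch.tensor([0.3]),
)

The pairwise term is the standard Bradley-Terry reward-model loss; minimizing it pushes the chosen reward above the rejected one, while the cross-entropy term keeps the language-model head trained on the chosen sequence.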
