Skip to content

Commit

Permalink
in progress (WIP): guard `_encode` against inputs ≥ 8000 tokens; set vLLM `gpu_memory_utilization=0.9` and `disable_custom_all_reduce=True` for chat inference; temporarily disable cherry-pick callbacks in SFT/DPO training and the `grad_term_std` DPO metric; add `num_items_in_batch` parameter to `compute_loss`
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Nov 11, 2024
1 parent c2a553a commit c46598d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 32 deletions.
2 changes: 2 additions & 0 deletions turbo_alignment/dataset/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def _encode(
inference=inference,
random_cut=random_cut,
)
if len(input_ids) >= 8000:
raise ValueError(f'{len(input_ids)=}, which is >=8000')

except ValueError as ex:
output.append(None)
Expand Down
2 changes: 2 additions & 0 deletions turbo_alignment/pipelines/inference/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def _get_single_inference_settings(
dtype='bfloat16',
tensor_parallel_size=model_inference_settings.tensor_parallel_size,
enable_lora=enable_lora,
gpu_memory_utilization=0.9,
disable_custom_all_reduce=True,
)

else:
Expand Down
33 changes: 17 additions & 16 deletions turbo_alignment/pipelines/train/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ def _get_cherry_pick_callback(
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> ChatCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
)

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
for metric in cherry_pick_settings.metric_settings
]

return ChatCherryPickCallback(
cherry_pick_settings=cherry_pick_settings,
datasets=cherry_pick_datasets,
metrics=metrics,
)
return None
# cherry_pick_settings = experiment_settings.cherry_pick_settings

# cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
# cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
# )

# metrics = [
# Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
# for metric in cherry_pick_settings.metric_settings
# ]

# return ChatCherryPickCallback(
# cherry_pick_settings=cherry_pick_settings,
# datasets=cherry_pick_datasets,
# metrics=metrics,
# )

@staticmethod
def _get_training_args(experiment_settings: DPOTrainExperimentSettings) -> DPOTrainingArguments:
Expand Down
27 changes: 14 additions & 13 deletions turbo_alignment/pipelines/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,23 @@ def _get_cherry_pick_callback(
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> ChatCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings
return None
# cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
)
# cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
# cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
# )

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
for metric in cherry_pick_settings.metric_settings
]
# metrics = [
# Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
# for metric in cherry_pick_settings.metric_settings
# ]

return ChatCherryPickCallback(
cherry_pick_settings=cherry_pick_settings,
datasets=cherry_pick_datasets,
metrics=metrics,
)
# return ChatCherryPickCallback(
# cherry_pick_settings=cherry_pick_settings,
# datasets=cherry_pick_datasets,
# metrics=metrics,
# )

@staticmethod
def _get_training_args(experiment_settings: SftTrainExperimentSettings) -> TrainingArguments:
Expand Down
7 changes: 4 additions & 3 deletions turbo_alignment/trainers/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,9 @@ def _compute_metrics(
metrics[f'{prefix_name}grad_term'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
)
metrics[f'{prefix_name}grad_term_std'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
)
# metrics[f'{prefix_name}grad_term_std'] = (
# (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
# )

return metrics

Expand Down Expand Up @@ -791,6 +791,7 @@ def compute_loss(
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')

Expand Down

0 comments on commit c46598d

Please sign in to comment.