diff --git a/tests/cli/test_classification.py b/tests/cli/test_classification.py
index 8de6b4b..115e475 100755
--- a/tests/cli/test_classification.py
+++ b/tests/cli/test_classification.py
@@ -1,26 +1,24 @@
 from pathlib import Path
 
+import pytest
 from typer.testing import CliRunner
 
 from tests.constants import FIXTURES_PATH
 from turbo_alignment.cli import app
+from turbo_alignment.settings.pipelines.train.classification import (
+    ClassificationTrainExperimentSettings,
+)
 
 runner = CliRunner()
 
 
-def test_classification_train():
-    result = runner.invoke(
-        app,
-        [
-            'train_classification',
-            '--experiment_settings_path',
-            FIXTURES_PATH / 'configs/train/classification/base.json',
-        ],
-        catch_exceptions=False,
-    )
+@pytest.mark.parametrize(
+    'config_path',
+    [
+        FIXTURES_PATH / 'configs/train/classification/base.json',
+    ],
+)
+def test_classification_train(config_path: Path):
+    result = runner.invoke(app, ['train_classification', '--experiment_settings_path', str(config_path)])
     assert result.exit_code == 0
-    assert Path('test_train_classification_output').is_dir()
-
-
-if __name__ == '__main__':
-    test_classification_train()
+    assert ClassificationTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
diff --git a/tests/fixtures/configs/inference/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/inference/multimodal/llama_llava_clip_pickle.json
index d4c41b2..9ed9dfd 100644
--- a/tests/fixtures/configs/inference/multimodal/llama_llava_clip_pickle.json
+++ b/tests/fixtures/configs/inference/multimodal/llama_llava_clip_pickle.json
@@ -38,7 +38,7 @@
         },
         "generation_settings": [
             {
-                "transformers_settings": {
+                "generation_config": {
                     "num_beams": 1,
                     "max_new_tokens": 8
                 },
diff --git a/tests/fixtures/configs/inference/rag/base.json b/tests/fixtures/configs/inference/rag/base.json
index dff8d56..eb86c41 100755
--- a/tests/fixtures/configs/inference/rag/base.json
+++ b/tests/fixtures/configs/inference/rag/base.json
@@ -37,7 +37,7 @@
         },
         "generation_settings": [
             {
-                "transformers_settings": {
+                "generation_config": {
                     "num_beams": 1,
                     "max_new_tokens": 10,
                     "repetition_penalty": 1.2,
diff --git a/tests/fixtures/configs/inference/sft/base.json b/tests/fixtures/configs/inference/sft/base.json
index 980d32f..3428953 100755
--- a/tests/fixtures/configs/inference/sft/base.json
+++ b/tests/fixtures/configs/inference/sft/base.json
@@ -13,7 +13,7 @@
         },
         "generation_settings": [
             {
-                "transformers_settings": {
+                "generation_config": {
                     "num_beams": 3,
                     "max_new_tokens": 8
                 },
diff --git a/tests/fixtures/configs/train/classification/base.json b/tests/fixtures/configs/train/classification/base.json
index 127fd26..069c11b 100755
--- a/tests/fixtures/configs/train/classification/base.json
+++ b/tests/fixtures/configs/train/classification/base.json
@@ -79,11 +79,18 @@
             "problem_type": "single_label_classification"
         },
         "peft_settings": {
-            "r": 8,
-            "lora_alpha": 16,
-            "lora_dropout": 0.05,
-            "target_modules": ["q_proj", "v_proj"],
-            "task_type": "SEQ_CLS"
+            "name": "LORA",
+            "config": {
+                "r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "target_modules": [
+                    "q_proj",
+                    "v_proj"
+                ],
+                "task_type": "CAUSAL_LM",
+                "modules_to_save": ["embed_tokens", "lm_head"]
+            }
         }
     },
     "tokenizer_settings": {},
@@ -92,7 +99,7 @@
         "eos_token": "",
         "pad_token": ""
     },
-    "trainer_settings": {
+    "training_arguments": {
         "evaluation_strategy": "steps",
         "per_device_train_batch_size": 1,
         "per_device_eval_batch_size": 1,
@@ -115,6 +122,7 @@
         "no_cuda": true
     },
     "logging_settings": {
+        "logging_type": "wandb",
"project_name": "alignment", "run_name": "classification", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/ddpo/base.json b/tests/fixtures/configs/train/ddpo/base.json index 82c48da..4e53533 100755 --- a/tests/fixtures/configs/train/ddpo/base.json +++ b/tests/fixtures/configs/train/ddpo/base.json @@ -92,7 +92,7 @@ "adapter_path": "tests/fixtures/models/llama2_tiny_rm" }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "do_sample": false, "max_new_tokens": 8 @@ -132,7 +132,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -156,6 +156,7 @@ "max_grad_norm": 1.0 }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "ddpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/dpo/base.json b/tests/fixtures/configs/train/dpo/base.json index 3140625..dd6e36c 100755 --- a/tests/fixtures/configs/train/dpo/base.json +++ b/tests/fixtures/configs/train/dpo/base.json @@ -53,7 +53,7 @@ "is_trainable": true }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "do_sample": false, "stop_strings": "", @@ -111,7 +111,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 2, "per_device_eval_batch_size": 2, @@ -137,6 +137,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "dpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/dpo/simpo.json b/tests/fixtures/configs/train/dpo/simpo.json index 293093a..70fd3c0 100755 --- a/tests/fixtures/configs/train/dpo/simpo.json +++ b/tests/fixtures/configs/train/dpo/simpo.json @@ -53,7 +53,7 @@ "is_trainable": true }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "do_sample": false, "stop_strings": "", @@ -103,7 +103,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 2, "per_device_eval_batch_size": 2, @@ -131,6 +131,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "dpo", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/kto/base.json b/tests/fixtures/configs/train/kto/base.json index a06d499..cb7654f 100755 --- a/tests/fixtures/configs/train/kto/base.json +++ b/tests/fixtures/configs/train/kto/base.json @@ -51,7 +51,7 @@ "is_trainable": true }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "do_sample": false, "stop_strings": "", @@ -89,7 +89,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 4, "per_device_eval_batch_size": 4, @@ -111,6 +111,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "kto", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json index 38130b4..643c945 100644 --- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json +++ 
b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot", @@ -108,7 +107,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "epoch", "save_strategy": "epoch", "per_device_train_batch_size": 1, @@ -126,6 +125,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" @@ -147,7 +147,7 @@ "audio": null }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "max_new_tokens": 16, "repetition_penalty": 1.1, diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json index d95ee19..62dbaff 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot", @@ -108,7 +107,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "epoch", "save_strategy": "epoch", "per_device_train_batch_size": 1, @@ -126,6 +125,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" @@ -147,7 +147,7 @@ "audio": null }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "max_new_tokens": 128, "repetition_penalty": 1.1, diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json index 71e04b0..5c807cf 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { 
"": "bot", @@ -108,7 +107,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "epoch", "save_strategy": "epoch", "per_device_train_batch_size": 1, @@ -126,6 +125,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "multimodal", "entity": "turbo-alignment" @@ -147,7 +147,7 @@ "audio": null }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "max_new_tokens": 4, "repetition_penalty": 1.1, diff --git a/tests/fixtures/configs/train/rag/end2end.json b/tests/fixtures/configs/train/rag/end2end.json index d54b66f..151471b 100755 --- a/tests/fixtures/configs/train/rag/end2end.json +++ b/tests/fixtures/configs/train/rag/end2end.json @@ -54,16 +54,18 @@ "": "system" }, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } } }, "question_encoder_settings": { @@ -83,7 +85,7 @@ } }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 3, "max_new_tokens": 16, "repetition_penalty": 1.1, @@ -123,7 +125,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "epoch", "save_strategy": "epoch", "per_device_train_batch_size": 1, @@ -146,6 +148,7 @@ "save_total_limit": 1 }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "rag", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/rm/base.json b/tests/fixtures/configs/train/rm/base.json index e1d5e21..10dcab0 100755 --- a/tests/fixtures/configs/train/rm/base.json +++ b/tests/fixtures/configs/train/rm/base.json @@ -81,11 +81,18 @@ "return_dict": true }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj"], - "task_type": "SEQ_CLS" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "SEQ_CLS", + "modules_to_save": ["embed_tokens", "lm_head"] + } } }, "tokenizer_settings": {}, @@ -94,7 +101,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -113,6 +120,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "rm", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json index 4cb4cdb..a6fcda7 100755 --- a/tests/fixtures/configs/train/sft/base.json +++ b/tests/fixtures/configs/train/sft/base.json @@ -47,16 +47,18 @@ "transformers_settings": { }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "v_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": ["embed_tokens", "lm_head"], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": 
"CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "", @@ -67,7 +69,7 @@ } }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 3, "stop_strings": ["", ""], "max_new_tokens": 8 @@ -104,7 +106,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -123,6 +125,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json index 4ef7df4..86c97de 100755 --- a/tests/fixtures/configs/train/sft/prompt_tuning.json +++ b/tests/fixtures/configs/train/sft/prompt_tuning.json @@ -46,9 +46,11 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "task_type": "CAUSAL_LM", "name": "PROMPT_TUNING", - "num_virtual_tokens": 32 + "config": { + "task_type": "CAUSAL_LM", + "num_virtual_tokens": 32 + } }, "embeddings_initialization_strategy": { "": "", @@ -59,7 +61,7 @@ } }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "max_new_tokens": 35, "repetition_penalty": 1.1, @@ -99,7 +101,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -118,6 +120,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/resume_from_checkpoint.json b/tests/fixtures/configs/train/sft/resume_from_checkpoint.json index 28c3f03..ada6496 100755 --- a/tests/fixtures/configs/train/sft/resume_from_checkpoint.json +++ b/tests/fixtures/configs/train/sft/resume_from_checkpoint.json @@ -49,7 +49,7 @@ "is_trainable": true }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "max_new_tokens": 8 }, @@ -83,7 +83,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -102,6 +102,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/sft_retrieval_utility.json b/tests/fixtures/configs/train/sft/sft_retrieval_utility.json index 1e4f5df..91e8afc 100755 --- a/tests/fixtures/configs/train/sft/sft_retrieval_utility.json +++ b/tests/fixtures/configs/train/sft/sft_retrieval_utility.json @@ -131,6 +131,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "sft_retrieval_utility", "entity": "turbo-alignment" diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json index 19bedde..8dd2e1b 100755 --- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json +++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json @@ -47,16 +47,18 @@ "transformers_settings": { }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "v_proj" - ], - 
"task_type": "CAUSAL_LM", - "modules_to_save": ["embed_tokens", "lm_head"], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "", @@ -67,10 +69,11 @@ } }, "cherry_pick_settings": { - "generator_transformers_settings": { + "generation_config": { "num_beams": 1, "num_return_sequences": 2, "stop_strings": "", + "do_sample": true, "max_new_tokens": 8 }, "custom_generation_settings": { @@ -126,7 +129,7 @@ "eos_token": "", "pad_token": "" }, - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -145,6 +148,7 @@ "no_cuda": true }, "logging_settings": { + "logging_type": "wandb", "project_name": "alignment", "run_name": "sft", "entity": "turbo-alignment" diff --git a/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json b/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json index 3425c8f..7e74e1f 100755 --- a/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json +++ b/tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/experiment_settings_config.json @@ -1,5 +1,5 @@ { - "trainer_settings": { + "training_arguments": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, diff --git a/tests/integration/test_trainers.py b/tests/integration/test_trainers.py index bdda2ae..97f8aa8 100644 --- a/tests/integration/test_trainers.py +++ b/tests/integration/test_trainers.py @@ -3,11 +3,11 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer +from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator -from turbo_alignment.settings.pipelines.train.dpo import ( +from turbo_alignment.settings.pipelines.train.utils import ( DPOLossesType, SigmoidLossSettings, - SyncRefModelSettings, ) from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index eec477e..7bc39d2 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -22,7 +22,7 @@ def __init__( ) -> None: super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics) self._custom_generation_settings = cherry_pick_settings.custom_generation_settings - self._generator_transformers_settings = cherry_pick_settings.generator_transformers_settings + self._generation_config = cherry_pick_settings.generation_config def _get_dataset_metrics( self, @@ -39,13 +39,13 @@ def _get_dataset_metrics( generator = ChatGenerator( model=model, tokenizer=tokenizer, - transformers_settings=self._generator_transformers_settings, + generation_config=self._generation_config, custom_generation_settings=self._custom_generation_settings, accelerator=accelerator, return_logits=True, ) - batch_size = self._generator_transformers_settings.num_return_sequences + batch_size = self._generation_config.num_return_sequences generations = generator.generate_from_dataset(dataset) diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py index 80c6022..b0cf990 100755 --- 
a/turbo_alignment/cherry_picks/multimodal.py +++ b/turbo_alignment/cherry_picks/multimodal.py @@ -22,7 +22,7 @@ def __init__( ) -> None: super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics) self._custom_generation_settings = cherry_pick_settings.custom_generation_settings - self._generator_transformers_settings = cherry_pick_settings.generator_transformers_settings + self._generation_config = cherry_pick_settings.generation_config def _get_dataset_metrics( self, @@ -34,7 +34,7 @@ def _get_dataset_metrics( generator = MultimodalGenerator( model=model, tokenizer=tokenizer, - transformers_settings=self._generator_transformers_settings, + generation_config=self._generation_config, custom_generation_settings=self._custom_generation_settings, ) diff --git a/turbo_alignment/cherry_picks/rag.py b/turbo_alignment/cherry_picks/rag.py index b1c2c4d..c44ffce 100755 --- a/turbo_alignment/cherry_picks/rag.py +++ b/turbo_alignment/cherry_picks/rag.py @@ -20,7 +20,7 @@ def _get_dataset_metrics( generator = RagGenerator( model=model, tokenizer=tokenizer, - transformers_settings=self._generator_transformers_settings, + generation_config=self._generation_config, custom_generation_settings=self._custom_generation_settings, accelerator=accelerator, ) diff --git a/turbo_alignment/common/tf/callbacks/sync_ref_model.py b/turbo_alignment/common/tf/callbacks/sync_ref_model.py index 555c9ec..f94ce81 100755 --- a/turbo_alignment/common/tf/callbacks/sync_ref_model.py +++ b/turbo_alignment/common/tf/callbacks/sync_ref_model.py @@ -7,7 +7,13 @@ TrainingArguments, ) -from turbo_alignment.settings.pipelines.train.dpo import SyncRefModelSettings +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel + + +class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): + sync_ref_model: bool = False + alpha: float = 1.0 + sync_steps: int = 1 class SyncRefModelCallback(TrainerCallback): diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index aba80e9..49a5f11 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -4,7 +4,6 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from turbo_alignment.common.tf.loaders.model.registry import ( - PeftConfigRegistry, TransformersAutoModelRegistry, ) from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2 @@ -18,12 +17,7 @@ def _prepare_model_for_peft(model: PreTrainedModel, peft_settings: PEFT_TYPE) -> PeftModel: - peft_params = peft_settings.dict() - peft_params.pop('name') - - peft_config = PeftConfigRegistry.by_name(peft_settings.name)(**peft_params) - - return get_peft_model(model, peft_config) + return get_peft_model(model, peft_settings.config) def _load_pretrained_adapters( diff --git a/turbo_alignment/common/tf/loaders/model/registry.py b/turbo_alignment/common/tf/loaders/model/registry.py index e2c5e08..565a61d 100755 --- a/turbo_alignment/common/tf/loaders/model/registry.py +++ b/turbo_alignment/common/tf/loaders/model/registry.py @@ -1,10 +1,3 @@ -from peft import ( - LoraConfig, - PeftType, - PrefixTuningConfig, - PromptEncoderConfig, - PromptTuningConfig, -) from transformers import ( AutoModel, AutoModelForCausalLM, @@ -19,15 +12,6 @@ class TransformersAutoModelRegistry(Registrable): ... -class PeftConfigRegistry(Registrable): - ... 
- - TransformersAutoModelRegistry.register(ModelType.CAUSAL)(AutoModelForCausalLM) TransformersAutoModelRegistry.register(ModelType.SEQ_CLS)(AutoModelForSequenceClassification) TransformersAutoModelRegistry.register(ModelType.ENC)(AutoModel) - -PeftConfigRegistry.register(PeftType.LORA)(LoraConfig) -PeftConfigRegistry.register(PeftType.PREFIX_TUNING)(PrefixTuningConfig) -PeftConfigRegistry.register(PeftType.PROMPT_TUNING)(PromptTuningConfig) -PeftConfigRegistry.register(PeftType.P_TUNING)(PromptEncoderConfig) diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 5415daf..a1e560b 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -10,7 +10,6 @@ from turbo_alignment.dataset.base.models import DatasetRecord from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings from turbo_alignment.settings.generators.outputs.base import BaseInferenceOutput -from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings InferenceOutputT = TypeVar('InferenceOutputT', bound=BaseInferenceOutput) DatasetRecordT = TypeVar('DatasetRecordT', bound=DatasetRecord) @@ -89,7 +88,7 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]: class ChatGeneratorBase(BaseGenerator, Generic[DatasetRecordT, InferenceOutputT]): def __init__( self, - transformers_settings: GeneratorTransformersSettings, + generation_config: GenerationConfig, custom_generation_settings: CustomChatGenerationSettings, tokenizer: PreTrainedTokenizerBase, return_logits: bool = False, @@ -99,10 +98,8 @@ def __init__( self._return_logits = return_logits - self._transformers_generator_parameters = GenerationConfig( - bos_token_id=self._tokenizer.bos_token_id, - **transformers_settings.dict(), - ) + self._generation_config = generation_config + self._generation_config.bos_token_id = self._tokenizer.bos_token_id self._custom_generation_settings = custom_generation_settings @@ -128,7 +125,7 @@ def _generate_from_batch( self, records: list[dict[str, Any]], original_records: list[DatasetRecordT], dataset_name: str ) -> list[InferenceOutputT]: if self._custom_generation_settings.batch > 1: - if self._transformers_generator_parameters.num_beams != 1: + if self._generation_config.num_beams != 1: raise ValueError('You can not use batch generation with num_beams != 1') self._tokenizer.padding_side = 'left' diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index acc4337..544da30 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -38,7 +38,7 @@ def _generate_from_batch_records( output_indices = self._model.generate( inputs=batched_input_ids, attention_mask=batched_attention_mask, - generation_config=self._transformers_generator_parameters, + generation_config=self._generation_config, tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, ) @@ -83,7 +83,7 @@ def _generate_from_single_record( output_indices = self._model.generate( inputs=input_ids, attention_mask=attention_mask, - generation_config=self._transformers_generator_parameters, + generation_config=self._generation_config, tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, ) diff --git a/turbo_alignment/generators/multimodal.py b/turbo_alignment/generators/multimodal.py index 0c64137..925ff33 100755 --- a/turbo_alignment/generators/multimodal.py +++ b/turbo_alignment/generators/multimodal.py @@ -1,5 +1,5 @@ import torch -from transformers import 
PreTrainedTokenizerBase +from transformers import GenerationConfig, PreTrainedTokenizerBase from turbo_alignment.dataset.multimodal.models import MultimodalDatasetRecord from turbo_alignment.generators.base import ChatGeneratorBase @@ -8,19 +8,18 @@ from turbo_alignment.settings.generators.outputs.multimodal import ( MultimodalInferenceOutput, ) -from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings class MultimodalGenerator(ChatGeneratorBase[MultimodalDatasetRecord, MultimodalInferenceOutput]): def __init__( self, - transformers_settings: GeneratorTransformersSettings, + generation_config: GenerationConfig, custom_generation_settings: CustomChatGenerationSettings, tokenizer: PreTrainedTokenizerBase, **kwargs, ) -> None: super().__init__( - transformers_settings=transformers_settings, + generation_config=generation_config, custom_generation_settings=custom_generation_settings, tokenizer=tokenizer, **kwargs, @@ -51,7 +50,7 @@ def _generate_from_single_record( inputs_embeds=inputs_embeds, attention_mask=attention_mask, tokenizer=self._tokenizer, - generation_config=self._transformers_generator_parameters, + generation_config=self._generation_config, ) answers = self._decode(token_indices=output_indices) diff --git a/turbo_alignment/generators/rag.py b/turbo_alignment/generators/rag.py index 4203de8..194b2c1 100755 --- a/turbo_alignment/generators/rag.py +++ b/turbo_alignment/generators/rag.py @@ -21,7 +21,7 @@ def _generate_from_single_record( answer_indices, document_indices, doc_scores = self._model.generate( inputs=input_ids, - generation_config=self._transformers_generator_parameters, + generation_config=self._generation_config, tokenizer=self._tokenizer.current_tokenizer, pad_token_id=self._tokenizer.pad_token_id, ) diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py index 5a89bb1..2ec1341 100755 --- a/turbo_alignment/generators/vllm_chat.py +++ b/turbo_alignment/generators/vllm_chat.py @@ -1,7 +1,7 @@ from typing import Any import torch -from transformers import PreTrainedTokenizerBase +from transformers import GenerationConfig, PreTrainedTokenizerBase from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -12,13 +12,12 @@ AnswerMessage, ChatInferenceOutput, ) -from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings class VLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]): def __init__( self, - transformers_settings: GeneratorTransformersSettings, + generation_config: GenerationConfig, custom_generation_settings: CustomChatGenerationSettings, model: LLM, tokenizer: PreTrainedTokenizerBase, @@ -29,28 +28,28 @@ def __init__( model.set_tokenizer(tokenizer) super().__init__(model, tokenizer, batch=batch) - if isinstance(transformers_settings.stop_strings, list): + if isinstance(generation_config.stop_strings, list): raise ValueError('You should use only 1 eos token with VLLM') - eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False) + eos_token_id: list[int] = self._tokenizer.encode(generation_config.stop_strings, add_special_tokens=False) beam_search_params: dict[str, Any] = { - 'best_of': transformers_settings.num_return_sequences, + 'best_of': generation_config.num_return_sequences, 'use_beam_search': False, } - if transformers_settings.num_beams > 1: + if generation_config.num_beams > 1: beam_search_params['use_beam_search'] = True - beam_search_params['best_of'] = 
transformers_settings.num_beams + beam_search_params['best_of'] = generation_config.num_beams self._sampling_params = SamplingParams( - n=transformers_settings.num_return_sequences, - repetition_penalty=transformers_settings.repetition_penalty, - temperature=transformers_settings.temperature, - top_p=transformers_settings.top_p, - top_k=transformers_settings.top_k, + n=generation_config.num_return_sequences, + repetition_penalty=generation_config.repetition_penalty, + temperature=generation_config.temperature, + top_p=generation_config.top_p, + top_k=generation_config.top_k, skip_special_tokens=custom_generation_settings.skip_special_tokens, stop_token_ids=eos_token_id, - max_tokens=transformers_settings.max_new_tokens, + max_tokens=generation_config.max_new_tokens, **beam_search_params, ) self._lora_request = lora_request diff --git a/turbo_alignment/pipelines/inference/chat.py b/turbo_alignment/pipelines/inference/chat.py index 93b975c..e95ebbd 100755 --- a/turbo_alignment/pipelines/inference/chat.py +++ b/turbo_alignment/pipelines/inference/chat.py @@ -58,7 +58,7 @@ def _get_single_inference_settings( generator_kwargs = { 'model': model, 'tokenizer': tokenizer, - 'transformers_settings': generation_settings.transformers_settings, + 'generation_config': generation_settings.generation_config, 'custom_generation_settings': generation_settings.custom_settings, 'batch': model_inference_settings.batch, } diff --git a/turbo_alignment/pipelines/inference/multimodal.py b/turbo_alignment/pipelines/inference/multimodal.py index addaf11..22be3ba 100755 --- a/turbo_alignment/pipelines/inference/multimodal.py +++ b/turbo_alignment/pipelines/inference/multimodal.py @@ -57,7 +57,7 @@ def _get_single_inference_settings( for generation_settings in model_inference_settings.generation_settings: generator = MultimodalGenerator( - transformers_settings=generation_settings.transformers_settings, + generation_config=generation_settings.generation_config, custom_generation_settings=generation_settings.custom_settings, tokenizer=tokenizer, model=model, diff --git a/turbo_alignment/pipelines/inference/rag.py b/turbo_alignment/pipelines/inference/rag.py index 01ffae1..68bb349 100755 --- a/turbo_alignment/pipelines/inference/rag.py +++ b/turbo_alignment/pipelines/inference/rag.py @@ -43,7 +43,7 @@ def _get_single_inference_settings( for generation_settings in model_inference_settings.generation_settings: generator = RagGenerator( - transformers_settings=generation_settings.transformers_settings, + generation_config=generation_settings.generation_config, custom_generation_settings=generation_settings.custom_settings, tokenizer=tokenizer, model=model, diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py index 0f98602..5e45325 100755 --- a/turbo_alignment/pipelines/train/base.py +++ b/turbo_alignment/pipelines/train/base.py @@ -84,11 +84,6 @@ def _load_model( ) -> torch.nn.Module | PreTrainedModel: return load_model(experiment_settings.model_settings, tokenizer) - @staticmethod - @abstractmethod - def _get_training_args(experiment_settings: ExperimentSettingsT) -> TrainingArguments: - ... 
- @staticmethod def _load_tokenizer(experiment_settings: ExperimentSettingsT) -> PreTrainedTokenizerBase: return load_tokenizer(experiment_settings.tokenizer_settings, experiment_settings.model_settings) @@ -134,8 +129,6 @@ def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwa ) def run(self, experiment_settings: ExperimentSettingsT) -> None: - training_args = self._get_training_args(experiment_settings) - self.tokenizer = self._load_tokenizer(experiment_settings) logger.info('Tokenizer is loaded!') @@ -173,7 +166,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None: data_collator = self._get_data_collator(experiment_settings, self.tokenizer) self.trainer = self._get_trainer( - training_args, + experiment_settings.training_arguments, experiment_settings, self.model, self.tokenizer, diff --git a/turbo_alignment/pipelines/train/classification.py b/turbo_alignment/pipelines/train/classification.py index 1bc09fe..9ca72d9 100755 --- a/turbo_alignment/pipelines/train/classification.py +++ b/turbo_alignment/pipelines/train/classification.py @@ -1,12 +1,11 @@ from typing import Callable, cast from torch.utils.data import Dataset -from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments +from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.data.data_collator import DataCollatorMixin, DataCollatorWithPadding from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.classification.classification import ClassificationDataset from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.metrics.metric import Metric @@ -18,6 +17,7 @@ ) from turbo_alignment.settings.pipelines.train.classification import ( ClassificationTrainExperimentSettings, + ClassificationTrainingArguments, ) from turbo_alignment.trainers.classification import ( ClassificationTrainer, @@ -62,28 +62,19 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: ClassificationTrainExperimentSettings) -> TrainingArguments: - return TrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - label_names=['labels'], - remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(exclude={'loss_settings'}), - ) - @staticmethod def _get_trainer( - training_args: TrainingArguments, + training_args: ClassificationTrainingArguments, experiment_settings: ClassificationTrainExperimentSettings, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, train_dataset: Dataset, val_dataset: Dataset, data_collator: DataCollatorMixin, - ): - if experiment_settings.trainer_settings.loss_settings.alpha == 'auto': - experiment_settings.trainer_settings.loss_settings.alpha = auto_class_weights(train_dataset) - logger.info(f'Auto computed class weights: {experiment_settings.trainer_settings.loss_settings.alpha}') + ) -> ClassificationTrainer: + if training_args.loss_settings.alpha == 'auto': + training_args.loss_settings.alpha = auto_class_weights(train_dataset) + logger.info(f'Auto computed class weights: {training_args.loss_settings.alpha}') return ClassificationTrainer( model=model, @@ -93,7 +84,6 @@ def _get_trainer( eval_dataset=val_dataset, data_collator=data_collator, callbacks=[], - loss_settings=experiment_settings.trainer_settings.loss_settings, ) def 
_dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None: diff --git a/turbo_alignment/pipelines/train/ddpo.py b/turbo_alignment/pipelines/train/ddpo.py index 2657fa1..58ecfe4 100755 --- a/turbo_alignment/pipelines/train/ddpo.py +++ b/turbo_alignment/pipelines/train/ddpo.py @@ -12,7 +12,6 @@ from turbo_alignment.common.tf.loaders.model import load_model from turbo_alignment.common.tf.loaders.tokenizer import load_tokenizer from turbo_alignment.common.tf.special_tokens_setter import SpecialTokensSetter -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.chat.chat import InferenceChatDataset from turbo_alignment.dataset.ddpo.collators import DDPODataCollator from turbo_alignment.dataset.ddpo.ddpo import load_ddpo_datasets as load_datasets @@ -60,18 +59,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: DDPOTrainExperimentSettings) -> DDPOTrainingArguments: - return DDPOTrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - label_names=[], - remove_unused_columns=False, - beta=experiment_settings.beta, - use_ref_model=experiment_settings.use_ref_model, - forward_kl=experiment_settings.forward_kl, - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_data_collator( experiment_settings: DDPOTrainExperimentSettings, tokenizer: PreTrainedTokenizerBase, **kwargs @@ -94,7 +81,7 @@ def _get_trainer( data_collator: Callable, rm_model: PreTrainedModel = None, ) -> DDPOTrainer: - model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + model.config.use_cache = not experiment_settings.training_arguments.gradient_checkpointing extra_args = {'rm': rm_model} @@ -110,8 +97,6 @@ def _get_trainer( ) def run(self, experiment_settings: DDPOTrainExperimentSettings) -> None: - training_args = self._get_training_args(experiment_settings) - self.tokenizer = load_tokenizer( experiment_settings.tokenizer_settings, experiment_settings.model_settings, @@ -170,7 +155,7 @@ def run(self, experiment_settings: DDPOTrainExperimentSettings) -> None: ) self.trainer = self._get_trainer( - training_args=training_args, + training_args=experiment_settings.training_arguments, experiment_settings=experiment_settings, model=self.model, tokenizer=self.tokenizer, diff --git a/turbo_alignment/pipelines/train/dpo.py b/turbo_alignment/pipelines/train/dpo.py index 155f6d3..ff9e9f4 100755 --- a/turbo_alignment/pipelines/train/dpo.py +++ b/turbo_alignment/pipelines/train/dpo.py @@ -7,7 +7,6 @@ from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.loaders.model import load_model -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.chat.chat import InferenceChatDataset from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator @@ -53,15 +52,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: DPOTrainExperimentSettings) -> DPOTrainingArguments: - return DPOTrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - label_names=[], - remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: DPOTrainingArguments, @@ -75,14 +65,14 @@ 
def _get_trainer( model.config.use_cache = not training_args.gradient_checkpointing extra_args = {} - if experiment_settings.trainer_settings.use_ref_model: + if experiment_settings.training_arguments.use_ref_model: ref_model = load_model(experiment_settings.model_settings, tokenizer) for _, param in ref_model.named_parameters(): param.requires_grad = False extra_args['ref_model'] = ref_model - if experiment_settings.trainer_settings.use_sft_model: + if experiment_settings.training_arguments.use_sft_model: sft_model = load_model(experiment_settings.model_settings, tokenizer) for _, param in sft_model.named_parameters(): param.requires_grad = False diff --git a/turbo_alignment/pipelines/train/kto.py b/turbo_alignment/pipelines/train/kto.py index b34d813..e827693 100755 --- a/turbo_alignment/pipelines/train/kto.py +++ b/turbo_alignment/pipelines/train/kto.py @@ -7,7 +7,6 @@ from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.loaders.model.model import load_model -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.chat.chat import InferenceChatDataset from turbo_alignment.dataset.kto.collators import KTODataCollator from turbo_alignment.dataset.loader import DatasetLoader @@ -53,15 +52,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: KTOTrainExperimentSettings) -> KTOTrainingArguments: - return KTOTrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - label_names=[], - remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: KTOTrainingArguments, @@ -72,10 +62,10 @@ def _get_trainer( val_dataset: Dataset, data_collator: Callable, ): - model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + model.config.use_cache = not experiment_settings.training_arguments.gradient_checkpointing extra_args = {} - if experiment_settings.trainer_settings.use_ref_model: + if experiment_settings.training_arguments.use_ref_model: ref_model = load_model(experiment_settings.model_settings, tokenizer) for _, param in ref_model.named_parameters(): param.requires_grad = False diff --git a/turbo_alignment/pipelines/train/multimodal.py b/turbo_alignment/pipelines/train/multimodal.py index cb523d0..842c226 100755 --- a/turbo_alignment/pipelines/train/multimodal.py +++ b/turbo_alignment/pipelines/train/multimodal.py @@ -8,7 +8,6 @@ from turbo_alignment.cherry_picks.multimodal import MultimodalCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.loaders import load_model -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset from turbo_alignment.dataset.multimodal.collators import DataCollatorWithModalityInputs @@ -60,13 +59,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: MultimodalTrainExperimentSettings) -> TrainingArguments: - return TrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: TrainingArguments, diff --git a/turbo_alignment/pipelines/train/rag.py 
b/turbo_alignment/pipelines/train/rag.py index 6c61fb7..7ba50bd 100755 --- a/turbo_alignment/pipelines/train/rag.py +++ b/turbo_alignment/pipelines/train/rag.py @@ -16,7 +16,6 @@ from turbo_alignment.cherry_picks.rag import RagCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.loaders.model.model import load_model -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.chat import InferenceChatDataset from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.metrics.metric import Metric @@ -71,13 +70,6 @@ def _get_additional_special_tokens( embeddings_initialization_strategy = gen_settings.embeddings_initialization_strategy return list(embeddings_initialization_strategy.keys()) if embeddings_initialization_strategy else [] - @staticmethod - def _get_training_args(experiment_settings: RAGTrainExperimentSettings) -> TrainingArguments: - return TrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: TrainingArguments, diff --git a/turbo_alignment/pipelines/train/rm.py b/turbo_alignment/pipelines/train/rm.py index 66ecac9..0edc1d8 100755 --- a/turbo_alignment/pipelines/train/rm.py +++ b/turbo_alignment/pipelines/train/rm.py @@ -6,7 +6,6 @@ from turbo_alignment.cherry_picks.rm import RmCherryPickCallback from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.dataset.pair_preferences import ( PairPreferenceDataCollator, @@ -55,15 +54,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> TrainingArguments: - return TrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - label_names=[], - remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: TrainingArguments, diff --git a/turbo_alignment/pipelines/train/sft.py b/turbo_alignment/pipelines/train/sft.py index a1bddec..6a81980 100755 --- a/turbo_alignment/pipelines/train/sft.py +++ b/turbo_alignment/pipelines/train/sft.py @@ -9,7 +9,6 @@ from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.dataset.chat import InferenceChatDataset from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.metrics.metric import Metric @@ -54,13 +53,6 @@ def _get_cherry_pick_callback( metrics=metrics, ) - @staticmethod - def _get_training_args(experiment_settings: SftTrainExperimentSettings) -> TrainingArguments: - return TrainingArguments( - output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - **experiment_settings.trainer_settings.dict(), - ) - @staticmethod def _get_trainer( training_args: TrainingArguments, @@ -72,7 +64,7 @@ def _get_trainer( data_collator: DataCollatorMixin, **_kwargs, ) -> MultiGPUCherryPicksTrainer: - model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + model.config.use_cache = not experiment_settings.training_arguments.gradient_checkpointing return MultiGPUCherryPicksTrainer( model=model, diff --git 
a/turbo_alignment/settings/base.py b/turbo_alignment/settings/base.py index e91eaf6..be294ad 100755 --- a/turbo_alignment/settings/base.py +++ b/turbo_alignment/settings/base.py @@ -5,3 +5,4 @@ class ExtraFieldsNotAllowedBaseModel(BaseModel): class Config: extra = Extra.forbid protected_namespaces = () + arbitrary_types_allowed = True diff --git a/turbo_alignment/settings/cherry_pick.py b/turbo_alignment/settings/cherry_pick.py index d67d61e..c889d24 100755 --- a/turbo_alignment/settings/cherry_pick.py +++ b/turbo_alignment/settings/cherry_pick.py @@ -1,3 +1,8 @@ +from typing import Any + +from pydantic import field_validator +from transformers import GenerationConfig + from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.datasets.base import MultiDatasetSettings from turbo_alignment.settings.datasets.chat import ChatMultiDatasetSettings @@ -10,7 +15,6 @@ ) from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings from turbo_alignment.settings.metric import MetricSettings -from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings class CherryPickSettings(ExtraFieldsNotAllowedBaseModel): @@ -27,9 +31,13 @@ class ClassificationCherryPickSettings(CherryPickSettings): class GenerationSettings(CherryPickSettings): - generator_transformers_settings: GeneratorTransformersSettings + generation_config: GenerationConfig custom_generation_settings: CustomChatGenerationSettings + @field_validator('generation_config', mode='before') + def convert_generation_config(cls, values: dict[str, Any]) -> GenerationConfig: + return GenerationConfig.from_dict(values) + class ChatCherryPickSettings(GenerationSettings): dataset_settings: ChatMultiDatasetSettings diff --git a/turbo_alignment/settings/generators/outputs/chat.py b/turbo_alignment/settings/generators/outputs/chat.py index 3d4c392..bee744c 100755 --- a/turbo_alignment/settings/generators/outputs/chat.py +++ b/turbo_alignment/settings/generators/outputs/chat.py @@ -13,9 +13,6 @@ class AnswerMessage(ExtraFieldsNotAllowedBaseModel): answer_token_ids: torch.Tensor | None = None logits: torch.Tensor | None = None - class Config: - arbitrary_types_allowed = True - class ChatInferenceOutput(BaseInferenceOutput, ChatDatasetRecord): answers: list[AnswerMessage] diff --git a/turbo_alignment/settings/logging/__init__.py b/turbo_alignment/settings/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/turbo_alignment/settings/logging/clearml.py b/turbo_alignment/settings/logging/clearml.py index b42f963..f602a11 100644 --- a/turbo_alignment/settings/logging/clearml.py +++ b/turbo_alignment/settings/logging/clearml.py @@ -1,7 +1,8 @@ -from pydantic_settings import BaseSettings +from turbo_alignment.settings.logging.common import LoggingSettings, LoggingType -class ClearMLSettings(BaseSettings): +class ClearMLSettings(LoggingSettings): + logging_type: LoggingType = LoggingType.CLEARML project_name: str task_name: str tags: list[str] = [] diff --git a/turbo_alignment/settings/logging/common.py b/turbo_alignment/settings/logging/common.py new file mode 100644 index 0000000..909791b --- /dev/null +++ b/turbo_alignment/settings/logging/common.py @@ -0,0 +1,12 @@ +from enum import Enum + +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel + + +class LoggingType(str, Enum): + WANDB: str = 'wandb' + CLEARML: str = 'clearml' + + +class LoggingSettings(ExtraFieldsNotAllowedBaseModel): + logging_type: LoggingType diff --git 
a/turbo_alignment/settings/logging/weights_and_biases.py b/turbo_alignment/settings/logging/weights_and_biases.py index fb24801..f5aff48 100755 --- a/turbo_alignment/settings/logging/weights_and_biases.py +++ b/turbo_alignment/settings/logging/weights_and_biases.py @@ -1,6 +1,6 @@ from enum import Enum -from pydantic_settings import BaseSettings +from turbo_alignment.settings.logging.common import LoggingSettings, LoggingType class WandbMode(str, Enum): @@ -9,7 +9,8 @@ class WandbMode(str, Enum): DISABLED: str = 'disabled' -class WandbSettings(BaseSettings): +class WandbSettings(LoggingSettings): + logging_type: LoggingType = LoggingType.WANDB project_name: str run_name: str entity: str diff --git a/turbo_alignment/settings/pipelines/inference/chat.py b/turbo_alignment/settings/pipelines/inference/chat.py index 1727c79..5817385 100755 --- a/turbo_alignment/settings/pipelines/inference/chat.py +++ b/turbo_alignment/settings/pipelines/inference/chat.py @@ -1,4 +1,7 @@ -from typing import Sequence +from typing import Any, Sequence + +from pydantic import field_validator +from transformers import GenerationConfig from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings @@ -6,19 +9,26 @@ InferenceExperimentSettings, SingleModelInferenceSettings, ) -from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings class ChatGenerationSettings(ExtraFieldsNotAllowedBaseModel): - transformers_settings: GeneratorTransformersSettings + generation_config: GenerationConfig custom_settings: CustomChatGenerationSettings + @field_validator('generation_config', mode='before') + def convert_generation_config(cls, values: dict[str, Any]) -> GenerationConfig: + return GenerationConfig.from_dict(values) + class ChatSingleModelInferenceSettings(SingleModelInferenceSettings): generation_settings: list[ChatGenerationSettings] use_vllm: bool = False tensor_parallel_size: int = 1 + @field_validator('generation_settings', mode='before') + def convert_generation_settings(cls, values: list[Any]) -> GenerationConfig: + return [ChatGenerationSettings(**value) for value in values] + class ChatInferenceExperimentSettings(InferenceExperimentSettings): inference_settings: Sequence[ChatSingleModelInferenceSettings] diff --git a/turbo_alignment/settings/pipelines/inference/rag.py b/turbo_alignment/settings/pipelines/inference/rag.py index 77e8dd0..7471f8a 100755 --- a/turbo_alignment/settings/pipelines/inference/rag.py +++ b/turbo_alignment/settings/pipelines/inference/rag.py @@ -1,5 +1,8 @@ from pathlib import Path -from typing import Sequence +from typing import Any, Sequence + +from pydantic import field_validator +from transformers import GenerationConfig from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.datasets.chat import ChatMultiDatasetSettings @@ -15,6 +18,10 @@ class RAGSingleModelInferenceSettings(ExtraFieldsNotAllowedBaseModel): generation_settings: list[ChatGenerationSettings] + @field_validator('generation_settings', mode='before') + def convert_generation_settings(cls, values: list[Any]) -> GenerationConfig: + return [ChatGenerationSettings(**value) for value in values] + class RAGInferenceExperimentSettings(ExtraFieldsNotAllowedBaseModel): inference_settings: Sequence[RAGSingleModelInferenceSettings] diff --git a/turbo_alignment/settings/pipelines/train/base.py b/turbo_alignment/settings/pipelines/train/base.py index d31242d..db62818 100755 --- 
a/turbo_alignment/settings/pipelines/train/base.py +++ b/turbo_alignment/settings/pipelines/train/base.py @@ -1,8 +1,12 @@ from pathlib import Path +from typing import Any +from pydantic import field_validator from pydantic_settings import BaseSettings +from transformers import TrainingArguments from turbo_alignment.common import set_random_seed +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.cherry_pick import CherryPickSettings from turbo_alignment.settings.datasets.base import MultiDatasetSettings from turbo_alignment.settings.logging.clearml import ClearMLSettings @@ -15,19 +19,13 @@ from turbo_alignment.settings.s3 import CheckpointUploaderCallbackParameters from turbo_alignment.settings.tf.special_tokens_setter import SpecialTokensSettings from turbo_alignment.settings.tf.tokenizer import TokenizerSettings -from turbo_alignment.settings.tf.trainer import TrainerSettings - - -class EarlyStoppingSettings(BaseSettings): - patience: int = 1 - threshold: float | None = 0.0 class BaseTrainExperimentSettings(BaseSettings): log_path: Path = Path('train_output') seed: int = 42 - trainer_settings: TrainerSettings + training_arguments: TrainingArguments tokenizer_settings: TokenizerSettings special_tokens_settings: SpecialTokensSettings @@ -42,8 +40,19 @@ class BaseTrainExperimentSettings(BaseSettings): checkpoint_uploader_callback_parameters: CheckpointUploaderCallbackParameters | None = None cherry_pick_settings: CherryPickSettings | None = None + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> TrainingArguments: + return TrainingArguments( + **values, + output_dir=TRAINER_LOGS_FOLDER, + report_to=[], + remove_unused_columns=False, + label_names=[], + ) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.log_path.mkdir(exist_ok=True) set_random_seed(self.seed) + self.training_arguments.output_dir = str(self.log_path / TRAINER_LOGS_FOLDER) diff --git a/turbo_alignment/settings/pipelines/train/classification.py b/turbo_alignment/settings/pipelines/train/classification.py index 3febb0e..78f8c4d 100755 --- a/turbo_alignment/settings/pipelines/train/classification.py +++ b/turbo_alignment/settings/pipelines/train/classification.py @@ -1,12 +1,16 @@ -from typing import Literal +from dataclasses import dataclass +from typing import Any, Literal +from pydantic import field_validator +from transformers import TrainingArguments + +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings from turbo_alignment.settings.datasets.classification import ( ClassificationMultiDatasetSettings, ) from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings -from turbo_alignment.settings.tf.trainer import TrainerSettings class ClassificationLossSettings(ExtraFieldsNotAllowedBaseModel): @@ -14,14 +18,32 @@ class ClassificationLossSettings(ExtraFieldsNotAllowedBaseModel): gamma: float -class ClassificationTrainerSettings(TrainerSettings): - loss_settings: ClassificationLossSettings +@dataclass +class ClassificationTrainingArguments(TrainingArguments): + loss_settings: ClassificationLossSettings = ClassificationLossSettings( + alpha=[1.0], + gamma=1.0, + ) class ClassificationTrainExperimentSettings(BaseTrainExperimentSettings): - trainer_settings: ClassificationTrainerSettings + training_arguments: 
ClassificationTrainingArguments train_dataset_settings: ClassificationMultiDatasetSettings val_dataset_settings: ClassificationMultiDatasetSettings cherry_pick_settings: ClassificationCherryPickSettings + + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> ClassificationTrainingArguments: + loss_settings = values.pop('loss_settings', {}) + return ClassificationTrainingArguments( + **values, + output_dir=TRAINER_LOGS_FOLDER, + report_to=[], + remove_unused_columns=False, + label_names=['labels'], + loss_settings=ClassificationLossSettings( + **loss_settings, + ), + ) diff --git a/turbo_alignment/settings/pipelines/train/ddpo.py b/turbo_alignment/settings/pipelines/train/ddpo.py index fef8bbe..1b51048 100755 --- a/turbo_alignment/settings/pipelines/train/ddpo.py +++ b/turbo_alignment/settings/pipelines/train/ddpo.py @@ -1,3 +1,8 @@ +from typing import Any + +from pydantic import field_validator + +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings from turbo_alignment.settings.datasets.ddpo import DDPOMultiDatasetSettings from turbo_alignment.settings.model import ( @@ -6,13 +11,10 @@ ) from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings from turbo_alignment.settings.tf.tokenizer import TokenizerSettings +from turbo_alignment.trainers.ddpo import DDPOTrainingArguments class DDPOTrainExperimentSettings(BaseTrainExperimentSettings): - beta: float = 0.1 - forward_kl: bool = False - - use_ref_model: bool = True rm_settings: PreTrainedModelSettings | PreTrainedAdaptersModelSettings train_dataset_settings: DDPOMultiDatasetSettings @@ -21,3 +23,11 @@ class DDPOTrainExperimentSettings(BaseTrainExperimentSettings): cherry_pick_settings: ChatCherryPickSettings rm_tokenizer_settings: TokenizerSettings + + training_arguments: DDPOTrainingArguments + + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> DDPOTrainingArguments: + return DDPOTrainingArguments( + **values, output_dir=TRAINER_LOGS_FOLDER, report_to=[], remove_unused_columns=False, label_names=[] + ) diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py index c5312f0..795b84b 100755 --- a/turbo_alignment/settings/pipelines/train/dpo.py +++ b/turbo_alignment/settings/pipelines/train/dpo.py @@ -1,117 +1,14 @@ -from enum import Enum -from typing import Literal +from typing import Any -from pydantic import Field +from pydantic import field_validator -from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings from turbo_alignment.settings.datasets.pair_preference import ( PairPreferenceMultiDatasetSettings, ) from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings -from turbo_alignment.settings.tf.trainer import TrainerSettings - - -class DPOLossesType(str, Enum): - SIGMOID = 'sigmoid' - SIGMOID_WITH_MARGIN = 'sigmoid_with_margin' - HINGE = 'hinge' - IPO = 'ipo' - KTO = 'kto' - SLIC_HF = 'slic_hf' - CPO = 'cpo' - ORPO = 'orpo' - SIMPO = 'simpo' - APO_ZERO = 'apo_zero' - APO_DOWN = 'apo_down' - ASFT = 'asft' - - -class DPOLossSettings(ExtraFieldsNotAllowedBaseModel): - loss_type: DPOLossesType - beta: float = 0.1 - - -class KTOLossSettings(DPOLossSettings): - loss_type: 
Literal[DPOLossesType.KTO] - - -class SigmoidLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.SIGMOID] - label_smoothing: float = 0 - - -class SigmoidLossWithMarginSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN] - - -class HingeLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.HINGE] - norm: bool = True - - -class IPOLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.IPO] - - -class CPOLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.CPO] - norm: bool = True - - -class SlicHfLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.SLIC_HF] - delta: float = 1.0 - lam: float = 0.1 - norm: bool = False - - -class SimPOLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.SIMPO] - gamma: float = 0.1 - - -class ORPOLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.ORPO] - - -class ASFTLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.ASFT] - - -class APOZeroLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.APO_ZERO] - - -class APODownLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.APO_DOWN] - - -class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): - sync_ref_model: bool = False - alpha: float = 1.0 - sync_steps: int = 1 - - -class DPOTrainerSettings(TrainerSettings): - loss_settings: ( - SigmoidLossSettings - | HingeLossSettings - | IPOLossSettings - | KTOLossSettings - | CPOLossSettings - | ORPOLossSettings - | ASFTLossSettings - | SimPOLossSettings - | SlicHfLossSettings - | SigmoidLossWithMarginSettings - | APOZeroLossSettings - | APODownLossSettings - ) - sync_ref_settings: SyncRefModelSettings - use_ref_model: bool = True - use_sft_model: bool = False - average_log_prob: bool = Field(default=False, description='Normalize log probability by length or not') +from turbo_alignment.trainers.dpo import DPOTrainingArguments class DPOTrainExperimentSettings(BaseTrainExperimentSettings): @@ -120,4 +17,8 @@ class DPOTrainExperimentSettings(BaseTrainExperimentSettings): cherry_pick_settings: ChatCherryPickSettings - trainer_settings: DPOTrainerSettings + training_arguments: DPOTrainingArguments + + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> DPOTrainingArguments: + return DPOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[]) diff --git a/turbo_alignment/settings/pipelines/train/kto.py b/turbo_alignment/settings/pipelines/train/kto.py index fd7b6ca..6de8b2a 100755 --- a/turbo_alignment/settings/pipelines/train/kto.py +++ b/turbo_alignment/settings/pipelines/train/kto.py @@ -1,19 +1,12 @@ -from pydantic import Field +from typing import Any +from pydantic import field_validator + +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings from turbo_alignment.settings.datasets.kto import KTOMultiDatasetSettings from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings -from turbo_alignment.settings.pipelines.train.dpo import SyncRefModelSettings -from turbo_alignment.settings.tf.trainer import TrainerSettings - - -class KTOTrainerSettings(TrainerSettings): - undesirable_weight: float = 1.0 - desirable_weight: float = 1.33 - beta: float = 0.1 - use_ref_model: bool = True - sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings() - average_log_prob: bool = Field(default=False, description='Normalize log probability 
by length or not') +from turbo_alignment.trainers.kto import KTOTrainingArguments class KTOTrainExperimentSettings(BaseTrainExperimentSettings): @@ -22,4 +15,10 @@ class KTOTrainExperimentSettings(BaseTrainExperimentSettings): cherry_pick_settings: ChatCherryPickSettings - trainer_settings: KTOTrainerSettings + training_arguments: KTOTrainingArguments + + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> KTOTrainingArguments: + return KTOTrainingArguments( + **values, output_dir=TRAINER_LOGS_FOLDER, report_to=[], label_names=[], remove_unused_columns=False + ) diff --git a/turbo_alignment/settings/pipelines/train/utils.py b/turbo_alignment/settings/pipelines/train/utils.py new file mode 100644 index 0000000..b74af56 --- /dev/null +++ b/turbo_alignment/settings/pipelines/train/utils.py @@ -0,0 +1,79 @@ +from enum import Enum +from typing import Literal + +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel + + +class DPOLossesType(str, Enum): + SIGMOID = 'sigmoid' + SIGMOID_WITH_MARGIN = 'sigmoid_with_margin' + HINGE = 'hinge' + IPO = 'ipo' + KTO = 'kto' + SLIC_HF = 'slic_hf' + CPO = 'cpo' + ORPO = 'orpo' + SIMPO = 'simpo' + APO_ZERO = 'apo_zero' + APO_DOWN = 'apo_down' + ASFT = 'asft' + + +class DPOLossSettings(ExtraFieldsNotAllowedBaseModel): + loss_type: DPOLossesType + beta: float = 0.1 + + +class KTOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.KTO] + + +class SigmoidLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.SIGMOID] + label_smoothing: float = 0 + + +class SigmoidLossWithMarginSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN] + + +class HingeLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.HINGE] + norm: bool = True + + +class IPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.IPO] + + +class CPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.CPO] + norm: bool = True + + +class SlicHfLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.SLIC_HF] + delta: float = 1.0 + lam: float = 0.1 + norm: bool = False + + +class SimPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.SIMPO] + gamma: float = 0.1 + + +class ORPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.ORPO] + + +class ASFTLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.ASFT] + + +class APOZeroLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.APO_ZERO] + + +class APODownLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.APO_DOWN] diff --git a/turbo_alignment/settings/s3.py b/turbo_alignment/settings/s3.py index 2e964fc..22c63ca 100755 --- a/turbo_alignment/settings/s3.py +++ b/turbo_alignment/settings/s3.py @@ -1,6 +1,5 @@ from pathlib import Path -from pydantic import field_validator from pydantic_settings import BaseSettings @@ -14,13 +13,6 @@ class S3HandlerParameters(BaseSettings): aws_access_key_id: str aws_secret_access_key: str - @field_validator('bucket') - def bucket_name_biglm_is_not_allowed(cls, values: str) -> str: - if values == 'biglm': - raise S3HandlerParametersWrongBucketException('Usage of biglm bucket is not allowed') - - return values - class Config: env_file: str = '.env' env_prefix: str = 'S3_CHECKPOINTS_' diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py deleted file mode 100755 index 14537d9..0000000 ---
a/turbo_alignment/settings/tf/generation.py +++ /dev/null @@ -1,13 +0,0 @@ -from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel - - -class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel): - num_beams: int = 1 - max_new_tokens: int = 15 - repetition_penalty: float = 1.0 - num_return_sequences: int = 1 - do_sample: bool = True - top_p: float = 1.0 - top_k: int = 50 - temperature: float = 1.0 - stop_strings: str | list[str] = '' diff --git a/turbo_alignment/settings/tf/peft.py b/turbo_alignment/settings/tf/peft.py index f479a27..d8aafd1 100755 --- a/turbo_alignment/settings/tf/peft.py +++ b/turbo_alignment/settings/tf/peft.py @@ -1,41 +1,40 @@ from typing import Literal -from peft import PeftType, TaskType +from peft import ( + LoraConfig, + PeftConfig, + PeftType, + PrefixTuningConfig, + PromptEncoderConfig, + PromptTuningConfig, +) from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel class BasePeftSettings(ExtraFieldsNotAllowedBaseModel): name: PeftType - task_type: TaskType = TaskType.CAUSAL_LM + config: PeftConfig class LoraSettings(BasePeftSettings): name: Literal[PeftType.LORA] = PeftType.LORA - r: int = 16 - lora_alpha: int = 16 - lora_dropout: float = 0.05 - bias: str = 'none' - target_modules: list[str] = ['q_proj', 'v_proj'] - modules_to_save: list[str] | None = None + config: LoraConfig class PrefixTuningSettings(BasePeftSettings): name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING - encoder_hidden_size: int - prefix_projection: bool + config: PrefixTuningConfig class PromptTuningSettings(BasePeftSettings): name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING - num_virtual_tokens: int = 32 - prompt_tuning_init_text: str | None = None + config: PromptTuningConfig class PTuningSettings(BasePeftSettings): name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING - num_virtual_tokens: int = 32 - encoder_reparameterization_type: str = 'MLP' + config: PromptEncoderConfig PEFT_TYPE = PrefixTuningSettings | LoraSettings | PromptTuningSettings | PTuningSettings diff --git a/turbo_alignment/settings/tf/trainer.py b/turbo_alignment/settings/tf/trainer.py deleted file mode 100755 index 662ee62..0000000 --- a/turbo_alignment/settings/tf/trainer.py +++ /dev/null @@ -1,49 +0,0 @@ -from pathlib import Path -from typing import Any - -from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel - - -class TrainerSettings(ExtraFieldsNotAllowedBaseModel): - evaluation_strategy: str = 'steps' - save_strategy: str = 'steps' - per_device_train_batch_size: int = 4 - per_device_eval_batch_size: int = 4 - gradient_accumulation_steps: int = 32 - eval_steps: int = 150 - save_steps: int = 150 - logging_steps: int = 5 - learning_rate: float = 0.0003 - num_train_epochs: int = 3 - max_steps: int = -1 - lr_scheduler_type: str = 'cosine' - lr_scheduler_kwargs: dict[str, Any] = {} - warmup_steps: int = 0 - warmup_ratio: float = 0.0 - fp16: bool = True - bf16: bool = False - tf32: bool = False - torch_compile: bool = False - optim: str = 'adamw_torch' - adam_beta1: float = 0.9 - adam_beta2: float = 0.999 - adam_epsilon: float = 1e-8 - weight_decay: float = 0.0 - max_grad_norm: float = 1.0 - deepspeed: Path | None = None - save_total_limit: int = 1 - save_only_model: bool = False - no_cuda: bool = False - prediction_loss_only: bool = False - load_best_model_at_end: bool = True - logging_first_step: bool = True - fsdp_config: dict[str, Any] | None = None - fsdp: str | list[str] | None = '' - dataloader_num_workers: int = 8 - 
dataloader_prefetch_factor: int | None = None - dataloader_persistent_workers: bool | None = False - dataloader_pin_memory: bool | None = True - gradient_checkpointing: bool = False - gradient_checkpointing_kwargs: dict[str, Any] = {} - neftune_noise_alpha: float | None = None - report_to: list[str] = [] diff --git a/turbo_alignment/trainers/__init__.py b/turbo_alignment/trainers/__init__.py index 0c170c5..527a11a 100755 --- a/turbo_alignment/trainers/__init__.py +++ b/turbo_alignment/trainers/__init__.py @@ -1,5 +1,4 @@ from .classification import ClassificationTrainer -from .custom_loss import CustomLossTrainer from .ddpo import DDPOTrainer from .dpo import DPOTrainer from .kto import KTOTrainer diff --git a/turbo_alignment/trainers/classification.py b/turbo_alignment/trainers/classification.py index b68ac53..bfc1500 100755 --- a/turbo_alignment/trainers/classification.py +++ b/turbo_alignment/trainers/classification.py @@ -1,5 +1,3 @@ -from functools import partial - import numpy as np import torch import torch.nn.functional as F @@ -15,10 +13,7 @@ from torch.utils.data import Dataset from transformers import EvalPrediction -from turbo_alignment.settings.pipelines.train.classification import ( - ClassificationLossSettings, -) -from turbo_alignment.trainers.custom_loss import CustomLossTrainer +from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer def compute_clf_metrics(eval_pred: EvalPrediction) -> dict[str, float]: @@ -41,19 +36,17 @@ def compute_clf_metrics(eval_pred: EvalPrediction) -> dict[str, float]: return metrics -def classification_loss( - logits: torch.Tensor, labels: torch.LongTensor, loss_settings: ClassificationLossSettings -) -> torch.Tensor: - if loss_settings.alpha is None: +def classification_loss(logits: torch.Tensor, labels: torch.LongTensor, alpha, gamma) -> torch.Tensor: + if alpha is None: alpha = torch.ones((logits.size(-1),), device=logits.device, dtype=logits.dtype) else: - alpha = torch.tensor(loss_settings.alpha, device=logits.device, dtype=logits.dtype) + alpha = torch.tensor(alpha, device=logits.device, dtype=logits.dtype) ce_loss = F.cross_entropy(logits, labels, weight=alpha, reduction='none') p_t = torch.exp(-ce_loss) - focal_loss = ((1 - p_t) ** loss_settings.gamma) * ce_loss + focal_loss = ((1 - p_t) ** gamma) * ce_loss return focal_loss.mean() @@ -64,10 +57,29 @@ def auto_class_weights(dataset: Dataset) -> list[float]: return class_weights.tolist() -class ClassificationTrainer(CustomLossTrainer): - def __init__(self, loss_settings: ClassificationLossSettings, **kwargs): +class ClassificationTrainer(MultiGPUCherryPicksTrainer): + def __init__(self, **kwargs) -> None: + args = kwargs.get('args') + self.loss_settings = args.loss_settings # type: ignore[union-attr] super().__init__( - custom_loss=partial(classification_loss, loss_settings=loss_settings), compute_metrics=compute_clf_metrics, **kwargs, ) + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Modified original version, without manual label smoothing + """ + if 'labels' in inputs: + labels = inputs.pop('labels') + else: + raise ValueError('No labels provided in the inputs') + + outputs = model(**inputs) + logits = outputs['logits'] if isinstance(outputs, dict) else outputs[0] + + loss = classification_loss( + logits=logits, labels=labels, alpha=self.loss_settings.alpha, gamma=self.loss_settings.gamma + ) + + return (loss, outputs) if return_outputs else loss diff --git a/turbo_alignment/trainers/custom_loss.py
b/turbo_alignment/trainers/custom_loss.py deleted file mode 100755 index f126bf3..0000000 --- a/turbo_alignment/trainers/custom_loss.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Callable - -import torch -from torch import nn -from torch.utils.data import Dataset -from transformers import ( - DataCollator, - PreTrainedModel, - PreTrainedTokenizerBase, - TrainerCallback, - TrainingArguments, -) - -from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer - - -class CustomLossTrainer(MultiGPUCherryPicksTrainer): - def __init__( - self, - model: PreTrainedModel | nn.Module, - args: TrainingArguments, - train_dataset: Dataset, - eval_dataset: Dataset, - custom_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - data_collator: DataCollator, - tokenizer: PreTrainedTokenizerBase | None = None, - callbacks: list[TrainerCallback] | None = None, - model_init: Callable[[], PreTrainedModel] | None = None, - **kwargs, - ): - self.custom_loss = custom_loss - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - model_init=model_init, - callbacks=callbacks, - **kwargs, - ) - - def compute_loss(self, model, inputs, return_outputs=False): - """ - Modified original version, without manual label smoothing - """ - if 'labels' in inputs: - labels = inputs.pop('labels') - else: - raise ValueError('No labels provided in the inputs') - - outputs = model(**inputs) - logits = outputs['logits'] if isinstance(outputs, dict) else outputs[0] - - loss = self.custom_loss(logits, labels) - - return (loss, outputs) if return_outputs else loss diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 490e6d1..f6e212e 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -1,5 +1,5 @@ from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Literal import torch @@ -21,14 +21,17 @@ from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler -from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback +from turbo_alignment.common.tf.callbacks.sync_ref_model import ( + SyncRefModelCallback, + SyncRefModelSettings, +) from turbo_alignment.constants import DISABLE_LOSS_LABEL -from turbo_alignment.settings.pipelines.train.dpo import ( +from turbo_alignment.settings.pipelines.train.utils import ( APODownLossSettings, APOZeroLossSettings, + ASFTLossSettings, CPOLossSettings, DPOLossesType, - ASFTLossSettings, HingeLossSettings, IPOLossSettings, KTOLossSettings, @@ -37,7 +40,6 @@ SigmoidLossWithMarginSettings, SimPOLossSettings, SlicHfLossSettings, - SyncRefModelSettings, ) from turbo_alignment.trainers.utils import ( DPOLossRegistry, @@ -442,12 +444,8 @@ class DPOTrainingArguments(TrainingArguments): | SigmoidLossWithMarginSettings | APOZeroLossSettings | APODownLossSettings - ) = field( - default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) - ) # type: ignore[call-overload] - sync_ref_settings: SyncRefModelSettings = field( # type: ignore[call-overload] - default_factory=SyncRefModelSettings() - ) + ) = SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) + sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings() use_ref_model: bool = True use_sft_model: bool = False average_log_prob: bool = False diff --git 
a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py index c6df560..6a11774 100755 --- a/turbo_alignment/trainers/kto.py +++ b/turbo_alignment/trainers/kto.py @@ -1,5 +1,5 @@ from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Literal import torch @@ -20,7 +20,7 @@ from turbo_alignment.common.logging import get_project_logger from turbo_alignment.common.tf.callbacks.common import MetricsCallbackHandler -from turbo_alignment.settings.pipelines.train.dpo import SyncRefModelSettings +from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings from turbo_alignment.trainers.dpo import DPOTrainer from turbo_alignment.trainers.utils import prepare_model @@ -30,9 +30,7 @@ @dataclass class KTOTrainingArguments(TrainingArguments): beta: float = 0.1 - sync_ref_settings: SyncRefModelSettings = field( - default_factory=SyncRefModelSettings() - ) # type: ignore[call-overload] + sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings() use_ref_model: bool = True average_log_prob: bool = False undesirable_weight: float = 1.0 diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json index 54b7948..7746c79 100644 --- a/tutorials/multimodal/multimodal.json +++ b/tutorials/multimodal/multimodal.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot",
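Note: a minimal sketch of the config plumbing this patch introduces, assuming transformers and pydantic are installed; the payload values below are hypothetical and the output directory is a stand-in for log_path / TRAINER_LOGS_FOLDER. The mode='before' validators pass raw JSON dicts straight to the Hugging Face objects:

from transformers import GenerationConfig, TrainingArguments

# Mirrors ChatGenerationSettings.convert_generation_config for a `generation_config` block.
generation_payload = {'num_beams': 1, 'max_new_tokens': 8}
generation_config = GenerationConfig.from_dict(generation_payload)

# Mirrors BaseTrainExperimentSettings.create_training_arguments for a `training_arguments` block;
# 'train_output/trainer_logs' stands in for log_path / TRAINER_LOGS_FOLDER.
training_payload = {'per_device_train_batch_size': 1, 'learning_rate': 0.0003, 'max_steps': 10}
training_arguments = TrainingArguments(
    **training_payload,
    output_dir='train_output/trainer_logs',
    report_to=[],
    remove_unused_columns=False,
    label_names=[],
)

Experiment configs therefore stay plain JSON: any key accepted by GenerationConfig or TrainingArguments passes through unchanged, which is why the fixture configs above rename transformers_settings to generation_config and trainer_settings to training_arguments.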