diff --git a/poetry.lock b/poetry.lock index 2cfb7dd..8f4b995 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "absl-py" @@ -218,6 +218,24 @@ scipy = ">=1.10.0" hub = ["huggingface-hub"] text = ["pillow"] +[[package]] +name = "allenai-common" +version = "1.1.2" +description = "Params, FromParams, Registrable, Lazy classes from AllenNLP" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "allenai-common-1.1.2.tar.gz", hash = "sha256:ae82947201c6c120dc954843fc0dc4a1e4b03f838283d62e2f11349ff9213004"}, + {file = "allenai_common-1.1.2-py3-none-any.whl", hash = "sha256:112559f6bc5a53d65f333cb3fb9748bb053c22241572ae177fd685520cce080a"}, +] + +[package.dependencies] +boto3 = ">=1.14.0,<2.0.0" +filelock = ">=3.0.0,<4.0.0" +overrides = "3.1.0" +requests = ">=2.18.0,<3.0.0" +tqdm = ">=4.19.0,<5.0.0" + [[package]] name = "annotated-types" version = "0.6.0" @@ -378,8 +396,8 @@ files = [ lazy-object-proxy = ">=1.4.0" typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} wrapt = [ - {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, {version = ">=1.11,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, ] [[package]] @@ -2541,39 +2559,39 @@ jupyter-server = ">=1.1.2" [[package]] name = "jupyter-server" -version = "2.14.2" +version = "2.13.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server-2.14.2-py3-none-any.whl", hash = "sha256:47ff506127c2f7851a17bf4713434208fc490955d0e8632e95014a9a9afbeefd"}, - {file = "jupyter_server-2.14.2.tar.gz", hash = "sha256:66095021aa9638ced276c248b1d81862e4c50f292d575920bbe960de1c56b12b"}, + {file = "jupyter_server-2.13.0-py3-none-any.whl", hash = "sha256:77b2b49c3831fbbfbdb5048cef4350d12946191f833a24e5f83e5f8f4803e97b"}, + {file = "jupyter_server-2.13.0.tar.gz", hash = "sha256:c80bfb049ea20053c3d9641c2add4848b38073bf79f1729cea1faed32fc1c78e"}, ] [package.dependencies] anyio = ">=3.1.0" -argon2-cffi = ">=21.1" -jinja2 = ">=3.0.3" +argon2-cffi = "*" +jinja2 = "*" jupyter-client = ">=7.4.4" jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" jupyter-events = ">=0.9.0" -jupyter-server-terminals = ">=0.4.4" +jupyter-server-terminals = "*" nbconvert = ">=6.4.4" nbformat = ">=5.3.0" -overrides = ">=5.0" -packaging = ">=22.0" -prometheus-client = ">=0.9" -pywinpty = {version = ">=2.0.1", markers = "os_name == \"nt\""} +overrides = "*" +packaging = "*" +prometheus-client = "*" +pywinpty = {version = "*", markers = "os_name == \"nt\""} pyzmq = ">=24" send2trash = ">=1.8.2" terminado = ">=0.8.3" tornado = ">=6.2.0" traitlets = ">=5.6.0" -websocket-client = ">=1.7" +websocket-client = "*" [package.extras] -docs = ["ipykernel", "jinja2", "jupyter-client", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"] -test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0,<9)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.7)", "pytest-timeout", "requests"] +docs = ["ipykernel", "jinja2", "jupyter-client", "jupyter-server", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"] +test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.7)", "pytest-timeout", "requests"] [[package]] name = "jupyter-server-terminals" @@ -3793,10 +3811,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] [[package]] @@ -3817,10 +3835,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] [[package]] @@ -3939,13 +3957,12 @@ test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "diff-cove [[package]] name = "overrides" -version = "7.7.0" +version = "3.1.0" description = "A decorator to automatically detect mismatch when overriding a method." optional = false -python-versions = ">=3.6" +python-versions = "*" files = [ - {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"}, - {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"}, + {file = "overrides-3.1.0.tar.gz", hash = "sha256:30f761124579e59884b018758c4d7794914ef02a6c038621123fec49ea7599c6"}, ] [[package]] @@ -4012,9 +4029,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4797,8 +4814,8 @@ files = [ astroid = ">=2.15.8,<=2.17.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, ] isort = ">=4.2.5,<6" mccabe = ">=0.6,<0.8" @@ -5781,6 +5798,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -7786,4 +7808,4 @@ gpu = ["faiss-gpu", "vllm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "28fbc02cd29c1a1d09e00a4eb6b5804842e68bded1b2106a4587063d266ec070" +content-hash = "036ba7fa3c16cb3d63bcef89c8798d69d68f60028c256ed58f117166ad452c69" diff --git a/pyproject.toml b/pyproject.toml index 9c2b472..2830202 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ faiss-gpu = { version = "^1.7.2", optional = true } deepspeed = { version = "0.12", optional = true } accelerate = { version = "0.29.0", optional = true } vllm = {version = "0.5.3", optional = true} +allenai-common = "^1.1.2" [tool.poetry.dev-dependencies] diff --git a/tests/fixtures/configs/train/reinforce/base.json b/tests/fixtures/configs/train/reinforce/base.json index 520ced7..6e9861e 100644 --- a/tests/fixtures/configs/train/reinforce/base.json +++ b/tests/fixtures/configs/train/reinforce/base.json @@ -125,6 +125,13 @@ } }, "trainer_settings": { + "actor_settings": { + "actor_type": "distributed_vllm", + "vllm_num_engines": 1, + "vllm_tensor_parallel_size": 1 + }, + "critic_type": "ray_transformers", + "reward_processor_type": "rloo", "evaluation_strategy": "steps", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index 9dd80b3..0525bed 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -135,7 +135,7 @@ def reinforce_training( reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel) # TODO_RLOO if possible hide init inside RayGroup - ray.get(policy_models.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path)) + ray.get(policy_models.async_init_model_from_pretrained()) ray.get(reward_model.async_init_model_from_pretrained(rm_model=experiment_settings.reward_model_settings.model_path)) ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path)) diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index eedd088..b813d02 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -16,7 +16,7 @@ from vllm import SamplingParams import ray -class VLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]): +class vLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]): def __init__( self, transformers_settings: GeneratorTransformersSettings, @@ -65,7 +65,7 @@ def generate_from_batch( original_records: list[ChatDatasetRecord] | None = None, ) -> list[ChatInferenceOutput]: - #TODO Make sure that records are already splitted between ranks + #TODO Make sure that records are already splitted between ranks(Assuming micro_rollout_batch_size equal to micro_batch_size) input_ids = [record['input_ids'].tolist() for record in records] diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py index 580c3ea..973a8f5 100644 --- a/turbo_alignment/settings/pipelines/train/reinforce.py +++ b/turbo_alignment/settings/pipelines/train/reinforce.py @@ -8,10 +8,10 @@ from turbo_alignment.settings.tf.trainer import TrainerSettings from turbo_alignment.settings.online import ( CriticType, - vLLMActorType, - HFActorType, + ActorType, RewardProcessorType, ) +from typing import Union class REINFORCETrainerSettings(TrainerSettings): max_tokens_count: int = 1024 @@ -30,8 +30,8 @@ class REINFORCETrainerSettings(TrainerSettings): temperature: float | None = None whiten_rewards: bool = False - actor_type: (vLLMActorType| HFActorType) - critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS + actor_type: ActorType = ActorType.DISTRIBUTED_VLLM + critic_type: CriticType = CriticType.RAY_TRANSFORMERS reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py index 201bb6b..7520548 100644 --- a/turbo_alignment/trainers/online/reinforce.py +++ b/turbo_alignment/trainers/online/reinforce.py @@ -61,7 +61,7 @@ class REINFORCETrainingArguments(TrainingArguments): temperature: float | None = None whiten_rewards: bool = False - actor_settings: vLLMActorSettings | HFActorSettings + actor_settings: vLLMActorSettings | HFActorSettings = vLLMActorSettings critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS