Commit 838b71f

fixes

Almaz Dautov committed Nov 6, 2024
1 parent fe7be1e commit 838b71f
Showing 7 changed files with 64 additions and 34 deletions.
74 changes: 48 additions & 26 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ faiss-gpu = { version = "^1.7.2", optional = true }
deepspeed = { version = "0.12", optional = true }
accelerate = { version = "0.29.0", optional = true }
vllm = {version = "0.5.3", optional = true}
+ allenai-common = "^1.1.2"


[tool.poetry.dev-dependencies]
7 changes: 7 additions & 0 deletions tests/fixtures/configs/train/reinforce/base.json
@@ -125,6 +125,13 @@
}
},
"trainer_settings": {
"actor_settings": {
"actor_type": "distributed_vllm",
"vllm_num_engines": 1,
"vllm_tensor_parallel_size": 1
},
"critic_type": "ray_transformers",
"reward_processor_type": "rloo",
"evaluation_strategy": "steps",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
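
The seven added lines configure a distributed vLLM actor for the REINFORCE trainer. As a hedged illustration only, the sketch below shows how such a block could be parsed into a pydantic-style settings object; the field names come from the fixture above and the class name vLLMActorSettings from the trainers/online/reinforce.py hunk further down, but everything else is an assumption rather than the repository's actual code.

    from pydantic import BaseModel

    class vLLMActorSettings(BaseModel):
        # Field names mirror the new "actor_settings" block in base.json above;
        # the real class presumably lives in turbo_alignment's settings package.
        actor_type: str = "distributed_vllm"
        vllm_num_engines: int = 1
        vllm_tensor_parallel_size: int = 1

    # Parsing the fixture block (a guess at how the training pipeline consumes it).
    actor_settings = vLLMActorSettings(
        actor_type="distributed_vllm",
        vllm_num_engines=1,
        vllm_tensor_parallel_size=1,
    )
    print(actor_settings.vllm_tensor_parallel_size)  # -> 1
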
2 changes: 1 addition & 1 deletion turbo_alignment/cli/train.py
@@ -135,7 +135,7 @@ def reinforce_training(
reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)

# TODO_RLOO if possible hide init inside RayGroup
- ray.get(policy_models.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))
+ ray.get(policy_models.async_init_model_from_pretrained())
ray.get(reward_model.async_init_model_from_pretrained(rm_model=experiment_settings.reward_model_settings.model_path))
ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))

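The only change here drops the explicit pretrain path from the policy group's initialization; presumably the policy actors now resolve the checkpoint from their own settings. The standalone Ray sketch below shows the call pattern of asynchronously initializing a remote actor and waiting on it with ray.get; the PolicyModel class, default path, and method body are illustrative assumptions, not the repository's RayGroup API.

    import ray

    @ray.remote
    class PolicyModel:
        """Stand-in for the repo's policy actor; not the actual implementation."""

        def __init__(self, model_path: str = "path/to/checkpoint"):
            # Assumption: after this commit the path comes from the actor's own
            # configuration instead of being passed into the init call.
            self.model_path = model_path
            self.loaded = False

        def async_init_model_from_pretrained(self) -> str:
            # Placeholder for loading weights from self.model_path.
            self.loaded = True
            return f"initialized from {self.model_path}"

    if __name__ == "__main__":
        ray.init(ignore_reinit_error=True)
        policy = PolicyModel.remote()
        # Mirrors ray.get(policy_models.async_init_model_from_pretrained()) above.
        print(ray.get(policy.async_init_model_from_pretrained.remote()))
        ray.shutdown()
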
4 changes: 2 additions & 2 deletions turbo_alignment/generators/chat.py
@@ -16,7 +16,7 @@
from vllm import SamplingParams
import ray

- class VLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]):
+ class vLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]):
def __init__(
self,
transformers_settings: GeneratorTransformersSettings,
@@ -65,7 +65,7 @@ def generate_from_batch(
original_records: list[ChatDatasetRecord] | None = None,
) -> list[ChatInferenceOutput]:

- #TODO Make sure that records are already splitted between ranks
+ #TODO Make sure that records are already splitted between ranks(Assuming micro_rollout_batch_size equal to micro_batch_size)

input_ids = [record['input_ids'].tolist() for record in records]

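The updated TODO assumes records arrive already split across ranks, with micro_rollout_batch_size equal to micro_batch_size. The snippet below is a generic illustration of that kind of per-rank sharding, not code from the repository.

    def shard_for_rank(records: list, rank: int, world_size: int) -> list:
        # Strided sharding: rank r takes records r, r + world_size, r + 2*world_size, ...
        # With micro_rollout_batch_size == micro_batch_size, each rank ends up
        # with exactly one micro-batch of records.
        return records[rank::world_size]

    if __name__ == "__main__":
        records = list(range(8))   # 8 rollout records
        world_size = 4             # 4 data-parallel ranks
        for rank in range(world_size):
            print(rank, shard_for_rank(records, rank, world_size))
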
8 changes: 4 additions & 4 deletions turbo_alignment/settings/pipelines/train/reinforce.py
@@ -8,10 +8,10 @@
from turbo_alignment.settings.tf.trainer import TrainerSettings
from turbo_alignment.settings.online import (
CriticType,
- vLLMActorType,
- HFActorType,
+ ActorType,
RewardProcessorType,
)
+ from typing import Union

class REINFORCETrainerSettings(TrainerSettings):
max_tokens_count: int = 1024
@@ -30,8 +30,8 @@ class REINFORCETrainerSettings(TrainerSettings):
temperature: float | None = None
whiten_rewards: bool = False

- actor_type: (vLLMActorType| HFActorType)
- critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS
+ actor_type: ActorType = ActorType.DISTRIBUTED_VLLM
+ critic_type: CriticType = CriticType.RAY_TRANSFORMERS

reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO

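The new defaults reference ActorType.DISTRIBUTED_VLLM, CriticType.RAY_TRANSFORMERS, and RewardProcessorType.RLOO from turbo_alignment.settings.online. Those enums are not shown in this commit; the sketch below is a guess at their shape, with string values taken from the base.json fixture above, and it may omit members the real module defines.

    from enum import Enum

    class ActorType(str, Enum):
        # Only the member used by the new default is certain; others are unknown.
        DISTRIBUTED_VLLM = "distributed_vllm"

    class CriticType(str, Enum):
        LOCAL_TRANSFORMERS = "local_transformers"  # previous default; value is a guess
        RAY_TRANSFORMERS = "ray_transformers"      # new default, matches the fixture

    class RewardProcessorType(str, Enum):
        RLOO = "rloo"

    print(CriticType.RAY_TRANSFORMERS.value)  # -> "ray_transformers"
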
2 changes: 1 addition & 1 deletion turbo_alignment/trainers/online/reinforce.py
@@ -61,7 +61,7 @@ class REINFORCETrainingArguments(TrainingArguments):
temperature: float | None = None
whiten_rewards: bool = False

- actor_settings: vLLMActorSettings | HFActorSettings
+ actor_settings: vLLMActorSettings | HFActorSettings = vLLMActorSettings

critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS

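This hunk gives the union-typed actor_settings field a default; as written in the diff, the default is the vLLMActorSettings class itself. Below is a hedged, self-contained sketch of the same pattern using an instance default; both settings classes and their fields are assumptions based on the fixture, not the repository's definitions.

    from pydantic import BaseModel

    class vLLMActorSettings(BaseModel):
        vllm_num_engines: int = 1
        vllm_tensor_parallel_size: int = 1

    class HFActorSettings(BaseModel):
        num_beams: int = 1  # hypothetical placeholder field

    class TrainingArgsSketch(BaseModel):
        # Union-typed field with a concrete default, mirroring
        # REINFORCETrainingArguments.actor_settings in the hunk above.
        actor_settings: vLLMActorSettings | HFActorSettings = vLLMActorSettings()

    args = TrainingArgsSketch()
    print(type(args.actor_settings).__name__)  # -> "vLLMActorSettings"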
