sarkar/Fix TRL training issues (#1085)
* Turn off fused RoPE

* Disable fused RoPE ops in the DPO/PPO/reward-modeling examples

Signed-off-by: Wang, Yi A <[email protected]>

* Fix PPO issue: RMS-norm backward fails

Signed-off-by: Wang, Yi A <[email protected]>

* Update text-generation-pipeline example commands

---------

Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: Wang, Yi A <[email protected]>
Co-authored-by: Yeonsil Yoon <[email protected]>
(cherry picked from commit 89b7dac)
ssarkar2 authored and mfuntowicz committed Jun 21, 2024
1 parent df497ca commit 0996308
Showing 7 changed files with 18 additions and 5 deletions.
examples/text-generation/text-generation-pipeline/README.md (4 additions, 0 deletions)

@@ -78,6 +78,7 @@ python run_pipeline.py \
 --use_kv_cache \
 --max_new_tokens 100 \
 --do_sample \
+--batch_size 2 \
 --prompt "Hello world" "How are you?"
 ```

@@ -101,6 +102,7 @@ python run_pipeline.py \
 --do_sample \
 --temperature 0.5 \
 --top_p 0.95 \
+--batch_size 2 \
 --prompt "Hello world" "How are you?"
 ```

@@ -114,6 +116,7 @@ python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
 --bf16 \
 --use_hpu_graphs \
 --use_kv_cache \
+--batch_size 4 \
 --prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time"
 ```

@@ -128,6 +131,7 @@ python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
 --do_sample \
 --temperature 0.5 \
 --top_p 0.95 \
+--batch_size 4 \
 --prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time"
 ```
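In each updated command, the new --batch_size matches the number of prompts supplied (2 on a single card, 4 under DeepSpeed across 8 cards), presumably so that every prompt is generated in a single batch.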
examples/trl/dpo.py (1 addition, 0 deletions)

@@ -185,6 +185,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
     torch_dtype=torch.bfloat16,
 )
 model.config.use_cache = False
+model.config.use_fused_rope = False

 if script_args.ignore_bias_buffers:
     # torch distributed hack
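The fix itself is this single config line, repeated in reward_modeling.py and sft.py below; optimum-habana's Llama forward reads the attribute at run time (see the modeling_llama.py hunk in this commit), so any script can opt out of the fused kernel the same way. A minimal sketch of the pattern, with a placeholder checkpoint name that is not part of this commit:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the policy model in bf16, as the TRL examples do.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint, not from the commit
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False

# Opt out of the Habana fused RoPE kernel; the Gaudi Llama forward
# checks this attribute and falls back to the eager implementation.
model.config.use_fused_rope = False
```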
examples/trl/ppo.py (2 additions, 2 deletions)

@@ -201,7 +201,8 @@ def collator(data):
     torch_dtype=torch.bfloat16,
     low_cpu_mem_usage=True,
 )
-
+model.config.use_fused_rope = False
+model.config.use_fused_rms_norm = False
 optimizer = None
 model = model.to(torch.bfloat16)

@@ -241,7 +242,6 @@ def collator(data):
     reward_model_name,
     num_labels=1,
     low_cpu_mem_usage=True,
-    torch_dtype=torch.bfloat16,
 )

 if config.use_habana:
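Two fixes land here. The policy model, which PPO backpropagates through, gets both fused kernels disabled; the commit message pins the training failure on the fused RMS-norm backward pass. The reward model, used only for scoring, additionally drops the explicit bf16 load and now falls back to from_pretrained's default dtype; the diff alone does not say why, but it is consistent with working around bf16 kernel issues in the reward pass.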
examples/trl/reward_modeling.py (1 addition, 0 deletions)

@@ -181,6 +181,7 @@ class ScriptArguments:
 tokenizer.pad_token = tokenizer.eos_token
 model.config.pad_token_id = tokenizer.eos_token_id
 model.config.use_cache = not script_args.gradient_checkpointing
+model.config.use_fused_rope = False
 num_proc = 24  # Can adjust to be higher if you have more processors.
 original_columns = train_dataset.column_names
examples/trl/sft.py (1 addition, 1 deletion)

@@ -151,6 +151,7 @@ def create_datasets(tokenizer, args, seed=None):
     token=script_args.token,
 )
 base_model.config.use_cache = False
+base_model.config.use_fused_rope = False

 tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token

@@ -167,7 +168,6 @@ def create_datasets(tokenizer, args, seed=None):
 gaudi_config = GaudiConfig()
 gaudi_config.use_fused_adam = True
 gaudi_config.use_fused_clip_norm = True
-
 trainer = GaudiSFTTrainer(
     model=base_model,
     gaudi_config=gaudi_config,
optimum/habana/transformers/models/llama/modeling_llama.py (6 additions, 0 deletions)

@@ -773,6 +773,12 @@ def forward(
     batch_size, seq_length = inputs_embeds.shape[:2]
 else:
     raise ValueError("You have to specify either input_ids or inputs_embeds")
+if hasattr(self.config, "use_fused_rope") and self.config.use_fused_rope is False:
+    global has_fused_rope
+    has_fused_rope = False
+if hasattr(self.config, "use_fused_rms_norm") and self.config.use_fused_rms_norm is False:
+    global has_fused_rms_norm
+    has_fused_rms_norm = False

 if self.gradient_checkpointing and self.training and use_cache:
     logger.warning_once(
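The check is deliberately cheap: it runs once per forward and, on an explicit False, downgrades a module-level flag that the downstream code consults when choosing between the fused kernel and the eager fallback. A self-contained sketch of that pattern (the helper functions are stand-ins, not the real kernels):

```python
# Minimal sketch of the flag-downgrade pattern added to the Llama forward.
# In optimum-habana these globals are set by probing for the Habana fused
# kernels at import time; here they simply default to True.
has_fused_rope = True
has_fused_rms_norm = True


def fused_rope(x):
    return x  # stand-in for the Habana fused kernel


def eager_rope(x):
    return x  # stand-in for the plain rotary-embedding fallback


class DummyConfig:
    use_fused_rope = False  # what the TRL scripts above now set
    use_fused_rms_norm = False


def forward(config, hidden_states):
    global has_fused_rope, has_fused_rms_norm
    # Mirror the commit: an explicit False in the config downgrades the
    # module-level flag, steering this and later calls onto the eager path.
    if hasattr(config, "use_fused_rope") and config.use_fused_rope is False:
        has_fused_rope = False
    if hasattr(config, "use_fused_rms_norm") and config.use_fused_rms_norm is False:
        has_fused_rms_norm = False
    return fused_rope(hidden_states) if has_fused_rope else eager_rope(hidden_states)


forward(DummyConfig(), [1.0])
print(has_fused_rope)  # False: the downgrade persists for the whole process
```

One consequence worth noting: because the check writes a global rather than per-instance state, the downgrade is one-way and process-wide. After the first forward of a model configured with use_fused_rope = False, every Llama model in the process takes the eager path.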
optimum/habana/trl/trainer/ppo_trainer.py (3 additions, 2 deletions)

@@ -539,11 +539,11 @@ def step(
     active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
     ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

-    rewards, non_score_reward = self.compute_rewards(
+    rewards, non_score_reward, kls = self.compute_rewards(
         scores, active_full_logprobs, ref_full_logprobs, masks
     )
 else:
-    rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
+    rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
 timing["time/ppo/compute_rewards"] = time.time() - t

 t = time.time()

@@ -648,6 +648,7 @@ def step(
     masks=masks,
     queries=queries,
     responses=responses,
+    kls=kls,
 )
 # Gather/Reduce stats from all processes
 if self.is_distributed:
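The unpacking change tracks the upstream TRL signature, where compute_rewards also returns the per-token KL terms as a third value so that step can forward them to record_step_stats for logging. A schematic of the reward shaping involved (a simplified stand-in, not the trainer's actual code):

```python
import torch


def compute_rewards(scores, logprobs, ref_logprobs, masks, kl_coef=0.2):
    """Schematic of TRL-style reward shaping: score minus a KL penalty.

    kl_coef is illustrative; TRL adapts this coefficient during training.
    """
    kls = logprobs - ref_logprobs        # per-token KL estimate
    non_score_reward = -kl_coef * kls    # KL penalty applied at every token
    rewards = non_score_reward.clone()
    # Add the scalar reward-model score at each sequence's last real token.
    last = masks.sum(dim=1) - 1
    rewards[torch.arange(rewards.size(0)), last] += scores
    return rewards, non_score_reward, kls  # kls is the new third value


scores = torch.tensor([1.0, 0.5])
logprobs = torch.randn(2, 4)
ref_logprobs = torch.randn(2, 4)
masks = torch.ones(2, 4, dtype=torch.long)

rewards, non_score_reward, kls = compute_rewards(scores, logprobs, ref_logprobs, masks)
print(rewards.shape, kls.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```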
