
Commit

try to fix cherry picks
Pavel Geyn committed Jan 21, 2025
1 parent b77df58 commit 386dfed
Showing 13 changed files with 174 additions and 17 deletions.
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/dpo/dpo_with_seq_p.json
@@ -70,7 +70,7 @@
            {
                "name": "chat_test",
                "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
                "num_samples": 2
                "sample_rate": 1
            }
        ],
        "prompt_template": {
10 changes: 3 additions & 7 deletions tests/sequence_parallel/test_dist_utils.py
@@ -66,7 +66,7 @@ def do_test_all_gather_variable(n_ranks: int = 4, group_size: int = 2):

@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 4,
    reason='at least four gpus required'
    reason='at least four gpus required',
)
def test_all_gather_variable():
    return subprocess.check_call(create_run_preamble(4) + ['--mode', 'all_gather_variable'])
@@ -90,9 +90,7 @@ def do_test_gather_and_split(test_case)
    range(len(INPUTS)),
)
def test_gather_and_split(test_case: int):
    subprocess.check_call(
        create_run_preamble(2) + ['--mode', 'gather_and_split', '--test-case', str(test_case)]
    )
    subprocess.check_call(create_run_preamble(2) + ['--mode', 'gather_and_split', '--test-case', str(test_case)])


CREATE_AND_BROADCAST_INPUTS = [
@@ -128,9 +126,7 @@ def do_test_create_and_broadcast(test_case: int):
    range(len(CREATE_AND_BROADCAST_INPUTS)),
)
def test_create_and_broadcast(test_case: int):
    subprocess.check_call(
        create_run_preamble(2) + ['--mode', 'create_and_broadcast', '--test-case', str(test_case)]
    )
    subprocess.check_call(create_run_preamble(2) + ['--mode', 'create_and_broadcast', '--test-case', str(test_case)])


if __name__ == '__main__':
2 changes: 2 additions & 0 deletions tests/sequence_parallel/test_gemma_model.py
@@ -172,6 +172,8 @@ def _test_genaration(test_case: int = 0, model_path: str = MODEL_PATH):

    vanilla_model = vanilla_model.to(args.device)

    run_in_order()(print)(f'{dist.get_rank()=} {parallel_states.get_data_parallel_rank()=}')

    trainer = Trainer(
        model=model,
        train_dataset=dataset,
19 changes: 18 additions & 1 deletion turbo_alignment/__main__.py
@@ -1,3 +1,20 @@
import os
from turbo_alignment.cli.app import app

app()

def set_prctl():
    try:
        import prctl

        prctl.set_ptracer(prctl.SET_PTRACER_ANY)

    except ImportError:
        print('prctl unavailable')


if __name__ == '__main__':
    set_prctl()
    os.register_at_fork(after_in_child=set_prctl)
    app()

# app()
20 changes: 19 additions & 1 deletion turbo_alignment/cherry_picks/chat.py
@@ -10,9 +10,13 @@
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import KLType
from turbo_alignment.metrics.utils import get_logits
from turbo_alignment.modeling import parallel_states
from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings
from turbo_alignment.settings.metric import ElementWiseScores, MetricResults

import torch.distributed as dist
from turbo_alignment.dist_utils.order import print_in_order, run_in_order


class ChatCherryPickCallback(CherryPickCallbackBase[InferenceChatDataset]):
    def __init__(
@@ -53,12 +57,20 @@ def _get_dataset_metrics(

        batch_size = self._generator_transformers_settings.num_return_sequences

        @run_in_order()
        def print_dataset(prefix, dataset):
            print(f'{prefix} {dist.get_rank()=} {dataset[0]=}')

        print_dataset('Before sharding:', dataset)

        if accelerator is not None:
            dataset = self._get_sharded_dataset(
                dataset=dataset,
                accelerator=accelerator,
            )

        print_dataset('After sharding:', dataset)

        generations = generator.generate_from_dataset(dataset)

        prompts = [record['prompt'] for record in dataset]
@@ -120,6 +132,12 @@ def _get_dataset_metrics(
    @staticmethod
    def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
        rank_device = accelerator.process_index
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)
        world_size = accelerator.num_processes
        if parallel_states.sequence_parallel_is_enabled():
            rank_device = parallel_states.get_data_parallel_rank()
            world_size = parallel_states.get_data_parallel_world_size()

        print_in_order(None)(f'{dist.get_rank()=} {rank_device=} {world_size=}')
        slice_size = math.ceil(len(dataset) / world_size)

        return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
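A quick sanity check of the slicing arithmetic used by _get_sharded_dataset (a standalone sketch; the record count below is illustrative, and get_slice is assumed to clamp an end index past the dataset length):

import math

def shard_bounds(dataset_len: int, world_size: int, rank: int) -> tuple[int, int]:
    # Mirrors the logic above: each rank takes a contiguous slice of
    # ceil(dataset_len / world_size) records; the last slice may be shorter.
    slice_size = math.ceil(dataset_len / world_size)
    start = rank * slice_size
    end = min(start + slice_size, dataset_len)
    return start, end

# 10 records across 4 data-parallel ranks -> (0, 3), (3, 6), (6, 9), (9, 10)
print([shard_bounds(10, 4, r) for r in range(4)])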
9 changes: 7 additions & 2 deletions turbo_alignment/dataset/chat/chat.py
@@ -38,9 +38,12 @@ def __init__(
        settings: ChatDatasetSettings,
        tokenizer: PreTrainedTokenizerBase,
        read: bool = True,
        cut_seed: int = 42,
    ) -> None:
        super().__init__(source=source, settings=settings, tokenizer=tokenizer)
        self.settings: ChatDatasetSettings = settings
        self.cut_seed = cut_seed
        self.cut_generator = random.Random(cut_seed)

        if read:
            self._read()
@@ -192,7 +195,7 @@ def _truncate_and_merge(
            for i, m in enumerate(conversation.messages)
            if m.role == ChatMessageRole.BOT and left_bound <= i < right_bound
        ]
        right_bound = random.choice(bot_indices) if bot_indices else right_bound
        right_bound = self.cut_generator.choice(bot_indices) if bot_indices else right_bound

        input_ids = np.array([])
        labels = np.array([])
@@ -358,10 +361,11 @@ def __init__(
        tokenizer: PreTrainedTokenizerBase,
        read: bool = True,
        random_cut: bool = False,
        cut_seed: int = 42,
    ) -> None:
        self._random_cut = random_cut

        super().__init__(source=source, settings=settings, tokenizer=tokenizer, read=read)
        super().__init__(source=source, settings=settings, tokenizer=tokenizer, read=read, cut_seed=cut_seed)

    def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]:
        return self._encode(records, inference=True, random_cut=self._random_cut)
@@ -373,6 +377,7 @@ def get_slice(self, start: int, end: int) -> Self:
            tokenizer=self.tokenizer,
            read=False,
            random_cut=self._random_cut,
            cut_seed=self.cut_seed,
        )

        dataset_records = [self[idx] for idx in range(len(self))]
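The point of threading cut_seed through these constructors is that every rank, and every get_slice copy, draws the same random cut positions instead of depending on the process-global random state. A minimal sketch of that property (the index values are illustrative):

import random

# Two generators seeded identically produce the same choices, regardless of
# whatever else has consumed the global random module in the meantime.
gen_a = random.Random(42)
gen_b = random.Random(42)
random.random()  # unrelated use of the global RNG does not affect gen_a / gen_b

bot_indices = [1, 3, 5, 7]
picks_a = [gen_a.choice(bot_indices) for _ in range(4)]
picks_b = [gen_b.choice(bot_indices) for _ in range(4)]
assert picks_a == picks_b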
4 changes: 4 additions & 0 deletions turbo_alignment/dist_utils/order.py
@@ -19,3 +19,7 @@ def wrapped(*args, **kwargs):
        return wrapped

    return inner


def print_in_order(group: dist.ProcessGroup | None = None):
    return run_in_order(group)(print)
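For reference, this is how the helper is used elsewhere in this commit (a usage sketch only; it assumes torch.distributed has already been initialized, and the serialization mechanics live inside run_in_order, which is only partially visible here):

import torch.distributed as dist

from turbo_alignment.dist_utils.order import print_in_order, run_in_order

# Print one line per rank, serialized in rank order within the default group.
print_in_order(None)(f'{dist.get_rank()=}')

# The same helper can wrap an arbitrary function so the ranks execute it one at a time.
@run_in_order()
def dump_first_record(prefix, dataset):
    print(f'{prefix} {dist.get_rank()=} {dataset[0]=}')

dump_first_record('After load:', ['record-0', 'record-1'])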
90 changes: 88 additions & 2 deletions turbo_alignment/generators/chat.py
@@ -4,11 +4,23 @@

from turbo_alignment.dataset.chat.models import ChatDatasetRecord
from turbo_alignment.generators.base import ChatGeneratorBase
from turbo_alignment.modeling import parallel_states
from turbo_alignment.sequence_parallel.collator import pad_for_sequence_parallel, pad_and_slice
from turbo_alignment.settings.generators.outputs.chat import (
    AnswerMessage,
    ChatInferenceOutput,
)

import torch.distributed as dist
from turbo_alignment.dist_utils.order import run_in_order


def slice_tensor(tensor: torch.Tensor, world_size: int, rank: int, dim: int = -1) -> torch.Tensor:
    dim_size = tensor.size(dim)
    chunk_size = (dim_size + world_size - 1) // world_size  # round up
    actual_size = min(chunk_size, dim_size - chunk_size * rank)
    return tensor.narrow(dim, chunk_size * rank, actual_size)


class ChatGenerator(ChatGeneratorBase[ChatDatasetRecord, ChatInferenceOutput]):
    def _generate_from_batch_records(
@@ -35,6 +47,31 @@ def _generate_from_batch_records(
        batched_input_ids = batch['input_ids'].to(self.device)
        batched_attention_mask = batch['attention_mask'].to(self.device)

        if parallel_states.sequence_parallel_is_initialized():
            batched_input_ids = pad_for_sequence_parallel(
                batched_input_ids,
                parallel_states.get_sequence_parallel_world_size(),
                padding_side=self._tokenizer.padding_side,
                padding_value=0,
            )

            batched_attention_mask = pad_for_sequence_parallel(
                batched_attention_mask,
                parallel_states.get_sequence_parallel_world_size(),
                padding_side=self._tokenizer.padding_side,
                padding_value=0,
            )

            seq_len = batched_input_ids.size(1)
            chunk_size = seq_len // parallel_states.get_sequence_parallel_world_size()
            rank = parallel_states.get_sequence_parallel_rank()
            batched_input_ids = batched_input_ids[:, rank * chunk_size : (rank + 1) * chunk_size]

            run_in_order()(print)(f'{dist.get_rank()=} {batched_input_ids.tolist()=}')

        else:
            print('WHAT')

        output_indices = self._model.generate(
            inputs=batched_input_ids,
            attention_mask=batched_attention_mask,
@@ -80,8 +117,31 @@ def _generate_from_single_record(
        input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device)
        attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device)

        actual_input_ids = input_ids

        if parallel_states.sequence_parallel_is_initialized():
            run_in_order()(print)(f'Before {dist.get_rank()=} {input_ids.tolist()=}')
            actual_input_ids = slice_tensor(
                input_ids,
                parallel_states.get_sequence_parallel_world_size(),
                parallel_states.get_sequence_parallel_rank(),
                dim=-1,
            )

            # attention_mask = pad_for_sequence_parallel(
            #     attention_mask,
            #     parallel_states.get_sequence_parallel_world_size(),
            #     padding_side='left',
            #     padding_value=0,
            # )

            run_in_order()(print)(f'After {dist.get_rank()=} {input_ids.tolist()=}')

        else:
            print('WHAT')

        output_indices = self._model.generate(
            inputs=input_ids,
            inputs=actual_input_ids,
            attention_mask=attention_mask,
            generation_config=self._transformers_generator_parameters,
            tokenizer=self._tokenizer,
@@ -102,7 +162,33 @@ def _generate_from_single_record(

        if self._return_logits:
            with torch.no_grad():
                logits = self._model(output_indices).logits.cpu()
                actual_output_indices = output_indices
                attention_mask = None
                if parallel_states.sequence_parallel_is_enabled():
                    attention_mask = torch.full_like(output_indices, fill_value=1)
                    attention_mask = pad_for_sequence_parallel(
                        attention_mask,
                        parallel_states.get_sequence_parallel_world_size(),
                        0,
                        padding_side='left',
                    )

                    actual_output_indices = pad_and_slice(
                        output_indices,
                        parallel_states.get_sequence_parallel_world_size(),
                        parallel_states.get_sequence_parallel_rank(),
                        self._tokenizer.pad_token_id,
                        padding_side='left',
                    )

                logits = self._model(actual_output_indices, attention_mask=attention_mask).logits.cpu()
                ws = parallel_states.get_sequence_parallel_world_size_or_one()
                assert logits.size(-2) == actual_output_indices.size(-1), (logits.size(), actual_output_indices.size())
                if ws != 1:
                    remainder = output_indices.size(1) % ws
                    padding = 0 if remainder == 0 else (ws - remainder)
                    if padding != 0:
                        logits = logits[:, padding:]

        answer_tokens_ids = postprocessed_output_indices
        input_token_ids = input_ids
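A standalone sketch of the chunking slice_tensor performs, to make the uneven last chunk explicit (the tensor sizes here are illustrative):

import torch

def slice_tensor(tensor: torch.Tensor, world_size: int, rank: int, dim: int = -1) -> torch.Tensor:
    # Same arithmetic as above: ceil-sized chunks, with the last rank taking the remainder.
    dim_size = tensor.size(dim)
    chunk_size = (dim_size + world_size - 1) // world_size
    actual_size = min(chunk_size, dim_size - chunk_size * rank)
    return tensor.narrow(dim, chunk_size * rank, actual_size)

tokens = torch.arange(10).unsqueeze(0)  # shape (1, 10)
parts = [slice_tensor(tokens, 4, r) for r in range(4)]
print([p.size(-1) for p in parts])  # [3, 3, 3, 1]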
2 changes: 2 additions & 0 deletions turbo_alignment/pipelines/train/base.py
@@ -161,6 +161,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:

        special_tokens_setter.setup_model_config(self.model)

        set_random_seed(training_args.seed)
        train_dataset: ConcatDataset = ConcatDataset(
            datasets=DatasetLoader().load_datasets(
                experiment_settings.train_dataset_settings,
@@ -169,6 +170,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
            )
        )

        set_random_seed(training_args.seed)
        val_dataset: ConcatDataset = ConcatDataset(
            datasets=DatasetLoader().load_datasets(
                experiment_settings.val_dataset_settings,
13 changes: 13 additions & 0 deletions turbo_alignment/pipelines/train/dpo.py
@@ -41,10 +41,23 @@ def _get_cherry_pick_callback(

        cherry_pick_settings = experiment_settings.cherry_pick_settings

        from turbo_alignment.common import set_random_seed

        set_random_seed(experiment_settings.seed)
        cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
            cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
        )

        import torch.distributed as dist
        from turbo_alignment.dist_utils.order import run_in_order

        @run_in_order()
        def print_dataset(prefix, dataset):
            print(f'{prefix} {dist.get_rank()=} {dataset[0]=}')

        for d in cherry_pick_datasets:
            print_dataset('After load:', d)

        metrics = [
            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
            for metric in cherry_pick_settings.metric_settings
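Seeding right before load_datasets is what the fix relies on: every rank then builds the cherry-pick dataset from an identical RNG state. A minimal sketch of the idea; the real turbo_alignment.common.set_random_seed is not shown in this diff, so the body below is an assumption of the conventional behaviour:

import random

import numpy as np
import torch

def set_random_seed(seed: int) -> None:
    # Assumed behaviour: seed the common RNGs so dataset construction is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_random_seed(42)  # called on every rank immediately before the dataset is built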
13 changes: 13 additions & 0 deletions turbo_alignment/sequence_parallel/collator.py
@@ -27,6 +27,19 @@ def pad_for_sequence_parallel(tensor, seq_parallel_world_size, padding_value, di
    return tensor


def pad_and_slice(
    tensor: torch.Tensor,
    seq_parallel_world_size: int,
    rank: int,
    padding_value: tp.Any,
    dim=-1,
    padding_side='right',
) -> torch.Tensor:
    padded = pad_for_sequence_parallel(tensor, seq_parallel_world_size, padding_value, dim, padding_side)
    chunk_size = padded.size(dim) // seq_parallel_world_size
    return padded.narrow(dim, rank * chunk_size, chunk_size)


DEFAULT_PAD_VALUES = {
    'attention_mask': False,
    'input_ids': 0,
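What pad_and_slice produces, sketched without the library helper (pad_for_sequence_parallel is assumed to pad the chosen dimension up to a multiple of the world size, which matches how it is called in this commit):

import torch

def pad_and_slice_sketch(tensor, world_size, rank, padding_value, padding_side='right'):
    # Pad the last dimension to a multiple of world_size, then take an equal chunk per rank.
    seq_len = tensor.size(-1)
    pad = (-seq_len) % world_size
    if pad:
        pad_block = torch.full((*tensor.shape[:-1], pad), padding_value, dtype=tensor.dtype)
        parts = (pad_block, tensor) if padding_side == 'left' else (tensor, pad_block)
        tensor = torch.cat(parts, dim=-1)
    chunk_size = tensor.size(-1) // world_size
    return tensor.narrow(-1, rank * chunk_size, chunk_size)

ids = torch.arange(1, 11).unsqueeze(0)  # (1, 10); with world_size=4 this pads to length 12
print([pad_and_slice_sketch(ids, 4, r, 0, padding_side='left').squeeze(0).tolist() for r in range(4)])
# [[0, 0, 1], [2, 3, 4], [5, 6, 7], [8, 9, 10]]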
3 changes: 1 addition & 2 deletions turbo_alignment/sequence_parallel/generation.py
@@ -103,7 +103,7 @@ def prepare_inputs_for_generation(
        past_key_values = None
        cache_position = None

        print(f'prepare {dist.get_rank()=} {input_ids=}')
        # print(f'prepare {dist.get_rank()=} {input_ids=}')
        input_ids = gather_and_split(
            input_ids,
            group=parallel_states.get_sequence_parallel_group(),
@@ -118,7 +118,6 @@ def prepare_inputs_for_generation(
            padding_side='left',
        )

        print(f'prepare II {dist.get_rank()=} {input_ids=}')
        # END OF THE PATCH

        # 1. Handle BC:
4 changes: 3 additions & 1 deletion turbo_alignment/trainers/dpo.py
@@ -777,7 +777,9 @@ def get_batch_metrics(
            policy_chosen_logits,
            policy_rejected_logits,
            precomputed_margins,
        ) = self.concatenated_forward(model, batch)  # pylint: disable=unbalanced-tuple-unpacking
        ) = self.concatenated_forward(
            model, batch
        )  # pylint: disable=unbalanced-tuple-unpacking

        reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])

