diff --git a/tests/fixtures/configs/train/dpo/dpo_with_seq_p.json b/tests/fixtures/configs/train/dpo/dpo_with_seq_p.json
index f1895e5f..6b42d196 100644
--- a/tests/fixtures/configs/train/dpo/dpo_with_seq_p.json
+++ b/tests/fixtures/configs/train/dpo/dpo_with_seq_p.json
@@ -70,7 +70,7 @@
         {
             "name": "chat_test",
             "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
-            "num_samples": 2
+            "sample_rate": 1
         }
     ],
     "prompt_template": {
diff --git a/tests/sequence_parallel/test_dist_utils.py b/tests/sequence_parallel/test_dist_utils.py
index fc1bd45d..45a57afd 100644
--- a/tests/sequence_parallel/test_dist_utils.py
+++ b/tests/sequence_parallel/test_dist_utils.py
@@ -66,7 +66,7 @@ def do_test_all_gather_variable(n_ranks: int = 4, group_size: int = 2):
 
 @pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 4,
-    reason='at least four gpus required'
+    reason='at least four gpus required',
 )
 def test_all_gather_variable():
     return subprocess.check_call(create_run_preamble(4) + ['--mode', 'all_gather_variable'])
@@ -90,9 +90,7 @@ def do_test_gather_and_split(test_case):
     range(len(INPUTS)),
 )
 def test_gather_and_split(test_case: int):
-    subprocess.check_call(
-        create_run_preamble(2) + ['--mode', 'gather_and_split', '--test-case', str(test_case)]
-    )
+    subprocess.check_call(create_run_preamble(2) + ['--mode', 'gather_and_split', '--test-case', str(test_case)])
 
 
 CREATE_AND_BROADCAST_INPUTS = [
@@ -128,9 +126,7 @@ def do_test_create_and_broadcast(test_case: int):
     range(len(CREATE_AND_BROADCAST_INPUTS)),
 )
 def test_create_and_broadcast(test_case: int):
-    subprocess.check_call(
-        create_run_preamble(2) + ['--mode', 'create_and_broadcast', '--test-case', str(test_case)]
-    )
+    subprocess.check_call(create_run_preamble(2) + ['--mode', 'create_and_broadcast', '--test-case', str(test_case)])
 
 
 if __name__ == '__main__':
diff --git a/tests/sequence_parallel/test_gemma_model.py b/tests/sequence_parallel/test_gemma_model.py
index 28946ea4..926d4963 100644
--- a/tests/sequence_parallel/test_gemma_model.py
+++ b/tests/sequence_parallel/test_gemma_model.py
@@ -172,6 +172,8 @@ def _test_genaration(test_case: int = 0, model_path: str = MODEL_PATH):
 
     vanilla_model = vanilla_model.to(args.device)
 
+    run_in_order()(print)(f'{dist.get_rank()=} {parallel_states.get_data_parallel_rank()=}')
+
     trainer = Trainer(
         model=model,
         train_dataset=dataset,
diff --git a/turbo_alignment/__main__.py b/turbo_alignment/__main__.py
index 38e89291..a221b4f1 100755
--- a/turbo_alignment/__main__.py
+++ b/turbo_alignment/__main__.py
@@ -1,3 +1,20 @@
+import os
 from turbo_alignment.cli.app import app
 
-app()
+
+def set_prctl():
+    # Allow any process to ptrace this one (handy for attaching debuggers
+    # to distributed workers); skip silently if python-prctl is missing.
+    try:
+        import prctl
+
+        prctl.set_ptracer(prctl.SET_PTRACER_ANY)
+
+    except ImportError:
+        print('prctl unavailable')
+
+
+if __name__ == '__main__':
+    set_prctl()
+    os.register_at_fork(after_in_child=set_prctl)
+    app()
diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py
index bdd19423..3301fe4f 100755
--- a/turbo_alignment/cherry_picks/chat.py
+++ b/turbo_alignment/cherry_picks/chat.py
@@ -10,9 +10,13 @@
 from turbo_alignment.metrics.metric import Metric
 from turbo_alignment.metrics.registry import KLType
 from turbo_alignment.metrics.utils import get_logits
+from turbo_alignment.modeling import parallel_states
 from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings
 from turbo_alignment.settings.metric import ElementWiseScores, MetricResults
 
+import torch.distributed as dist
+from turbo_alignment.dist_utils.order import print_in_order, run_in_order
+
 
 class ChatCherryPickCallback(CherryPickCallbackBase[InferenceChatDataset]):
     def __init__(
@@ -53,12 +57,20 @@ def _get_dataset_metrics(
 
         batch_size = self._generator_transformers_settings.num_return_sequences
 
+        @run_in_order()
+        def print_dataset(prefix, dataset):
+            print(f'{prefix} {dist.get_rank()=} {dataset[0]=}')
+
+        print_dataset('Before sharding:', dataset)
+
         if accelerator is not None:
             dataset = self._get_sharded_dataset(
                 dataset=dataset,
                 accelerator=accelerator,
             )
 
+        print_dataset('After sharding:', dataset)
+
         generations = generator.generate_from_dataset(dataset)
 
         prompts = [record['prompt'] for record in dataset]
@@ -120,6 +132,12 @@ def _get_dataset_metrics(
     @staticmethod
     def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
         rank_device = accelerator.process_index
-        slice_size = math.ceil(len(dataset) / accelerator.num_processes)
+        world_size = accelerator.num_processes
+        if parallel_states.sequence_parallel_is_enabled():
+            rank_device = parallel_states.get_data_parallel_rank()  # all ranks of one SP group share a slice
+            world_size = parallel_states.get_data_parallel_world_size()
+
+        print_in_order(None)(f'{dist.get_rank()=} {rank_device=} {world_size=}')
+        slice_size = math.ceil(len(dataset) / world_size)
 
         return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py
index ca9129d8..07456950 100755
--- a/turbo_alignment/dataset/chat/chat.py
+++ b/turbo_alignment/dataset/chat/chat.py
@@ -38,9 +38,12 @@ def __init__(
         settings: ChatDatasetSettings,
         tokenizer: PreTrainedTokenizerBase,
         read: bool = True,
+        cut_seed: int = 42,
     ) -> None:
         super().__init__(source=source, settings=settings, tokenizer=tokenizer)
         self.settings: ChatDatasetSettings = settings
+        self.cut_seed = cut_seed
+        self.cut_generator = random.Random(cut_seed)  # dedicated RNG keeps random truncation reproducible
 
         if read:
             self._read()
@@ -192,7 +195,7 @@ def _truncate_and_merge(
             for i, m in enumerate(conversation.messages)
             if m.role == ChatMessageRole.BOT and left_bound <= i < right_bound
         ]
-        right_bound = random.choice(bot_indices) if bot_indices else right_bound
+        right_bound = self.cut_generator.choice(bot_indices) if bot_indices else right_bound
 
         input_ids = np.array([])
         labels = np.array([])
@@ -358,10 +361,11 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         read: bool = True,
         random_cut: bool = False,
+        cut_seed: int = 42,
     ) -> None:
         self._random_cut = random_cut
 
-        super().__init__(source=source, settings=settings, tokenizer=tokenizer, read=read)
+        super().__init__(source=source, settings=settings, tokenizer=tokenizer, read=read, cut_seed=cut_seed)
 
     def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]:
         return self._encode(records, inference=True, random_cut=self._random_cut)
@@ -373,6 +377,7 @@ def get_slice(self, start: int, end: int) -> Self:
             tokenizer=self.tokenizer,
             read=False,
             random_cut=self._random_cut,
+            cut_seed=self.cut_seed,
         )
 
         dataset_records = [self[idx] for idx in range(len(self))]
diff --git a/turbo_alignment/dist_utils/order.py b/turbo_alignment/dist_utils/order.py
index 7fbba3af..f6fb5c6d 100644
--- a/turbo_alignment/dist_utils/order.py
+++ b/turbo_alignment/dist_utils/order.py
@@ -19,3 +19,7 @@ def wrapped(*args, **kwargs):
             return wrapped
 
     return inner
+
+
+def print_in_order(group: dist.ProcessGroup | None = None):
+    return run_in_order(group)(print)
diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py
index acc43379..5da6d64d 100755
--- a/turbo_alignment/generators/chat.py
+++ b/turbo_alignment/generators/chat.py
@@ -4,11 +4,24 @@
 from turbo_alignment.dataset.chat.models import ChatDatasetRecord
 from turbo_alignment.generators.base import ChatGeneratorBase
+from turbo_alignment.modeling import parallel_states
+from turbo_alignment.sequence_parallel.collator import pad_for_sequence_parallel, pad_and_slice
 from turbo_alignment.settings.generators.outputs.chat import (
     AnswerMessage,
     ChatInferenceOutput,
 )
 
+import torch.distributed as dist
+from turbo_alignment.dist_utils.order import run_in_order
+
+
+def slice_tensor(tensor: torch.Tensor, world_size: int, rank: int, dim: int = -1) -> torch.Tensor:
+    # Each rank gets ceil(dim_size / world_size) positions; the last chunk may be shorter.
+    dim_size = tensor.size(dim)
+    chunk_size = (dim_size + world_size - 1) // world_size  # round up
+    actual_size = min(chunk_size, dim_size - chunk_size * rank)
+    return tensor.narrow(dim, chunk_size * rank, actual_size)
+
 
 class ChatGenerator(ChatGeneratorBase[ChatDatasetRecord, ChatInferenceOutput]):
     def _generate_from_batch_records(
@@ -35,6 +47,31 @@ def _generate_from_batch_records(
         batched_input_ids = batch['input_ids'].to(self.device)
         batched_attention_mask = batch['attention_mask'].to(self.device)
 
+        if parallel_states.sequence_parallel_is_initialized():
+            batched_input_ids = pad_for_sequence_parallel(
+                batched_input_ids,
+                parallel_states.get_sequence_parallel_world_size(),
+                padding_side=self._tokenizer.padding_side,
+                padding_value=0,
+            )
+
+            batched_attention_mask = pad_for_sequence_parallel(
+                batched_attention_mask,
+                parallel_states.get_sequence_parallel_world_size(),
+                padding_side=self._tokenizer.padding_side,
+                padding_value=0,
+            )
+
+            seq_len = batched_input_ids.size(1)
+            chunk_size = seq_len // parallel_states.get_sequence_parallel_world_size()
+            rank = parallel_states.get_sequence_parallel_rank()
+            batched_input_ids = batched_input_ids[:, rank * chunk_size : (rank + 1) * chunk_size]
+
+            run_in_order()(print)(f'{dist.get_rank()=} {batched_input_ids.tolist()=}')
+
+        else:
+            print('WHAT')
+
         output_indices = self._model.generate(
             inputs=batched_input_ids,
             attention_mask=batched_attention_mask,
@@ -80,8 +117,31 @@ def _generate_from_single_record(
         input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device)
         attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device)
 
+        actual_input_ids = input_ids
+
+        if parallel_states.sequence_parallel_is_initialized():
+            run_in_order()(print)(f'Before {dist.get_rank()=} {input_ids.tolist()=}')
+            actual_input_ids = slice_tensor(
+                input_ids,
+                parallel_states.get_sequence_parallel_world_size(),
+                parallel_states.get_sequence_parallel_rank(),
+                dim=-1,
+            )
+
+            # attention_mask = pad_for_sequence_parallel(
+            #     attention_mask,
+            #     parallel_states.get_sequence_parallel_world_size(),
+            #     padding_side='left',
+            #     padding_value=0,
+            # )
+
+            run_in_order()(print)(f'After {dist.get_rank()=} {actual_input_ids.tolist()=}')
+
+        else:
+            print('WHAT')
+
         output_indices = self._model.generate(
-            inputs=input_ids,
+            inputs=actual_input_ids,
             attention_mask=attention_mask,
             generation_config=self._transformers_generator_parameters,
             tokenizer=self._tokenizer,
@@ -102,7 +162,33 @@ def _generate_from_single_record(
 
         if self._return_logits:
             with torch.no_grad():
-                logits = self._model(output_indices).logits.cpu()
+                actual_output_indices = output_indices
+                attention_mask = None
+                if parallel_states.sequence_parallel_is_enabled():
+                    attention_mask = torch.full_like(output_indices, fill_value=1)
+                    attention_mask = pad_for_sequence_parallel(
+                        attention_mask,
+                        parallel_states.get_sequence_parallel_world_size(),
+                        0,
+                        padding_side='left',
+                    )
+
+                    actual_output_indices = pad_and_slice(
+                        output_indices,
+                        parallel_states.get_sequence_parallel_world_size(),
+                        parallel_states.get_sequence_parallel_rank(),
+                        self._tokenizer.pad_token_id,
+                        padding_side='left',
+                    )
+
+                logits = self._model(actual_output_indices, attention_mask=attention_mask).logits.cpu()
+                ws = parallel_states.get_sequence_parallel_world_size_or_one()
+                assert logits.size(-2) == actual_output_indices.size(-1), (logits.size(), actual_output_indices.size())
+                if ws != 1:
+                    remainder = output_indices.size(1) % ws
+                    padding = 0 if remainder == 0 else (ws - remainder)
+                    if padding != 0:
+                        logits = logits[:, padding:]  # drop logits of the artificial left padding
 
         answer_tokens_ids = postprocessed_output_indices
         input_token_ids = input_ids
diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py
index de94d40d..38bd23b3 100755
--- a/turbo_alignment/pipelines/train/base.py
+++ b/turbo_alignment/pipelines/train/base.py
@@ -161,6 +161,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
 
         special_tokens_setter.setup_model_config(self.model)
 
+        set_random_seed(training_args.seed)  # reseed so train dataset sampling is reproducible
         train_dataset: ConcatDataset = ConcatDataset(
             datasets=DatasetLoader().load_datasets(
                 experiment_settings.train_dataset_settings,
@@ -169,6 +170,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
             )
         )
 
+        set_random_seed(training_args.seed)  # reseed again so val dataset does not depend on train loading
         val_dataset: ConcatDataset = ConcatDataset(
             datasets=DatasetLoader().load_datasets(
                 experiment_settings.val_dataset_settings,
diff --git a/turbo_alignment/pipelines/train/dpo.py b/turbo_alignment/pipelines/train/dpo.py
index 82d19d30..d94ef536 100755
--- a/turbo_alignment/pipelines/train/dpo.py
+++ b/turbo_alignment/pipelines/train/dpo.py
@@ -41,10 +41,23 @@ def _get_cherry_pick_callback(
 
         cherry_pick_settings = experiment_settings.cherry_pick_settings
 
+        from turbo_alignment.common import set_random_seed
+
+        set_random_seed(experiment_settings.seed)  # deterministic dataset construction on every rank
         cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
             cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
         )
 
+        import torch.distributed as dist
+        from turbo_alignment.dist_utils.order import run_in_order
+
+        @run_in_order()
+        def print_dataset(prefix, dataset):
+            print(f'{prefix} {dist.get_rank()=} {dataset[0]=}')
+
+        for d in cherry_pick_datasets:
+            print_dataset('After load:', d)
+
         metrics = [
             Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
             for metric in cherry_pick_settings.metric_settings
diff --git a/turbo_alignment/sequence_parallel/collator.py b/turbo_alignment/sequence_parallel/collator.py
index 562a9b75..af01d90e 100644
--- a/turbo_alignment/sequence_parallel/collator.py
+++ b/turbo_alignment/sequence_parallel/collator.py
@@ -27,6 +27,19 @@ def pad_for_sequence_parallel(tensor, seq_parallel_world_size, padding_value, di
     return tensor
 
 
+def pad_and_slice(
+    tensor: torch.Tensor,
+    seq_parallel_world_size: int,
+    rank: int,
+    padding_value: tp.Any,
+    dim=-1,
+    padding_side='right',
+) -> torch.Tensor:
+    padded = pad_for_sequence_parallel(tensor, seq_parallel_world_size, padding_value, dim, padding_side)
+    chunk_size = padded.size(dim) // seq_parallel_world_size  # even split is guaranteed by the padding
+    return padded.narrow(dim, rank * chunk_size, chunk_size)
+
+
 DEFAULT_PAD_VALUES = {
     'attention_mask': False,
     'input_ids': 0,
diff --git a/turbo_alignment/sequence_parallel/generation.py b/turbo_alignment/sequence_parallel/generation.py
index 2a62993f..749d3e9b 100644
--- a/turbo_alignment/sequence_parallel/generation.py
+++ b/turbo_alignment/sequence_parallel/generation.py
@@ -103,7 +103,7 @@ def prepare_inputs_for_generation(
             past_key_values = None
             cache_position = None
 
-        print(f'prepare {dist.get_rank()=} {input_ids=}')
+        # print(f'prepare {dist.get_rank()=} {input_ids=}')
         input_ids = gather_and_split(
             input_ids,
             group=parallel_states.get_sequence_parallel_group(),
@@ -118,7 +118,6 @@ def prepare_inputs_for_generation(
             padding_side='left',
         )
 
-        print(f'prepare II {dist.get_rank()=} {input_ids=}')
         # END OF THE PATCH
 
         # 1. Handle BC:
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index ae5b8e9a..5805dcf1 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -777,7 +777,9 @@ def get_batch_metrics(
             policy_chosen_logits,
             policy_rejected_logits,
             precomputed_margins,
-        ) = self.concatenated_forward(model, batch)  # pylit: disable=unbalanced-tuple-unpacking
+        ) = self.concatenated_forward(
+            model, batch
+        )  # pylint: disable=unbalanced-tuple-unpacking
 
         reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
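
Reviewer note: the sketch below is not part of the patch. It restates, with plain torch and toy shapes, the arithmetic that pad_and_slice (turbo_alignment/sequence_parallel/collator.py) applies before the forward pass and that the logits post-processing in _generate_from_single_record (turbo_alignment/generators/chat.py) undoes. pad_and_slice_sketch, world_size, and rank are illustrative stand-ins for the real helper and the parallel_states group state.

import torch


def pad_and_slice_sketch(tensor, world_size, rank, padding_value, padding_side='left'):
    # Pad the last dim to a multiple of world_size, then return this rank's chunk.
    seq_len = tensor.size(-1)
    remainder = seq_len % world_size
    if remainder != 0:
        pad = torch.full((*tensor.shape[:-1], world_size - remainder), padding_value, dtype=tensor.dtype)
        pieces = [pad, tensor] if padding_side == 'left' else [tensor, pad]
        tensor = torch.cat(pieces, dim=-1)
    chunk = tensor.size(-1) // world_size
    return tensor.narrow(-1, rank * chunk, chunk)


ids = torch.arange(1, 6).unsqueeze(0)  # one sequence of length 5: [[1, 2, 3, 4, 5]]
parts = [pad_and_slice_sketch(ids, 2, r, padding_value=0) for r in range(2)]
assert parts[0].tolist() == [[0, 1, 2]] and parts[1].tolist() == [[3, 4, 5]]

# After the forward pass, the generator drops the logits that belong to the
# artificial left-padding positions, exactly as the patched code does:
ws, seq_len = 2, ids.size(-1)
padding = (ws - seq_len % ws) % ws                    # one fake position was added above
full_logits = torch.randn(1, seq_len + padding, 11)   # stand-in for the model output
logits = full_logits[:, padding:]
assert logits.size(1) == seq_len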
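
Likewise, the _get_sharded_dataset hunk swaps accelerator.process_index and accelerator.num_processes for the data-parallel rank and world size when sequence parallelism is enabled, so every rank of one sequence-parallel group generates from the same shard. A minimal sketch of the ceil-division bounds, assuming dataset.get_slice clamps the end index (shard_bounds is a hypothetical name):

import math


def shard_bounds(n_records, world_size, rank):
    # Same ceil-division arithmetic as _get_sharded_dataset.
    slice_size = math.ceil(n_records / world_size)
    return rank * slice_size, min(n_records, (rank + 1) * slice_size)


# 10 cherry-pick records over 4 data-parallel ranks:
assert [shard_bounds(10, 4, r) for r in range(4)] == [(0, 3), (3, 6), (6, 9), (9, 10)]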